medseg-diffusion

MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Models πŸš€

PyTorch Badge Python Badge Jupyter Notebook Badge License Badge GitHub Stars Badge

MedSegDiff is a comprehensive PyTorch implementation of the MedSegDiff paper, presenting the first Diffusion Probabilistic Model (DPM) designed specifically for general medical image segmentation tasks. This repository aims to provide researchers and practitioners with a clear, step-by-step codebase and documentation to facilitate understanding and application of MedSegDiff across various medical imaging modalities.


Source Code Website
github.com/deepmancer/medseg-diffusion deepmancer.github.io/medseg-diffusion

🌟 Key Features

- First Diffusion Probabilistic Model (DPM) designed for general medical image segmentation
- Dynamic conditional encoding that sharpens the model's focus on critical regions
- Feature Frequency Parser (FF-Parser) for Fourier-based suppression of high-frequency noise
- Clear, step-by-step PyTorch codebase and documentation for researchers and practitioners

πŸ“– Table of Contents


πŸ” Overview

MedSegDiff addresses a fundamental challenge in medical imaging: achieving accurate and robust segmentation across various imaging modalities. Building upon the principles of Diffusion Probabilistic Models (DPMs), MedSegDiff introduces innovative techniques like dynamic conditional encoding and the Feature Frequency Parser (FF-Parser) to enhance the model’s ability to focus on critical regions, reduce high-frequency noise, and achieve state-of-the-art segmentation results.

MedSegDiff Overview
An overview of the MedSegDiff architecture. The time step encoding component is omitted for clarity.

Formally, at each diffusion step the model estimates the noise to remove from the current mask, conditioned on the raw image:

$$\epsilon_\theta(\text{mask}_t, I, t)$$

Here:

- $\text{mask}_t$ is the noisy segmentation mask at timestep $t$
- $I$ is the raw input image used as conditioning
- $t$ is the diffusion timestep
- $\epsilon_\theta$ is the noise estimator parameterized by the U-Net

The training objective is the standard noise-prediction loss:

$$\mathcal{L} = \mathbb{E}_{\text{mask}_0,\, \epsilon \sim \mathcal{N}(0, \mathbf{I}),\, t}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,\text{mask}_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,\; I,\; t\right)\right\rVert^2\right], \qquad \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$$

This loss encourages the model to accurately predict the noise added at each step, ultimately guiding the segmentation toward a clean, high-quality mask.
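As a minimal sketch of this objective (standard DDPM noise-prediction training; the `model(mask_t, image, t)` signature is a hypothetical stand-in for the repository's U-Net):

```python
import torch
import torch.nn.functional as F

def diffusion_loss(model, mask0, image, alphas_cumprod):
    """One training step of the noise-prediction (epsilon) loss.

    model          : callable (mask_t, image, t) -> predicted noise (illustrative signature)
    mask0          : clean segmentation masks, shape (B, 1, H, W)
    alphas_cumprod : precomputed cumulative products of (1 - beta_t), shape (T,)
    """
    B = mask0.shape[0]
    t = torch.randint(0, alphas_cumprod.shape[0], (B,), device=mask0.device)
    noise = torch.randn_like(mask0)
    # Closed-form forward process: mask_t = sqrt(a_bar)*mask_0 + sqrt(1 - a_bar)*noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    mask_t = a_bar.sqrt() * mask0 + (1.0 - a_bar).sqrt() * noise
    return F.mse_loss(model(mask_t, image, t), noise)
```

Sampling `t` uniformly per example is the usual Monte Carlo estimate of the expectation over timesteps in the loss above.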


πŸ› οΈ Methodology

MedSegDiff employs a U-Net-based architecture enriched with diffusion steps, dynamic conditional encoding, and Fourier-based noise reduction. The key idea is to iteratively refine a noisy segmentation map into a clean, accurate mask using reverse diffusion steps guided by learned conditioning from the original image.

πŸ”§ Dynamic Conditional Encoding

  1. Feature Frequency Parser (FF-Parser): The segmentation map first passes through the FF-Parser, which utilizes Fourier transforms to filter out high-frequency noise components, thereby refining the feature representation.

    FF-Parser Illustration
    The FF-Parser integrates FFT-based denoising before feature fusion.

  2. Attentive Fusion: The denoised feature map is then fused with the image embeddings through an attentive mechanism, enhancing regional attention and improving segmentation precision.

  3. Iterative Refinement: This combined feature undergoes further refinement, culminating in a bottleneck phase that integrates with encoder features.

  4. Bottleneck Integration: The refined features merge with the encoder outputs, resulting in the final segmentation mask.
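The FF-Parser step above can be sketched as a learnable filter in the frequency domain. This is a simplified version of the idea described in the paper (a learned per-frequency gain that can attenuate noisy high-frequency components); the class name and parameter shapes are illustrative, not the repository's exact layers:

```python
import torch
import torch.nn as nn

class FFParser(nn.Module):
    """Learnable frequency-domain filter (simplified sketch).

    Transforms features to the Fourier domain, rescales each frequency
    with a learned weight, and transforms back to the spatial domain.
    """
    def __init__(self, channels, height, width):
        super().__init__()
        # One learnable gain per channel/frequency; rfft2 halves the last axis.
        self.weight = nn.Parameter(torch.ones(channels, height, width // 2 + 1))

    def forward(self, x):
        freq = torch.fft.rfft2(x, norm="ortho")   # (B, C, H, W//2+1), complex
        freq = freq * self.weight                 # elementwise spectral gating
        return torch.fft.irfft2(freq, s=x.shape[-2:], norm="ortho")
```

Initializing the weights to one makes the layer start as an identity, so training only deviates from pass-through where filtering helps.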

⏳ Time Encoding Block

Time Embedding Illustration
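A common choice for this block, assumed here, is the sinusoidal timestep embedding used by DDPMs (the repository's exact variant may differ):

```python
import math
import torch

def timestep_embedding(t, dim):
    """Sinusoidal embedding of integer timesteps t -> tensor of shape (len(t), dim)."""
    half = dim // 2
    # Geometrically spaced frequencies, as in the original Transformer/DDPM encoding
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
```

The resulting vector is typically projected by a small MLP before being injected into each residual block.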

πŸ—οΈ Encoder & Decoder Blocks

πŸ”„ Diffusion Forward & Reverse Processes (Review)

🟒 Forward Diffusion

In the forward diffusion process, Gaussian noise is progressively added to the segmentation mask over a series of timesteps, degrading it into pure noise.

  1. Noise Addition: Starting from the original segmentation mask $\text{mask}_0$, Gaussian noise is added iteratively at each timestep $t$, controlled by a variance schedule $\beta_t$.

  2. Progressive Degradation: This process produces a sequence of increasingly noisy masks $\text{mask}_0, \text{mask}_1, \dots, \text{mask}_T$.

  3. Convergence to Noise: As $T \to \infty$, the mask becomes indistinguishable from pure Gaussian noise.

Forward Diffusion Process
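The three steps above collapse into a closed-form sampler for $q(\text{mask}_t \mid \text{mask}_0)$, a standard DDPM identity that avoids iterating $t$ individual noising steps (schedule values here are the common linear defaults, not necessarily the repository's):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # variance schedule beta_t
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # a_bar_t

def q_sample(mask0, t, noise=None):
    """Sample mask_t directly from mask_0 in one shot."""
    if noise is None:
        noise = torch.randn_like(mask0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    # mask_t = sqrt(a_bar_t) * mask_0 + sqrt(1 - a_bar_t) * noise
    return a_bar.sqrt() * mask0 + (1.0 - a_bar).sqrt() * noise
```

At $t = 0$ the sample is almost the clean mask; near $t = T$ the signal coefficient is tiny and the output is essentially pure noise, matching the convergence property above.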

πŸ”΄ Reverse Diffusion

The reverse diffusion process aims to reconstruct the original segmentation mask from the noisy data by iteratively denoising.

  1. Noise Prediction: A U-Net is trained to predict the noise added at each timestep, learning a mapping $\epsilon_\theta(\text{mask}_t, t)$.

  2. Stepwise Denoising: Starting from $\text{mask}_T$, the model refines the mask by subtracting the predicted noise at each timestep, moving backward from $t = T$ to $t = 0$.

  3. Final Reconstruction: After $T$ steps, the output $\text{mask}_0$ approximates the original segmentation mask.

Reverse Diffusion Process
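A single reverse step can be sketched with the standard DDPM posterior mean. Here `model(mask_t, t)` stands in for the trained noise predictor $\epsilon_\theta$ (image conditioning omitted for brevity, and the schedule values are illustrative defaults):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def p_sample(model, mask_t, t):
    """One denoising step: mask_t -> mask_{t-1}."""
    eps = model(mask_t, torch.full((mask_t.shape[0],), t))  # predicted noise
    coef = betas[t] / (1.0 - alphas_cumprod[t]).sqrt()
    mean = (mask_t - coef * eps) / alphas[t].sqrt()
    if t == 0:
        return mean                                   # final step: no noise added
    return mean + betas[t].sqrt() * torch.randn_like(mask_t)
```

Running this from `t = T - 1` down to `t = 0`, starting from pure Gaussian noise, yields the reconstructed segmentation mask.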


🎯 Results

MedSegDiff demonstrates superior performance across various medical image segmentation tasks, outperforming state-of-the-art methods by a significant margin.

Evaluation Results
Visual comparisons with other segmentation methods.

Quantitative Results
Quantitative results comparing MedSegDiff with state-of-the-art methods. Best results are highlighted in bold.


πŸš€ Installation & Usage

Requirements

- Python 3.x
- PyTorch
- Remaining dependencies are listed in `requirements.txt`
Installation

Clone the repository and install the required packages:

git clone https://github.com/deepmancer/medseg-diffusion.git
cd medseg-diffusion
pip install -r requirements.txt

Quick Start


πŸ“ License

This project is licensed under the MIT License. See the LICENSE file for details.


πŸ™ Acknowledgments

We extend our gratitude to the authors of the MedSegDiff paper and other referenced works for their valuable research and insights that inspired this implementation.


🌟 Support the Project

If you find MedSegDiff valuable for your research or projects, please consider starring ⭐ this repository on GitHub. Your support helps others discover this work!


πŸ“š Citations

If you utilize this repository, please consider citing the following works:

@article{Wu2022MedSegDiffMI,
    title   = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
    author  = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
    title     = {simple diffusion: End-to-end diffusion for high resolution images},
    author    = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
    booktitle = {International Conference on Machine Learning (ICML)},
    year      = {2023}
}