commit message
o0t1ng0o committed Jul 28, 2024
0 parents commit 9fe1db9
Showing 164 changed files with 15,254 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pretrained_models/VQGAN/2024_01_07_090227/epoch=284-step=114000.ckpt filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
73 changes: 73 additions & 0 deletions Medical_Diffusion.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
Metadata-Version: 2.1
Name: Medical-Diffusion
Version: 1.0
Summary: Diffusion model for medical images
Home-page: UNKNOWN
Author:
License: UNKNOWN
Platform: UNKNOWN
Description-Content-Type: text/markdown
License-File: LICENSE

Medfusion - Medical Denoising Diffusion Probabilistic Model
=============

Paper
=======
Please see: [**Diffusion Probabilistic Models beat GANs on Medical 2D Images**](https://arxiv.org/abs/2212.07501)

![](media/Medfusion.png)
*Figure: Medfusion*

![](media/animation_eye.gif) ![](media/animation_histo.gif) ![](media/animation_chest.gif)\
*Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (warning: color quality is limited by the .gif format)*

Demo
=============
[Link](https://huggingface.co/spaces/mueller-franzes/medfusion-app) to streamlit app.

Install
=============

Create virtual environment and install packages: \
`python -m venv venv` \
`source venv/bin/activate`\
`pip install -e .`


Get Started
=============

1 Prepare Data
-------------

* Go to [medical_diffusion/data/datasets/dataset_simple_2d.py](medical_diffusion/data/datasets/dataset_simple_2d.py) and create a new `SimpleDataset2D` or write your own Dataset.
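
A custom dataset only needs to follow the map-style protocol (`__len__` and `__getitem__`). The sketch below is a hypothetical, standard-library-only stand-in: the class name `FolderDataset2D`, the `load_item` hook and the `{'uid', 'source'}` item layout are assumptions for illustration, not the confirmed `SimpleDataset2D` interface. In practice you would likely subclass `torch.utils.data.Dataset` and let `load_item` return a tensor.

```python
from pathlib import Path


class FolderDataset2D:
    """Hypothetical map-style dataset: one item per image file in a folder."""

    def __init__(self, path_root, extensions=(".png", ".jpg", ".tif"), transform=None):
        self.path_root = Path(path_root)
        self.transform = transform
        # Sort so that indices are stable across runs and machines.
        self.item_pointers = sorted(
            p.relative_to(self.path_root)
            for p in self.path_root.rglob("*")
            if p.is_file() and p.suffix.lower() in extensions
        )

    def __len__(self):
        return len(self.item_pointers)

    def load_item(self, path_item):
        # Placeholder loader: returns raw bytes. Swap in real image decoding here.
        return path_item.read_bytes()

    def __getitem__(self, index):
        rel_path = self.item_pointers[index]
        item = self.load_item(self.path_root / rel_path)
        if self.transform is not None:
            item = self.transform(item)
        # Assumed item layout: a unique id plus the (transformed) image.
        return {"uid": rel_path.stem, "source": item}
```

Keeping decoding in `load_item` and preprocessing in `transform` means either can be swapped without touching the indexing logic.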


2 Train Autoencoder
----------------
* Go to [scripts/train_latent_embedder_2d.py](scripts/train_latent_embedder_2d.py) and import your Dataset.
* Load your dataset with e.g. `SimpleDataModule`
* Customize `VAE` to your needs
* (Optional): Train a `VAEGAN` instead, or load a pre-trained `VAE` and set `start_gan_train_step=-1` to start GAN training immediately.

2.1 Evaluate Autoencoder
----------------
* Use [scripts/evaluate_latent_embedder.py](scripts/evaluate_latent_embedder.py) to evaluate the performance of the Autoencoder.

3 Train Diffusion
----------------
* Go to [scripts/train_diffusion.py](scripts/train_diffusion.py) and import/load your Dataset as before.
* Load your pre-trained VAE or VAEGAN with `latent_embedder_checkpoint=...`
* Use `cond_embedder = LabelEmbedder` for conditional training, otherwise `cond_embedder = None`

3.1 Evaluate Diffusion
----------------
* Go to [scripts/sample.py](scripts/sample.py) to sample a test image.
* Go to [scripts/helpers/sample_dataset.py](scripts/helpers/sample_dataset.py) to sample a more representative number of images.
* Use [scripts/evaluate_images.py](scripts/evaluate_images.py) to evaluate the quality of the samples (FID, Precision, Recall).

Acknowledgment
=============
* Code builds upon https://github.com/lucidrains/denoising-diffusion-pytorch

8 changes: 8 additions & 0 deletions Medical_Diffusion.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
LICENSE
README.md
setup.py
Medical_Diffusion.egg-info/PKG-INFO
Medical_Diffusion.egg-info/SOURCES.txt
Medical_Diffusion.egg-info/dependency_links.txt
Medical_Diffusion.egg-info/requires.txt
Medical_Diffusion.egg-info/top_level.txt
1 change: 1 addition & 0 deletions Medical_Diffusion.egg-info/dependency_links.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

14 changes: 14 additions & 0 deletions Medical_Diffusion.egg-info/requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
torch
pytorch-lightning
pytorch_msssim
monai
torchmetrics
torch-fidelity
torchio
pillow
einops
torchvision
matplotlib
pandas
lpips
streamlit
1 change: 1 addition & 0 deletions Medical_Diffusion.egg-info/top_level.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Mask2PET
This repository is a PyTorch implementation of PET tumor generation from a benign PET image and a tumor mask. The code builds upon [medfusion](https://github.com/mueller-franzes/medfusion), and the medfusion instructions apply here in full.

## Dataset
We use the HECKTOR 2021 dataset, which can be downloaded from https://www.aicrowd.com/challenges/miccai-2021-hecktor. Check the data paths in ./launch/train.sh and ./launch/test.sh.

## Data preprocessing
In this work, we use a VQ-GAN as the encoder for the PET images. To adapt the model to your own dataset, first train a VQ-GAN on that dataset; see https://github.com/CompVis/taming-transformers.

We provide the pre-trained VQ-GAN model in ./pretrained_models.

## Training
```
sh ./launch/train.sh
```
## Inference
```
sh ./launch/test.sh
```

12 changes: 12 additions & 0 deletions lanuch/env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@


pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113


pip install torchio
pip install monai

pip install torch-fidelity
pip install pytorch_msssim
pip install lpips

1 change: 1 addition & 0 deletions lanuch/mask2pet_train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python3 ./scripts/train_diffusion.py
5 changes: 5 additions & 0 deletions lanuch/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
python3 ./scripts/sample2.py \
-i data/Task107_hecktor2021/labelsTrain/ \
-t data/Task107_hecktor2021/imagesTrain/ \
-i_val data/Task107_hecktor2021/labelsTest/ \
-t_val data/Task107_hecktor2021/imagesTest/
8 changes: 8 additions & 0 deletions lanuch/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
python3 ./scripts/train_diffusion.py --masked_condition True \
-i data/Task107_hecktor2021/labelsTrain/ \
-t data/Task107_hecktor2021/imagesTrain/ \
-i_val data/Task107_hecktor2021/labelsTest/ \
-t_val data/Task107_hecktor2021/imagesTest/

# continue to train
#--resume_from_checkpoint ./runs/LDM_VQGAN/2024_06_07_115628/epoch=199-step=9999.ckpt
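
`train.sh` passes `--masked_condition True` as a string value. How `train_diffusion.py` parses it is not shown in this part of the commit; if a script like this used `argparse` with `type=bool`, any non-empty string (including `False`) would evaluate to `True`. The sketch below shows one way such an interface could be declared safely. The parser, the `str2bool` helper and the `dest` names are hypothetical; only the flag names come from the launch scripts.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings explicitly instead of relying on bool()."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "1", "yes", "y"}:
        return True
    if lowered in {"false", "0", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical parser: flag names mirror train.sh, everything else is assumed.
    parser = argparse.ArgumentParser(description="Sketch of the train.sh interface")
    parser.add_argument("--masked_condition", type=str2bool, default=False)
    parser.add_argument("-i", dest="input_dir", required=True, help="training masks")
    parser.add_argument("-t", dest="target_dir", required=True, help="training PET images")
    parser.add_argument("-i_val", dest="input_dir_val", help="validation masks")
    parser.add_argument("-t_val", dest="target_dir_val", help="validation PET images")
    parser.add_argument("--resume_from_checkpoint", default=None)
    return parser
```

With this declaration, `--masked_condition False` really disables the option, and an unrecognized value such as `--masked_condition maybe` fails with a clear error instead of silently enabling it.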
Empty file added medical_diffusion/__init__.py
Empty file.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
27 changes: 27 additions & 0 deletions medical_diffusion/data/augmentation/augmentations_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

import torch
import numpy as np

class ToTensor16bit(object):
    """PyTorch can not handle uint16 only int16. First transform to int32. Note, this function also adds a channel-dim"""
    def __call__(self, image):
        # return torch.as_tensor(np.array(image, dtype=np.int32)[None])
        # return torch.from_numpy(np.array(image, np.int32, copy=True)[None])
        image = np.array(image, np.int32, copy=True) # [H,W,C] or [H,W]
        image = np.expand_dims(image, axis=-1) if image.ndim ==2 else image
        return torch.from_numpy(np.moveaxis(image, -1, 0)) #[C, H, W]

class Normalize(object):
    """Rescale the image to [0,1] and ensure float32 dtype """

    def __call__(self, image):
        image = image.type(torch.FloatTensor)
        return (image-image.min())/(image.max()-image.min())


class RandomBackground(object):
    """Fill Background (intensity ==0) with random values"""

    def __call__(self, image):
        image[image==0] = torch.rand(*image[image==0].shape) #(image.max()-image.min())
        return image
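
`Normalize` divides by `image.max() - image.min()`, which is zero for a constant image and then produces NaNs. The arithmetic, with a guard for that case, can be sketched in plain Python. `min_max_rescale` and its `eps` parameter are hypothetical names, and returning zeros for a constant input is an assumption about the desired behavior, not part of the committed code.

```python
def min_max_rescale(values, eps=1e-8):
    """Rescale a flat sequence of intensities to [0, 1].

    Hypothetical helper: returns all zeros when the input is (nearly)
    constant instead of dividing by zero.
    """
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


print(min_max_rescale([0, 512, 1024]))  # [0.0, 0.5, 1.0]
print(min_max_rescale([7, 7, 7]))       # [0.0, 0.0, 0.0]
```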
38 changes: 38 additions & 0 deletions medical_diffusion/data/augmentation/augmentations_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torchio as tio
from typing import Union, Optional, Sequence
from torchio.typing import TypeTripletInt
from torchio import Subject, Image
from torchio.utils import to_tuple

class CropOrPad_None(tio.CropOrPad):
    def __init__(
        self,
        target_shape: Union[int, TypeTripletInt, None] = None,
        padding_mode: Union[str, float] = 0,
        mask_name: Optional[str] = None,
        labels: Optional[Sequence[int]] = None,
        **kwargs
    ):

        # WARNING: Ugly workaround to allow None values
        if target_shape is not None:
            self.original_target_shape = to_tuple(target_shape, length=3)
            target_shape = [1 if t_s is None else t_s for t_s in target_shape]
        super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs)

    def apply_transform(self, subject: Subject):
        # WARNING: This makes the transformation subject dependent - reverse transformation must be adapted
        if self.target_shape is not None:
            self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)]
        return super().apply_transform(subject=subject)


class SubjectToTensor(object):
    """Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch"""
    def __call__(self, subject: Subject):
        return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()}

class ImageToTensor(object):
    """Transforms a TorchIO Image into a Torch Tensor and changes axes order from TorchIO [C, W, H, D] to Torch [C, D, H, W]"""
    def __call__(self, image: Image):
        return image.data.swapaxes(1,-1)
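
The `None` handling in `CropOrPad_None` comes down to one rule: an axis whose target size is `None` keeps the subject's own size. A minimal plain-Python sketch of that rule follows; `resolve_target_shape` is a hypothetical helper for illustration, not part of TorchIO or of this repository.

```python
def resolve_target_shape(target_shape, spatial_shape):
    """Replace each None in target_shape with the subject's size on that axis."""
    if target_shape is None:
        return None  # nothing to crop or pad
    if len(target_shape) != len(spatial_shape):
        raise ValueError("target_shape and spatial_shape must have the same length")
    return tuple(s if t is None else t for t, s in zip(target_shape, spatial_shape))


# Crop/pad the first two axes to 128 but keep each subject's own size on the third.
print(resolve_target_shape((128, 128, None), (144, 160, 37)))  # (128, 128, 37)
```

Because the resolved shape depends on each subject, the same transform instance yields different output sizes for different subjects, which is what the warning in `apply_transform` refers to.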
1 change: 1 addition & 0 deletions medical_diffusion/data/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .datamodule_simple import SimpleDataModule
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
79 changes: 79 additions & 0 deletions medical_diffusion/data/datamodules/datamodule_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@

import pytorch_lightning as pl
import torch
from torch.utils.data.dataloader import DataLoader
import torch.multiprocessing as mp
from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler



class SimpleDataModule(pl.LightningDataModule):

    def __init__(self,
                 ds_train: object,
                 ds_val:object =None,
                 ds_test:object =None,
                 batch_size: int = 1,
                 num_workers: int = mp.cpu_count(),
                 seed: int = 0,
                 pin_memory: bool = False,
                 weights: list = None
                ):
        super().__init__()
        self.hyperparameters = {**locals()}
        self.hyperparameters.pop('__class__')
        self.hyperparameters.pop('self')

        self.ds_train = ds_train
        self.ds_val = ds_val
        self.ds_test = ds_test

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.pin_memory = pin_memory
        self.weights = weights



    def train_dataloader(self):
        generator = torch.Generator()
        generator.manual_seed(self.seed)

        if self.weights is not None:
            sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator)
        else:
            sampler = RandomSampler(self.ds_train, replacement=False, generator=generator)
        return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers,
                          sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory)


    def val_dataloader(self):
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        if self.ds_val is not None:
            return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
                              generator=generator, drop_last=False, pin_memory=self.pin_memory)
        else:
            raise AssertionError("A validation set was not initialized.")


    def test_dataloader(self):
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        if self.ds_test is not None:
            return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
                              generator=generator, drop_last=False, pin_memory=self.pin_memory)
        else:
            raise AssertionError("A test set was not initialized.")
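
When `weights` is given, `train_dataloader` draws `len(weights)` indices per epoch with replacement, in proportion to the weights, from a generator seeded with `seed`. The sketch below reproduces that idea with the standard library only; `random.Random` and `choices` stand in for `torch.Generator` and `WeightedRandomSampler`, and the labels and inverse-frequency weighting are made-up illustration data.

```python
import random
from collections import Counter


def weighted_epoch_indices(weights, seed=0):
    """Draw len(weights) indices with replacement, proportionally to weights."""
    rng = random.Random(seed)  # private generator, so the draw is reproducible
    return rng.choices(range(len(weights)), weights=weights, k=len(weights))


# Hypothetical imbalanced dataset: 90 samples of class 0, 10 of class 1.
labels = [0] * 90 + [1] * 10
class_counts = Counter(labels)
# Inverse-frequency weights so both classes are drawn about equally often.
weights = [1.0 / class_counts[y] for y in labels]

indices = weighted_epoch_indices(weights, seed=0)
print(Counter(labels[i] for i in indices))
```

Since sampling is with replacement, some items appear several times per epoch and others not at all; the fixed seed only makes that selection repeatable.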











2 changes: 2 additions & 0 deletions medical_diffusion/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .dataset_simple_2d import *
from .dataset_simple_3d import *
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.