-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
o0t1ng0o
committed
Jul 28, 2024
0 parents
commit 9fe1db9
Showing
164 changed files
with
15,254 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pretrained_models/VQGAN/2024_01_07_090227/epoch=284-step=114000.ckpt filter=lfs diff=lfs merge=lfs -text | ||
*.ckpt filter=lfs diff=lfs merge=lfs -text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
Metadata-Version: 2.1 | ||
Name: Medical-Diffusion | ||
Version: 1.0 | ||
Summary: Diffusion model for medical images | ||
Home-page: UNKNOWN | ||
Author: | ||
License: UNKNOWN | ||
Platform: UNKNOWN | ||
Description-Content-Type: text/markdown | ||
License-File: LICENSE | ||
|
||
Medfusion - Medical Denoising Diffusion Probabilistic Model | ||
============= | ||
|
||
Paper | ||
======= | ||
Please see: [**Diffusion Probabilistic Models beat GANs on Medical 2D Images**](https://arxiv.org/abs/2212.07501) | ||
|
||
![](media/Medfusion.png) | ||
*Figure: Medfusion* | ||
|
||
![](media/animation_eye.gif) ![](media/animation_histo.gif) ![](media/animation_chest.gif)\ | ||
*Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (Warning color quality limited by .gif)* | ||
|
||
Demo | ||
============= | ||
[Link](https://huggingface.co/spaces/mueller-franzes/medfusion-app) to streamlit app. | ||
|
||
Install | ||
============= | ||
|
||
Create virtual environment and install packages: \ | ||
`python -m venv venv` \ | ||
`source venv/bin/activate`\ | ||
`pip install -e .` | ||
|
||
|
||
Get Started | ||
============= | ||
|
||
1 Prepare Data | ||
------------- | ||
|
||
* Go to [medical_diffusion/data/datasets/dataset_simple_2d.py](medical_diffusion/data/datasets/dataset_simple_2d.py) and create a new `SimpleDataset2D` or write your own Dataset. | ||
|
||
|
||
2 Train Autoencoder | ||
---------------- | ||
* Go to [scripts/train_latent_embedder_2d.py](scripts/train_latent_embedder_2d.py) and import your Dataset. | ||
* Load your dataset with eg. `SimpleDataModule` | ||
* Customize `VAE` to your needs | ||
* (Optional): Train a `VAEGAN` instead or load a pre-trained `VAE` and set `start_gan_train_step=-1` to start training of GAN immediately. | ||
|
||
2.1 Evaluate Autoencoder | ||
---------------- | ||
* Use [scripts/evaluate_latent_embedder.py](scripts/evaluate_latent_embedder.py) to evaluate the performance of the Autoencoder. | ||
|
||
3 Train Diffusion | ||
---------------- | ||
* Go to [scripts/train_diffusion.py](scripts/train_diffusion.py) and import/load your Dataset as before. | ||
* Load your pre-trained VAE or VAEGAN with `latent_embedder_checkpoint=...` | ||
* Use `cond_embedder = LabelEmbedder` for conditional training, otherwise `cond_embedder = None` | ||
|
||
3.1 Evaluate Diffusion | ||
---------------- | ||
* Go to [scripts/sample.py](scripts/sample.py) to sample a test image. | ||
* Go to [scripts/helpers/sample_dataset.py](scripts/helpers/sample_dataset.py) to sample a more reprensative sample size. | ||
* Use [scripts/evaluate_images.py](scripts/evaluate_images.py) to evaluate performance of sample (FID, Precision, Recall) | ||
|
||
Acknowledgment | ||
============= | ||
* Code builds upon https://github.com/lucidrains/denoising-diffusion-pytorch | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
LICENSE | ||
README.md | ||
setup.py | ||
Medical_Diffusion.egg-info/PKG-INFO | ||
Medical_Diffusion.egg-info/SOURCES.txt | ||
Medical_Diffusion.egg-info/dependency_links.txt | ||
Medical_Diffusion.egg-info/requires.txt | ||
Medical_Diffusion.egg-info/top_level.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
torch | ||
pytorch-lightning | ||
pytorch_msssim | ||
monai | ||
torchmetrics | ||
torch-fidelity | ||
torchio | ||
pillow | ||
einops | ||
torchvision | ||
matplotlib | ||
pandas | ||
lpips | ||
streamlit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Mask2PET | ||
This repository is a PyTorch implementation for PET tumor generation from benign PET and tumor mask. We build this code upon [medfusion](https://github.com/mueller-franzes/medfusion). You can completely follow the instruction in medfusion. | ||
|
||
## Dataset | ||
We use the HECKTOR 2021 dataset, please download from this link (https://www.aicrowd.com/challenges/miccai-2021-hecktor). Please check the path of data in ./launch/train.sh and ./launch/test.sh | ||
|
||
## Data preprocessing | ||
In this work, we use the VQ-GAN as encoder to encode the PET. If you want to custom this model on your dataset, please train VQ-GAN on your dataset first. https://github.com/CompVis/taming-transformers. | ||
|
||
We provide the pre-trained VQ-GAN model in ./pretrained_models | ||
|
||
## Training | ||
``` | ||
sh ./launch/train.sh | ||
``` | ||
## inference | ||
``` | ||
sh ./launch/test.sh | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
|
||
pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 | ||
|
||
|
||
pip install torchio | ||
pip install monai | ||
|
||
pip install torch-fidelity | ||
pip install pytorch_msssim | ||
pip install lpips | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 ./scripts/train_diffusion.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
python3 ./scripts/sample2.py \ | ||
-i data/Task107_hecktor2021/labelsTrain/ \ | ||
-t data/Task107_hecktor2021/imagesTrain/ \ | ||
-i_val data/Task107_hecktor2021/labelsTest/ \ | ||
-t_val data/Task107_hecktor2021/imagesTest/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
python3 ./scripts/train_diffusion.py --masked_condition True \ | ||
-i data/Task107_hecktor2021/labelsTrain/ \ | ||
-t data/Task107_hecktor2021/imagesTrain/ \ | ||
-i_val data/Task107_hecktor2021/labelsTest/ \ | ||
-t_val data/Task107_hecktor2021/imagesTest/ | ||
|
||
# continue to train | ||
#--resume_from_checkpoint ./runs/LDM_VQGAN/2024_06_07_115628/epoch=199-step=9999.ckpt |
Empty file.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file added
BIN
+204 Bytes
medical_diffusion/data/augmentation/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+212 Bytes
medical_diffusion/data/augmentation/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added
BIN
+1.56 KB
medical_diffusion/data/augmentation/__pycache__/augmentations_2d.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+1.52 KB
medical_diffusion/data/augmentation/__pycache__/augmentations_2d.cpython-38.pyc
Binary file not shown.
Binary file added
BIN
+2.68 KB
medical_diffusion/data/augmentation/__pycache__/augmentations_3d.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+2.73 KB
medical_diffusion/data/augmentation/__pycache__/augmentations_3d.cpython-38.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
import torch | ||
import numpy as np | ||
|
||
class ToTensor16bit(object): | ||
"""PyTorch can not handle uint16 only int16. First transform to int32. Note, this function also adds a channel-dim""" | ||
def __call__(self, image): | ||
# return torch.as_tensor(np.array(image, dtype=np.int32)[None]) | ||
# return torch.from_numpy(np.array(image, np.int32, copy=True)[None]) | ||
image = np.array(image, np.int32, copy=True) # [H,W,C] or [H,W] | ||
image = np.expand_dims(image, axis=-1) if image.ndim ==2 else image | ||
return torch.from_numpy(np.moveaxis(image, -1, 0)) #[C, H, W] | ||
|
||
class Normalize(object): | ||
"""Rescale the image to [0,1] and ensure float32 dtype """ | ||
|
||
def __call__(self, image): | ||
image = image.type(torch.FloatTensor) | ||
return (image-image.min())/(image.max()-image.min()) | ||
|
||
|
||
class RandomBackground(object): | ||
"""Fill Background (intensity ==0) with random values""" | ||
|
||
def __call__(self, image): | ||
image[image==0] = torch.rand(*image[image==0].shape) #(image.max()-image.min()) | ||
return image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torchio as tio | ||
from typing import Union, Optional, Sequence | ||
from torchio.typing import TypeTripletInt | ||
from torchio import Subject, Image | ||
from torchio.utils import to_tuple | ||
|
||
class CropOrPad_None(tio.CropOrPad): | ||
def __init__( | ||
self, | ||
target_shape: Union[int, TypeTripletInt, None] = None, | ||
padding_mode: Union[str, float] = 0, | ||
mask_name: Optional[str] = None, | ||
labels: Optional[Sequence[int]] = None, | ||
**kwargs | ||
): | ||
|
||
# WARNING: Ugly workaround to allow None values | ||
if target_shape is not None: | ||
self.original_target_shape = to_tuple(target_shape, length=3) | ||
target_shape = [1 if t_s is None else t_s for t_s in target_shape] | ||
super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs) | ||
|
||
def apply_transform(self, subject: Subject): | ||
# WARNING: This makes the transformation subject dependent - reverse transformation must be adapted | ||
if self.target_shape is not None: | ||
self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)] | ||
return super().apply_transform(subject=subject) | ||
|
||
|
||
class SubjectToTensor(object): | ||
"""Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch""" | ||
def __call__(self, subject: Subject): | ||
return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()} | ||
|
||
class ImageToTensor(object): | ||
"""Transforms TorchIO Image into a Numpy/Torch Tensor and changes axes order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W]""" | ||
def __call__(self, image: Image): | ||
return image.data.swapaxes(1,-1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .datamodule_simple import SimpleDataModule |
Binary file added
BIN
+261 Bytes
medical_diffusion/data/datamodules/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+269 Bytes
medical_diffusion/data/datamodules/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added
BIN
+2.29 KB
medical_diffusion/data/datamodules/__pycache__/datamodule_simple.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+2.32 KB
medical_diffusion/data/datamodules/__pycache__/datamodule_simple.cpython-38.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
from torch.utils.data.dataloader import DataLoader | ||
import torch.multiprocessing as mp | ||
from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler | ||
|
||
|
||
|
||
class SimpleDataModule(pl.LightningDataModule): | ||
|
||
def __init__(self, | ||
ds_train: object, | ||
ds_val:object =None, | ||
ds_test:object =None, | ||
batch_size: int = 1, | ||
num_workers: int = mp.cpu_count(), | ||
seed: int = 0, | ||
pin_memory: bool = False, | ||
weights: list = None | ||
): | ||
super().__init__() | ||
self.hyperparameters = {**locals()} | ||
self.hyperparameters.pop('__class__') | ||
self.hyperparameters.pop('self') | ||
|
||
self.ds_train = ds_train | ||
self.ds_val = ds_val | ||
self.ds_test = ds_test | ||
|
||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
self.seed = seed | ||
self.pin_memory = pin_memory | ||
self.weights = weights | ||
|
||
|
||
|
||
def train_dataloader(self): | ||
generator = torch.Generator() | ||
generator.manual_seed(self.seed) | ||
|
||
if self.weights is not None: | ||
sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator) | ||
else: | ||
sampler = RandomSampler(self.ds_train, replacement=False, generator=generator) | ||
return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers, | ||
sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory) | ||
|
||
|
||
def val_dataloader(self): | ||
generator = torch.Generator() | ||
generator.manual_seed(self.seed) | ||
if self.ds_val is not None: | ||
return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, | ||
generator=generator, drop_last=False, pin_memory=self.pin_memory) | ||
else: | ||
raise AssertionError("A validation set was not initialized.") | ||
|
||
|
||
def test_dataloader(self): | ||
generator = torch.Generator() | ||
generator.manual_seed(self.seed) | ||
if self.ds_test is not None: | ||
return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, | ||
generator = generator, drop_last=False, pin_memory=self.pin_memory) | ||
else: | ||
raise AssertionError("A test test set was not initialized.") | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .dataset_simple_2d import * | ||
from .dataset_simple_3d import * |
Binary file added
BIN
+263 Bytes
medical_diffusion/data/datasets/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+271 Bytes
medical_diffusion/data/datasets/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file added
BIN
+7.79 KB
medical_diffusion/data/datasets/__pycache__/dataset_simple_2d.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+7.55 KB
medical_diffusion/data/datasets/__pycache__/dataset_simple_2d.cpython-38.pyc
Binary file not shown.
Binary file added
BIN
+6.83 KB
medical_diffusion/data/datasets/__pycache__/dataset_simple_3d.cpython-36.pyc
Binary file not shown.
Binary file added
BIN
+6.89 KB
medical_diffusion/data/datasets/__pycache__/dataset_simple_3d.cpython-38.pyc
Binary file not shown.
Oops, something went wrong.