s4879083 3d-unet #195

Open · wants to merge 26 commits into base: topic-recognition
68 changes: 68 additions & 0 deletions recognition/3D-UNT 48790835/README.md
# 3D UNet for Prostate Segmentation

## Introduction

This project utilizes the 3D UNet architecture to train on the Prostate 3D dataset, aiming to achieve precise medical volumetric image segmentation. We evaluate the performance of the segmentation using the Dice similarity coefficient, targeting a minimum score of 0.7 for all labels on the test set. Image segmentation transforms a volumetric image into segmented areas represented by masks, which facilitates medical condition analysis, symptom prediction, and treatment planning.

## Background

### UNet-3D

The 3D UNet is an extension of the original UNet architecture, which is widely used for segmenting 2D medical images. While the standard UNet processes 2D images, UNet-3D extends this functionality to volumetric (3D) images, allowing for more accurate segmentation of complex medical structures found in modalities like MRI or CT scans.

The UNet architecture combines a convolutional encoder-decoder with skip connections, improving performance by fusing high-resolution features from the contracting path with the upsampled, lower-resolution context of the expansive path. This design preserves spatial information throughout the segmentation process, which is critical in medical imaging.


![3D U-Net Architecture](https://raw.githubusercontent.com/Han1zen/PatternAnalysis-2024/refs/heads/topic-recognition/recognition/3D-UNT%2048790835/picture/3D%20U-Net.webp)

### Dataset

For this project, we segment the downsampled Prostate 3D dataset. Sample code for loading and processing the NIfTI file format is provided in Appendix B. We also encourage the use of data augmentation libraries for TensorFlow (TF) or the corresponding transformations in PyTorch to improve the robustness of the model.
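
For readers unfamiliar with the format, a minimal NIfTI-loading sketch using `nibabel` is shown below; the project's own loader (`dataset.py`) relies on MONAI's `LoadImaged`, which wraps the same functionality, and the file name here is only a placeholder:

```python
import nibabel as nib
import numpy as np

img = nib.load("case_000_LFOV.nii.gz")       # placeholder file name
volume = img.get_fdata().astype(np.float32)  # voxel array, shape (x, y, z)
print(volume.shape, img.affine)              # spatial size and orientation matrix
```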

### Evaluation Metric

We will employ the Dice similarity coefficient as our primary evaluation metric. The Dice coefficient measures the overlap between the predicted segmentation and the ground truth, mathematically expressed as:

\[ \text{Dice} = \frac{2 |A \cap B|}{|A| + |B|} \]

where \( A \) and \( B \) are the predicted and ground-truth regions respectively. A Dice coefficient of 0.7 or greater indicates substantial overlap between the prediction and the ground truth.
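
Translated directly into code, the metric looks like this (a minimal sketch for a single binary mask; `pred` and `truth` are assumed boolean arrays, and the epsilon guards against empty masks):

```python
import numpy as np

def dice(pred: np.ndarray, truth: np.ndarray, eps: float = 1e-6) -> float:
    """Dice = 2|A intersect B| / (|A| + |B|) for boolean masks."""
    intersection = np.logical_and(pred, truth).sum()
    return (2.0 * intersection + eps) / (pred.sum() + truth.sum() + eps)
```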

## Objectives

- Implement the 3D UNet architecture for the Prostate 3D dataset.
- Achieve a minimum Dice similarity coefficient of 0.7 for all labels on the test set.
- Utilize data augmentation techniques to improve model generalization.
- Load and preprocess the NIfTI file format for volumetric data analysis.

## Quick Start

To get started with the 3D UNet model for prostate segmentation, follow these steps:

1. **Clone the Repository**: Clone the repository to your local machine.
2. **Install Dependencies**: Ensure you have the required libraries installed.
3. **Prepare the Dataset**: Download the Prostate 3D dataset and place it in the `data/` directory.
4. **Run Training**: Execute the training script to begin training the model on the Prostate 3D dataset (a sketch of such a loop is shown below).
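
This diff does not include the training script itself, so the following is only a sketch of how the pieces in this PR could be wired together; the optimizer, learning rate, loss, epoch count, and `data` path are assumptions, not the author's settings:

```python
import torch
from torch.utils.data import DataLoader
from dataset import CustomDataset
from modules import UNet3D

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D(in_channels=1, out_channels=6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

train_loader = DataLoader(CustomDataset(mode='train', dataset_path="data"),
                          batch_size=1, shuffle=True)

for epoch in range(2):  # illustrative epoch count
    for image, label in train_loader:
        image = image.to(device)                    # (N, C, x, y, z)
        label = label.to(device).long().squeeze(1)  # (N, x, y, z) class indices
        optimizer.zero_grad()
        logits = model(image)                       # (N, 6, x, y, z)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```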

## Results

### Training and Validation Loss

![Training and Validation Loss](https://github.com/Han1zen/PatternAnalysis-2024/blob/topic-recognition/recognition/3D-UNT%2048790835/picture/train_loss_and_valid_loss.png)

- The **training loss** curve demonstrates a rapid decline in the early stages of training, indicating that the model is effectively learning and adapting to the training data.
- As training progresses, the loss stabilizes, ultimately reaching around **0.6**. This suggests that the model performs well on the training set and is capable of effective feature learning.

- The **validation loss** curve also exhibits a downward trend, remaining relatively close to the training loss in the later stages of training.
- This indicates that the model has good generalization capabilities on the validation set, with no significant signs of overfitting. The validation loss stabilizes at approximately **0.62**, further supporting the model's effectiveness.

### Dice Similarity Coefficient

![Dice](https://github.com/Han1zen/PatternAnalysis-2024/blob/topic-recognition/recognition/3D-UNT%2048790835/picture/dice.png)
- The model achieves a **Dice similarity coefficient** of over **0.7** for all labels, meeting our established target.
- This indicates that the model performs excellently in the segmentation task, accurately identifying and segmenting different regions of the prostate.


## References

1. Sik-Ho Tsang. "Review: 3D U-Net — Volumetric Segmentation (Medical Image Segmentation)." [Towards Data Science](https://towardsdatascience.com/review-3d-u-net-volumetric-segmentation-medical-image-segmentation-8b592560fac1).

122 changes: 122 additions & 0 deletions recognition/3D-UNT 48790835/dataset.py
import os
import torch
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
Compose,
LoadImaged,
EnsureTyped,
RandFlipd,
Lambdad,
Resized,
EnsureChannelFirstd,
ScaleIntensityd,
RandRotate90d,
)

# Transforms for training data: load, resize, and apply random flips and rotations
train_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
ScaleIntensityd(keys="image"), # Normalize intensity
    Lambdad(keys="image", func=lambda x: (x - x.min()) / (x.max() - x.min())),  # Min-max rescale (largely redundant: ScaleIntensityd above already maps to [0, 1])
RandRotate90d(keys=("image", "label"), prob=0.5), # Random 90-degree rotations
RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[0]),
RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[1]),
RandFlipd(keys=("image", "label"), prob=0.5, spatial_axis=[2]),
Resized(keys=["image", "label"], spatial_size=(256, 256, 128)),
EnsureTyped(keys=("image", "label"), dtype=torch.float32),
])

# Transforms for validation/testing data: load and normalize only (no augmentation)
val_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
ScaleIntensityd(keys="image"),
Lambdad(keys="image", func=lambda x: (x - x.min()) / (x.max() - x.min())),
EnsureTyped(keys=("image", "label"), dtype=torch.float32),
])

class CustomDataset(Dataset):
"""
Dataset class for reading pelvic MRI data.
"""

def __init__(self, mode, dataset_path):
"""
Args:
            mode (str): One of 'train' or 'test'.
dataset_path (str): Root directory of the dataset.
"""
self.mode = mode
self.train_transform = train_transforms
self.test_transform = val_transforms

        # Load image and label file paths; train and test share the same
        # directory layout and differ only in the list file
        with open(f'{self.mode}_list.txt', 'r') as f:
            select_list = [line.strip() for line in f.readlines()]
        self.img_list = [os.path.join(dataset_path, 'semantic_MRs_anon', name)
                         for name in select_list]
        self.label_list = [os.path.join(dataset_path, 'semantic_labels_anon', name.replace('_LFOV', '_SEMANTIC_LFOV'))
                           for name in select_list]

def __len__(self):
return len(self.label_list)

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label_path = self.label_list[index]

        transform = self.train_transform if self.mode == 'train' else self.test_transform
        augmented = transform({'image': img_path, 'label': label_path})
        image = augmented['image']
        label = augmented['label']

        # Ensure image and label are 4D tensors (C, x, y, z); if a transform
        # produced a 5D tensor, squeeze out the extra dimension
        if image.dim() == 5:
            image = image.squeeze(1)
        if label.dim() == 5:
            label = label.squeeze(1)

        return image, label


if __name__ == '__main__':
# Test the dataset
test_dataset = CustomDataset(mode='test', dataset_path=r"path_to_your_dataset")
test_dataloader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)
print(len(test_dataset))
for batch_ndx, sample in enumerate(test_dataloader):
print('test')
        print(sample[0].shape)  # expected: (batch_size, channels, x, y, z)
        print(sample[1].shape)  # expected: (batch_size, channels, x, y, z)
break

train_dataset = CustomDataset(mode='train', dataset_path=r"path_to_your_dataset")
train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=False)
for batch_ndx, sample in enumerate(train_dataloader):
print('train')
        print(sample[0].shape)  # expected: (batch_size, channels, x, y, z)
        print(sample[1].shape)  # expected: (batch_size, channels, x, y, z)
break
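
The loader above reads `train_list.txt` and `test_list.txt`, which are not part of this diff. One hedged way such lists could be generated (the 80/20 split, seed, and script itself are assumptions, not the author's code):

import os
import random

dataset_path = r"path_to_your_dataset"  # placeholder, as in the test block above
cases = sorted(os.listdir(os.path.join(dataset_path, 'semantic_MRs_anon')))
random.seed(42)
random.shuffle(cases)
split = int(0.8 * len(cases))
with open('train_list.txt', 'w') as f:
    f.write('\n'.join(cases[:split]))
with open('test_list.txt', 'w') as f:
    f.write('\n'.join(cases[split:]))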
92 changes: 92 additions & 0 deletions recognition/3D-UNT 48790835/modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True),
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.double_conv(x)

class Down(nn.Module):
"""Downscaling with maxpool then double conv"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2),
DoubleConv(in_channels, out_channels)
)

def forward(self, x):
return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv.

    Note: in_channels counts the channels after concatenation with the skip
    connection (upsampled decoder features + encoder features).
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use trilinear upsampling; otherwise a transposed
        # convolution that keeps the decoder's in_channels // 2 channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are (N, C, D, H, W); pad x1 so its spatial size matches x2
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet3D(nn.Module):
def __init__(self, in_channels=1, out_channels=6):
super(UNet3D, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels

self.inc = DoubleConv(in_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
        # in_channels of each Up counts decoder channels plus the concatenated skip
        self.up1 = Up(1024, 256)  # 512 (decoder) + 512 (skip)
        self.up2 = Up(512, 128)   # 256 + 256
        self.up3 = Up(256, 64)    # 128 + 128
        self.up4 = Up(128, 64)    # 64 + 64
self.outc = nn.Conv3d(64, out_channels, kernel_size=1)

def forward(self, x):

x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
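
Following the pattern of the test block in dataset.py, a quick shape check for the network (an illustrative sketch, not part of this diff):

if __name__ == '__main__':
    model = UNet3D(in_channels=1, out_channels=6)
    x = torch.randn(1, 1, 64, 64, 32)  # (N, C, x, y, z) dummy volume; sizes divisible by 16 avoid the padding path
    print(model(x).shape)              # expected: torch.Size([1, 6, 64, 64, 32])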
Binary file added recognition/3D-UNT 48790835/picture/dice.png
Binary file added recognition/3D-UNT 48790835/picture/loss.jpg
69 changes: 69 additions & 0 deletions recognition/3D-UNT 48790835/predict.py
import torch
import numpy as np
import random
from modules import UNet3D
from dataset import CustomDataset
from torch.utils.data import DataLoader
import torch.nn as nn

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load the model
model = UNet3D(in_channels=1, out_channels=6).cuda()  # keyword names match modules.UNet3D
model.load_state_dict(torch.load(r'epoch_2_lossdice1.pth'))
model.eval()

# Define the test dataloader
test_dataset = CustomDataset(mode='test', dataset_path=r'C:\Users\111\Desktop\3710\新建文件夹\数据集\Labelled_weekly_MR_images_of_the_male_pelvis-Xken7gkM-\data\HipMRI_study_complete_release_v1')
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Weighted Dice loss (defined for reference; not used in the evaluation loop below)
class WeightedDiceLoss(nn.Module):
def __init__(self, weights=None, smooth=1):
super(WeightedDiceLoss, self).__init__()
self.weights = weights
self.smooth = smooth

def forward(self, inputs, targets):
# Flatten the input and target tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

intersection = (inputs * targets).sum()
total = inputs.sum() + targets.sum()

# Calculate Dice coefficient
dice = (2. * intersection + self.smooth) / (total + self.smooth)

if self.weights is not None:
return (1 - dice) * self.weights
return 1 - dice

valid_loss = []
for idx, (data_x, data_y) in enumerate(test_dataloader):
data_x = data_x.to(torch.float32).cuda()
data_y = data_y.to(torch.float32).cuda().squeeze()

# Get model outputs
outputs = model(data_x)

    # Take the arg-max over the class dimension to get predicted labels
    outputs_class = torch.argmax(outputs, dim=1).squeeze()

    # Count voxels whose predicted class matches the ground truth
    intersection = torch.sum(outputs_class == data_y)
    assert outputs_class.size() == data_y.size()

    # Note: this ratio is overall voxel accuracy, used here as a proxy for
    # Dice; see the per-label Dice sketch below for the metric in the README
    dice_coeff = intersection.item() / outputs_class.numel()
    print('Dice Coefficient:', dice_coeff)
valid_loss.append(dice_coeff)

# Print the average Dice coefficient for the test set
average_loss = np.average(valid_loss)
print('Average Dice Coefficient:', average_loss)
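
The metric above is overall voxel accuracy; the per-label Dice targeted in the README could be computed as follows (a sketch; the function name and epsilon are illustrative, not part of this PR):

def per_label_dice(pred, target, num_classes=6, eps=1e-6):
    """Dice score per label from integer label volumes of equal shape."""
    scores = []
    for c in range(num_classes):
        p = (pred == c)
        t = (target == c)
        intersection = (p & t).sum().item()
        scores.append((2.0 * intersection + eps) / (p.sum().item() + t.sum().item() + eps))
    return scores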