gravitational net for mitosis net
kapoorlab committed Dec 9, 2023
1 parent 936f902 commit fe15851
Showing 2 changed files with 58 additions and 54 deletions.
108 changes: 56 additions & 52 deletions src/napatrackmater/Trackvector.py
@@ -25,9 +25,9 @@
from torch.optim.lr_scheduler import MultiStepLR
import matplotlib.pyplot as plt
from typing import List, Union
import torch.nn.functional as F
from torchsummary import summary


class TrackVector(TrackMate):
def __init__(
self,
@@ -1195,7 +1195,7 @@ def unsupervised_clustering(

track_arrays_array_names = ["shape_dynamic", "shape", "dynamic"]
clusterable_track_arrays = [
shape_dynamic_covariance_2d,
shape_covariance_2d,
dynamic_covariance_2d,
]
@@ -1218,8 +1218,8 @@ def unsupervised_clustering(
shape_dynamic_cosine_distance, method=method
)
shape_dynamic_cluster_labels = fcluster(
shape_dynamic_linkage_matrix, num_clusters, criterion=criterion
)
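        # linkage builds the hierarchical clustering tree from the pairwise
        # cosine distances between track feature vectors, and fcluster cuts it
        # into num_clusters flat cluster labels using the chosen criterion.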

cluster_centroids = calculate_cluster_centroids(
clusterable_track_array, shape_dynamic_cluster_labels
@@ -1364,8 +1364,8 @@ def compute_covariance_matrix(track_arrays, shape_features=5, mask_features=None


class DenseLayer(nn.Module):
    """DenseNet-style 1D convolutional layer: BN -> ReLU -> conv, with an
    optional 1x1 bottleneck convolution when bottleneck_size > 0."""

def __init__(self, in_channels, growth_rate, bottleneck_size, kernel_size):
super().__init__()
self.use_bottleneck = bottleneck_size > 0
@@ -1374,19 +1374,18 @@ def __init__(self, in_channels, growth_rate, bottleneck_size, kernel_size):
self.bn2 = nn.BatchNorm1d(in_channels)
self.act2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(
            in_channels, self.num_bottleneck_output_filters, kernel_size=1, stride=1
        )
self.bn1 = nn.BatchNorm1d(self.num_bottleneck_output_filters)
self.act1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv1d(
self.num_bottleneck_output_filters,
growth_rate,
kernel_size=kernel_size,
            stride=1,
            dilation=1,
            padding=kernel_size // 2,
        )

def forward(self, x):
if self.use_bottleneck:
@@ -1400,17 +1399,23 @@ def forward(self, x):


class DenseBlock(nn.ModuleDict):
    """Block of DenseLayers; each layer consumes the channel-wise concatenation
    of the block input and all preceding layer outputs."""

def __init__(
self, num_layers, in_channels, growth_rate, kernel_size, bottleneck_size
):
super().__init__()
self.num_layers = num_layers
for i in range(self.num_layers):
            self.add_module(
                f"denselayer{i}",
                DenseLayer(
                    in_channels + i * growth_rate,
                    growth_rate,
                    bottleneck_size,
                    kernel_size,
                ),
            )
self.add_module(
f"denselayer{i}",
DenseLayer(
in_channels + i * growth_rate,
growth_rate,
bottleneck_size,
kernel_size,
),
)
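            # Note: layer i sees in_channels + i * growth_rate input channels,
            # because each earlier layer in the block adds growth_rate feature maps.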

def forward(self, x):
layer_outputs = [x]
@@ -1422,43 +1427,41 @@ def forward(self, x):


class TransitionBlock(nn.Module):
    """Transition between dense blocks: BN -> ReLU -> 1x1 conv -> stride-2
    average pool, reducing both channel count and temporal resolution."""

def __init__(self, in_channels, out_channels):
super().__init__()
self.bn = nn.BatchNorm1d(in_channels)
self.act = nn.ReLU(inplace=True)
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size=1, stride=1, dilation=1
        )
self.pool = nn.AvgPool1d(kernel_size=2, stride=2)

def forward(self, x):
x = self.bn(x)
x = self.act(x)
x = self.conv(x)
x = self.pool(x)
return x
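        # Illustrative shapes (not part of the original code): an input of
        # (N, C_in, L) leaves this block as (N, C_out, L // 2); the 1x1 conv
        # remaps channels and the stride-2 average pool halves the length.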



class DenseNet1d(nn.Module):
def __init__(
self,
growth_rate: int = 32,
block_config: tuple = (6, 12, 24, 16),
num_init_features: int = 64,
bottleneck_size: int = 4,
kernel_size: int = 3,
in_channels: int = 1,
num_classes: int = 1,
reinit: bool = True,
):
super().__init__()

self.features = nn.Sequential(
            nn.Conv1d(in_channels, num_init_features, kernel_size=3),
nn.BatchNorm1d(num_init_features),
nn.ReLU(inplace=True),
nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
@@ -1473,21 +1476,20 @@ def __init__(
kernel_size=kernel_size,
bottleneck_size=bottleneck_size,
)
            self.features.add_module(f"denseblock{i}", block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = TransitionBlock(
                    in_channels=num_features, out_channels=num_features // 2
                )
                self.features.add_module(f"transition{i}", trans)
num_features = num_features // 2

self.final_bn = nn.BatchNorm1d(num_features)
self.final_act = nn.ReLU(inplace=True)
self.final_pool = nn.AdaptiveAvgPool1d(1)
self.classifier = nn.Linear(num_features, num_classes)
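        # With the defaults (num_init_features=64, growth_rate=32,
        # block_config=(6, 12, 24, 16)), num_features evolves
        # 64 -> 256 -> 128 -> 512 -> 256 -> 1024 -> 512 -> 1024,
        # so the classifier is Linear(1024, num_classes).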


if reinit:
for m in self.modules():
if isinstance(m, nn.Conv1d):
@@ -1513,25 +1515,27 @@ def forward(self, x):

def reset_classifier(self):
self.classifier = nn.Identity()

def get_classifier(self):
return self.classifier


class MitosisNet(nn.Module):
def __init__(self, num_classes_class1, num_classes_class2):
super().__init__()
        self.densenet = DenseNet1d(
            in_channels=1, num_classes=num_classes_class1 + num_classes_class2
        )
self.num_classes_class1 = num_classes_class1
self.num_classes_class2 = num_classes_class2

def forward(self, x):
logits = self.densenet(x)

        class_output1 = logits[:, : self.num_classes_class1]
        class_output2 = logits[:, self.num_classes_class1 :]

return class_output1, class_output2
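        # A minimal usage sketch (hypothetical shapes, not from this commit):
        # with num_classes_class1=3 and num_classes_class2=2 the shared
        # DenseNet1d emits 5 logits per sample and the slices above split them:
        #   net = MitosisNet(num_classes_class1=3, num_classes_class2=2)
        #   out1, out2 = net(torch.randn(8, 1, 64))  # out1: (8, 3), out2: (8, 2)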



def train_mitosis_neural_net(
@@ -1579,18 +1583,18 @@ def train_mitosis_neural_net(
}
with open(save_path + "_model_info.json", "w") as json_file:
json.dump(model_info, json_file)

model = MitosisNet(
num_classes_class1=num_classes1,
num_classes_class2=num_classes2,
)

model.to(device)
    summary(model, (1, input_size))
criterion_class1 = nn.CrossEntropyLoss()
criterion_class2 = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    milestones = [int(epochs * 0.25), int(epochs * 0.5), int(epochs * 0.75)]
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
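    # Illustrative schedule: with epochs=100 the learning rate would be scaled
    # by gamma=0.1 at epochs 25, 50 and 75, i.e. lr -> lr/10 -> lr/100 -> lr/1000.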

train_dataset = TensorDataset(
4 changes: 2 additions & 2 deletions src/napatrackmater/_version.py
@@ -1,2 +1,2 @@
__version__ = version = "4.5.6"
__version_tuple__ = version_tuple = (4, 5, 6)
__version__ = version = "4.5.7"
__version_tuple__ = version_tuple = (4, 5, 7)
