Skip to content

Commit

Permalink
mitosis training
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Nov 25, 2023
1 parent 77606c3 commit 45c5299
Showing 1 changed file with 51 additions and 17 deletions.
68 changes: 51 additions & 17 deletions src/napatrackmater/Trackvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,11 +1265,22 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

train_loss_class1_values = []
train_loss_class2_values = []
val_loss_class1_values = []
val_loss_class2_values = []
train_acc_class1_values = []
train_acc_class2_values = []
val_acc_class1_values = []
val_acc_class2_values = []
for epoch in range(epochs):
model.train()
running_loss_class1 = 0.0
running_loss_class2 = 0.0
correct_train_class1 = 0
total_train_class1 = 0
correct_train_class2 = 0
total_train_class2 = 0

with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
for i, data in enumerate(train_loader):
Expand All @@ -1287,17 +1298,27 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c

running_loss_class1 += loss_class1.item()
running_loss_class2 += loss_class2.item()

# Update the progress bar
correct_train_class1 += (predicted_class1 == labels_class1).sum().item()
total_train_class1 += labels_class1.size(0)
correct_train_class2 += (predicted_class2 == labels_class2).sum().item()
total_train_class2 += labels_class2.size(0)
pbar.update(1)
pbar.set_postfix({'Class1 Loss': running_loss_class1 / (i + 1), 'Class2 Loss': running_loss_class2 / (i + 1)})
pbar.set_postfix({'Acc Class1': correct_train_class1 / total_train_class1 if total_train_class1 > 0 else 0,
'Acc Class2': correct_train_class2 / total_train_class2 if total_train_class2 > 0 else 0,'Class1 Loss': running_loss_class1 / (i + 1), 'Class2 Loss': running_loss_class2 / (i + 1)})
scheduler.step()
train_loss_class1_values.append(running_loss_class1 / len(train_loader))
train_loss_class2_values.append(running_loss_class2 / len(train_loader))
train_acc_class1_values.append(correct_train_class1 / total_train_class1 if total_train_class1 > 0 else 0)
train_acc_class2_values.append(correct_train_class2 / total_train_class2 if total_train_class2 > 0 else 0)


model.eval()
correct_class1 = 0
total_class1 = 0
correct_class2 = 0
total_class2 = 0
running_val_loss_class1 = 0.0
running_val_loss_class2 = 0.0
correct_val_class1 = 0
total_val_class1 = 0
correct_val_class2 = 0
total_val_class2 = 0

with tqdm(total=len(val_loader), desc=f"Validation Epoch {epoch + 1}/{epochs}") as pbar_val:
with torch.no_grad():
Expand All @@ -1308,19 +1329,32 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c
_, predicted_class1 = torch.max(outputs_class1.data, 1)
_, predicted_class2 = torch.max(outputs_class2.data, 1)

total_class1 += labels_class1.size(0)
correct_class1 += (predicted_class1 == labels_class1).sum().item()
total_val_class1 += labels_class1.size(0)
correct_val_class1 += (predicted_class1 == labels_class1).sum().item()

total_class2 += labels_class2.size(0)
correct_class2 += (predicted_class2 == labels_class2).sum().item()
total_val_class2 += labels_class2.size(0)
correct_val_class2 += (predicted_class2 == labels_class2).sum().item()

# Update the validation progress bar
pbar_val.update(1)
accuracy_class1 = correct_class1 / total_class1 if total_class1 > 0 else 0
accuracy_class2 = correct_class2 / total_class2 if total_class2 > 0 else 0
accuracy_class1 = correct_val_class1 / total_val_class1 if total_val_class1 > 0 else 0
accuracy_class2 = correct_val_class2 / total_val_class2 if total_val_class2 > 0 else 0
pbar_val.set_postfix({'Acc Class1': accuracy_class1, 'Acc Class2': accuracy_class2})



val_loss_class1_values.append(running_val_loss_class1 / len(val_loader))
val_loss_class2_values.append(running_val_loss_class2 / len(val_loader))
val_acc_class1_values.append(correct_val_class1 / total_val_class1 if total_val_class1 > 0 else 0)
val_acc_class2_values.append(correct_val_class2 / total_val_class2 if total_val_class2 > 0 else 0)


np.savez(save_path + '_metrics.npz',
train_loss_class1=train_loss_class1_values,
train_loss_class2=train_loss_class2_values,
val_loss_class1=val_loss_class1_values,
val_loss_class2=val_loss_class2_values,
train_acc_class1=train_acc_class1_values,
train_acc_class2=train_acc_class2_values,
val_acc_class1=val_acc_class1_values,
val_acc_class2=val_acc_class2_values)
torch.save(model.state_dict(), save_path + '_mitosis_track_model.pth')

def predict_with_model(saved_model_path, features_array):
Expand Down

0 comments on commit 45c5299

Please sign in to comment.