From 6062572d923509cc9bc3aa2b511e630da1b4fe47 Mon Sep 17 00:00:00 2001 From: kapoorlab Date: Sat, 25 Nov 2023 17:43:46 +0100 Subject: [PATCH] mitosis training --- src/napatrackmater/Trackvector.py | 50 +++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index 189b9352..c10fb829 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -29,6 +29,7 @@ import json from tqdm import tqdm from torch.optim.lr_scheduler import MultiStepLR +import matplotlib.pyplot as plt class TrackVector(TrackMate): def __init__( @@ -1357,6 +1358,55 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c val_acc_class2=val_acc_class2_values) torch.save(model.state_dict(), save_path + '_mitosis_track_model.pth') + + +def plot_metrics_from_npz(npz_file): + data = np.load(npz_file) + + train_loss_class1 = data['train_loss_class1'] + train_loss_class2 = data['train_loss_class2'] + val_loss_class1 = data['val_loss_class1'] + val_loss_class2 = data['val_loss_class2'] + train_acc_class1 = data['train_acc_class1'] + train_acc_class2 = data['train_acc_class2'] + val_acc_class1 = data['val_acc_class1'] + val_acc_class2 = data['val_acc_class2'] + + epochs = len(train_loss_class1) + + plt.figure(figsize=(12, 4)) + plt.subplot(1, 2, 1) + plt.plot(range(epochs), train_loss_class1, label='Train Loss Class 1') + plt.plot(range(epochs), val_loss_class1, label='Validation Loss Class 1') + plt.legend() + plt.title('Loss for Class 1') + + plt.subplot(1, 2, 2) + plt.plot(range(epochs), train_loss_class2, label='Train Loss Class 2') + plt.plot(range(epochs), val_loss_class2, label='Validation Loss Class 2') + plt.legend() + plt.title('Loss for Class 2') + + plt.tight_layout() + plt.show() + + plt.figure(figsize=(12, 4)) + plt.subplot(1, 2, 1) + plt.plot(range(epochs), train_acc_class1, label='Train Acc Class 1') + plt.plot(range(epochs), val_acc_class1, label='Validation Acc Class 1') + plt.legend() + plt.title('Accuracy for Class 1') + + plt.subplot(1, 2, 2) + plt.plot(range(epochs), train_acc_class2, label='Train Acc Class 2') + plt.plot(range(epochs), val_acc_class2, label='Validation Acc Class 2') + plt.legend() + plt.title('Accuracy for Class 2') + + plt.tight_layout() + plt.show() + + def predict_with_model(saved_model_path, features_array): device = torch.device("cuda" if torch.cuda.is_available() else "cpu")