Skip to content

Commit

Permalink
mitosis training
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Nov 25, 2023
1 parent 45c5299 commit 6062572
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/napatrackmater/Trackvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import json
from tqdm import tqdm
from torch.optim.lr_scheduler import MultiStepLR
import matplotlib.pyplot as plt

class TrackVector(TrackMate):
def __init__(
Expand Down Expand Up @@ -1357,6 +1358,55 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c
val_acc_class2=val_acc_class2_values)
torch.save(model.state_dict(), save_path + '_mitosis_track_model.pth')



def plot_metrics_from_npz(npz_file):
data = np.load(npz_file)

train_loss_class1 = data['train_loss_class1']
train_loss_class2 = data['train_loss_class2']
val_loss_class1 = data['val_loss_class1']
val_loss_class2 = data['val_loss_class2']
train_acc_class1 = data['train_acc_class1']
train_acc_class2 = data['train_acc_class2']
val_acc_class1 = data['val_acc_class1']
val_acc_class2 = data['val_acc_class2']

epochs = len(train_loss_class1)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(epochs), train_loss_class1, label='Train Loss Class 1')
plt.plot(range(epochs), val_loss_class1, label='Validation Loss Class 1')
plt.legend()
plt.title('Loss for Class 1')

plt.subplot(1, 2, 2)
plt.plot(range(epochs), train_loss_class2, label='Train Loss Class 2')
plt.plot(range(epochs), val_loss_class2, label='Validation Loss Class 2')
plt.legend()
plt.title('Loss for Class 2')

plt.tight_layout()
plt.show()

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(epochs), train_acc_class1, label='Train Acc Class 1')
plt.plot(range(epochs), val_acc_class1, label='Validation Acc Class 1')
plt.legend()
plt.title('Accuracy for Class 1')

plt.subplot(1, 2, 2)
plt.plot(range(epochs), train_acc_class2, label='Train Acc Class 2')
plt.plot(range(epochs), val_acc_class2, label='Validation Acc Class 2')
plt.legend()
plt.title('Accuracy for Class 2')

plt.tight_layout()
plt.show()


def predict_with_model(saved_model_path, features_array):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand Down

0 comments on commit 6062572

Please sign in to comment.