diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index c10fb829..f56daf1d 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -1296,6 +1296,11 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c loss_class2.backward() optimizer.step() + + outputs_class1, outputs_class2 = model(inputs) + + _, predicted_class1 = torch.max(outputs_class1.data, 1) + _, predicted_class2 = torch.max(outputs_class2.data, 1) running_loss_class1 += loss_class1.item() running_loss_class2 += loss_class2.item() @@ -1359,7 +1364,7 @@ def train_mitosis_neural_net(features_array, labels_array_class1, labels_array_c torch.save(model.state_dict(), save_path + '_mitosis_track_model.pth') - + def plot_metrics_from_npz(npz_file): data = np.load(npz_file)