Skip to content

Commit

Permalink
handle negative eigenvalues
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Dec 4, 2023
1 parent 7d3d748 commit 1387287
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/napatrackmater/Trackmate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path
import concurrent
from .clustering import Clustering

from lightning import Trainer

class TrackMate:
def __init__(
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
self.center = center
self.compute_with_autoencoder = compute_with_autoencoder
self.latent_features = latent_features

self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices)
if image is not None:
self.image = image.astype(np.uint8)
else:
Expand Down Expand Up @@ -1705,6 +1705,7 @@ def _compute_latent_space(self):
self.progress_bar.show()

cluster_eval = Clustering(
self.pretrainer,
self.accelerator,
self.devices,
self.seg_image[int(time_key), :],
Expand Down Expand Up @@ -1767,6 +1768,7 @@ def _assign_cluster_class(self):
self.progress_bar.show()

cluster_eval = Clustering(
self.pretrainer,
self.accelerator,
self.devices,
self.seg_image[int(time_key), :],
Expand Down
5 changes: 3 additions & 2 deletions src/napatrackmater/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __getitem__(self, idx):
class Clustering:
def __init__(
self,
pretrainer: Trainer,
accelerator: str,
devices: List[int],
label_image: np.ndarray,
Expand All @@ -59,7 +60,8 @@ def __init__(
center=True,
compute_with_autoencoder=True,
):


self.pretrainer = pretrainer
self.accelerator = accelerator
self.devices = devices
self.label_image = label_image
Expand Down Expand Up @@ -174,7 +176,6 @@ def _latent_computer(self, i, dim):
def _create_cluster_labels(self):

ndim = len(self.label_image.shape)
self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices)
if ndim == 2:

labels, centroids, clouds, marching_cube_points = _label_cluster(
Expand Down

0 comments on commit 1387287

Please sign in to comment.