diff --git a/src/napatrackmater/Trackmate.py b/src/napatrackmater/Trackmate.py index fb57f879..98ffcd94 100644 --- a/src/napatrackmater/Trackmate.py +++ b/src/napatrackmater/Trackmate.py @@ -14,7 +14,7 @@ from pathlib import Path import concurrent from .clustering import Clustering - +from lightning import Trainer class TrackMate: def __init__( @@ -59,7 +59,7 @@ def __init__( self.center = center self.compute_with_autoencoder = compute_with_autoencoder self.latent_features = latent_features - + self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices) if image is not None: self.image = image.astype(np.uint8) else: @@ -1705,6 +1705,7 @@ def _compute_latent_space(self): self.progress_bar.show() cluster_eval = Clustering( + self.pretrainer, self.accelerator, self.devices, self.seg_image[int(time_key), :], @@ -1767,6 +1768,7 @@ def _assign_cluster_class(self): self.progress_bar.show() cluster_eval = Clustering( + self.pretrainer, self.accelerator, self.devices, self.seg_image[int(time_key), :], diff --git a/src/napatrackmater/clustering.py b/src/napatrackmater/clustering.py index 23d8a5f8..8324aa97 100644 --- a/src/napatrackmater/clustering.py +++ b/src/napatrackmater/clustering.py @@ -44,6 +44,7 @@ def __getitem__(self, idx): class Clustering: def __init__( self, + pretrainer: Trainer, accelerator: str, devices: List[int], label_image: np.ndarray, @@ -59,7 +60,8 @@ def __init__( center=True, compute_with_autoencoder=True, ): - + + self.pretrainer = pretrainer self.accelerator = accelerator self.devices = devices self.label_image = label_image @@ -174,7 +176,6 @@ def _latent_computer(self, i, dim): def _create_cluster_labels(self): ndim = len(self.label_image.shape) - self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices) if ndim == 2: labels, centroids, clouds, marching_cube_points = _label_cluster(