diff --git a/src/napatrackmater/Trackmate.py b/src/napatrackmater/Trackmate.py index 74a104e9..fb57f879 100644 --- a/src/napatrackmater/Trackmate.py +++ b/src/napatrackmater/Trackmate.py @@ -14,7 +14,7 @@ from pathlib import Path import concurrent from .clustering import Clustering -from lightning import Trainer + class TrackMate: def __init__( @@ -59,7 +59,7 @@ def __init__( self.center = center self.compute_with_autoencoder = compute_with_autoencoder self.latent_features = latent_features - self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices) + if image is not None: self.image = image.astype(np.uint8) else: diff --git a/src/napatrackmater/clustering.py b/src/napatrackmater/clustering.py index b6115ff7..23d8a5f8 100644 --- a/src/napatrackmater/clustering.py +++ b/src/napatrackmater/clustering.py @@ -16,7 +16,6 @@ from lightning import Trainer from typing import List from tqdm import tqdm -from .Trackmate import TrackMate class PointCloudDataset(Dataset): def __init__(self, clouds: List[PyntCloud], center=True, scale_z=1.0, scale_xy=1.0): @@ -42,7 +41,7 @@ def __getitem__(self, idx): return point_cloud -class Clustering(TrackMate): +class Clustering: def __init__( self, accelerator: str, @@ -60,10 +59,6 @@ def __init__( center=True, compute_with_autoencoder=True, ): - super().__init__(xml_path=None, spot_csv_path=None, track_csv_path=None, - AttributeBoxname = None, TrackAttributeBoxname = None, - TrackidBox = None, seg_image=label_image, - autoencoder_model=model, accelerator=accelerator, devices=devices, axes=axes, key=key,) self.accelerator = accelerator self.devices = devices @@ -179,6 +174,7 @@ def _latent_computer(self, i, dim): def _create_cluster_labels(self): ndim = len(self.label_image.shape) + self.pretrainer = Trainer(accelerator=self.accelerator, devices=self.devices) if ndim == 2: labels, centroids, clouds, marching_cube_points = _label_cluster(