diff --git a/src/napatrackmater/clustering.py b/src/napatrackmater/clustering.py index 2636a518..489d54fa 100644 --- a/src/napatrackmater/clustering.py +++ b/src/napatrackmater/clustering.py @@ -16,6 +16,7 @@ from lightning import Trainer from typing import List from tqdm import tqdm +from .Trackmate import TrackMate class PointCloudDataset(Dataset): def __init__(self, clouds: List[PyntCloud], center=True, scale_z=1.0, scale_xy=1.0): @@ -41,7 +42,7 @@ def __getitem__(self, idx): return point_cloud -class Clustering: +class Clustering(TrackMate): def __init__( self, accelerator: str, @@ -59,6 +60,7 @@ def __init__( center=True, compute_with_autoencoder=True, ): + super().__init__(None, None, None, None, None, None, None, None, seg_image=label_image, autoencoder_model=model) self.accelerator = accelerator self.devices = devices