diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 502925d..8fd4500 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -415,11 +415,9 @@ def fit( transforms.Crop if not self.config.is_3d else transforms3d.Crop3D ) assert any( - isinstance(p, _crop_cls) for p in augment_train.transforms + isinstance(p, _crop_cls) for p in augment_train.augmentations ), "Custom augmenter must contain a cropping transform!" - tr_augmenter = self.build_image_augmenter( - crop_size, point_priority=point_priority - ) + tr_augmenter = augment_train elif augment_train: tr_augmenter = self.build_image_augmenter( crop_size, point_priority=point_priority