diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index a355288d2..1a5b4d7ce 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -25,14 +25,10 @@ class AffinitiesTask(Task): def __init__(self, task_config): """ - Create a `DummyTask` from a `DummyTaskConfig`. + Create an AffinitiesTask object from a given AffinitiesTaskConfig. Args: - task_config: The configuration for the task. - Returns: - A `DummyTask` object. - Raises: - NotImplementedError: This method is not implemented. + task_config (AffinitiesTaskConfig): The configuration for the affinities task Examples: >>> task = AffinitiesTask(task_config) """ @@ -40,6 +36,8 @@ def __init__(self, task_config): self.predictor = AffinitiesPredictor( neighborhood=task_config.neighborhood, lsds=task_config.lsds, + num_voxels=task_config.num_lsd_voxels, + downsample_lsds=task_config.downsample_lsds, affs_weight_clipmin=task_config.affs_weight_clipmin, affs_weight_clipmax=task_config.affs_weight_clipmax, lsd_weight_clipmin=task_config.lsd_weight_clipmin, diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index 5e22f2a0d..0aeb72763 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -17,6 +17,8 @@ class AffinitiesTaskConfig(TaskConfig): Attributes: neighborhood: A list of Coordinate objects. lsds: Whether or not to train lsds along with your affinities. + num_lsd_voxels: The number of voxels to use for the lsd center of mass calculation. + downsample_lsds: The factor by which to downsample the lsds. lsds_to_affs_weight_ratio: If training with lsds, set how much they should be weighted compared to affs. affs_weight_clipmin: The minimum value for affinities weights. affs_weight_clipmax: The maximum value for affinities weights. @@ -45,6 +47,19 @@ class AffinitiesTaskConfig(TaskConfig): "It has been shown that lsds as an auxiliary task can help affinity predictions." }, ) + num_lsd_voxels: int = attr.ib( + default=10, + metadata={ + "help_text": "The number of voxels to use for the lsd center of mass calculation." + }, + ) + downsample_lsds: int = attr.ib( + default=1, + metadata={ + "help_text": "The factor by which to downsample the lsds. " + "This can be useful to reduce the computational cost of training." + }, + ) lsds_to_affs_weight_ratio: float = attr.ib( default=1, metadata={ diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index 59f0cfa60..e4084270a 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -4,22 +4,6 @@ from dacapo.experiments.datasplits.datasets.arrays import NumpyArray from dacapo.utils.affinities import seg_to_affgraph, padding as aff_padding from dacapo.utils.balance_weights import balance_weights - -from funlib.geometry import Coordinate -from lsd.train import LsdExtractor - -from scipy import ndimage -import numpy as np -import torch -import itertools - -from typing import List -from .predictor import Predictor -from dacapo.experiments import Model -from dacapo.experiments.arraytypes import EmbeddingArray -from dacapo.experiments.datasplits.datasets.arrays import NumpyArray -from dacapo.utils.affinities import seg_to_affgraph, padding as aff_padding -from dacapo.utils.balance_weights import balance_weights from funlib.geometry import Coordinate from lsd.train import LsdExtractor from scipy import ndimage