Skip to content

Commit

Permalink
add hot_distance test, and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 15, 2024
1 parent 31c2911 commit c4a234b
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 14 deletions.
6 changes: 2 additions & 4 deletions dacapo/experiments/tasks/predictors/dummy_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ def create_target(self, gt):
# zeros
return np_to_funlib_array(
np.zeros((self.embedding_dims,) + gt.data.shape[-gt.dims :]),
gt.roi,
gt.roi.offset,
gt.voxel_size,
["c^"] + gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand All @@ -96,9 +95,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
np.ones(target.data.shape),
target.roi,
target.roi.offset,
target.voxel_size,
target.axis_names,
),
None,
)
Expand Down
8 changes: 3 additions & 5 deletions dacapo/experiments/tasks/predictors/hot_distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,11 @@ def create_target(self, gt):
Examples:
>>> target = predictor.create_target(gt)
"""
target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor)
target = self.process(gt[:], gt.voxel_size, self.norm, self.dt_scale_factor)
return np_to_funlib_array(
target,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand Down Expand Up @@ -209,9 +208,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
weights,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
),
moving_class_counts,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ def create_target(self, gt):
)
return np_to_funlib_array(
distances,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand Down Expand Up @@ -155,9 +154,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
weights,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
),
moving_class_counts,
)
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
onehot_run,
unet_2d_distance_run,
unet_3d_distance_run,
hot_distance_run,
)
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task, hot_distance_task
from .trainers import dummy_trainer, gunpowder_trainer
18 changes: 18 additions & 0 deletions tests/fixtures/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ def distance_run(
num_iterations=10,
)

@pytest.fixture()
def hot_distance_run(
six_class_datasplit,
dummy_architecture,
hot_distance_task,
gunpowder_trainer,
):
yield RunConfig(
name="hot_distance_run",
task_config=hot_distance_task,
architecture_config=dummy_architecture,
trainer_config=gunpowder_trainer,
datasplit_config=six_class_datasplit,
repetition=0,
num_iterations=10,
)



@pytest.fixture()
def dummy_run(
Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
DistanceTaskConfig,
DummyTaskConfig,
OneHotTaskConfig,
HotDistanceTaskConfig,
)
import pytest

Expand All @@ -27,6 +28,21 @@ def distance_task():
tol_distance=10,
)

@pytest.fixture()
def hot_distance_task():
yield HotDistanceTaskConfig(
name="hot_distance_task",
channels=[
"a",
"b",
"c",
"d",
"e",
"f",
],
clip_distance=5,
tol_distance=10,
)

@pytest.fixture()
def onehot_task():
Expand Down
1 change: 1 addition & 0 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
lf("distance_run"),
lf("dummy_run"),
lf("onehot_run"),
lf("hot_distance_run"),
],
)
def test_train(
Expand Down

0 comments on commit c4a234b

Please sign in to comment.