Skip to content

Commit

Permalink
fix hard coded dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Nov 19, 2024
1 parent 29cbd6d commit 02ee936
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def predict(
else:
input_roi = output_roi.grow(context, context)

read_roi = Roi((0, 0, 0), input_size)
read_roi = Roi((0,) * input_size.dims, input_size)
write_roi = read_roi.grow(-context, -context)

axes = ["c^", "z", "y", "x"]
axes = raw_array.axis_names
if "c^" not in axes:
axes = ["c^"] + axes

num_channels = model.num_out_channels

Expand All @@ -73,8 +75,8 @@ def predict(

model_device = str(next(model.parameters()).device).split(":")[0]

assert model_device == str(
device
assert (
model_device == str(device)
), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
Expand Down Expand Up @@ -103,7 +105,7 @@ def predict_fn(block):
predictions = Array(
predictions,
block.write_roi.offset,
raw_array.voxel_size,
output_voxel_size,
axis_names,
raw_array.units,
)
Expand All @@ -120,7 +122,7 @@ def predict_fn(block):
task = daisy.Task(
f"predict_{out_container}_{out_dataset}",
total_roi=input_roi,
read_roi=Roi((0, 0, 0), input_size),
read_roi=Roi((0,)*input_size.dims, input_size),
write_roi=Roi(context, output_size),
process_function=predict_fn,
check_function=None,
Expand Down

0 comments on commit 02ee936

Please sign in to comment.