Skip to content

Commit

Permalink
Dev/main (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Feb 28, 2024
2 parents 44ff6a0 + 9df4e0b commit 24f9d4f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def start_worker(
weights_store.retrieve_weights(run_name, iteration)

# get arrays
raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier)
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
Expand Down
18 changes: 9 additions & 9 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def predict(
)

# get arrays
raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier)
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)
if isinstance(output_path, LocalArrayIdentifier):
prediction_array_identifier = output_path
output_array_identifier = output_path
else:
if ".zarr" in str(output_path) or ".n5" in str(output_path):
output_container = Path(output_path)
Expand All @@ -68,7 +68,7 @@ def predict(
output_path,
Path(input_container).stem + ".zarr",
) # TODO: zarr hardcoded
prediction_array_identifier = LocalArrayIdentifier(
output_array_identifier = LocalArrayIdentifier(
output_container, f"prediction_{run_name}_{iteration}"
)

Expand Down Expand Up @@ -115,7 +115,7 @@ def predict(
# prepare prediction dataset
axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"]
ZarrArray.create_from_array_identifier(
prediction_array_identifier,
output_array_identifier,
axes,
output_roi,
model.num_out_channels,
Expand All @@ -138,12 +138,12 @@ def predict(
######
run_name=run_name,
iteration=iteration,
raw_array_identifier=raw_array_identifier,
prediction_array_identifier=prediction_array_identifier,
input_array_identifier=input_array_identifier,
output_array_identifier=output_array_identifier,
)

container = zarr.open(str(prediction_array_identifier.container))
dataset = container[prediction_array_identifier.dataset]
container = zarr.open(str(output_array_identifier.container))
dataset = container[output_array_identifier.dataset]
dataset.attrs["axes"] = ( # type: ignore
raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes
)

0 comments on commit 24f9d4f

Please sign in to comment.