diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index f758a32f6..4a31fceaf 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -80,8 +80,8 @@ def start_worker( weights_store.retrieve_weights(run_name, iteration) # get arrays - raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) - raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(input_array_identifier) output_array_identifier = LocalArrayIdentifier( Path(output_container), output_dataset diff --git a/dacapo/predict.py b/dacapo/predict.py index d0db9149f..1ea363ea0 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -56,10 +56,10 @@ def predict( ) # get arrays - raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) - raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) + input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) + raw_array = ZarrArray.open_from_array_identifier(input_array_identifier) if isinstance(output_path, LocalArrayIdentifier): - prediction_array_identifier = output_path + output_array_identifier = output_path else: if ".zarr" in str(output_path) or ".n5" in str(output_path): output_container = Path(output_path) @@ -68,7 +68,7 @@ def predict( output_path, Path(input_container).stem + ".zarr", ) # TODO: zarr hardcoded - prediction_array_identifier = LocalArrayIdentifier( + output_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" ) @@ -115,7 +115,7 @@ def predict( # prepare prediction dataset axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"] ZarrArray.create_from_array_identifier( - prediction_array_identifier, + output_array_identifier, axes, output_roi, model.num_out_channels, @@ -138,12 +138,12 @@ def predict( ###### run_name=run_name, iteration=iteration, - raw_array_identifier=raw_array_identifier, - prediction_array_identifier=prediction_array_identifier, + input_array_identifier=input_array_identifier, + output_array_identifier=output_array_identifier, ) - container = zarr.open(str(prediction_array_identifier.container)) - dataset = container[prediction_array_identifier.dataset] + container = zarr.open(str(output_array_identifier.container)) + dataset = container[output_array_identifier.dataset] dataset.attrs["axes"] = ( # type: ignore raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes )