diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 579d7e20..fb88ba87 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -95,7 +95,9 @@ def __call__( intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = [slice(0, res.shape[0])] + list(subidx) with semaphore("read"): - res[subidx_channels] = torch.maximum(res[subidx_channels], layer[intscn]) + res[subidx_channels] = torch.maximum( + res[subidx_channels], convert.to_torch(layer[intscn], res.device) + ) else: for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)