From 7b040e908a4ff2ecf16c51683cfc3bbdb8dbb5c4 Mon Sep 17 00:00:00 2001 From: JoeStrout Date: Fri, 11 Oct 2024 16:55:08 +0000 Subject: [PATCH] Fix incorrect argument type to torch.maximum. --- .../mazepa_layer_processing/common/volumetric_apply_flow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 579d7e20e..fb88ba876 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -95,7 +95,9 @@ def __call__( intscn, subidx = src_idx.get_intersection_and_subindex(red_idx) subidx_channels = [slice(0, res.shape[0])] + list(subidx) with semaphore("read"): - res[subidx_channels] = torch.maximum(res[subidx_channels], layer[intscn]) + res[subidx_channels] = torch.maximum( + res[subidx_channels], convert.to_torch(layer[intscn], res.device) + ) else: for src_idx, layer in zip(src_idxs, src_layers): intscn, subidx = src_idx.get_intersection_and_subindex(red_idx)