From 29bb48cbdbff251ad380e7e19b7a60d2e2ce31a5 Mon Sep 17 00:00:00 2001 From: Nico Kemnitz Date: Wed, 8 Nov 2023 10:06:03 +0100 Subject: [PATCH] feat(inference): base coarsener with output channel support --- zetta_utils/alignment/base_coarsener.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/zetta_utils/alignment/base_coarsener.py b/zetta_utils/alignment/base_coarsener.py index 1d114e0f6..4dd495619 100644 --- a/zetta_utils/alignment/base_coarsener.py +++ b/zetta_utils/alignment/base_coarsener.py @@ -17,6 +17,7 @@ class BaseCoarsener: model_path: str abs_val_thr: float = 0.005 ds_factor: int = 1 + output_channels: int = 1 tile_pad_in: int = 128 tile_size: int = 1024 @@ -36,14 +37,15 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor: raise ValueError(f"Unsupported src dtype: {src.dtype}") data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y").to(device) - result = torch.zeros_like( - data_in[ - ..., - : data_in.shape[-2] // self.ds_factor, - : data_in.shape[-1] // self.ds_factor, - ] - ).float() - + result = torch.zeros( + data_in.shape[0], + self.output_channels, + data_in.shape[-2] // self.ds_factor, + data_in.shape[-1] // self.ds_factor, + dtype=torch.float32, + layout=data_in.layout, + device=data_in.device + ) tile_pad_out = self.tile_pad_in // self.ds_factor for x in range(self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size):