Skip to content

Commit

Permalink
feat(inference): base coarsener with output channel support
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Nov 8, 2023
1 parent 2f9533c commit 29bb48c
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions zetta_utils/alignment/base_coarsener.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class BaseCoarsener:
model_path: str
abs_val_thr: float = 0.005
ds_factor: int = 1
output_channels: int = 1
tile_pad_in: int = 128
tile_size: int = 1024

Expand All @@ -36,14 +37,15 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Unsupported src dtype: {src.dtype}")

data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y").to(device)
result = torch.zeros_like(
data_in[
...,
: data_in.shape[-2] // self.ds_factor,
: data_in.shape[-1] // self.ds_factor,
]
).float()

result = torch.zeros(
data_in.shape[0],
self.output_channels,
data_in.shape[-2] // self.ds_factor,
data_in.shape[-1] // self.ds_factor,
dtype=torch.float32,
layout=data_in.layout,
device=data_in.device
)
tile_pad_out = self.tile_pad_in // self.ds_factor

for x in range(self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size):
Expand Down

0 comments on commit 29bb48c

Please sign in to comment.