Skip to content

Commit

Permalink
feat(inference): mixed precision support for base encoder/coarsener
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Nov 8, 2023
1 parent ede7425 commit 7f0a270
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion zetta_utils/alignment/base_coarsener.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor:
y_end = y + self.tile_size + self.tile_pad_in
tile = data_in[:, :, x_start:x_end, y_start:y_end]
if (tile != 0).sum() > 0.0:
tile_result = model(tile)
with torch.autocast(device_type=device):
tile_result = model(tile)
if tile_pad_out > 0:
tile_result = tile_result[
:, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out
Expand Down
3 changes: 2 additions & 1 deletion zetta_utils/alignment/base_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __call__(self, src: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Unsupported src dtype: {src.dtype}")

data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y")
result = model(data_in.to(device))
with torch.autocast(device_type=device):
result = model(data_in.to(device))
result = einops.rearrange(result, "Z C X Y -> C X Y Z")

# Final layer assumed to be tanh
Expand Down

0 comments on commit 7f0a270

Please sign in to comment.