Skip to content

Commit

Permalink
fix: summarize_tensor(tensor) when tensor.numel() == 0
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus authored and deltheil committed Jan 20, 2024
1 parent 2b4bc77 commit 86c11ab
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
if tensor.is_complex():
tensor_f = tensor.real.float()
else:
info_list.extend(
[
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
]
)
if tensor.numel() > 0:
info_list.extend(
[
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
]
)
tensor_f = tensor.float()

info_list.extend(
Expand Down
1 change: 1 addition & 0 deletions tests/fluxion/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_summarize_tensor() -> None:
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
assert summarize_tensor(torch.zeros(1, 0, 512, 512).int())


def test_no_grad() -> None:
Expand Down

0 comments on commit 86c11ab

Please sign in to comment.