Skip to content

Commit

Permalink
Make summarize_tensor robust to non-float dtypes (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus authored Jan 11, 2024
1 parent ce0f988 commit c141091
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
28 changes: 17 additions & 11 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:


def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
info_list = [
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
]
if not tensor.is_complex():
info_list.extend(
[
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"

info_list.extend(
[
f"mean={tensor.float().mean():.2f}",
f"std={tensor.float().std():.2f}",
f"norm={norm(x=tensor.float()):.2f}",
f"grad={tensor.requires_grad}",
]
)

return "Tensor(" + ", ".join(info_list) + ")"
18 changes: 17 additions & 1 deletion tests/fluxion/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from torch import device as Device, dtype as DType
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore

from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
from refiners.fluxion.utils import (
gaussian_blur,
image_to_tensor,
manual_seed,
no_grad,
summarize_tensor,
tensor_to_image,
)


@dataclass
Expand Down Expand Up @@ -64,6 +71,15 @@ def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"


def test_summarize_tensor() -> None:
assert summarize_tensor(torch.zeros(1, 3, 512, 512).int())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).float())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).double())
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())


def test_no_grad() -> None:
x = torch.randn(1, 1, requires_grad=True)

Expand Down

0 comments on commit c141091

Please sign in to comment.