Skip to content

Commit

Permalink
Add load_tensors utils in fluxion
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Jan 21, 2024
1 parent 91aea9b commit ed36213
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 16 deletions.
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet

# Running:
Expand Down Expand Up @@ -66,7 +66,7 @@ def main() -> None:
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"

weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

Expand Down
4 changes: 2 additions & 2 deletions scripts/conversion/convert_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import load_tensors, save_to_safetensors


def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
Expand Down Expand Up @@ -148,7 +148,7 @@ def main() -> None:
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()

weights = torch.load(args.source_path) # type: ignore
weights = load_tensors(args.source_path)
convert_dinov2_facebook(weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
Expand Down
3 changes: 2 additions & 1 deletion scripts/conversion/convert_informative_drawings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn

from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings

try:
Expand All @@ -27,7 +28,7 @@ class Args(argparse.Namespace):

def setup_converter(args: Args) -> ModelConverter:
source = Generator(3, 1, 3)
source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore
source.load_state_dict(state_dict=load_tensors(args.source_path))
source.eval()
target = InformativeDrawings()
x = torch.randn(1, 3, 512, 512)
Expand Down
4 changes: 2 additions & 2 deletions scripts/conversion/convert_segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import manual_seed, save_to_safetensors
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
Expand Down Expand Up @@ -245,7 +245,7 @@ def main() -> None:
args = parser.parse_args(namespace=Args())

sam_h = build_sam_vit_h() # type: ignore
sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore
sam_h.load_state_dict(state_dict=load_tensors(args.source_path))

vit_state_dict = convert_vit(vit=sam_h.image_encoder)
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
Expand Down
22 changes: 21 additions & 1 deletion src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from pathlib import Path
from typing import Any, Iterable, Literal, TypeVar
from typing import Any, Iterable, Literal, TypeVar, cast

import torch
from jaxtyping import Float
Expand Down Expand Up @@ -173,6 +174,25 @@ def safe_open(
return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore


def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
"""
Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode
for additional safety (see `torch.load` for more details). Still, *only load data you trust* and
favor using `load_from_safetensors`.
"""
# see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore

assert isinstance(tensors, dict) and all(
isinstance(key, str) and isinstance(value, Tensor)
for key, value in tensors.items() # type: ignore
), "Invalid tensor file, expected a dict[str, Tensor]"

return cast(dict[str, Tensor], tensors)


def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from PIL import Image

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
Expand Down Expand Up @@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
warn(f"could not find weights at {weights_path}, skipping")
pytest.skip(allow_module_level=True)

tensors = torch.load(weights_path) # type: ignore
tensors = load_tensors(weights_path)
return expected_image, tensors


Expand Down Expand Up @@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.

@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]


@pytest.fixture(scope="module")
Expand Down
29 changes: 29 additions & 0 deletions tests/fluxion/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pickle
from dataclasses import dataclass
from pathlib import Path
from warnings import warn

import pytest
Expand All @@ -7,9 +9,11 @@
from torch import device as Device, dtype as DType
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore

from refiners.fluxion import layers as fl
from refiners.fluxion.utils import (
gaussian_blur,
image_to_tensor,
load_tensors,
manual_seed,
no_grad,
summarize_tensor,
Expand Down Expand Up @@ -95,3 +99,28 @@ def test_no_grad() -> None:

w = x + 1
assert w.requires_grad


def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
pickle_path = tmp_path / "valid.pickle"

tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])}
torch.save(tensors, pickle_path) # type: ignore
loaded_tensor = load_tensors(pickle_path)
assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"])

tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"}
torch.save(tensors, pickle_path) # type: ignore

with pytest.raises(AssertionError):
loaded_tensor = load_tensors(pickle_path)


def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
invalid_pickle_path = tmp_path / "invalid.pickle"
model = fl.Chain(fl.Linear(1, 1))
torch.save(model, invalid_pickle_path) # type: ignore
with pytest.raises(
pickle.UnpicklingError,
):
load_tensors(invalid_pickle_path)
6 changes: 3 additions & 3 deletions tests/foundationals/clip/test_concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from diffusers import StableDiffusionPipeline # type: ignore

import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
Expand Down Expand Up @@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest):

@pytest.fixture(scope="module")
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]


@pytest.fixture(scope="module")
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]


def test_tokenizer_with_special_character():
Expand Down
4 changes: 2 additions & 2 deletions tests/foundationals/segment_anything/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, no_grad
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
Expand Down Expand Up @@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F
from segment_anything import build_sam_vit_h # type: ignore

sam_h = cast(FacebookSAM, build_sam_vit_h())
sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore
sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
return sam_h.to(device=test_device)


Expand Down

0 comments on commit ed36213

Please sign in to comment.