Skip to content

Commit

Permalink
chore:using flux port code
Browse files Browse the repository at this point in the history
  • Loading branch information
ariG23498 committed Oct 8, 2024
1 parent cbdb372 commit bce2003
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions jflux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from jflux.modules.conditioner import HFEmbedder

from port import port_autoencoder
from port import port_autoencoder, port_flux


@dataclass
Expand Down Expand Up @@ -128,13 +128,13 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:

model = Flux(params=configs[name].params)

# TODO (ariG23498): Port the flux model
if ckpt_path is not None:
print("Loading checkpoint")
# load_sft doesn't support torch.device
sd = load_sft(ckpt_path)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
tensors = {}
with safe_open(ckpt_path, framework="flax") as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)

model = port_flux(flux=model, tensors=tensors)
return model


Expand Down Expand Up @@ -166,7 +166,6 @@ def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
print("Init AE")
ae = AutoEncoder(params=configs[name].ae_params)

# TODO (ariG23498): Port the flux model
if ckpt_path is not None:
tensors = {}
with safe_open(ckpt_path, framework="flax") as f:
Expand Down

0 comments on commit bce2003

Please sign in to comment.