diff --git a/jflux/util.py b/jflux/util.py
index bdb5c4a..2243718 100644
--- a/jflux/util.py
+++ b/jflux/util.py
@@ -13,7 +13,7 @@
 from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 from jflux.modules.conditioner import HFEmbedder
-from port import port_autoencoder
+from port import port_autoencoder, port_flux
 
 
 @dataclass
@@ -128,13 +128,13 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:
     model = Flux(params=configs[name].params)
 
-    # TODO (ariG23498): Port the flux model
     if ckpt_path is not None:
-        print("Loading checkpoint")
-        # load_sft doesn't support torch.device
-        sd = load_sft(ckpt_path)
-        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
-        print_load_warning(missing, unexpected)
+        tensors = {}
+        with safe_open(ckpt_path, framework="flax") as f:
+            for k in f.keys():
+                tensors[k] = f.get_tensor(k)
+
+        model = port_flux(flux=model, tensors=tensors)
 
     return model
 
@@ -166,7 +166,6 @@ def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
     print("Init AE")
     ae = AutoEncoder(params=configs[name].ae_params)
 
-    # TODO (ariG23498): Port the flux model
     if ckpt_path is not None:
         tensors = {}
         with safe_open(ckpt_path, framework="flax") as f:
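
For reviewers, here is a minimal usage sketch of the updated loading path, not part of the diff itself. It assumes a config key registered in `configs` ("flux-schnell" is used as a placeholder) and that `hf_download=True` resolves `ckpt_path` to a safetensors checkpoint; `port_flux` and `port_autoencoder` are expected to copy the arrays read by `safe_open(..., framework="flax")` into the corresponding module parameters, per their definitions in `port.py`.

```python
# Usage sketch (hypothetical config name): exercise the new
# safetensors -> JAX loading path end to end.
from jflux.util import load_ae, load_flow_model

# Builds Flux from its config, reads the checkpoint with
# safe_open(..., framework="flax"), and ports the tensors via port_flux.
model = load_flow_model(name="flux-schnell", hf_download=True)

# Same flow for the autoencoder, which goes through port_autoencoder.
ae = load_ae(name="flux-schnell", hf_download=True)
```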