Skip to content

Commit

Permalink
Merge pull request #196 from huggingface/main
Browse files Browse the repository at this point in the history
Merge changes
  • Loading branch information
Skquark authored Dec 22, 2024
2 parents d18e6ae + 233dffd commit 0ed3049
Show file tree
Hide file tree
Showing 20 changed files with 1,174 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push_tests_mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install -e [quality,test]
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
Expand Down
3 changes: 1 addition & 2 deletions docs/source/en/quantization/gguf.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file(
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
generator=torch.manual_seed(0),
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt).images[0]
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```

Expand Down
6 changes: 3 additions & 3 deletions docs/source/en/quantization/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?

Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes.md)
- [TorchAO](./torchao.md)
- [GGUF](./gguf.md)
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
97 changes: 97 additions & 0 deletions scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


# The CLIP image encoder export is optional: only available when transformers
# is installed.
if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors" \
--output_path "flux-ip-adapter-hf/"
"""


# Fix: is_accelerate_available must be *called* — the bare function object is
# always truthy, so the conditional previously always picked init_empty_weights
# regardless of whether accelerate is actually available.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()

def load_original_checkpoint(args):
    """Resolve the checkpoint location from CLI args and load its tensors.

    Downloads the file from the Hub when a repo id was supplied; otherwise
    falls back to a local path. Raises ValueError when neither is given.
    """
    repo_id = args.original_state_dict_repo_id
    if repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=args.filename)
    else:
        ckpt_path = args.checkpoint_path
        if ckpt_path is None:
            raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(ckpt_path)


def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    """Remap XLabs Flux IP-Adapter keys to the diffusers naming scheme.

    Every converted key is popped out of ``original_state_dict``, so the input
    dict is consumed as a side effect.

    Args:
        original_state_dict: state dict loaded from the XLabs checkpoint.
        num_layers: number of double-stream transformer blocks to convert
            (19 for Flux).

    Returns:
        dict mapping diffusers-style keys to the original tensors.
    """
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    # Fix: the proj layer must be sourced from the `proj` keys. The previous
    # code re-popped the (already removed) norm keys, which raised KeyError at
    # runtime and would have mapped the wrong tensors anyway.
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        # Fix: the v-projection weight was previously written to to_k_ip.weight,
        # clobbering the k weight and leaving to_v_ip.weight missing.
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict


def main(args):
    """Convert the XLabs checkpoint and write the diffusers-format artifacts."""
    # Flux has 19 double-stream transformer blocks.
    state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(load_original_checkpoint(args), 19)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(state_dict, f"{args.output_path}/model.safetensors")

    # Optionally export the CLIP image encoder when transformers is installed.
    if vision:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        image_encoder.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)
5 changes: 3 additions & 2 deletions scripts/convert_sana_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CTX = init_empty_weights if is_accelerate_available else nullcontext

ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
Expand Down Expand Up @@ -265,9 +266,9 @@ def main(args):
"--image_size",
default=1024,
type=int,
choices=[512, 1024],
choices=[512, 1024, 2048],
required=False,
help="Image size of pretrained model, 512 or 1024.",
help="Image size of pretrained model, 512, 1024 or 2048.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]

_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
Expand All @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]

Expand All @@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
Expand Down
Loading

0 comments on commit 0ed3049

Please sign in to comment.