From ee519cfef5d274f7d0c67270674523833083640d Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 21 Nov 2023 18:56:13 +0800 Subject: [PATCH 1/8] Update README.md (#5855) --- examples/community/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index 87b0ed9151a7..36b34e8b0be5 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -2052,7 +2052,7 @@ import torch from PIL import Image from io import BytesIO -from diffusers import Diffusionpipeline +from diffusers import DiffusionPipeline # load the pipeline # make sure you're logged in with `huggingface-cli login` From ebc7bedeb7d07269f4e2b2ce38afb77ed9b4c91e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 21 Nov 2023 18:01:44 +0530 Subject: [PATCH 2/8] Add tests fetcher (#5848) * add tests fetcher to utils * add test fetcher * update * update * remove unused dependency version check script * update * fix mistake * update * update * update * update * update * update * update * remove concurrency params * update * update * update * update * update * update * move test fetcher to dedicated workflow --- .github/workflows/pr_test_fetcher.yml | 171 ++++ utils/tests_fetcher.py | 1107 +++++++++++++++++++++++++ 2 files changed, 1278 insertions(+) create mode 100644 .github/workflows/pr_test_fetcher.yml create mode 100644 utils/tests_fetcher.py diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml new file mode 100644 index 000000000000..d33bca1903f4 --- /dev/null +++ b/.github/workflows/pr_test_fetcher.yml @@ -0,0 +1,171 @@ +name: Fast tests for PRs + +on: + pull_request: + branches: + - main + push: + branches: + - ci-* + +env: + DIFFUSERS_IS_CI: yes + OMP_NUM_THREADS: 4 + MKL_NUM_THREADS: 4 + PYTEST_TIMEOUT: 60 + +jobs: + setup_pr_tests: + name: Setup PR Tests + runs-on: docker-cpu + container: + image: diffusers/diffusers-pytorch-cpu + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + defaults: + run: + shell: bash + outputs: + matrix: ${{ steps.set_matrix.outputs.matrix }} + test_map: ${{ steps.set_matrix.outputs.test_map }} + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: Install dependencies + run: | + apt-get update && apt-get install libsndfile1-dev libgl1 -y + python -m pip install -e . + - name: Environment + run: | + python utils/print_env.py + - name: Fetch Tests + run: | + python utils/tests_fetcher.py | tee test_preparation.txt + - name: Report fetched tests + uses: actions/upload-artifact@v3 + with: + name: test_fetched + path: test_preparation.txt + - id: set_matrix + name: Create Test Matrix + # The `keys` is used as GitHub actions matrix for jobs, i.e. `models`, `pipelines`, etc. + # The `test_map` is used to get the actual identified test files under each key. + # If no test to run (so no `test_map.json` file), create a dummy map (empty matrix will fail) + run: | + if [ -f test_map.json ]; then + keys=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))') + test_map=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))') + else + keys=$(python3 -c 'keys = ["dummy"]; print(keys)') + test_map=$(python3 -c 'test_map = {"dummy": []}; print(test_map)') + fi + echo $keys + echo $test_map + echo "matrix=$keys" >> $GITHUB_OUTPUT + echo "test_map=$test_map" >> $GITHUB_OUTPUT + + run_pr_tests: + name: Run PR Tests + needs: setup_pr_tests + if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true + strategy: + fail-fast: false + max-parallel: 2 + matrix: + modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }} + runs-on: docker-cpu + container: + image: diffusers/diffusers-pytorch-cpu + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + defaults: + run: + shell: bash + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Install dependencies + run: | + apt-get update && apt-get install libsndfile1-dev libgl1 -y + python -m pip install -e .[quality,test] + python -m pip install accelerate + + - name: Environment + run: | + python utils/print_env.py + + - name: Run all selected tests on CPU + run: | + python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }} + + - name: Failure short reports + if: ${{ failure() }} + continue-on-error: true + run: | + cat reports/${{ matrix.modules }}_tests_cpu_stats.txt + cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v3 + with: + name: ${{ matrix.modules }}_test_reports + path: reports + + run_staging_tests: + strategy: + fail-fast: false + matrix: + config: + - name: Hub tests for models, schedulers, and pipelines + framework: hub_tests_pytorch + runner: docker-cpu + image: diffusers/diffusers-pytorch-cpu + report: torch_hub + + name: ${{ matrix.config.name }} + runs-on: ${{ matrix.config.runner }} + container: + image: ${{ matrix.config.image }} + options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ + + defaults: + run: + shell: bash + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Install dependencies + run: | + apt-get update && apt-get install libsndfile1-dev libgl1 -y + python -m pip install -e .[quality,test] + + - name: Environment + run: | + python utils/print_env.py + + - name: Run Hub tests for models, schedulers, and pipelines on a staging env + if: ${{ matrix.config.framework == 'hub_tests_pytorch' }} + run: | + HUGGINGFACE_CO_STAGING=true python -m pytest \ + -m "is_staging_test" \ + --make-reports=tests_${{ matrix.config.report }} \ + tests + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: pr_${{ matrix.config.report }}_test_reports + path: reports diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py new file mode 100644 index 000000000000..365310f415a2 --- /dev/null +++ b/utils/tests_fetcher.py @@ -0,0 +1,1107 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Welcome to tests_fetcher V2. + +This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and +when too many models are being impacted, only run the tests of a subset of core models. It works like this. + +Stage 1: Identify the modified files. For jobs that run on the main branch, it's just the diff with the last commit. +On a PR, this takes all the files from the branching point to the current commit (so all modifications in a PR, not +just the last commit) but excludes modifications that are on docstrings or comments only. + +Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A +imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the +dependencies of each model and then recursively builds the 'reverse' map of dependencies to get all modules and tests +impacted by a given file. We then only keep the tests (and only the core models tests if there are too many modules). + +Caveats: + - This module only filters tests by files (not individual tests) so it's better to have tests for different things + in different files. + - This module assumes inits are just importing things, not really building objects, so it's better to structure + them this way and move objects building in separate submodules. + +Usage: + +Base use to fetch the tests in a pull request + +```bash +python utils/tests_fetcher.py +``` + +Base use to fetch the tests on a the main branch (with diff from the last commit): + +```bash +python utils/tests_fetcher.py --diff_with_last_commit +``` +""" + +import argparse +import collections +import json +import os +import re +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +from git import Repo + + +PATH_TO_REPO = Path(__file__).parent.parent.resolve() +PATH_TO_EXAMPLES = PATH_TO_REPO / "examples" +PATH_TO_DIFFUSERS = PATH_TO_REPO / "src/diffusers" +PATH_TO_TESTS = PATH_TO_REPO / "tests" + +# List here the pipelines to always test. +IMPORTANT_PIPELINES = [ + "controlnet", + "stable_diffusion", + "stable_diffusion_2", + "stable_diffusion_xl", + "deepfloyd_if", + "kandinsky", + "kandinsky2_2", + "text_to_video_synthesis", + "wuerstchen", +] + +# Ignore fixtures in tests folder +# Ignore lora since they are always tested +MODULES_TO_IGNORE = ["fixtures", "lora"] + + +@contextmanager +def checkout_commit(repo: Repo, commit_id: str): + """ + Context manager that checks out a given commit when entered, but gets back to the reference it was at on exit. + + Args: + repo (`git.Repo`): A git repository (for instance the Transformers repo). + commit_id (`str`): The commit reference to checkout inside the context manager. + """ + current_head = repo.head.commit if repo.head.is_detached else repo.head.ref + + try: + repo.git.checkout(commit_id) + yield + + finally: + repo.git.checkout(current_head) + + +def clean_code(content: str) -> str: + """ + Remove docstrings, empty line or comments from some code (used to detect if a diff is real or only concern + comments or docstings). + + Args: + content (`str`): The code to clean + + Returns: + `str`: The cleaned code. + """ + # We need to deactivate autoformatting here to write escaped triple quotes (we cannot use real triple quotes or + # this would mess up the result if this function applied to this particular file). + # fmt: off + # Remove docstrings by splitting on triple " then triple ': + splits = content.split('\"\"\"') + content = "".join(splits[::2]) + splits = content.split("\'\'\'") + # fmt: on + content = "".join(splits[::2]) + + # Remove empty lines and comments + lines_to_keep = [] + for line in content.split("\n"): + # remove anything that is after a # sign. + line = re.sub("#.*$", "", line) + # remove white lines + if len(line) != 0 and not line.isspace(): + lines_to_keep.append(line) + return "\n".join(lines_to_keep) + + +def keep_doc_examples_only(content: str) -> str: + """ + Remove everything from the code content except the doc examples (used to determined if a diff should trigger doc + tests or not). + + Args: + content (`str`): The code to clean + + Returns: + `str`: The cleaned code. + """ + # Keep doc examples only by splitting on triple "`" + splits = content.split("```") + # Add leading and trailing "```" so the navigation is easier when compared to the original input `content` + content = "```" + "```".join(splits[1::2]) + "```" + + # Remove empty lines and comments + lines_to_keep = [] + for line in content.split("\n"): + # remove anything that is after a # sign. + line = re.sub("#.*$", "", line) + # remove white lines + if len(line) != 0 and not line.isspace(): + lines_to_keep.append(line) + return "\n".join(lines_to_keep) + + +def get_all_tests() -> List[str]: + """ + Walks the `tests` folder to return a list of files/subfolders. This is used to split the tests to run when using + paralellism. The split is: + + - folders under `tests`: (`tokenization`, `pipelines`, etc) except the subfolder `models` is excluded. + - folders under `tests/models`: `bert`, `gpt2`, etc. + - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc. + """ + + # test folders/files directly under `tests` folder + tests = os.listdir(PATH_TO_TESTS) + tests = [f"tests/{f}" for f in tests if "__pycache__" not in f] + tests = sorted([f for f in tests if (PATH_TO_REPO / f).is_dir() or f.startswith("tests/test_")]) + + return tests + + +def diff_is_docstring_only(repo: Repo, branching_point: str, filename: str) -> bool: + """ + Check if the diff is only in docstrings (or comments and whitespace) in a filename. + + Args: + repo (`git.Repo`): A git repository (for instance the Transformers repo). + branching_point (`str`): The commit reference of where to compare for the diff. + filename (`str`): The filename where we want to know if the diff isonly in docstrings/comments. + + Returns: + `bool`: Whether the diff is docstring/comments only or not. + """ + folder = Path(repo.working_dir) + with checkout_commit(repo, branching_point): + with open(folder / filename, "r", encoding="utf-8") as f: + old_content = f.read() + + with open(folder / filename, "r", encoding="utf-8") as f: + new_content = f.read() + + old_content_clean = clean_code(old_content) + new_content_clean = clean_code(new_content) + + return old_content_clean == new_content_clean + + +def diff_contains_doc_examples(repo: Repo, branching_point: str, filename: str) -> bool: + """ + Check if the diff is only in code examples of the doc in a filename. + + Args: + repo (`git.Repo`): A git repository (for instance the Transformers repo). + branching_point (`str`): The commit reference of where to compare for the diff. + filename (`str`): The filename where we want to know if the diff is only in codes examples. + + Returns: + `bool`: Whether the diff is only in code examples of the doc or not. + """ + folder = Path(repo.working_dir) + with checkout_commit(repo, branching_point): + with open(folder / filename, "r", encoding="utf-8") as f: + old_content = f.read() + + with open(folder / filename, "r", encoding="utf-8") as f: + new_content = f.read() + + old_content_clean = keep_doc_examples_only(old_content) + new_content_clean = keep_doc_examples_only(new_content) + + return old_content_clean != new_content_clean + + +def get_diff(repo: Repo, base_commit: str, commits: List[str]) -> List[str]: + """ + Get the diff between a base commit and one or several commits. + + Args: + repo (`git.Repo`): + A git repository (for instance the Transformers repo). + base_commit (`str`): + The commit reference of where to compare for the diff. This is the current commit, not the branching point! + commits (`List[str]`): + The list of commits with which to compare the repo at `base_commit` (so the branching point). + + Returns: + `List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files + modified are returned if the diff in the file is not only in docstrings or comments, see + `diff_is_docstring_only`). + """ + print("\n### DIFF ###\n") + code_diff = [] + for commit in commits: + for diff_obj in commit.diff(base_commit): + # We always add new python files + if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"): + code_diff.append(diff_obj.b_path) + # We check that deleted python files won't break corresponding tests. + elif diff_obj.change_type == "D" and diff_obj.a_path.endswith(".py"): + code_diff.append(diff_obj.a_path) + # Now for modified files + elif diff_obj.change_type in ["M", "R"] and diff_obj.b_path.endswith(".py"): + # In case of renames, we'll look at the tests using both the old and new name. + if diff_obj.a_path != diff_obj.b_path: + code_diff.extend([diff_obj.a_path, diff_obj.b_path]) + else: + # Otherwise, we check modifications are in code and not docstrings. + if diff_is_docstring_only(repo, commit, diff_obj.b_path): + print(f"Ignoring diff in {diff_obj.b_path} as it only concerns docstrings or comments.") + else: + code_diff.append(diff_obj.a_path) + + return code_diff + + +def get_modified_python_files(diff_with_last_commit: bool = False) -> List[str]: + """ + Return a list of python files that have been modified between: + + - the current head and the main branch if `diff_with_last_commit=False` (default) + - the current head and its parent commit otherwise. + + Returns: + `List[str]`: The list of Python files with a diff (files added, renamed or deleted are always returned, files + modified are returned if the diff in the file is not only in docstrings or comments, see + `diff_is_docstring_only`). + """ + repo = Repo(PATH_TO_REPO) + + if not diff_with_last_commit: + print(f"main is at {repo.refs.main.commit}") + print(f"Current head is at {repo.head.commit}") + + branching_commits = repo.merge_base(repo.refs.main, repo.head) + for commit in branching_commits: + print(f"Branching commit: {commit}") + return get_diff(repo, repo.head.commit, branching_commits) + else: + print(f"main is at {repo.head.commit}") + parent_commits = repo.head.commit.parents + for commit in parent_commits: + print(f"Parent commit: {commit}") + return get_diff(repo, repo.head.commit, parent_commits) + + +def get_diff_for_doctesting(repo: Repo, base_commit: str, commits: List[str]) -> List[str]: + """ + Get the diff in doc examples between a base commit and one or several commits. + + Args: + repo (`git.Repo`): + A git repository (for instance the Transformers repo). + base_commit (`str`): + The commit reference of where to compare for the diff. This is the current commit, not the branching point! + commits (`List[str]`): + The list of commits with which to compare the repo at `base_commit` (so the branching point). + + Returns: + `List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files + modified are returned if the diff in the file is only in doctest examples). + """ + print("\n### DIFF ###\n") + code_diff = [] + for commit in commits: + for diff_obj in commit.diff(base_commit): + # We only consider Python files and doc files. + if not diff_obj.b_path.endswith(".py") and not diff_obj.b_path.endswith(".md"): + continue + # We always add new python/md files + if diff_obj.change_type in ["A"]: + code_diff.append(diff_obj.b_path) + # Now for modified files + elif diff_obj.change_type in ["M", "R"]: + # In case of renames, we'll look at the tests using both the old and new name. + if diff_obj.a_path != diff_obj.b_path: + code_diff.extend([diff_obj.a_path, diff_obj.b_path]) + else: + # Otherwise, we check modifications contain some doc example(s). + if diff_contains_doc_examples(repo, commit, diff_obj.b_path): + code_diff.append(diff_obj.a_path) + else: + print(f"Ignoring diff in {diff_obj.b_path} as it doesn't contain any doc example.") + + return code_diff + + +def get_all_doctest_files() -> List[str]: + """ + Return the complete list of python and Markdown files on which we run doctest. + + At this moment, we restrict this to only take files from `src/` or `docs/source/en/` that are not in `utils/not_doctested.txt`. + + Returns: + `List[str]`: The complete list of Python and Markdown files on which we run doctest. + """ + py_files = [str(x.relative_to(PATH_TO_REPO)) for x in PATH_TO_REPO.glob("**/*.py")] + md_files = [str(x.relative_to(PATH_TO_REPO)) for x in PATH_TO_REPO.glob("**/*.md")] + test_files_to_run = py_files + md_files + + # only include files in `src` or `docs/source/en/` + test_files_to_run = [x for x in test_files_to_run if x.startswith(("src/", "docs/source/en/"))] + # not include init files + test_files_to_run = [x for x in test_files_to_run if not x.endswith(("__init__.py",))] + + # These are files not doctested yet. + with open("utils/not_doctested.txt") as fp: + not_doctested = {x.split(" ")[0] for x in fp.read().strip().split("\n")} + + # So far we don't have 100% coverage for doctest. This line will be removed once we achieve 100%. + test_files_to_run = [x for x in test_files_to_run if x not in not_doctested] + + return sorted(test_files_to_run) + + +def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]: + """ + Get the list of files that were removed from "utils/not_doctested.txt", between `base_commit` and + `branching_commit`. + + Returns: + `List[str]`: List of files that were removed from "utils/not_doctested.txt". + """ + for diff_obj in branching_commit.diff(base_commit): + # Ignores all but the "utils/not_doctested.txt" file. + if diff_obj.a_path != "utils/not_doctested.txt": + continue + # Loads the two versions + folder = Path(repo.working_dir) + with checkout_commit(repo, branching_commit): + with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f: + old_content = f.read() + with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f: + new_content = f.read() + # Compute the removed lines and return them + removed_content = {x.split(" ")[0] for x in old_content.split("\n")} - { + x.split(" ")[0] for x in new_content.split("\n") + } + return sorted(removed_content) + return [] + + +def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]: + """ + Return a list of python and Markdown files where doc example have been modified between: + + - the current head and the main branch if `diff_with_last_commit=False` (default) + - the current head and its parent commit otherwise. + + Returns: + `List[str]`: The list of Python and Markdown files with a diff (files added or renamed are always returned, files + modified are returned if the diff in the file is only in doctest examples). + """ + repo = Repo(PATH_TO_REPO) + + test_files_to_run = [] # noqa + if not diff_with_last_commit: + print(f"main is at {repo.refs.main.commit}") + print(f"Current head is at {repo.head.commit}") + + branching_commits = repo.merge_base(repo.refs.main, repo.head) + for commit in branching_commits: + print(f"Branching commit: {commit}") + test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, branching_commits) + else: + print(f"main is at {repo.head.commit}") + parent_commits = repo.head.commit.parents + for commit in parent_commits: + print(f"Parent commit: {commit}") + test_files_to_run = get_diff_for_doctesting(repo, repo.head.commit, parent_commits) + + all_test_files_to_run = get_all_doctest_files() + + # Add to the test files to run any removed entry from "utils/not_doctested.txt". + new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit) + test_files_to_run = list(set(test_files_to_run + new_test_files)) + + # Do not run slow doctest tests on CircleCI + with open("utils/slow_documentation_tests.txt") as fp: + slow_documentation_tests = set(fp.read().strip().split("\n")) + test_files_to_run = [ + x for x in test_files_to_run if x in all_test_files_to_run and x not in slow_documentation_tests + ] + + # Make sure we did not end up with a test file that was removed + test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()] + + return sorted(test_files_to_run) + + +# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line. +# \s*from\s+(\.+\S+)\s+import\s+([^\n]+) -> Line only contains from .xxx import yyy and we catch .xxx and yyy +# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every +# other import. +_re_single_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+([^\n]+)(?=\n)") +# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line. +# \s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\) -> Line continues with from .xxx import (yyy) and we catch .xxx and yyy +# yyy will take multiple lines otherwise there wouldn't be parenthesis. +_re_multi_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\)") +# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line. +# \s*from\s+transformers(\S*)\s+import\s+([^\n]+) -> Line only contains from transformers.xxx import yyy and we catch +# .xxx and yyy +# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every +# other import. +_re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+([^\n]+)(?=\n)") +# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line. +# \s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\) -> Line continues with from transformers.xxx import (yyy) and we +# catch .xxx and yyy. yyy will take multiple lines otherwise there wouldn't be parenthesis. +_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+diffusers(\S*)\s+import\s+\(([^\)]+)\)") + + +def extract_imports(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]: + """ + Get the imports a given module makes. + + Args: + module_fname (`str`): + The name of the file of the module where we want to look at the imports (given relative to the root of + the repo). + cache (Dictionary `str` to `List[str]`, *optional*): + To speed up this function if it was previously called on `module_fname`, the cache of all previously + computed results. + + Returns: + `List[str]`: The list of module filenames imported in the input `module_fname` (a submodule we import from that + is a subfolder will give its init file). + """ + if cache is not None and module_fname in cache: + return cache[module_fname] + + with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f: + content = f.read() + + # Filter out all docstrings to not get imports in code examples. As before we need to deactivate formatting to + # keep this as escaped quotes and avoid this function failing on this file. + # fmt: off + splits = content.split('\"\"\"') + # fmt: on + content = "".join(splits[::2]) + + module_parts = str(module_fname).split(os.path.sep) + imported_modules = [] + + # Let's start with relative imports + relative_imports = _re_single_line_relative_imports.findall(content) + relative_imports = [ + (mod, imp) for mod, imp in relative_imports if "# tests_ignore" not in imp and imp.strip() != "(" + ] + multiline_relative_imports = _re_multi_line_relative_imports.findall(content) + relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp] + + # We need to remove parts of the module name depending on the depth of the relative imports. + for module, imports in relative_imports: + level = 0 + while module.startswith("."): + module = module[1:] + level += 1 + + if len(module) > 0: + dep_parts = module_parts[: len(module_parts) - level] + module.split(".") + else: + dep_parts = module_parts[: len(module_parts) - level] + imported_module = os.path.sep.join(dep_parts) + imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")])) + + # Let's continue with direct imports + direct_imports = _re_single_line_direct_imports.findall(content) + direct_imports = [(mod, imp) for mod, imp in direct_imports if "# tests_ignore" not in imp and imp.strip() != "("] + multiline_direct_imports = _re_multi_line_direct_imports.findall(content) + direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp] + + # We need to find the relative path of those imports. + for module, imports in direct_imports: + import_parts = module.split(".")[1:] # ignore the name of the repo since we add it below. + dep_parts = ["src", "diffusers"] + import_parts + imported_module = os.path.sep.join(dep_parts) + imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")])) + + result = [] + # Double check we get proper modules (either a python file or a folder with an init). + for module_file, imports in imported_modules: + if (PATH_TO_REPO / f"{module_file}.py").is_file(): + module_file = f"{module_file}.py" + elif (PATH_TO_REPO / module_file).is_dir() and (PATH_TO_REPO / module_file / "__init__.py").is_file(): + module_file = os.path.sep.join([module_file, "__init__.py"]) + imports = [imp for imp in imports if len(imp) > 0 and re.match("^[A-Za-z0-9_]*$", imp)] + if len(imports) > 0: + result.append((module_file, imports)) + + if cache is not None: + cache[module_fname] = result + + return result + + +def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = None) -> List[str]: + """ + Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file + as an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse + the `utils` init file to check where those dependencies come from: for instance the files utils/foo.py and utils/bar.py. + + Warning: This presupposes that all intermediate inits are properly built (with imports from the respective + submodules) and work better if objects are defined in submodules and not the intermediate init (otherwise the + intermediate init is added, and inits usually have a lot of dependencies). + + Args: + module_fname (`str`): + The name of the file of the module where we want to look at the imports (given relative to the root of + the repo). + cache (Dictionary `str` to `List[str]`, *optional*): + To speed up this function if it was previously called on `module_fname`, the cache of all previously + computed results. + + Returns: + `List[str]`: The list of module filenames imported in the input `module_fname` (with submodule imports refined). + """ + dependencies = [] + imported_modules = extract_imports(module_fname, cache=cache) + # The while loop is to recursively traverse all inits we may encounter: we will add things as we go. + while len(imported_modules) > 0: + new_modules = [] + for module, imports in imported_modules: + # If we end up in an __init__ we are often not actually importing from this init (except in the case where + # the object is fully defined in the __init__) + if module.endswith("__init__.py"): + # So we get the imports from that init then try to find where our objects come from. + new_imported_modules = extract_imports(module, cache=cache) + for new_module, new_imports in new_imported_modules: + if any(i in new_imports for i in imports): + if new_module not in dependencies: + new_modules.append((new_module, [i for i in new_imports if i in imports])) + imports = [i for i in imports if i not in new_imports] + if len(imports) > 0: + # If there are any objects lefts, they may be a submodule + path_to_module = PATH_TO_REPO / module.replace("__init__.py", "") + dependencies.extend( + [ + os.path.join(module.replace("__init__.py", ""), f"{i}.py") + for i in imports + if (path_to_module / f"{i}.py").is_file() + ] + ) + imports = [i for i in imports if not (path_to_module / f"{i}.py").is_file()] + if len(imports) > 0: + # Then if there are still objects left, they are fully defined in the init, so we keep it as a + # dependency. + dependencies.append(module) + else: + dependencies.append(module) + + imported_modules = new_modules + + return dependencies + + +def create_reverse_dependency_tree() -> List[Tuple[str, str]]: + """ + Create a list of all edges (a, b) which mean that modifying a impacts b with a going over all module and test files. + """ + cache = {} + all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules] + edges = [(dep, mod) for mod in all_modules for dep in get_module_dependencies(mod, cache=cache)] + + return list(set(edges)) + + +def get_tree_starting_at(module: str, edges: List[Tuple[str, str]]) -> List[Union[str, List[str]]]: + """ + Returns the tree starting at a given module following all edges. + + Args: + module (`str`): The module that will be the root of the subtree we want. + eges (`List[Tuple[str, str]]`): The list of all edges of the tree. + + Returns: + `List[Union[str, List[str]]]`: The tree to print in the following format: [module, [list of edges + starting at module], [list of edges starting at the preceding level], ...] + """ + vertices_seen = [module] + new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]] + tree = [module] + while len(new_edges) > 0: + tree.append(new_edges) + final_vertices = list({edge[1] for edge in new_edges}) + vertices_seen.extend(final_vertices) + new_edges = [ + edge + for edge in edges + if edge[0] in final_vertices and edge[1] not in vertices_seen and "__init__.py" not in edge[1] + ] + + return tree + + +def print_tree_deps_of(module, all_edges=None): + """ + Prints the tree of modules depending on a given module. + + Args: + module (`str`): The module that will be the root of the subtree we want. + all_eges (`List[Tuple[str, str]]`, *optional*): + The list of all edges of the tree. Will be set to `create_reverse_dependency_tree()` if not passed. + """ + if all_edges is None: + all_edges = create_reverse_dependency_tree() + tree = get_tree_starting_at(module, all_edges) + + # The list of lines is a list of tuples (line_to_be_printed, module) + # Keeping the modules lets us know where to insert each new lines in the list. + lines = [(tree[0], tree[0])] + for index in range(1, len(tree)): + edges = tree[index] + start_edges = {edge[0] for edge in edges} + + for start in start_edges: + end_edges = {edge[1] for edge in edges if edge[0] == start} + # We will insert all those edges just after the line showing start. + pos = 0 + while lines[pos][1] != start: + pos += 1 + lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :] + + for line in lines: + # We don't print the refs that where just here to help build lines. + print(line[0]) + + +def init_test_examples_dependencies() -> Tuple[Dict[str, List[str]], List[str]]: + """ + The test examples do not import from the examples (which are just scripts, not modules) so we need som extra + care initializing the dependency map, which is the goal of this function. It initializes the dependency map for + example files by linking each example to the example test file for the example framework. + + Returns: + `Tuple[Dict[str, List[str]], List[str]]`: A tuple with two elements: the initialized dependency map which is a + dict test example file to list of example files potentially tested by that test file, and the list of all + example files (to avoid recomputing it later). + """ + test_example_deps = {} + all_examples = [] + for framework in ["flax", "pytorch", "tensorflow"]: + test_files = list((PATH_TO_EXAMPLES / framework).glob("test_*.py")) + all_examples.extend(test_files) + # Remove the files at the root of examples/framework since they are not proper examples (they are eith utils + # or example test files). + examples = [ + f for f in (PATH_TO_EXAMPLES / framework).glob("**/*.py") if f.parent != PATH_TO_EXAMPLES / framework + ] + all_examples.extend(examples) + for test_file in test_files: + with open(test_file, "r", encoding="utf-8") as f: + content = f.read() + # Map all examples to the test files found in examples/framework. + test_example_deps[str(test_file.relative_to(PATH_TO_REPO))] = [ + str(e.relative_to(PATH_TO_REPO)) for e in examples if e.name in content + ] + # Also map the test files to themselves. + test_example_deps[str(test_file.relative_to(PATH_TO_REPO))].append( + str(test_file.relative_to(PATH_TO_REPO)) + ) + return test_example_deps, all_examples + + +def create_reverse_dependency_map() -> Dict[str, List[str]]: + """ + Create the dependency map from module/test filename to the list of modules/tests that depend on it recursively. + + Returns: + `Dict[str, List[str]]`: The reverse dependency map as a dictionary mapping filenames to all the filenames + depending on it recursively. This way the tests impacted by a change in file A are the test files in the list + corresponding to key A in this result. + """ + cache = {} + # Start from the example deps init. + example_deps, examples = init_test_examples_dependencies() + # Add all modules and all tests to all examples + all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + examples + all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules] + # Compute the direct dependencies of all modules. + direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules} + direct_deps.update(example_deps) + + # This recurses the dependencies + something_changed = True + while something_changed: + something_changed = False + for m in all_modules: + for d in direct_deps[m]: + # We stop recursing at an init (cause we always end up in the main init and we don't want to add all + # files which the main init imports) + if d.endswith("__init__.py"): + continue + if d not in direct_deps: + raise ValueError(f"KeyError:{d}. From {m}") + new_deps = set(direct_deps[d]) - set(direct_deps[m]) + if len(new_deps) > 0: + direct_deps[m].extend(list(new_deps)) + something_changed = True + + # Finally we can build the reverse map. + reverse_map = collections.defaultdict(list) + for m in all_modules: + for d in direct_deps[m]: + reverse_map[d].append(m) + + # For inits, we don't do the reverse deps but the direct deps: if modifying an init, we want to make sure we test + # all the modules impacted by that init. + for m in [f for f in all_modules if f.endswith("__init__.py")]: + direct_deps = get_module_dependencies(m, cache=cache) + deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps) + reverse_map[m] = list(set(deps) - {m}) + + return reverse_map + + +def create_module_to_test_map( + reverse_map: Dict[str, List[str]] = None, filter_models: bool = False +) -> Dict[str, List[str]]: + """ + Extract the tests from the reverse_dependency_map and potentially filters the model tests. + + Args: + reverse_map (`Dict[str, List[str]]`, *optional*): + The reverse dependency map as created by `create_reverse_dependency_map`. Will default to the result of + that function if not provided. + filter_models (`bool`, *optional*, defaults to `False`): + Whether or not to filter model tests to only include core models if a file impacts a lot of models. + + Returns: + `Dict[str, List[str]]`: A dictionary that maps each file to the tests to execute if that file was modified. + """ + if reverse_map is None: + reverse_map = create_reverse_dependency_map() + + # Utility that tells us if a given file is a test (taking test examples into account) + def is_test(fname): + if fname.startswith("tests"): + return True + if fname.startswith("examples") and fname.split(os.path.sep)[-1].startswith("test"): + return True + return False + + # Build the test map + test_map = {module: [f for f in deps if is_test(f)] for module, deps in reverse_map.items()} + + if not filter_models: + return test_map + + # Now we deal with the filtering if `filter_models` is True. + num_model_tests = len(list(PATH_TO_TESTS.glob("models/*"))) + + def has_many_models(tests): + # We filter to core models when a given file impacts more than half the model tests. + model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")} + return len(model_tests) > num_model_tests // 2 + + def filter_tests(tests): + return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_PIPELINES] + + return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()} + + +def check_imports_all_exist(): + """ + Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the + code is not lost. This checks all imports in a given file do exist. + """ + cache = {} + all_modules = list(PATH_TO_DIFFUSERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py")) + all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules] + direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules} + + for module, deps in direct_deps.items(): + for dep in deps: + if not (PATH_TO_REPO / dep).is_file(): + print(f"{module} has dependency on {dep} which does not exist.") + + +def _print_list(l) -> str: + """ + Pretty print a list of elements with one line per element and a - starting each line. + """ + return "\n".join([f"- {f}" for f in l]) + + +def create_json_map(test_files_to_run: List[str], json_output_file: str): + """ + Creates a map from a list of tests to run to easily split them by category, when running parallelism of slow tests. + + Args: + test_files_to_run (`List[str]`): The list of tests to run. + json_output_file (`str`): The path where to store the built json map. + """ + if json_output_file is None: + return + + test_map = {} + for test_file in test_files_to_run: + # `test_file` is a path to a test folder/file, starting with `tests/`. For example, + # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert` + # - `tests/trainer/test_trainer.py` or `tests/trainer` + # - `tests/test_modeling_common.py` + names = test_file.split(os.path.sep) + module = names[1] + if module in MODULES_TO_IGNORE: + continue + + if len(names) > 2 or not test_file.endswith(".py"): + # test folders under `tests` or python files under them + # take the part like tokenization, `pipeline`, etc. for other test categories + key = os.path.sep.join(names[1:2]) + else: + # common test files directly under `tests/` + key = "common" + + if key not in test_map: + test_map[key] = [] + test_map[key].append(test_file) + + # sort the keys & values + keys = sorted(test_map.keys()) + test_map = {k: " ".join(sorted(test_map[k])) for k in keys} + with open(json_output_file, "w", encoding="UTF-8") as fp: + json.dump(test_map, fp, ensure_ascii=False) + + +def infer_tests_to_run( + output_file: str, + diff_with_last_commit: bool = False, + filter_models: bool = True, + json_output_file: Optional[str] = None, +): + """ + The main function called by the test fetcher. Determines the tests to run from the diff. + + Args: + output_file (`str`): + The path where to store the summary of the test fetcher analysis. Other files will be stored in the same + folder: + + - examples_test_list.txt: The list of examples tests to run. + - test_repo_utils.txt: Will indicate if the repo utils tests should be run or not. + - doctest_list.txt: The list of doctests to run. + + diff_with_last_commit (`bool`, *optional*, defaults to `False`): + Whether to analyze the diff with the last commit (for use on the main branch after a PR is merged) or with + the branching point from main (for use on each PR). + filter_models (`bool`, *optional*, defaults to `True`): + Whether or not to filter the tests to core models only, when a file modified results in a lot of model + tests. + json_output_file (`str`, *optional*): + The path where to store the json file mapping categories of tests to tests to run (used for parallelism or + the slow tests). + """ + modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit) + print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}") + # Create the map that will give us all impacted modules. + reverse_map = create_reverse_dependency_map() + impacted_files = modified_files.copy() + for f in modified_files: + if f in reverse_map: + impacted_files.extend(reverse_map[f]) + + # Remove duplicates + impacted_files = sorted(set(impacted_files)) + print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}") + + # Grab the corresponding test files: + if any(x in modified_files for x in ["setup.py"]): + test_files_to_run = ["tests", "examples"] + # in order to trigger pipeline tests even if no code change at all + elif "tests/utils/tiny_model_summary.json" in modified_files: + test_files_to_run = ["tests"] + any(f.split(os.path.sep)[0] == "utils" for f in modified_files) + else: + # All modified tests need to be run. + test_files_to_run = [ + f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test") + ] + # Then we grab the corresponding test files. + test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models) + for f in modified_files: + if f in test_map: + test_files_to_run.extend(test_map[f]) + test_files_to_run = sorted(set(test_files_to_run)) + # Make sure we did not end up with a test file that was removed + test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()] + + any(f.split(os.path.sep)[0] == "utils" for f in modified_files) + + examples_tests_to_run = [f for f in test_files_to_run if f.startswith("examples")] + test_files_to_run = [f for f in test_files_to_run if not f.startswith("examples")] + print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}") + if len(test_files_to_run) > 0: + with open(output_file, "w", encoding="utf-8") as f: + f.write(" ".join(test_files_to_run)) + + # Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...] + + # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run` + # contains `tests` (i.e. when `setup.py` is changed). + if "tests" in test_files_to_run: + test_files_to_run = get_all_tests() + + create_json_map(test_files_to_run, json_output_file) + + print(f"\n### EXAMPLES TEST TO RUN ###\n{_print_list(examples_tests_to_run)}") + if len(examples_tests_to_run) > 0: + # We use `all` in the case `commit_flags["test_all"]` as well as in `create_circleci_config.py` for processing + if examples_tests_to_run == ["examples"]: + examples_tests_to_run = ["all"] + example_file = Path(output_file).parent / "examples_test_list.txt" + with open(example_file, "w", encoding="utf-8") as f: + f.write(" ".join(examples_tests_to_run)) + + +def filter_tests(output_file: str, filters: List[str]): + """ + Reads the content of the output file and filters out all the tests in a list of given folders. + + Args: + output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher. + filters (`List[str]`): A list of folders to filter. + """ + if not os.path.isfile(output_file): + print("No test file found.") + return + with open(output_file, "r", encoding="utf-8") as f: + test_files = f.read().split(" ") + + if len(test_files) == 0 or test_files == [""]: + print("No tests to filter.") + return + + if test_files == ["tests"]: + test_files = [os.path.join("tests", f) for f in os.listdir("tests") if f not in ["__init__.py"] + filters] + else: + test_files = [f for f in test_files if f.split(os.path.sep)[1] not in filters] + + with open(output_file, "w", encoding="utf-8") as f: + f.write(" ".join(test_files)) + + +def parse_commit_message(commit_message: str) -> Dict[str, bool]: + """ + Parses the commit message to detect if a command is there to skip, force all or part of the CI. + + Args: + commit_message (`str`): The commit message of the current commit. + + Returns: + `Dict[str, bool]`: A dictionary of strings to bools with keys the following keys: `"skip"`, + `"test_all_models"` and `"test_all"`. + """ + if commit_message is None: + return {"skip": False, "no_filter": False, "test_all": False} + + command_search = re.search(r"\[([^\]]*)\]", commit_message) + if command_search is not None: + command = command_search.groups()[0] + command = command.lower().replace("-", " ").replace("_", " ") + skip = command in ["ci skip", "skip ci", "circleci skip", "skip circleci"] + no_filter = set(command.split(" ")) == {"no", "filter"} + test_all = set(command.split(" ")) == {"test", "all"} + return {"skip": skip, "no_filter": no_filter, "test_all": test_all} + else: + return {"skip": False, "no_filter": False, "test_all": False} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run" + ) + parser.add_argument( + "--json_output_file", + type=str, + default="test_map.json", + help="Where to store the tests to run in a dictionary format mapping test categories to test files", + ) + parser.add_argument( + "--diff_with_last_commit", + action="store_true", + help="To fetch the tests between the current commit and the last commit", + ) + parser.add_argument( + "--filter_tests", + action="store_true", + help="Will filter the pipeline/repo utils tests outside of the generated list of tests.", + ) + parser.add_argument( + "--print_dependencies_of", + type=str, + help="Will only print the tree of modules depending on the file passed.", + default=None, + ) + parser.add_argument( + "--commit_message", + type=str, + help="The commit message (which could contain a command to force all tests or skip the CI).", + default=None, + ) + args = parser.parse_args() + if args.print_dependencies_of is not None: + print_tree_deps_of(args.print_dependencies_of) + elif args.filter_tests: + filter_tests(args.output_file, ["pipelines", "repo_utils"]) + else: + repo = Repo(PATH_TO_REPO) + commit_message = repo.head.commit.message + commit_flags = parse_commit_message(commit_message) + if commit_flags["skip"]: + print("Force-skipping the CI") + quit() + if commit_flags["no_filter"]: + print("Running all tests fetched without filtering.") + if commit_flags["test_all"]: + print("Force-launching all tests") + + diff_with_last_commit = args.diff_with_last_commit + if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main: + print("main branch detected, fetching tests against last commit.") + diff_with_last_commit = True + + if not commit_flags["test_all"]: + try: + infer_tests_to_run( + args.output_file, + diff_with_last_commit=diff_with_last_commit, + json_output_file=args.json_output_file, + filter_models=not commit_flags["no_filter"], + ) + filter_tests(args.output_file, ["repo_utils"]) + except Exception as e: + print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.") + commit_flags["test_all"] = True + + if commit_flags["test_all"]: + with open(args.output_file, "w", encoding="utf-8") as f: + f.write("tests") + example_file = Path(args.output_file).parent / "examples_test_list.txt" + with open(example_file, "w", encoding="utf-8") as f: + f.write("all") + + test_files_to_run = get_all_tests() + create_json_map(test_files_to_run, args.json_output_file) From 81780882b8cfb7628a2e09dbaae566ead5d760e8 Mon Sep 17 00:00:00 2001 From: Aryan V S Date: Tue, 21 Nov 2023 19:52:20 +0530 Subject: [PATCH 3/8] Addition of new callbacks to controlnets (#5812) * add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py * update callbacks * fix repeated kwarg * update --------- Co-authored-by: Patrick von Platen --- .../controlnet/pipeline_controlnet.py | 78 +++++++++-- .../controlnet/pipeline_controlnet_img2img.py | 115 ++++++++++++---- .../controlnet/pipeline_controlnet_inpaint.py | 119 +++++++++++++---- .../pipeline_controlnet_inpaint_sd_xl.py | 123 +++++++++++++----- .../controlnet/pipeline_controlnet_sd_xl.py | 91 ++++++++++--- .../pipeline_controlnet_sd_xl_img2img.py | 117 +++++++++++++---- 6 files changed, 510 insertions(+), 133 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index db17e2b7c181..4f625304fdf9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -485,15 +486,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -760,6 +767,10 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 def guidance_scale(self): return self._guidance_scale + @property + def clip_skip(self): + return self._clip_skip + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -767,6 +778,14 @@ def guidance_scale(self): def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -786,14 +805,15 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -868,6 +888,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -878,6 +907,23 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -903,9 +949,12 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -929,7 +978,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, @@ -940,7 +989,7 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -988,6 +1037,7 @@ def __call__( # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1078,7 +1128,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, @@ -1087,11 +1137,21 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 24e4050939c8..8945bd3d9c81 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -164,6 +164,7 @@ class StableDiffusionControlNetImg2ImgPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -519,15 +520,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -808,6 +815,29 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -829,14 +859,15 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -892,12 +923,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -915,6 +940,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -925,6 +959,23 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -950,8 +1001,13 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -961,10 +1017,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) @@ -978,23 +1030,23 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image @@ -1010,7 +1062,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): @@ -1025,7 +1077,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -1039,6 +1091,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) # 6. Prepare latent variables latents = self.prepare_latents( @@ -1068,11 +1121,11 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) @@ -1099,7 +1152,7 @@ def __call__( return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. @@ -1111,20 +1164,30 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 99c72d21e22e..9e2e428eaf91 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -286,6 +286,7 @@ class StableDiffusionControlNetInpaintPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -656,18 +657,24 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -999,6 +1006,29 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1021,14 +1051,15 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.5, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -1101,12 +1132,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -1124,6 +1149,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -1134,6 +1168,23 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1161,8 +1212,13 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1172,10 +1228,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) @@ -1189,23 +1241,23 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image @@ -1218,7 +1270,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): @@ -1233,7 +1285,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -1261,6 +1313,7 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels @@ -1297,7 +1350,7 @@ def __call__( prompt_embeds.dtype, device, generator, - do_classifier_free_guidance, + self.do_classifier_free_guidance, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -1317,11 +1370,11 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) @@ -1348,7 +1401,7 @@ def __call__( return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. @@ -1363,14 +1416,14 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -1379,7 +1432,7 @@ def __call__( if num_channels_unet == 4: init_latents_proper = image_latents - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask @@ -1392,6 +1445,16 @@ def __call__( latents = (1 - init_mask) * init_latents_proper + init_mask * latents + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 856ebbe6bbb5..3e5cba79f50b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -34,6 +34,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + deprecate, is_invisible_watermark_available, logging, replace_example_docstring, @@ -167,6 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline( model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -555,6 +557,7 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -565,14 +568,20 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -1008,6 +1017,29 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1039,8 +1071,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, @@ -1053,6 +1083,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -1147,12 +1180,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -1182,6 +1209,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -1190,6 +1226,23 @@ def __call__( [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1237,8 +1290,13 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1248,17 +1306,13 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( @@ -1271,7 +1325,7 @@ def __call__( prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, @@ -1279,7 +1333,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # 4. set timesteps @@ -1300,6 +1354,7 @@ def denoising_value_valid(dnv): latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) # 5. Preprocess mask and image - resizes image and mask w.r.t height and width # 5.1 Prepare init image @@ -1316,7 +1371,7 @@ def denoising_value_valid(dnv): num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): @@ -1331,7 +1386,7 @@ def denoising_value_valid(dnv): num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -1385,7 +1440,7 @@ def denoising_value_valid(dnv): prompt_embeds.dtype, device, generator, - do_classifier_free_guidance, + self.do_classifier_free_guidance, ) # 8. Check that sizes of mask, masked image and latents match @@ -1446,7 +1501,7 @@ def denoising_value_valid(dnv): ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) @@ -1483,7 +1538,7 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1491,7 +1546,7 @@ def denoising_value_valid(dnv): added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) @@ -1528,7 +1583,7 @@ def denoising_value_valid(dnv): return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. @@ -1543,7 +1598,7 @@ def denoising_value_valid(dnv): latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, @@ -1551,11 +1606,11 @@ def denoising_value_valid(dnv): )[0] # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if do_classifier_free_guidance and guidance_rescale > 0.0: + if self.do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) @@ -1564,7 +1619,7 @@ def denoising_value_valid(dnv): if num_channels_unet == 4: init_latents_proper = image_latents - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask @@ -1577,6 +1632,16 @@ def denoising_value_valid(dnv): latents = (1 - init_mask) * init_latents_proper + init_mask * latents + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index e248a48f8ed7..c1efd8aaa397 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -35,7 +35,14 @@ ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -143,6 +150,7 @@ class StableDiffusionXLControlNetPipeline( # leave controlnet out on purpose because it iterates with unet model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -487,15 +495,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -825,6 +839,10 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 def guidance_scale(self): return self._guidance_scale + @property + def clip_skip(self): + return self._clip_skip + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -832,6 +850,14 @@ def guidance_scale(self): def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -855,8 +881,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, @@ -869,6 +893,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -937,12 +964,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -989,6 +1010,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -997,6 +1027,23 @@ def __call__( If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1026,9 +1073,12 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1052,7 +1102,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1072,7 +1122,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # 4. Prepare image @@ -1115,6 +1165,7 @@ def __call__( # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1254,7 +1305,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, @@ -1269,6 +1320,16 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 3926eba33024..4fccd6a91b0f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -37,6 +37,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + deprecate, logging, replace_example_docstring, scale_lora_layers, @@ -195,6 +196,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -543,6 +545,7 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -553,14 +556,20 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -951,6 +960,29 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -976,8 +1008,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, @@ -992,6 +1022,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -1077,12 +1110,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -1138,6 +1165,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -1146,6 +1182,23 @@ def __call__( [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` containing the output images. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1177,8 +1230,13 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1188,10 +1246,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) @@ -1205,7 +1259,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1217,7 +1271,7 @@ def __call__( prompt_2, device, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, @@ -1225,7 +1279,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # 4. Prepare image and controlnet_conditioning_image @@ -1240,7 +1294,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = control_image.shape[-2:] @@ -1256,7 +1310,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -1271,6 +1325,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) # 6. Prepare latent variables latents = self.prepare_latents( @@ -1328,7 +1383,7 @@ def __call__( ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) @@ -1343,13 +1398,13 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) @@ -1382,7 +1437,7 @@ def __call__( return_dict=False, ) - if guess_mode and do_classifier_free_guidance: + if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. @@ -1394,7 +1449,7 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, @@ -1402,13 +1457,23 @@ def __call__( )[0] # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From 1093f9d615abaad27347279a889e1dba5f0df39c Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 21 Nov 2023 06:27:41 -0800 Subject: [PATCH 4/8] [docs] MusicLDM (#5854) * fix * feedback --- src/diffusers/pipelines/musicldm/pipeline_musicldm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py index 9e6b6fea13e5..68af3925fa02 100644 --- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py +++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -51,7 +51,7 @@ >>> import torch >>> import scipy - >>> repo_id = "cvssp/audioldm-s-full-v2" + >>> repo_id = "ucsd-reach/musicldm" >>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") From 6fac1369d0140ebeafdaeeef3f558f21fa3d5108 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 21 Nov 2023 18:38:43 +0200 Subject: [PATCH 5/8] Add features to the Dreambooth LoRA SDXL training script (#5508) * Additions: - support for different lr for text encoder - support for Prodigy optimizer - support for min snr gamma - support for custom captions and dataset loading from the hub * adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided) * fixed --output_dir / --model_dir_name confusion * added --repeats, --adam_weight_decay_text_encoder + some fixes * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen * Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen * - import compute_snr from diffusers/training_utils.py - cluster adamw together - when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors) * shape fixes when custom captions are used * formatting and a little cleanup * code styling * --repeats default value fixed, changed to 1 * bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings) * changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)- 1. user provides --dataset_name 2. user provides local dir --instance_data_dir that contains a metadata .jsonl file 3. user provides local dir --instance_data_dir that contains only images in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting * styling fix * arg name fix * adjusted the --repeats logic * -removed redundant arg and 'if' when loading local folder with prompts -updated readme template -some default val fixes -custom caption tests * image path fix for readme * code style * bug fix * --caption_column arg * readme fix --------- Co-authored-by: Patrick von Platen Co-authored-by: Linoy Tsaban --- .../dreambooth/train_dreambooth_lora_sdxl.py | 524 ++++++++++++++---- examples/test_examples.py | 43 ++ 2 files changed, 468 insertions(+), 99 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9baa137656f0..97b60c8f527d 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -52,7 +52,7 @@ from diffusers.loaders import LoraLoaderMixin from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict from diffusers.optimization import get_scheduler -from diffusers.training_utils import unet_lora_state_dict +from diffusers.training_utils import compute_snr, unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -64,36 +64,65 @@ def save_model_card( - repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, ): - img_str = "" + img_str = "widget:\n" if images else "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + output: + url: >- + "image_{i}.png" + """ yaml = f""" --- -license: openrail++ -base_model: {base_model} -instance_prompt: {prompt} tags: - stable-diffusion-xl - stable-diffusion-xl-diffusers - text-to-image - diffusers - lora -inference: true +- template:sd-lora +widget: +{img_str} +--- +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ --- """ + model_card = f""" -# LoRA DreamBooth - {repo_id} +# SDXL LoRA DreamBooth - {repo_id} -These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n -{img_str} + -LoRA for the text encoder was enabled: {train_text_encoder}. +## Model description +These are {repo_id} LoRA adaption weights for {base_model}. +The weights were trained using [DreamBooth](https://dreambooth.github.io/). +LoRA for the text encoder was enabled: {train_text_encoder}. Special VAE used for training: {vae_path}. + +## Trigger words + +You should use {instance_prompt} to trigger the image generation. + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -141,13 +170,53 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) parser.add_argument( "--instance_data_dir", type=str, default=None, - required=True, - help="A folder containing the training data of instance images.", + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + parser.add_argument( "--class_data_dir", type=str, @@ -160,7 +229,7 @@ def parse_args(input_args=None): type=str, default=None, required=True, - help="The prompt with identifier specifying the instance", + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( "--class_prompt", @@ -299,9 +368,16 @@ def parse_args(input_args=None): parser.add_argument( "--learning_rate", type=float, - default=5e-4, + default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -317,6 +393,14 @@ def parse_args(input_args=None): ' "constant", "constant_with_warmup"]' ), ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) @@ -335,13 +419,59 @@ def parse_args(input_args=None): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) + + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") @@ -414,6 +544,12 @@ def parse_args(input_args=None): else: args = parser.parse_args() + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -442,20 +578,84 @@ class DreamBoothDataset(Dataset): def __init__( self, instance_data_root, + instance_prompt, + class_prompt, class_data_root=None, class_num=None, size=1024, + repeats=1, center_crop=False, ): self.size = size self.center_crop = center_crop - self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") - self.instance_images_path = list(Path(instance_data_root).iterdir()) - self.num_instance_images = len(self.instance_images_path) + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images if class_data_root is not None: @@ -484,13 +684,23 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = self.instance_images[index % self.num_instance_images] instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) @@ -498,22 +708,25 @@ def __getitem__(self, index): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt return example def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - batch = {"pixel_values": pixel_values} + batch = {"pixel_values": pixel_values, "prompts": prompts} return batch @@ -732,7 +945,8 @@ def main(args): xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: @@ -866,35 +1080,119 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] + else: + params_to_optimize = [unet_lora_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": try: - import bitsandbytes as bnb + import prodigyopt except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW + optimizer_class = prodigyopt.Prodigy - # Optimizer creation - params_to_optimize = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) - if args.train_text_encoder - else unet_lora_parameters + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, ) - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, ) # Computes additional embeddings/ids required by the SDXL UNet. - # regular text emebddings (when `train_text_encoder` is not True) + # regular text embeddings (when `train_text_encoder` is not True) # pooled text embeddings # time ids @@ -921,7 +1219,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Handle instance prompt. instance_time_ids = compute_time_ids() - if not args.train_text_encoder: + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) @@ -934,49 +1236,36 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): args.class_prompt, text_encoders, tokenizers ) - # Clear the memory here. - if not args.train_text_encoder: + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders gc.collect() torch.cuda.empty_cache() - # Pack the statically computed variables appropriately. This is so that we don't + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. add_time_ids = instance_time_ids if args.with_prior_preservation: add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) - if not args.train_text_encoder: - prompt_embeds = instance_prompt_hidden_states - unet_add_text_embeds = instance_pooled_prompt_embeds - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) - else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) - tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) - if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) - tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) - tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - - # Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_num=args.num_class_images, - size=args.resolution, - center_crop=args.center_crop, - ) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), - num_workers=args.dataloader_num_workers, - ) + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1079,6 +1368,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds, unet_add_text_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts) + tokens_two = tokenize_prompt(tokenizer_two, prompts) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1099,16 +1399,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # Calculate the elements to repeat depending on the use of prior-preservation. - elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + else: + elems_to_repeat_text_embeds = 1 + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz # Predict the noise residual if not args.train_text_encoder: unet_added_conditions = { - "time_ids": add_time_ids.repeat(elems_to_repeat, 1), - "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1), + "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), } - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, @@ -1116,15 +1421,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): added_cond_kwargs=unet_added_conditions, ).sample else: - unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)} + unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two], ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + unet_added_conditions.update( + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions ).sample @@ -1142,16 +1449,34 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -1353,7 +1678,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): images=images, base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, - prompt=args.instance_prompt, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, repo_folder=args.output_dir, vae_path=args.pretrained_vae_model_name_or_path, ) diff --git a/examples/test_examples.py b/examples/test_examples.py index 89e866231e89..292c433a3395 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -421,6 +421,49 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self): ) self.assertTrue(starts_with_unet) + def test_dreambooth_lora_sdxl_custom_captions(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --caption_column text + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --caption_column text + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --train_text_encoder + """.split() + + run_command(self._launch_args + test_args) + def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" From ba352aea29df9a7f086bf0815fe9fe479218f801 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 21 Nov 2023 07:34:30 -1000 Subject: [PATCH 6/8] [feat] IP Adapters (author @okotaku ) (#5713) * add ip-adapter --------- Co-authored-by: okotaku Co-authored-by: sayakpaul Co-authored-by: yiyixuxu Co-authored-by: Patrick von Platen Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/using-diffusers/loading_adapters.md | 328 ++++++++++++++++++ src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/ip_adapter.py | 157 +++++++++ src/diffusers/loaders/unet.py | 70 ++++ src/diffusers/models/attention_processor.py | 246 +++++++++++++ src/diffusers/models/unet_2d_condition.py | 9 + src/diffusers/models/unet_motion_model.py | 16 + .../alt_diffusion/pipeline_alt_diffusion.py | 42 ++- .../pipeline_alt_diffusion_img2img.py | 37 +- .../animatediff/pipeline_animatediff.py | 37 +- .../controlnet/pipeline_controlnet.py | 38 +- .../controlnet/pipeline_controlnet_sd_xl.py | 61 +++- .../pipelines/pipeline_flax_utils.py | 5 +- src/diffusers/pipelines/pipeline_utils.py | 13 +- .../pipeline_stable_diffusion.py | 42 ++- .../pipeline_stable_diffusion_img2img.py | 38 +- .../pipeline_stable_diffusion_inpaint.py | 38 +- .../pipeline_stable_diffusion_xl.py | 54 ++- .../pipeline_stable_diffusion_xl_img2img.py | 58 +++- .../pipeline_stable_diffusion_xl_inpaint.py | 58 +++- .../versatile_diffusion/modeling_text_unet.py | 9 + tests/lora/test_lora_layers_old_backend.py | 4 + tests/lora/test_lora_layers_peft.py | 3 + tests/models/test_models_unet_2d_condition.py | 104 +++++- .../altdiffusion/test_alt_diffusion.py | 1 + .../test_alt_diffusion_img2img.py | 2 + .../pipelines/animatediff/test_animatediff.py | 2 + tests/pipelines/controlnet/test_controlnet.py | 3 + .../controlnet/test_controlnet_sdxl.py | 6 + .../test_ip_adapter_stable_diffusion.py | 221 ++++++++++++ .../stable_diffusion/test_stable_diffusion.py | 1 + .../test_stable_diffusion_img2img.py | 1 + .../test_stable_diffusion_inpaint.py | 2 + .../test_stable_diffusion.py | 1 + .../test_stable_diffusion_inpaint.py | 1 + .../test_stable_diffusion_v_pred.py | 3 + .../test_stable_diffusion_xl.py | 2 + .../test_stable_diffusion_xl_img2img.py | 39 ++- .../test_stable_diffusion_xl_inpaint.py | 37 +- tests/pipelines/test_pipelines.py | 27 +- 40 files changed, 1755 insertions(+), 63 deletions(-) create mode 100644 src/diffusers/loaders/ip_adapter.py create mode 100644 tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index e73e042bd4d5..c14b38a9dd89 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -307,3 +307,331 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b image = pipeline(prompt=prompt).images[0] image ``` + +## IP-Adapter + +[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs. + +IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, AnimateDiff. And you can use any custom models finetuned from the same base models. It also works with LCM-Lora out of box. + + + + +You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter). + +IP-Adapter was contributed by [okotaku](https://github.com/okotaku). + + + +Let's first create a Stable Diffusion Pipeline. + +```py +from diffusers import AutoPipelineForText2Image +import torch +from diffusers.utils import load_image + + +pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") +``` + +Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method. + +```py +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") +``` + + +IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it. + +```py +from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection +import torch + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="models/image_encoder", + torch_dtype=torch.float16, +).to("cuda") + +pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda") +``` + + +IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎 + +```py +pipeline.set_ip_adapter_scale(0.6) +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png") +generator = torch.Generator(device="cpu").manual_seed(33) +images = pipeline( +    prompt='best quality, high quality, wearing sunglasses', +    ip_adapter_image=image, +    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", +    num_inference_steps=50, +    generator=generator, +).images +images[0] +``` + +
+    +
+ + + +You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the text prompt and image prompt condition ratio.  If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale to get more generation diversity, but it'll be less aligned with the prompt. +`scale=0.5` can achieve good results in most cases when you use both text and image prompts. + + +IP-Adapter also works great with Image-to-Image and Inpainting pipelines. See below examples of how you can use it with Image-to-Image and Inpaint. + + + + +```py +from diffusers import AutoPipelineForImage2Image +import torch +from diffusers.utils import load_image + +pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg") +ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png") + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") +generator = torch.Generator(device="cpu").manual_seed(33) +images = pipeline( +    prompt='best quality, high quality', +    image = image, +    ip_adapter_image=ip_image, +    num_inference_steps=50, +    generator=generator, +    strength=0.6, +).images +images[0] +``` + + + + +```py +from diffusers import AutoPipelineForInpaint +import torch +from diffusers.utils import load_image + +pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float).to("cuda") + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png") +mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png") +ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png") + +image = image.resize((512, 768)) +mask = mask.resize((512, 768)) + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + +generator = torch.Generator(device="cpu").manual_seed(33) +images = pipeline( + prompt='best quality, high quality', + image = image, + mask_image = mask, + ip_adapter_image=ip_image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=50, + generator=generator, + strength=0.5, +).images +images[0] +``` + + + + +IP-Adapters can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md) + +```python +from diffusers import AutoPipelineForText2Image +from diffusers.utils import load_image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16 +).to("cuda") + +image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg") + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") + +generator = torch.Generator(device="cpu").manual_seed(33) +image = pipeline( + prompt="best quality, high quality", + ip_adapter_image=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=25, + generator=generator, +).images[0] +image.save("sdxl_t2i.png") +``` + +
+
+ +
input image
+
+
+ +
adapted image
+
+
+ + +### LCM-Lora + +You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights. + +```py +from diffusers import DiffusionPipeline, LCMScheduler +import torch +from diffusers.utils import load_image + +model_id = "sd-dreambooth-library/herge-style" +lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5" + +pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) + +pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") +pipe.load_lora_weights(lcm_lora_id) +pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) +pipe.enable_model_cpu_offload() + +prompt = "best quality, high quality" +image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png") +images = pipe( + prompt=prompt, + ip_adapter_image=image, + num_inference_steps=4, + guidance_scale=1, +).images[0] +``` + +### Other pipelines + +IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is to run `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image` + + + +🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require integrating IP-adapters with a pipeline that does not support it yet! + + + +You can find below examples on how to use IP-Adapter with ControlNet and AnimateDiff. + + + + +``` +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel +import torch +from diffusers.utils import load_image + +controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth" +controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16) + +pipeline = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16) +pipeline.to("cuda") + +image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png") +depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png") + +pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + +generator = torch.Generator(device="cpu").manual_seed(33) +images = pipeline( + prompt='best quality, high quality', + image=depth_map, + ip_adapter_image=image, + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=50, + generator=generator, +).images +images[0] +``` +
+
+ +
input image
+
+
+ +
adapted image
+
+
+ +
+ + +```py +# animate diff + ip adapter +import torch +from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler +from diffusers.utils import export_to_gif, load_image + +# Load the motion adapter +adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16) +# load SD 1.5 based finetuned model +model_id = "Lykon/DreamShaper" +pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16) + +# scheduler +scheduler = DDIMScheduler( + clip_sample=False, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + timestep_spacing="trailing", + steps_offset=1 +) +pipe.scheduler = scheduler + +# enable memory savings +pipe.enable_vae_slicing() +pipe.enable_model_cpu_offload() + +# load ip_adapter +pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + +# load motion adapters +pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out") +pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up") +pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left") + +seed = 42 +image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png") +images = [image] * 3 +prompts = ["best quality, high quality"] * 3 +negative_prompt = "bad quality, worst quality" +adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]] + +# generate +output_frames = [] +for prompt, image, adapter_weight in zip(prompts, images, adapter_weights): + pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight) + output = pipe( + prompt= prompt, + num_frames=16, + guidance_scale=7.5, + num_inference_steps=30, + ip_adapter_image = image, + generator=torch.Generator("cpu").manual_seed(seed), + ) + frames = output.frames[0] + output_frames.extend(frames) + +export_to_gif(output_frames, "test_out_animation.gif") +``` + + +
+ diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 14fd985f69e4..684736856029 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -62,6 +62,7 @@ def text_encoder_attn_modules(text_encoder): _import_structure["single_file"].extend(["FromSingleFileMixin"]) _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] + _import_structure["ip_adapter"] = ["IPAdapterMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -72,6 +73,7 @@ def text_encoder_attn_modules(text_encoder): from .utils import AttnProcsLayers if is_transformers_available(): + from .ip_adapter import IPAdapterMixin from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py new file mode 100644 index 000000000000..32c558554be2 --- /dev/null +++ b/src/diffusers/loaders/ip_adapter.py @@ -0,0 +1,157 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Dict, Union + +import torch +from safetensors import safe_open + +from ..utils import ( + DIFFUSERS_CACHE, + HF_HUB_OFFLINE, + _get_model_file, + is_transformers_available, + logging, +) + + +if is_transformers_available(): + from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + ) + + from ..models.attention_processor import ( + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + +logger = logging.get_logger(__name__) + + +class IPAdapterMixin: + """Mixin for handling IP Adapters.""" + + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + subfolder: str, + weight_name: str, + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + """ + + # Load the main state dict first. + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # load CLIP image encoer here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=os.path.join(subfolder, "image_encoder"), + ).to(self.device, dtype=self.dtype) + self.image_encoder = image_encoder + else: + raise ValueError("`image_encoder` cannot be None when using IP Adapters.") + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + self.feature_extractor = CLIPImageProcessor() + + # load ip-adapter into unet + self.unet._load_ip_adapter_weights(state_dict) + + def set_ip_adapter_scale(self, scale): + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + attn_processor.scale = scale diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9555ac9e7d8b..6c805672c9cd 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -18,8 +18,10 @@ import safetensors import torch +import torch.nn.functional as F from torch import nn +from ..models.embeddings import ImageProjection from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, @@ -662,4 +664,72 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + def _load_ip_adapter_weights(self, state_dict): + from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + ) + + # set ip-adapter cross-attention processors & load state_dict + attn_procs = {} + key_id = 1 + for name in self.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + if cross_attention_dim is None or "motion_modules" in name: + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).to(dtype=self.dtype, device=self.device) + + value_dict = {} + for k, w in attn_procs[name].state_dict().items(): + value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + + attn_procs[name].load_state_dict(value_dict) + key_id += 2 + + self.set_attn_processor(attn_procs) + + # create image projection layers. + clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 + ) + image_projection.to(dtype=self.dtype, device=self.device) + + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) + + image_projection.load_state_dict(image_proj_state_dict) + + self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.config.encoder_hid_dim_type = "ip_image_proj" + delete_adapter_layers diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1234dbd2d5ce..6b86ba66db37 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1975,6 +1975,250 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k return attn.processor(attn, hidden_states, *args, **kwargs) +class IPAdapterAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, defaults to 4): + The context length of the image features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + if scale != 1.0: + logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.") + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAdapterAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, defaults to 4): + The context length of the image features. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + if scale != 1.0: + logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.") + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # split hidden states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, @@ -1998,6 +2242,8 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f248b243f376..dd91d8007229 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1022,6 +1022,15 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 7be1a59114ef..0bbc573e7df1 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -208,6 +208,8 @@ def __init__( motion_max_seq_length: int = 32, motion_num_attention_heads: int = 8, use_motion_mid_block: int = True, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, ): super().__init__() @@ -248,6 +250,9 @@ def __init__( act_fn=act_fn, ) + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -684,6 +689,7 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, @@ -767,6 +773,16 @@ def forward( emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 9f51c084d5f8..843e3b8b9410 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -17,11 +17,11 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ...configuration_utils import FrozenDict -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -74,7 +74,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class AltDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -86,6 +88,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -108,7 +111,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -121,6 +124,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -197,6 +201,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -444,6 +449,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -652,6 +670,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -698,6 +717,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -797,12 +817,18 @@ def __call__( lora_scale=lora_scale, clip_skip=self.clip_skip, ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -823,7 +849,10 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.5 Optionally get Guidance Scale Embedding + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -847,6 +876,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 129794f7fbbd..b196ac4d3f69 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, XLMRobertaTokenizer +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -111,7 +111,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Alt Diffusion. @@ -124,6 +124,7 @@ class AltDiffusionImg2ImgPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -146,7 +147,7 @@ class AltDiffusionImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -159,6 +160,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -235,6 +237,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -453,6 +456,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -705,6 +721,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -754,6 +771,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -846,6 +864,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Preprocess image image = self.image_processor.preprocess(image) @@ -868,7 +891,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.5 Optionally get Guidance Scale Embedding + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -892,6 +918,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 71adb8408c88..28dc220545dc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,10 +18,10 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter @@ -77,7 +77,7 @@ class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -101,6 +101,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo """ model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] def __init__( self, @@ -117,6 +118,8 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() unet = UNetMotionModel.from_unet2d(unet, motion_adapter) @@ -128,6 +131,8 @@ def __init__( unet=unet, motion_adapter=motion_adapter, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -314,6 +319,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -512,6 +531,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -558,6 +578,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -629,6 +650,11 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -649,6 +675,8 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -664,6 +692,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 4f625304fdf9..41e5e75f68e5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -20,10 +20,10 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -92,7 +92,7 @@ class StableDiffusionControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -102,6 +102,7 @@ class StableDiffusionControlNetPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -128,7 +129,7 @@ class StableDiffusionControlNetPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -142,6 +143,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -174,6 +176,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -430,6 +433,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -803,6 +820,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -860,6 +878,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -997,6 +1016,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( @@ -1063,7 +1087,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Create tensor stating which controlnets to keep + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ @@ -1131,6 +1158,7 @@ def __call__( cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c1efd8aaa397..4696781dce0c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -20,12 +20,23 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from diffusers.utils.import_utils import is_invisible_watermark_available from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -104,7 +115,11 @@ class StableDiffusionXLControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. @@ -149,7 +164,14 @@ class StableDiffusionXLControlNetPipeline( # leave controlnet out on purpose because it iterates with unet model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -164,6 +186,8 @@ def __init__( scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() @@ -179,6 +203,8 @@ def __init__( unet=unet, controlnet=controlnet, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) @@ -462,6 +488,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -879,6 +919,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -959,6 +1000,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -1100,7 +1142,7 @@ def __call__( ) guess_mode = guess_mode or global_pool_conditions - # 3. Encode input prompt + # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) @@ -1125,6 +1167,12 @@ def __call__( clip_skip=self.clip_skip, ) + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( @@ -1299,6 +1347,9 @@ def __call__( down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + # predict the noise residual noise_pred = self.unet( latent_model_input, diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index cbb55b504c54..2e25a40295b4 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -538,12 +538,13 @@ def load_module(name, value): model = pipeline_class(**init_kwargs, dtype=dtype) return model, params - @staticmethod - def _get_signature_keys(obj): + @classmethod + def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} + return expected_modules, optional_parameters @property diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 5fa1938983d5..0208ade020bd 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -557,7 +557,7 @@ def register_modules(self, **kwargs): for name, module in kwargs.items(): # retrieve library - if module is None: + if module is None or isinstance(module, (tuple, list)) and module[0] is None: register_dict = {name: (None, None)} else: # register the config from the original module, not the dynamo compiled one @@ -1906,12 +1906,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: " above." ) from model_info_call_error - @staticmethod - def _get_signature_keys(obj): + @classmethod + def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} + + optional_names = list(optional_parameters) + for name in optional_names: + if name in cls._optional_components: + expected_modules.add(name) + optional_parameters.remove(name) + return expected_modules, optional_parameters @property diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5af5a42256f3..a05abe00f2b1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -17,11 +17,11 @@ import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -70,7 +70,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -82,6 +84,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -104,7 +107,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -117,6 +120,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -193,6 +197,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -440,6 +445,19 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -649,6 +667,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -695,6 +714,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -794,12 +814,18 @@ def __call__( lora_scale=lora_scale, clip_skip=self.clip_skip, ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps @@ -820,7 +846,10 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.5 Optionally get Guidance Scale Embedding + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -844,6 +873,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c75afb0789a4..029cd2b04839 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -106,7 +106,7 @@ def preprocess(image): class StableDiffusionImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image-to-image generation using Stable Diffusion. @@ -119,6 +119,7 @@ class StableDiffusionImg2ImgPipeline( - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -141,7 +142,7 @@ class StableDiffusionImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -154,6 +155,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -230,6 +232,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -450,6 +453,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -708,6 +725,7 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -757,6 +775,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -849,6 +868,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Preprocess image image = self.image_processor.preprocess(image) @@ -871,7 +895,10 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.5 Optionally get Guidance Scale Embedding + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -895,6 +922,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index e4a25e181e42..09e50c60a807 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -19,11 +19,11 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers @@ -170,7 +170,7 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image inpainting using Stable Diffusion. @@ -182,6 +182,7 @@ class StableDiffusionInpaintPipeline( - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): @@ -204,7 +205,7 @@ class StableDiffusionInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] @@ -217,6 +218,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() @@ -298,6 +300,7 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -521,6 +524,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -837,6 +854,7 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -902,6 +920,7 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -1029,6 +1048,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( @@ -1117,7 +1141,10 @@ def __call__( # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 9.5 Optionally get Guidance Scale Embedding + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -1146,6 +1173,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c50a036a88f8..e32791693012 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -16,11 +16,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) -from ...image_processor import VaeImageProcessor +from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, + IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) @@ -94,7 +101,11 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLPipeline( - DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -142,7 +153,14 @@ class StableDiffusionXLPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -162,6 +180,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, ): @@ -175,6 +195,8 @@ def __init__( tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -456,6 +478,20 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -718,6 +754,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -802,6 +839,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1000,6 +1038,12 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1037,6 +1081,8 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 56f1a5196cf0..d40a037e67fe 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -17,10 +17,21 @@ import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -104,7 +115,11 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionXLImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -155,7 +170,14 @@ class StableDiffusionXLImg2ImgPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -175,6 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -188,6 +212,8 @@ def __init__( tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @@ -665,6 +691,20 @@ def prepare_latents( return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + def _get_add_time_ids( self, original_size, @@ -850,6 +890,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -943,6 +984,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1162,6 +1204,12 @@ def denoising_value_valid(dnv): add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1205,6 +1253,8 @@ def denoising_value_valid(dnv): # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index d618ea4c2a71..3a9d068d60f3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -18,10 +18,21 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -249,7 +260,11 @@ def retrieve_latents(encoder_output, generator): class StableDiffusionXLInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -301,7 +316,14 @@ class StableDiffusionXLInpaintPipeline( model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -323,6 +345,8 @@ def __init__( tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, @@ -336,6 +360,8 @@ def __init__( tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) @@ -386,6 +412,20 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, @@ -1074,6 +1114,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1172,6 +1213,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -1471,6 +1513,12 @@ def denoising_value_valid(dnv): add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + # 11. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1517,6 +1565,8 @@ def denoising_value_valid(dnv): # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 91f8c2c3dc03..a940cec5e46a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1221,6 +1221,15 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + # 2. pre-process sample = self.conv_in(sample) diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index 285c5e864a04..19505a1d906d 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -246,6 +246,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } lora_components = { "unet_lora_layers": unet_lora_layers, @@ -757,6 +758,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -866,6 +868,8 @@ def get_dummy_components(self): "text_encoder_2": text_encoder_2, "tokenizer": tokenizer, "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, } lora_components = { "unet_lora_layers": unet_lora_layers, diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index c290850a10b6..48ae5d197273 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -140,6 +140,8 @@ def get_dummy_components(self, scheduler_cls=None): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, } else: pipeline_components = { @@ -150,6 +152,7 @@ def get_dummy_components(self, scheduler_cls=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } lora_components = { "unet_lora_layers": unet_lora_layers, diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 0db336a88029..06bf2685560d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -24,7 +24,8 @@ from pytest import mark from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import CustomDiffusionAttnProcessor +from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor +from diffusers.models.embeddings import ImageProjection from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -45,6 +46,57 @@ enable_full_determinism() +def create_ip_adapter_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4 + ) + + ip_image_projection_state_dict = {} + sd = image_projection.state_dict() + ip_image_projection_state_dict.update( + { + "proj.weight": sd["image_embeds.weight"], + "proj.bias": sd["image_embeds.bias"], + "norm.weight": sd["norm.weight"], + "norm.bias": sd["norm.bias"], + } + ) + + del sd + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True @@ -622,6 +674,56 @@ def test_asymmetrical_unet(self): # Check if input and output shapes are the same self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_ip_adapter(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without ip-adapter + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + # update inputs_dict for ip-adapter + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} + + # make ip_adapter_1 and ip_adapter_2 + ip_adapter_1 = create_ip_adapter_state_dict(model) + + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} + ip_adapter_2 = {} + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) + + # forward pass ip_adapter_1 + model._load_ip_adapter_weights(ip_adapter_1) + assert model.config.encoder_hid_dim_type == "ip_image_proj" + assert model.encoder_hid_proj is not None + assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( + "IPAdapterAttnProcessor", + "IPAdapterAttnProcessor2_0", + ) + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + # forward pass with ip_adapter_2 + model._load_ip_adapter_weights(ip_adapter_2) + with torch.no_grad(): + sample3 = model(**inputs_dict).sample + + # forward pass with ip_adapter_1 again + model._load_ip_adapter_weights(ip_adapter_1) + with torch.no_grad(): + sample4 = model(**inputs_dict).sample + + assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4) + assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) + assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 5befe60cf6d9..b4a2847bb84d 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -117,6 +117,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py index 57001f7bea52..3fd1a90172ca 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -141,6 +141,7 @@ def test_stable_diffusion_img2img_default_case(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=self.dummy_extractor, + image_encoder=None, ) alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True) alt_pipe = alt_pipe.to(device) @@ -205,6 +206,7 @@ def test_stable_diffusion_img2img_fp16(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=self.dummy_extractor, + image_encoder=None, ) alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False) alt_pipe = alt_pipe.to(torch_device) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 3c9390f2d1b6..5cd0a45c7406 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -99,6 +99,8 @@ def get_dummy_components(self): "motion_adapter": motion_adapter, "text_encoder": text_encoder, "tokenizer": tokenizer, + "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 2d8c8869c23c..1cf52bfeebe2 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -183,6 +183,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -341,6 +342,7 @@ def init_weights(m): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -518,6 +520,7 @@ def init_weights(m): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 36ddee36eb52..88d2df1ec0f8 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -146,6 +146,8 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "feature_extractor": None, + "image_encoder": None, } return components @@ -471,6 +473,8 @@ def init_weights(m): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "feature_extractor": None, + "image_encoder": None, } return components @@ -656,6 +660,8 @@ def init_weights(m): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py new file mode 100644 index 000000000000..57eb49013c1f --- /dev/null +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, +) +from diffusers.utils import load_image +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_gpu, + slow, + torch_device, +) + + +enable_full_determinism() + + +class IPAdapterNightlyTestsMixin(unittest.TestCase): + dtype = torch.float16 + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_image_encoder(self, repo_id, subfolder): + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + repo_id, subfolder=subfolder, torch_dtype=self.dtype + ).to(torch_device) + return image_encoder + + def get_image_processor(self, repo_id): + image_processor = CLIPImageProcessor.from_pretrained(repo_id) + return image_processor + + def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False): + image = load_image( + "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" + ) + if for_sdxl: + image = image.resize((1024, 1024)) + + input_kwargs = { + "prompt": "best quality, high quality", + "negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality", + "num_inference_steps": 5, + "generator": torch.Generator(device="cpu").manual_seed(33), + "ip_adapter_image": image, + "output_type": "np", + } + if for_image_to_image: + image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg") + ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png") + + if for_sdxl: + image = image.resize((1024, 1024)) + ip_image = ip_image.resize((1024, 1024)) + + input_kwargs.update({"image": image, "ip_adapter_image": ip_image}) + + elif for_inpainting: + image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png") + mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png") + ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png") + + if for_sdxl: + image = image.resize((1024, 1024)) + mask = mask.resize((1024, 1024)) + ip_image = ip_image.resize((1024, 1024)) + + input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) + + return input_kwargs + + +@slow +@require_torch_gpu +class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): + def test_text_to_image(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_image_to_image(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_inpainting(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + +@slow +@require_torch_gpu +class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): + def test_text_to_image_sdxl(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_image_to_image_sdxl(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_inpainting_sdxl(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + image_slice.tolist() + + expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 53284b80952c..15c1c4fe6671 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -163,6 +163,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 12c6d8cf63d3..1a482b38e2ee 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -150,6 +150,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 287b2eac4d75..cbe4fb2a0ddf 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -153,6 +153,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components @@ -353,6 +354,7 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 4414d1ec5075..ed295f792f99 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -123,6 +123,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 92e8857610ea..41b9f83914a6 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -108,6 +108,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "image_encoder": None, } return components diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index e2d476dec502..09034789c61c 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -127,6 +127,7 @@ def test_stable_diffusion_v_pred_ddim(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=None, + image_encoder=None, requires_safety_checker=False, ) sd_pipe = sd_pipe.to(device) @@ -176,6 +177,7 @@ def test_stable_diffusion_v_pred_k_euler(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=None, + image_encoder=None, requires_safety_checker=False, ) sd_pipe = sd_pipe.to(device) @@ -236,6 +238,7 @@ def test_stable_diffusion_v_pred_fp16(self): tokenizer=tokenizer, safety_checker=None, feature_extractor=None, + image_encoder=None, requires_safety_checker=False, ) sd_pipe = sd_pipe.to(torch_device) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 95fbb658fe5e..8957ebbef5ab 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -131,6 +131,8 @@ def get_dummy_components(self, time_cond_proj_dim=None): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "image_encoder": None, + "feature_extractor": None, } return components diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 55779e5f060d..444f12ecfa9d 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -18,7 +18,15 @@ import numpy as np import torch -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from diffusers import ( AutoencoderKL, @@ -95,6 +103,31 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim latent_channels=4, sample_size=128, ) + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=32, + image_size=224, + projection_dim=32, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + feature_extractor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + torch.manual_seed(0) text_encoder_config = CLIPTextConfig( bos_token_id=0, @@ -125,6 +158,8 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, "requires_aesthetics_score": True, + "image_encoder": image_encoder, + "feature_extractor": feature_extractor, } return components @@ -458,6 +493,8 @@ def get_dummy_components(self): "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, "requires_aesthetics_score": True, + "image_encoder": None, + "feature_extractor": None, } return components diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 54c750f997b6..7f7a0d81e5a2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -20,7 +20,15 @@ import numpy as np import torch from PIL import Image -from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) from diffusers import ( AutoencoderKL, @@ -120,6 +128,31 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=32, + image_size=224, + projection_dim=32, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + feature_extractor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + components = { "unet": unet, "scheduler": scheduler, @@ -128,6 +161,8 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim "tokenizer": tokenizer if not skip_first_text_encoder else None, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "image_encoder": image_encoder, + "feature_extractor": feature_extractor, "requires_aesthetics_score": True, } return components diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 42c90e47af80..d812ce0ccb95 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1136,8 +1136,8 @@ def test_stable_diffusion_components(self): safety_checker=None, feature_extractor=self.dummy_extractor, ).to(torch_device) - img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device) - text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device) + img2img = StableDiffusionImg2ImgPipeline(**inpaint.components, image_encoder=None).to(torch_device) + text2img = StableDiffusionPipeline(**inpaint.components, image_encoder=None).to(torch_device) prompt = "A painting of a squirrel eating a burger" @@ -1276,6 +1276,29 @@ def test_set_component_to_none(self): assert out_image.shape == (1, 64, 64, 3) assert np.abs(out_image - out_image_2).max() < 1e-3 + def test_optional_components_is_none(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + items = { + "feature_extractor": self.dummy_extractor, + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": bert, + "tokenizer": tokenizer, + "safety_checker": None, + # we don't add an image encoder + } + + pipeline = StableDiffusionPipeline(**items) + + assert sorted(pipeline.components.keys()) == sorted(["image_encoder"] + list(items.keys())) + assert pipeline.image_encoder is None + def test_set_scheduler_consistency(self): unet = self.dummy_cond_unet() pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") From 13d73d9303583f430763357975fcb2398c009a50 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 21 Nov 2023 18:58:37 +0100 Subject: [PATCH 7/8] [Lora] Seperate logic (#5809) * [Lora] Seperate logic * [Lora] Seperate logic * [Lora] Seperate logic * add comments to explain the code better * add comments to explain the code better --- examples/dreambooth/train_dreambooth_lora.py | 35 ++++++- .../dreambooth/train_dreambooth_lora_sdxl.py | 35 ++++++- .../text_to_image/train_text_to_image_lora.py | 98 ++++++++++++++----- .../train_text_to_image_lora_sdxl.py | 35 ++++++- src/diffusers/loaders/__init__.py | 5 +- src/diffusers/loaders/lora.py | 37 ++++++- src/diffusers/models/lora.py | 28 ++---- 7 files changed, 219 insertions(+), 54 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9250865a3ad1..b82dfa38c172 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -57,7 +57,7 @@ AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, ) -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available @@ -70,6 +70,39 @@ logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 97b60c8f527d..dd7b29ca8842 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -50,7 +50,7 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr, unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available @@ -63,6 +63,39 @@ logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 78b443d149e8..b7309196dec8 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -40,8 +40,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -54,6 +53,39 @@ logger = get_logger(__name__, log_level="INFO") +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" for i, image in enumerate(images): @@ -458,25 +490,43 @@ def main(): # => 32 layers # Set correct lora layers - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=args.rank, + unet_lora_parameters = [] + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank + ) + ) + + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=args.rank, + ) ) - unet.set_attn_processor(lora_attn_procs) + # Accumulate the LoRA params to optimize. + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -491,8 +541,6 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") - lora_layers = AttnProcsLayers(unet.attn_processors) - # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -517,7 +565,7 @@ def main(): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - lora_layers.parameters(), + unet_lora_parameters, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -644,8 +692,8 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - lora_layers, optimizer, train_dataloader, lr_scheduler + unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet_lora_parameters, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -777,7 +825,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = lora_layers.parameters() + params_to_clip = unet_lora_parameters accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index bff928541f57..96bfe9e16783 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -50,7 +50,7 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available @@ -63,6 +63,39 @@ logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 684736856029..45c8c97c76eb 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder): deprecate( "text_encoder_load_state_dict in `models`", "0.27.0", - "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", ) state_dict = {} @@ -34,7 +34,7 @@ def text_encoder_attn_modules(text_encoder): deprecate( "text_encoder_attn_modules in `models`", "0.27.0", - "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.", + "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.", ) from transformers import CLIPTextModel, CLIPTextModelWithProjection @@ -67,7 +67,6 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from ..models.lora import text_encoder_lora_state_dict from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index ab5d0ffd0157..06eb3af05ee2 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -47,9 +47,10 @@ if is_transformers_available(): - from transformers import PreTrainedModel + from transformers import CLIPTextModel, CLIPTextModelWithProjection - from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules + # To be deprecated soon + from ..models.lora import PatchedLoraProjection if is_accelerate_available(): from accelerate import init_empty_weights @@ -66,6 +67,34 @@ LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." +def text_encoder_attn_modules(text_encoder): + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + else: + raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") + + return attn_modules + + +def text_encoder_mlp_modules(text_encoder): + mlp_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + mlp_mod = layer.mlp + name = f"text_model.encoder.layers.{i}.mlp" + mlp_modules.append((name, mlp_mod)) + else: + raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") + + return mlp_modules + + class LoraLoaderMixin: r""" Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`]. @@ -1415,7 +1444,7 @@ def process_weights(adapter_names, weights): ) set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): + def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 """ Disable the text encoder's LoRA layers. @@ -1445,7 +1474,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel" raise ValueError("Text Encoder not found.") set_adapter_layers(text_encoder, enabled=False) - def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): + def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 """ Enables the text encoder's LoRA layers. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 9edec19a3a34..daac8f902cd6 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -12,6 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# IMPORTANT: # +################################################################### +# ----------------------------------------------------------------# +# This file is deprecated and will be removed soon # +# (as soon as PEFT will become a required dependency for LoRA) # +# ----------------------------------------------------------------# +################################################################### + from typing import Optional, Tuple, Union import torch @@ -57,25 +66,6 @@ def text_encoder_mlp_modules(text_encoder): return mlp_modules -def text_encoder_lora_state_dict(text_encoder): - state_dict = {} - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): From 93f1a14cab56364494ba8aba916f9cdf56f8e3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Romero?= Date: Tue, 21 Nov 2023 19:59:29 +0100 Subject: [PATCH 8/8] ControlNet+Adapter pipeline, and ControlNet+Adapter+Inpaint pipeline (#5869) * ControlNet+Adapter pipeline, and +Inpaint pipeline --------- Co-authored-by: andres --- examples/community/README.md | 138 ++ ..._stable_diffusion_xl_controlnet_adapter.py | 1463 +++++++++++++ ...diffusion_xl_controlnet_adapter_inpaint.py | 1896 +++++++++++++++++ 3 files changed, 3497 insertions(+) create mode 100644 examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py create mode 100644 examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py diff --git a/examples/community/README.md b/examples/community/README.md index 36b34e8b0be5..96d530412979 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -2343,3 +2343,141 @@ images = pipe( assert len(images) == (len(prompts) - 1) * num_interpolation_steps ``` + +### ControlNet + T2I Adapter Pipeline +This pipelines combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once. +It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively. + +```py +import cv2 +import numpy as np +import torch +from controlnet_aux.midas import MidasDetector +from PIL import Image + +from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.utils import load_image +from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import ( + StableDiffusionXLControlNetAdapterPipeline, +) + +controlnet_depth = ControlNetModel.from_pretrained( + "diffusers/controlnet-depth-sdxl-1.0", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True +) +adapter_depth = T2IAdapter.from_pretrained( + "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True) + +pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=controlnet_depth, + adapter=adapter_depth, + vae=vae, + variant="fp16", + use_safetensors=True, + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") +pipe.enable_xformers_memory_efficient_attention() +# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2) +midas_depth = MidasDetector.from_pretrained( + "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" +).to("cuda") + +prompt = "a tiger sitting on a park bench" +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + +image = load_image(img_url).resize((1024, 1024)) + +depth_image = midas_depth( + image, detect_resolution=512, image_resolution=1024 +) + +strength = 0.5 + +images = pipe( + prompt, + control_image=depth_image, + adapter_image=depth_image, + num_inference_steps=30, + controlnet_conditioning_scale=strength, + adapter_conditioning_scale=strength, +).images +images[0].save("controlnet_and_adapter.png") + +``` + +### ControlNet + T2I Adapter + Inpainting Pipeline +```py +import cv2 +import numpy as np +import torch +from controlnet_aux.midas import MidasDetector +from PIL import Image + +from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.utils import load_image +from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import ( + StableDiffusionXLControlNetAdapterInpaintPipeline, +) + +controlnet_depth = ControlNetModel.from_pretrained( + "diffusers/controlnet-depth-sdxl-1.0", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True +) +adapter_depth = T2IAdapter.from_pretrained( + "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True +) +vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True) + +pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained( + "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + controlnet=controlnet_depth, + adapter=adapter_depth, + vae=vae, + variant="fp16", + use_safetensors=True, + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") +pipe.enable_xformers_memory_efficient_attention() +# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2) +midas_depth = MidasDetector.from_pretrained( + "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" +).to("cuda") + +prompt = "a tiger sitting on a park bench" +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +image = load_image(img_url).resize((1024, 1024)) +mask_image = load_image(mask_url).resize((1024, 1024)) + +depth_image = midas_depth( + image, detect_resolution=512, image_resolution=1024 +) + +strength = 0.4 + +images = pipe( + prompt, + image=image, + mask_image=mask_image, + control_image=depth_image, + adapter_image=depth_image, + num_inference_steps=30, + controlnet_conditioning_scale=strength, + adapter_conditioning_scale=strength, + strength=0.7, +).images +images[0].save("controlnet_and_adapter_inpaint.png") + +``` \ No newline at end of file diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py new file mode 100644 index 000000000000..d801de86cc70 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py @@ -0,0 +1,1463 @@ +# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler + >>> from diffusers.utils import load_image + >>> from controlnet_aux.midas import MidasDetector + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> image = load_image(img_url).resize((1024, 1024)) + >>> mask_image = load_image(mask_url).resize((1024, 1024)) + + >>> midas_depth = MidasDetector.from_pretrained( + ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" + ... ).to("cuda") + + >>> depth_image = midas_depth( + ... image, detect_resolution=512, image_resolution=1024 + ... ) + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + + >>> adapter = T2IAdapter.from_pretrained( + ... "Adapter/t2iadapter", + ... subfolder="sketch_sdxl_1.0", + ... torch_dtype=torch.float16, + ... adapter_type="full_adapter_xl", + ... ) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True + ... ).to("cuda") + + >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") + + >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained( + ... model_id, + ... adapter=adapter, + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... variant="fp16", + ... scheduler=scheduler + ... ).to("cuda") + + >>> strength = 0.5 + + >>> generator = torch.manual_seed(42) + >>> sketch_image_out = pipe( + ... prompt="a photo of a tiger sitting on a park bench", + ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality", + ... adapter_image=depth_image, + ... control_image=mask_image, + ... adapter_conditioning_scale=strength, + ... controlnet_conditioning_scale=strength, + ... generator=generator, + ... guidance_scale=7.5, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLControlNetAdapterPipeline( + DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], + controlnet: Union[ControlNetModel, MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + controlnet=controlnet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def check_conditions( + self, + prompt, + prompt_embeds, + adapter_image, + control_image, + adapter_conditioning_scale, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ): + # controlnet checks + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Check controlnet `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(control_image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(control_image, list): + raise TypeError("For multiple controlnets: `control_image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in control_image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(control_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in control_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + # adapter checks + if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter): + self.check_image(adapter_image, prompt, prompt_embeds) + elif ( + isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter) + ): + if not isinstance(adapter_image, list): + raise TypeError("For multiple adapters: `adapter_image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in adapter_image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(adapter_image) != len(self.adapter.adapters): + raise ValueError( + f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapters.nets)} Adapters." + ) + + for image_ in adapter_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `adapter_conditioning_scale` + if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter): + if not isinstance(adapter_conditioning_scale, float): + raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter) + ): + if isinstance(adapter_conditioning_scale, list): + if any(isinstance(i, list) for i in adapter_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len( + self.adapter.adapters + ): + raise ValueError( + "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of adapters" + ) + else: + assert False + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + adapter_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + adapter_conditioning_factor: float = 1.0, + clip_skip: Optional[int] = None, + controlnet_conditioning_scale=1.0, + guess_mode: bool = False, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + adapter_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] + instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + adapter_conditioning_factor (`float`, *optional*, defaults to 1.0): + The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is + `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for + all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter + + # 0. Default height and width to unet + + height, width = self._default_height_width(height, width, adapter_image) + device = self._execution_device + + if isinstance(adapter, MultiAdapter): + adapter_input = [] + + for one_image in adapter_image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(adapter_image, height, width) + adapter_input = adapter_input.to(device=device, dtype=adapter.dtype) + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 0.1 align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float): + adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.adapters) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + self.check_conditions( + prompt, + prompt_embeds, + adapter_image, + control_image, + adapter_conditioning_scale, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=clip_skip, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings & adapter features + if isinstance(adapter, MultiAdapter): + adapter_state = adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + # 7.2 Prepare control images + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + raise ValueError(f"{controlnet.__class__} is not supported.") + + # 8.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + if isinstance(self.controlnet, MultiControlNetModel): + controlnet_keep.append(keeps) + else: + controlnet_keep.append(keeps[0]) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if i < int(num_inference_steps * adapter_conditioning_factor): + down_intrablock_additional_residuals = [state.clone() for state in adapter_state] + else: + down_intrablock_additional_residuals = None + + # ----------- ControlNet + + # expand the latents if we are doing classifier free guidance + latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input_controlnet + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, # t2iadapter + down_block_additional_residuals=down_block_res_samples, # controlnet + mid_block_additional_residual=mid_block_res_sample, # controlnet + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py new file mode 100644 index 000000000000..bc612edbc20e --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -0,0 +1,1896 @@ +# Copyright 2023 Jake Babbidge, TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ignore the entire file for precommit +# type: ignore + +import inspect +from collections.abc import Callable +from typing import Any, List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import ( + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from diffusers import DiffusionPipeline +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import ( + AutoencoderKL, + ControlNetModel, + MultiAdapter, + T2IAdapter, + UNet2DConditionModel, +) +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, T2IAdapter + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> from controlnet_aux.midas import MidasDetector + + >>> adapter = T2IAdapter.from_pretrained( + ... "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16" + ... ).to("cuda") + + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True + ... ).to("cuda") + + >>> pipe = DiffusionPipeline.from_pretrained( + ... "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True, + ... custom_pipeline="stable_diffusion_xl_adapter_controlnet_inpaint", + ... adapter=adapter, + ... controlnet=controlnet, + ... ).to("cuda") + + >>> prompt = "a tiger sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> image = load_image(img_url).resize((1024, 1024)) + >>> mask_image = load_image(mask_url).resize((1024, 1024)) + + >>> midas_depth = MidasDetector.from_pretrained( + ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large" + ... ).to("cuda") + + >>> depth_image = midas_depth( + ... image, detect_resolution=512, image_resolution=1024 + ... ) + + >>> strength = 0.4 + + >>> generator = torch.manual_seed(42) + + >>> result_image = pipe( + ... image=image, + ... mask_image=mask, + ... adapter_image=depth_image, + ... control_image=depth_image, + ... controlnet_conditioning_scale=strength, + ... adapter_conditioning_scale=strength, + ... strength=0.7, + ... generator=generator, + ... prompt=prompt, + ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality", + ... num_inference_steps=50 + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +def mask_pil_to_torch(mask, height, width): + # preprocess mask + if isinstance(mask, Union[PIL.Image.Image, np.ndarray]): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + return mask + + +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + # checkpoint. TOD(Yiyi) - need to clean this up later + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + mask = mask_pil_to_torch(mask, height, width) + + if image.ndim == 3: + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + # if image.min() < -1 or image.max() > 1: + # raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, Union[PIL.Image.Image, np.ndarray]): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = mask_pil_to_torch(mask, height, width) + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + if image.shape[1] == 4: + # images are in latent space and thus can't + # be masked set masked_image to None + # we assume that the checkpoint is not an inpainting + # checkpoint. TOD(Yiyi) - need to clean this up later + masked_image = None + else: + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter], + controlnet: Union[ControlNetModel, MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + controlnet=controlnet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.default_sample_size = self.unet.config.sample_size + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def check_conditions( + self, + prompt, + prompt_embeds, + adapter_image, + control_image, + adapter_conditioning_scale, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ): + # controlnet checks + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Check controlnet `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(control_image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(control_image, list): + raise TypeError("For multiple controlnets: `control_image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in control_image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(control_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in control_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + # adapter checks + if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter): + self.check_image(adapter_image, prompt, prompt_embeds) + elif ( + isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter) + ): + if not isinstance(adapter_image, list): + raise TypeError("For multiple adapters: `adapter_image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in adapter_image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(adapter_image) != len(self.adapter.adapters): + raise ValueError( + f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapters.nets)} Adapters." + ) + + for image_ in adapter_image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `adapter_conditioning_scale` + if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter): + if not isinstance(adapter_conditioning_scale, float): + raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter) + ): + if isinstance(adapter_conditioning_scale, list): + if any(isinstance(i, list) for i in adapter_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len( + self.adapter.adapters + ): + raise ValueError( + "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of adapters" + ) + else: + assert False + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, + size=( + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ), + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + masked_image_latents = None + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + + return timesteps, num_inference_steps - t_start + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, list[str]]] = None, + prompt_2: Optional[Union[str, list[str]]] = None, + image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, + mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, + adapter_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, list[str]]] = None, + negative_prompt_2: Optional[Union[str, list[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[Union[torch.FloatTensor]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[tuple[int, int]] = None, + crops_coords_top_left: Optional[tuple[int, int]] = (0, 0), + target_size: Optional[tuple[int, int]] = None, + adapter_conditioning_scale: Optional[Union[float, list[float]]] = 1.0, + cond_tau: float = 1.0, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + controlnet_conditioning_scale=1.0, + guess_mode: bool = False, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + adapter_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] + instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter + height, width = self._default_height_width(height, width, adapter_image) + device = self._execution_device + + adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device) + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 0.1 align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float): + adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.nets) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + self.check_conditions( + prompt, + prompt_embeds, + adapter_image, + control_image, + adapter_conditioning_scale, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(denoising_end, float) and 0 < dnv < 1 + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = denoising_start is None + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings & adapter features + adapter_input = adapter_input.type(latents.dtype) + adapter_state = adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + # 10.2 Prepare control images + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + raise ValueError(f"{controlnet.__class__} is not supported.") + + # 8.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + if isinstance(self.controlnet, MultiControlNetModel): + controlnet_keep.append(keeps) + else: + controlnet_keep.append(keeps[0]) + # ---------------------------------------------------------------- + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 11.1 Apply denoising_end + if ( + denoising_end is not None + and denoising_start is not None + and denoising_value_valid(denoising_end) + and denoising_value_valid(denoising_start) + and denoising_start >= denoising_end + ): + raise ValueError( + f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {denoising_end} when using type float." + ) + elif denoising_end is not None and denoising_value_valid(denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + + if i < int(num_inference_steps * cond_tau): + down_block_additional_residuals = [state.clone() for state in adapter_state] + else: + down_block_additional_residuals = None + + # ----------- ControlNet + + # expand the latents if we are doing classifier free guidance + latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input_controlnet + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + down_intrablock_additional_residuals=down_block_additional_residuals, # t2iadapter + down_block_additional_residuals=down_block_res_samples, # controlnet + mid_block_additional_residual=mid_block_res_sample, # controlnet + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, + noise, + torch.tensor([noise_timestep]), + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if output_type != "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image)