diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..728f512 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023, Replicate, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 6da2ab3..53e1bf8 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,53 @@ -# FLUX (in cog!) +# cog-flux -This is a repository for running flux-dev and flux-schnell within a cog container. +This is a [Cog](https://cog.run) inference model for FLUX.1 [schnell] and FLUX.1 [dev] by [Black Forest Labs](https://blackforestlabs.ai/). It powers the following Replicate models: -## How to use this repo +* https://replicate.com/black-forest-labs/flux-schnell +* https://replicate.com/black-forest-labs/flux-dev -### Selecting a model +## Features -run `script/select.sh (dev,schnell)` and that'll create a cog.yaml configured for the appropriate model. 
+* Compilation with `torch.compile` +* Optional fp8 quantization based on [aredden/flux-fp8-api](https://github.com/aredden/flux-fp8-api), using fast cuDNN attention from PyTorch nightlies +* NSFW checking with [CompVis](https://huggingface.co/CompVis/stable-diffusion-safety-checker) and [Falcons.ai](https://huggingface.co/Falconsai/nsfw_image_detection) safety checkers +* img2img support -### Pushing a model +## Getting started -run `script/push.sh (dev,schnell) (test, prod)` to push the model to Replicate. +If you just want to use the models, you can run [FLUX.1 [schnell]](https://replicate.com/black-forest-labs/flux-schnell) and [FLUX.1 [dev]](https://replicate.com/black-forest-labs/flux-dev) on Replicate with an API or in the browser. -To push all models, run `script/prod-deploy-all.sh`. Note that after doing this you'll still need to manually go in and update deployments. +The code in this repo can be used as a template for customizations on FLUX.1, or to run the models on your own hardware. + +First you need to select which model to run: + +```shell +$ script/select.sh {dev,schnell} +``` + +Then you can run a single prediction on the model using: + +```shell +$ cog predict -i prompt="a cat in a hat" +``` + +For more documentation about how to interact with Cog models and push customized FLUX.1 models to Replicate: +* The [Cog getting started guide](https://cog.run/getting-started/) explains what Cog is and how it works +* [This guide](https://replicate.com/docs/guides/push-a-model) describes how to push a model to Replicate + +## Contributing + +Pull requests and issues are welcome! If you see a novel technique or feature you think will make FLUX.1 inference better or faster, let us know and we'll do our best to integrate it.
+ +## Rough, partial roadmap + +* Serialize quantized model instead of quantizing on the fly +* Use row-wise quantization +* Port quantization and compilation code over to https://github.com/replicate/flux-fine-tuner + +## License + +The code in this repository is licensed under the [Apache-2.0 License](LICENSE). + +FLUX.1 [dev] falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). + +FLUX.1 [schnell] falls under the [Apache-2.0 License](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). diff --git a/demo_st.py b/demo_st.py deleted file mode 100644 index 30764a5..0000000 --- a/demo_st.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -import time -from glob import glob -from io import BytesIO - -import streamlit as st -import torch -from einops import rearrange -from fire import Fire -from PIL import Image -from st_keyup import st_keyup - -from flux.cli import SamplingOptions -from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack -from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5 - - -@st.cache_resource() -def get_models(name: str, device: torch.device, offload: bool, quantize_flow: bool): - t5 = load_t5(device) - clip = load_clip(device) - model = load_flow_model(name, device="cpu" if offload else device, quantize=quantize_flow) - ae = load_ae(name, device="cpu" if offload else device) - return model, ae, t5, clip - - -@torch.inference_mode() -def main( - quantize_flow: bool = False, - device: str = "cuda" if torch.cuda.is_available() else "cpu", - offload: bool = False, - output_dir: str = "output", -): - torch_device = torch.device(device) - names = list(configs.keys()) - name = st.selectbox("Which model to load?", names) - if name is None or not st.checkbox("Load model", False): - return - - model, ae, t5, clip = get_models( - name, - device=torch_device, - offload=offload, - 
quantize_flow=quantize_flow, - ) - is_schnell = name == "flux-schnell" - - # allow for packing and conversion to latent space - width = int(16 * (st.number_input("Width", min_value=128, max_value=8192, value=1024) // 16)) - height = int(16 * (st.number_input("Height", min_value=128, max_value=8192, value=1024) // 16)) - num_steps = int(st.number_input("Number of steps?", min_value=1, value=(4 if is_schnell else 50))) - guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell)) - seed = int(st.number_input("Seed (-1 to disable)", min_value=-1, value=-1, disabled=is_schnell)) - if seed == -1: - seed = None - save_samples = st.checkbox("Save samples?", not is_schnell) - - default_prompt = ( - "a photo of a forest with mist swirling around the tree trunks. The word " - '"FLUX" is painted over it in big, red brush strokes with visible texture' - ) - prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text") - - output_name = os.path.join(output_dir, "img_{idx}.png") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - idx = 0 - elif len(os.listdir(output_dir)) > 0: - fns = glob(output_name.format(idx="*")) - idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 - else: - idx = 0 - - rng = torch.Generator(device="cpu") - - if "seed" not in st.session_state: - st.session_state.seed = rng.seed() - - def increment_counter(): - st.session_state.seed += 1 - - def decrement_counter(): - if st.session_state.seed > 0: - st.session_state.seed -= 1 - - opts = SamplingOptions( - prompt=prompt, - width=width, - height=height, - num_steps=num_steps, - guidance=guidance, - seed=seed, - ) - - if name == "flux-schnell": - cols = st.columns([5, 1, 1, 5]) - with cols[1]: - st.button("↩", on_click=increment_counter) - with cols[2]: - st.button("↪", on_click=decrement_counter) - if is_schnell or st.button("Sample"): - if is_schnell: - opts.seed = st.session_state.seed - elif opts.seed is None: - 
opts.seed = rng.seed() - print(f"Generating '{opts.prompt}' with seed {opts.seed}") - t0 = time.perf_counter() - - # prepare input - x = get_noise( - 1, - opts.height, - opts.width, - device=torch_device, - dtype=torch.bfloat16, - seed=opts.seed, - ) - if offload: - ae = ae.cpu() - torch.cuda.empty_cache() - t5, clip = t5.to(torch_device), clip.to(torch_device) - inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt) - timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(not is_schnell)) - - # offload TEs to CPU, load model to gpu - if offload: - t5, clip = t5.cpu(), clip.cpu() - torch.cuda.empty_cache() - model = model.to(torch_device) - - # denoise initial noise - x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) - - # offload model, load autoencoder to gpu - if offload: - model.cpu() - torch.cuda.empty_cache() - ae.decoder.to(x.device) - - # decode latents to pixel space - x = unpack(x.float(), opts.height, opts.width) - with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): - x = ae.decode(x) - t1 = time.perf_counter() - - fn = output_name.format(idx=idx) - print(f"Done in {t1 - t0:.1f}s.") - # bring into PIL format and save - x = rearrange(x[0], "c h w -> h w c").clamp(-1, 1) - img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) - if save_samples: - print(f"Saving {fn}") - img.save(fn) - idx += 1 - - st.session_state["samples"] = { - "prompt": opts.prompt, - "img": img, - "seed": opts.seed, - } - opts.seed = None - - samples = st.session_state.get("samples", None) - if samples is not None: - st.image(samples["img"], caption=samples["prompt"]) - if "bytes" not in samples: - img: Image.Image = samples["img"] - buffer = BytesIO() - img.save(buffer, format="png") - samples["bytes"] = buffer.getvalue() - st.download_button( - "Download full-resolution", - samples["bytes"], - file_name="generated.png", - mime="image/png", - ) - st.write(f"Seed: {samples['seed']}") - - -def app(): - Fire(main) - 
- -if __name__ == "__main__": - app() diff --git a/flux_emphasis.py b/flux_emphasis.py deleted file mode 100644 index efb74ac..0000000 --- a/flux_emphasis.py +++ /dev/null @@ -1,444 +0,0 @@ -from typing import TYPE_CHECKING, Optional -from pydash import flatten - -import torch -from transformers.models.clip.tokenization_clip import CLIPTokenizer -from einops import repeat - -if TYPE_CHECKING: - from flux_pipeline import FluxPipeline - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \\( - literal character '(' - \\[ - literal character '[' - \\) - literal character ')' - \\] - literal character ']' - \\ - literal character '\' - anything else - just text - - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\\(literal\\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - import re - - re_attention = re.compile( - r""" - \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)| - \)|]|[^\\()\[\]:]+|: - """, - re.X, - ) - - re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def 
multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - parts = re.split(re_break, text) - for i, part in enumerate(parts): - if i > 0: - res.append(["BREAK", -1]) - res.append([part, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_tokens_with_weights( - clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False -): - """ - Get prompt token ids and weights, this function works for both prompt and negative prompt - - Args: - pipe (CLIPTokenizer) - A CLIPTokenizer - prompt (str) - A prompt string with weights - - Returns: - text_tokens (list) - A list contains token ids - text_weight (list) - A list contains the correspodent weight of token ids - - Example: - import torch - from transformers import CLIPTokenizer - - clip_tokenizer = CLIPTokenizer.from_pretrained( - "stablediffusionapi/deliberate-v2" - , subfolder = "tokenizer" - , dtype = torch.float16 - ) - - token_id_list, token_weight_list = get_prompts_tokens_with_weights( - clip_tokenizer = 
clip_tokenizer - ,prompt = "a (red:1.5) cat"*70 - ) - """ - texts_and_weights = parse_prompt_attention(prompt) - text_tokens, text_weights = [], [] - maxlen = clip_tokenizer.model_max_length - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = clip_tokenizer( - word, truncation=False, padding=False, add_special_tokens=False - ).input_ids - # so that tokenize whatever length prompt - # the returned token is a 1d list: [320, 1125, 539, 320] - if debug: - print( - token, - "|FOR MODEL LEN{}|".format(maxlen), - clip_tokenizer.decode( - token, skip_special_tokens=True, clean_up_tokenization_spaces=True - ), - ) - # merge the new tokens to the all tokens holder: text_tokens - text_tokens = [*text_tokens, *token] - - # each token chunk will come with one weight, like ['red cat', 2.0] - # need to expand weight for each token. - chunk_weights = [weight] * len(token) - - # append the weight back to the weight holder: text_weights - text_weights = [*text_weights, *chunk_weights] - return text_tokens, text_weights - - -def group_tokens_and_weights( - token_ids: list, - weights: list, - pad_last_block=False, - bos=49406, - eos=49407, - max_length=77, - pad_tokens=True, -): - """ - Produce tokens and weights in groups and pad the missing tokens - - Args: - token_ids (list) - The token ids from tokenizer - weights (list) - The weights list from function get_prompts_tokens_with_weights - pad_last_block (bool) - Control if fill the last token list to 75 tokens with eos - Returns: - new_token_ids (2d list) - new_weights (2d list) - - Example: - token_groups,weight_groups = group_tokens_and_weights( - token_ids = token_id_list - , weights = token_weight_list - ) - """ - max_len = max_length - 2 if max_length < 77 else max_length - # this will be a 2d list - new_token_ids = [] - new_weights = [] - while len(token_ids) >= max_len: - # get the first 75 tokens - head_75_tokens = [token_ids.pop(0) for _ in range(max_len)] - 
head_75_weights = [weights.pop(0) for _ in range(max_len)] - - # extract token ids and weights - - if pad_tokens: - if bos is not None: - temp_77_token_ids = [bos] + head_75_tokens + [eos] - temp_77_weights = [1.0] + head_75_weights + [1.0] - else: - temp_77_token_ids = head_75_tokens + [eos] - temp_77_weights = head_75_weights + [1.0] - - # add 77 token and weights chunk to the holder list - new_token_ids.append(temp_77_token_ids) - new_weights.append(temp_77_weights) - - # padding the left - if len(token_ids) > 0: - if pad_tokens: - padding_len = max_len - len(token_ids) if pad_last_block else 0 - - temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos] - new_token_ids.append(temp_77_token_ids) - - temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0] - new_weights.append(temp_77_weights) - else: - new_token_ids.append(token_ids) - new_weights.append(weights) - return new_token_ids, new_weights - - -def standardize_tensor( - input_tensor: torch.Tensor, target_mean: float, target_std: float -) -> torch.Tensor: - """ - This function standardizes an input tensor so that it has a specific mean and standard deviation. - - Parameters: - input_tensor (torch.Tensor): The tensor to standardize. - target_mean (float): The target mean for the tensor. - target_std (float): The target standard deviation for the tensor. - - Returns: - torch.Tensor: The standardized tensor. 
- """ - - # First, compute the mean and std of the input tensor - mean = input_tensor.mean() - std = input_tensor.std() - - # Then, standardize the tensor to have a mean of 0 and std of 1 - standardized_tensor = (input_tensor - mean) / std - - # Finally, scale the tensor to the target mean and std - output_tensor = standardized_tensor * target_std + target_mean - - return output_tensor - - -def apply_weights( - prompt_tokens: torch.Tensor, - weight_tensor: torch.Tensor, - token_embedding: torch.Tensor, - eos_token_id: int, - pad_last_block: bool = True, -) -> torch.FloatTensor: - mean = token_embedding.mean() - std = token_embedding.std() - if pad_last_block: - pooled_tensor = token_embedding[ - torch.arange(token_embedding.shape[0], device=token_embedding.device), - ( - prompt_tokens.to(dtype=torch.int, device=token_embedding.device) - == eos_token_id - ) - .int() - .argmax(dim=-1), - ] - else: - pooled_tensor = token_embedding[:, -1] - - for j in range(len(weight_tensor)): - if weight_tensor[j] != 1.0: - token_embedding[:, j] = ( - pooled_tensor - + (token_embedding[:, j] - pooled_tensor) * weight_tensor[j] - ) - return standardize_tensor(token_embedding, mean, std) - - -@torch.inference_mode() -def get_weighted_text_embeddings_flux( - pipe: "FluxPipeline", - prompt: str = "", - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - target_device: Optional[torch.device] = torch.device("cuda:0"), - target_dtype: Optional[torch.dtype] = torch.bfloat16, - debug: bool = False, -): - """ - This function can process long prompt with weights, no length limitation - for Stable Diffusion XL - - Args: - pipe (StableDiffusionPipeline) - prompt (str) - prompt_2 (str) - neg_prompt (str) - neg_prompt_2 (str) - num_images_per_prompt (int) - device (torch.device) - Returns: - prompt_embeds (torch.Tensor) - neg_prompt_embeds (torch.Tensor) - """ - device = device or pipe._execution_device - - eos = pipe.clip.tokenizer.eos_token_id - eos_2 = 
pipe.t5.tokenizer.eos_token_id - bos = pipe.clip.tokenizer.bos_token_id - bos_2 = pipe.t5.tokenizer.bos_token_id - - clip = pipe.clip.hf_module - t5 = pipe.t5.hf_module - - tokenizer_clip = pipe.clip.tokenizer - tokenizer_t5 = pipe.t5.tokenizer - - t5_length = 512 if pipe.name == "flux-dev" else 256 - clip_length = 77 - - # tokenizer 1 - prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights( - tokenizer_clip, prompt, debug=debug - ) - - # tokenizer 2 - prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights( - tokenizer_t5, prompt, debug=debug - ) - - prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights( - prompt_tokens_clip, - prompt_weights_clip, - pad_last_block=True, - bos=bos, - eos=eos, - max_length=clip_length, - ) - prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights( - prompt_tokens_t5, - prompt_weights_t5, - pad_last_block=True, - bos=bos_2, - eos=eos_2, - max_length=t5_length, - pad_tokens=False, - ) - prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped) - prompt_weights_t5 = flatten(prompt_weights_t5_grouped) - prompt_tokens_clip = flatten(prompt_tokens_clip_grouped) - prompt_weights_clip = flatten(prompt_weights_clip_grouped) - - prompt_tokens_clip = tokenizer_clip.decode( - prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - prompt_tokens_clip = tokenizer_clip( - prompt_tokens_clip, - add_special_tokens=True, - padding="max_length", - truncation=True, - max_length=clip_length, - return_tensors="pt", - ).input_ids.to(device) - prompt_tokens_t5 = tokenizer_t5.decode( - prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - prompt_tokens_t5 = tokenizer_t5( - prompt_tokens_t5, - add_special_tokens=True, - padding="max_length", - truncation=True, - max_length=t5_length, - return_tensors="pt", - ).input_ids.to(device) - - prompt_weights_t5 = torch.cat( - [ - torch.tensor(prompt_weights_t5, 
dtype=torch.float32), - torch.full( - (t5_length - torch.tensor(prompt_weights_t5).numel(),), - 1.0, - dtype=torch.float32, - ), - ], - dim=0, - ).to(device) - - clip_embeds = clip( - prompt_tokens_clip, output_hidden_states=True, attention_mask=None - )["pooler_output"] - if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1: - clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt) - - weight_tensor_t5 = torch.tensor( - flatten(prompt_weights_t5), dtype=torch.float32, device=device - ) - t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[ - "last_hidden_state" - ] - t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2) - if debug: - print(t5_embeds.shape) - if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1: - t5_embeds = repeat(t5_embeds, "1 ... -> bs ...", bs=num_images_per_prompt) - txt_ids = torch.zeros( - num_images_per_prompt, - t5_embeds.shape[1], - 3, - device=target_device, - dtype=target_dtype, - ) - t5_embeds = t5_embeds.to(target_device, dtype=target_dtype) - clip_embeds = clip_embeds.to(target_device, dtype=target_dtype) - - return ( - clip_embeds, - t5_embeds, - txt_ids, - ) diff --git a/ruff.toml b/ruff.toml index aba6cd8..9499594 100644 --- a/ruff.toml +++ b/ruff.toml @@ -31,8 +31,6 @@ exclude = [ "util.py", "lora_loading.py", "flux_pipeline.py", - "flux_emphasis.py", - "demo_st.py", ] # Same as Black.