support flux example #1073

Open · wants to merge 4 commits into base: main
29 changes: 27 additions & 2 deletions benchmarks/text_to_image.py
@@ -46,6 +46,7 @@
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--dtype", type=str, default="half")
parser.add_argument("--variant", type=str, default=VARIANT)
parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
parser.add_argument("--scheduler", type=str, default=SCHEDULER)
@@ -92,6 +93,8 @@ def parse_args():
default=QUANTIZE_CONFIG,
)
parser.add_argument("--quant-submodules-config-path", type=str, default=None)
parser.add_argument("--revision", type=str, default=None)
parser.add_argument("--local-files-only", action="store_true")
return parser.parse_args()


@@ -108,13 +111,17 @@ def load_pipe(
scheduler=None,
lora=None,
controlnet=None,
revision=None,
local_files_only=False,
):
extra_kwargs = {}
if custom_pipeline is not None:
extra_kwargs["custom_pipeline"] = custom_pipeline
if variant is not None:
extra_kwargs["variant"] = variant
if dtype is not None:
dtype = getattr(torch, dtype)
assert isinstance(dtype, torch.dtype)
extra_kwargs["torch_dtype"] = dtype
if controlnet is not None:
from diffusers import ControlNetModel
@@ -124,6 +131,11 @@
torch_dtype=dtype,
)
extra_kwargs["controlnet"] = controlnet
if revision is not None:
extra_kwargs["revision"] = revision
if local_files_only:
extra_kwargs["local_files_only"] = True

if os.path.exists(os.path.join(model_name, "calibrate_info.txt")):
from onediff.quantization import QuantPipeline

@@ -231,11 +243,14 @@ def main():
pipe = load_pipe(
pipeline_cls,
args.model,
dtype=args.dtype,
variant=args.variant,
custom_pipeline=args.custom_pipeline,
scheduler=args.scheduler,
lora=args.lora,
controlnet=args.controlnet,
revision=args.revision,
local_files_only=args.local_files_only,
)

core_net = None
@@ -349,6 +364,13 @@ def get_kwarg_inputs():
kwarg_inputs["cache_block_id"] = args.cache_block_id
return kwarg_inputs

kwarg_inputs = get_kwarg_inputs()

# Patch for the Flux pipeline: rename negative_prompt to prompt_2, since FluxPipeline does not accept a negative_prompt argument
if pipe.__class__.__name__ == "FluxPipeline":
kwarg_inputs["prompt_2"] = kwarg_inputs["negative_prompt"]
kwarg_inputs.pop("negative_prompt")
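# (In diffusers' FluxPipeline, prompt_2 feeds the second text encoder; the
# pipeline has no negative_prompt parameter because FLUX.1 runs without
# classifier-free guidance, so the benchmark's negative prompt is rerouted.)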

# NOTE: Warm it up.
# The initial calls will trigger compilation and might be very slow.
# After that, it should be very fast.
@@ -357,15 +379,15 @@
print("=======================================")
print("Begin warmup")
for _ in range(args.warmups):
pipe(**get_kwarg_inputs())
pipe(**kwarg_inputs)
end = time.time()
print("End warmup")
print(f"Warmup time: {end - begin:.3f}s")
print("=======================================")

# Let"s see it!
# Note: Progress bar might work incorrectly due to the async nature of CUDA.
kwarg_inputs = get_kwarg_inputs()

iter_profiler = IterationProfiler()
if "callback_on_step_end" in inspect.signature(pipe).parameters:
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
@@ -387,6 +409,9 @@
else:
cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")
if args.compiler != "oneflow":
cuda_mem_max_reserved = torch.cuda.max_memory_reserved() / (1024**3)
print(f"Peak CUDA memory : {cuda_mem_max_reserved:.3f}GiB")
print("=======================================")

if args.print_output:
101 changes: 101 additions & 0 deletions onediff_diffusers_extensions/examples/flux/README.md
@@ -0,0 +1,101 @@
# Run FLUX with nexfort backend (Beta Release)

1. [Environment Setup](#environment-setup)
- [Set Up OneDiff](#set-up-onediff)
- [Set Up NexFort Backend](#set-up-nexfort-backend)
- [Set Up Diffusers Library](#set-up-diffusers)
- [Set Up FLUX](#set-up-flux)
2. [Execution Instructions](#run)
- [Run Without Compilation (Baseline)](#run-without-compilation-baseline)
- [Run With Compilation](#run-with-compilation)
3. [Performance Comparison](#performance-comparison)
4. [Dynamic Shape for FLUX](#dynamic-shape-for-flux)

## Environment setup
### Set up onediff
https://github.com/siliconflow/onediff?tab=readme-ov-file#installation

### Set up nexfort backend
https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort

### Set up diffusers

```shell
pip3 install --upgrade diffusers[torch]
```
### Set up FLUX
Model version for diffusers: https://huggingface.co/black-forest-labs/FLUX.1-schnell

HF pipeline: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/flux.md
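
Before running the benchmarks, it can help to confirm that the FLUX weights load and run at all. Below is a minimal sanity check (a sketch assuming a CUDA GPU with enough memory for the bfloat16 weights; the prompt and output filename are arbitrary):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# schnell is a few-step distilled model; 4 steps matches the commands below
image = pipe("a glass bottle on a beach", num_inference_steps=4).images[0]
image.save("flux-sanity.png")
```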

## Run

### Run without compilation (Baseline)
```shell
python3 benchmarks/text_to_image.py \
--model black-forest-labs/FLUX.1-schnell \
--height 1024 --width 1024 \
--scheduler none \
--steps 4 \
--output-image ./flux-schnell.png \
--prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
--compiler none \
--dtype bfloat16 \
--seed 1 \
--print-output
```

### Run with compilation

```shell
python3 benchmarks/text_to_image.py \
--model black-forest-labs/FLUX.1-schnell \
--height 1024 --width 1024 \
--scheduler none \
--steps 4 \
--output-image ./flux-schnell-compile.png \
--prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
--compiler nexfort \
--compiler-config '{"mode": "benchmark:cudagraphs:max-autotune:low-precision:cache-all", "memory_format": "channels_last", "options": {"cuda.fuse_timestep_embedding": false, "inductor.force_triton_sdpa": true}}' \
--dtype bfloat16 \
--seed 1 \
--print-output
```
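
For reference, the example script added in this PR applies the same compiler options through the onediffx API. A minimal sketch (the script itself only passes the inner "options" dict; forwarding "mode" and "memory_format" through compile_pipe is an assumption mirroring the --compiler-config JSON above):

```python
import torch
from diffusers import FluxPipeline
from onediffx import compile_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

pipe = compile_pipe(
    pipe,
    backend="nexfort",
    options={
        # assumed to mirror the CLI JSON; the PR's script omits these two keys
        "mode": "benchmark:cudagraphs:max-autotune:low-precision:cache-all",
        "memory_format": "channels_last",
        "options": {
            "cuda.fuse_timestep_embedding": False,
            "inductor.force_triton_sdpa": True,
        },
    },
)
```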

## Performance comparison

Testing on an NVIDIA A800-SXM4-80GB with an image size of 1024×1024, iterating 4 steps:
| Metric | A800-SXM4-80GB 1024×1024 |
| ------------------------------------ | ------------------------ |
| Data update date (yyyy-mm-dd) | 2024-08-07 |
| PyTorch iteration speed | 2.18 it/s |
| OneDiff iteration speed | 2.80 it/s (+28.4%) |
| PyTorch E2E time | 2.06 s |
| OneDiff E2E time | 1.53 s (-25.7%) |
| PyTorch Max Mem Used | 35.79 GiB |
| OneDiff Max Mem Used | 40.44 GiB |
| PyTorch Warmup with Run time | 2.81 s |
| OneDiff Warmup with Compilation time <sup>1</sup> | 253.01 s |
| OneDiff Warmup with Cache time | 73.63 s |

<sup>1</sup> OneDiff Warmup with Compilation time was measured on an Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60GHz. Note this is just for reference; it varies a lot across different CPUs.


## Dynamic shape for FLUX

Run:

```shell
python3 benchmarks/text_to_image.py \
--model black-forest-labs/FLUX.1-schnell \
--height 1024 --width 1024 \
--scheduler none \
--steps 4 \
--output-image ./flux-schnell-compile.png \
--prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
--compiler nexfort \
--compiler-config '{"mode": "benchmark:cudagraphs:max-autotune:low-precision:cache-all", "memory_format": "channels_last", "options": {"cuda.fuse_timestep_embedding": false, "inductor.force_triton_sdpa": true}, "dynamic", true}' \
--run_multiple_resolutions 1 \
--dtype bfloat16 \
--seed 1 \
```
101 changes: 101 additions & 0 deletions onediff_diffusers_extensions/examples/text_to_image_flux.py
@@ -0,0 +1,101 @@
import argparse
import time

import cv2
[Review comment: coderabbitai (Contributor)]
Remove unused imports. The imports cv2, numpy, and PIL.Image are not used in the script and should be removed to clean up the code. Also applies to lines 5 and 9. (Ruff F401, line 4: cv2 imported but unused.)

-import cv2
-import numpy as np
-import PIL.Image

import numpy as np
import torch

from diffusers import FluxPipeline
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("--base", type=str, default="black-forest-labs/FLUX.1-schnell")
parser.add_argument(
"--prompt",
type=str,
default="chinese painting style women",
)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=512)
parser.add_argument("--n_steps", type=int, default=4)
parser.add_argument("--saved_image", type=str, required=False, default="flux-out.png")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--warmup", type=int, default=1)
parser.add_argument("--run", type=int, default=3)
parser.add_argument(
"--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True
)
parser.add_argument("--run-multiple-resolutions", action="store_true")
args = parser.parse_args()


# load the FLUX pipeline
pipe = FluxPipeline.from_pretrained(args.base, torch_dtype=torch.bfloat16)
# pipe = FluxPipeline.from_pretrained(args.base, torch_dtype=torch.bfloat16, local_files_only=True, revision="93424e3a1530639fefdf08d2a7a954312e5cb254")
pipe.to("cuda")

if args.compile:
from onediffx import compile_pipe

pipe = compile_pipe(
pipe,
backend="nexfort",
options={
"options": {
"cuda.fuse_timestep_embedding": False,
"inductor.force_triton_sdpa": True,
}
},
)


# generate image
generator = torch.manual_seed(args.seed)

print("Warmup")
for i in range(args.warmup):
image = pipe(
args.prompt,
height=args.height,
width=args.width,
output_type="pil",
num_inference_steps=args.n_steps, # use a larger number if you are using [dev]
generator=torch.Generator("cpu").manual_seed(args.seed),
).images[0]


print("Run")
for i in range(args.run):
begin = time.time()
image = pipe(
args.prompt,
height=args.height,
width=args.width,
output_type="pil",
num_inference_steps=args.n_steps, # use a larger number if you are using [dev]
generator=torch.Generator("cpu").manual_seed(args.seed),
).images[0]
end = time.time()
print(f"Inference time: {end - begin:.3f}s")

image.save(f"{i=}th_{args.saved_image}.png")
[Review comment: coderabbitai (Contributor), Aug 8, 2024]
Fix incorrect string interpolation in image.save. The {i=} form is Python's self-documenting f-string, so f"{i=}th_{args.saved_image}.png" embeds a literal "i=" in the filename; drop the equals sign:

-    image.save(f"{i=}th_{args.saved_image}.png")
+    image.save(f"{i}th_{args.saved_image}.png")

[Review comment: jackalcooper (Collaborator), Aug 9, 2024]
Drop the equals sign; it is a hassle to have it in a filename.

[Review comment: coderabbitai (Contributor)]
@jackalcooper, understood; removing the equals sign is the right choice and makes the filename cleaner. Thanks for confirming!



if args.run_multiple_resolutions:
print("Test run with multiple resolutions...")
sizes = [1024, 512, 768, 256]
for h in sizes:
for w in sizes:
print(f"Running at resolution: {h}x{w}")
start_time = time.time()
image = pipe(
args.prompt,
height=h,
width=w,
output_type="pil",
num_inference_steps=args.n_steps, # use a larger number if you are using [dev]
generator=torch.Generator("cpu").manual_seed(args.seed),
).images[0]
end_time = time.time()
print(f"Inference time: {end_time - start_time:.2f} seconds")
image.save(f"{i=}th_{args.saved_image}_{h}x{w}.png")
[Review comment: coderabbitai (Contributor), Aug 8, 2024]
Fix incorrect string interpolation in image.save, same as above:

-    image.save(f"{i=}th_{args.saved_image}_{h}x{w}.png")
+    image.save(f"{i}th_{args.saved_image}_{h}x{w}.png")

[Review comment: jackalcooper (Collaborator)]
Drop the equals sign; it is a hassle to have it in a filename.

[Review comment: coderabbitai (Contributor)]
@jackalcooper, as shown in the suggestion, the equals sign has been removed from the filename. Please check whether this meets your requirements!

[Review comment: coderabbitai (Contributor)]
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

5 changes: 4 additions & 1 deletion src/onediff/utils/import_utils.py
@@ -23,7 +23,7 @@ def check_module_availability(module_name):
return True


_oneflow_available = check_module_availability("oneflow")
_oneflow_available = None
_onediff_quant_available = check_module_availability("onediff_quant")
_nexfort_available = check_module_availability("nexfort")

@@ -33,6 +33,9 @@ def check_module_availability(module_name):


def is_oneflow_available():
global _oneflow_available
if _oneflow_available is None:
_oneflow_available = check_module_availability("oneflow")
return _oneflow_available
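
This change defers the oneflow probe from import time to the first call of is_oneflow_available(), so importing onediff.utils no longer pays for a failed oneflow import when oneflow is absent. The same pattern in isolation (a minimal, self-contained sketch; the names are hypothetical, and the real helper may do more than a find_spec probe):

```python
import importlib.util

_availability_cache = {}

def _check_module_availability(module_name):
    # find_spec returns None when the module cannot be located
    return importlib.util.find_spec(module_name) is not None

def is_module_available(module_name):
    # Probe lazily on first use, then reuse the cached result
    if module_name not in _availability_cache:
        _availability_cache[module_name] = _check_module_availability(module_name)
    return _availability_cache[module_name]
```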

