diff --git a/examples/diffusion/python_flux/README.md b/examples/diffusion/python_flux/README.md new file mode 100644 index 00000000000..98dc86ca815 --- /dev/null +++ b/examples/diffusion/python_flux/README.md @@ -0,0 +1,27 @@ +## Setup + +Make sure python interpreter can find migraphx. Default location: +``` +export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH +``` + +Install dependencies +``` +pip install -r torch_requirements.txt -r requirements.txt +``` + +Login to Huggingface: +``` +huggingface-cli login +``` + +## Generate Image +``` +python3 txt2img.py -p "A cat holding a sign that says hello world" +``` + +## Benchmark +Ex. 10 full executions: +``` +python3 txt2img.py -b 10 --fp16 +``` \ No newline at end of file diff --git a/examples/diffusion/python_flux/flux_0.png b/examples/diffusion/python_flux/flux_0.png new file mode 100644 index 00000000000..029d0b27eb8 Binary files /dev/null and b/examples/diffusion/python_flux/flux_0.png differ diff --git a/examples/diffusion/python_flux/flux_pipeline.py b/examples/diffusion/python_flux/flux_pipeline.py new file mode 100644 index 00000000000..da1af37506d --- /dev/null +++ b/examples/diffusion/python_flux/flux_pipeline.py @@ -0,0 +1,380 @@ +# pip install transformers, diffusers, sentencepiece, accelerate, onnx + +import os +import warnings +import time +from tabulate import tabulate +import numpy as np +import torch +from models import (get_tokenizer, get_clip_model, get_t5_model, + get_flux_transformer_model, get_vae_model, get_scheduler, AutoencoderKL) +from PIL import Image +# import migraphx as mgx + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class FluxPipeline: + + def __init__(self, + hf_model_path="black-forest-labs/FLUX.1-dev", + local_dir=None, + compile_dir=None, + pipeline_type="txt2img", + img_height=1024, + img_width=1024, + guidance_scale=3.5, + max_sequence_length=512, + batch_size=1, + denoising_steps=50, + fp16=True, + exhaustive_tune=False, + manual_seed=None): + + self.hf_model_path = hf_model_path + self.height = img_height + self.width = img_width + self.guidance_scale = guidance_scale + self.max_sequence_length = max_sequence_length + self.pipeline_type = pipeline_type + self.bs = batch_size + self.steps = denoising_steps + self.fp16 = fp16 + self.exhaustive_tune = exhaustive_tune + + if not local_dir: + self.local_dir = self.hf_model_path.split("/")[-1] + if not compile_dir: + self.compile_dir = self.hf_model_path.split("/")[-1] + "_compiled" + + self.models = {} + + # self.stages = ["clip", "t5", "transformer", "vae"] + + self.generator = torch.Generator(device="cuda") + if manual_seed: + self.generator.manual_seed(manual_seed) + self.device = torch.cuda.current_device() + + self.times = [] + + def load_models(self): + self.scheduler = get_scheduler(self.local_dir, self.hf_model_path) + self.tokenizer = get_tokenizer(self.local_dir, self.hf_model_path) + self.tokenizer2 = get_tokenizer(self.local_dir, self.hf_model_path, + "t5", "tokenizer_2") + + self.clip = get_clip_model(self.local_dir, + self.hf_model_path, + self.compile_dir, + fp16=self.fp16, + bs=self.bs, + exhaustive_tune=self.exhaustive_tune) + + self.t5 = get_t5_model(self.local_dir, + self.hf_model_path, + self.compile_dir, + self.max_sequence_length, + bs=self.bs, + exhaustive_tune=self.exhaustive_tune) + + 
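+        # Each get_*_model call either loads a previously compiled MIGraphX
+        # .mxr from compile_dir or, on first use, exports the Hugging Face
+        # module to ONNX and compiles it, so the first run takes much longer
+        # than subsequent runs.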
self.flux_transformer = get_flux_transformer_model( + self.local_dir, + self.hf_model_path, + self.compile_dir, + img_height=self.height, + img_width=self.width, + max_len=self.max_sequence_length, + fp16=self.fp16, + bs=self.bs, + exhaustive_tune=self.exhaustive_tune) + + self.vae = get_vae_model(self.local_dir, + self.hf_model_path, + self.compile_dir, + img_height=self.height, + img_width=self.width, + bs=self.bs, + exhaustive_tune=self.exhaustive_tune) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, + width): + """ + Reshapes latents from (B, C, H, W) to (B, H/2, W/2, C*4) as expected by the denoiser + """ + latents = latents.view(batch_size, num_channels_latents, height // 2, + 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), + num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + """ + Reshapes denoised latents to the format (B, C, H, W) + """ + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, + width * 2) + + return latents + + @staticmethod + def _prepare_latent_image_ids(height, width, dtype, device): + """ + Prepares latent image indices + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = (latent_image_ids[..., 1] + + torch.arange(height // 2)[:, None]) + latent_image_ids[..., 2] = (latent_image_ids[..., 2] + + torch.arange(width // 2)[None, :]) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape) + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, + latent_image_id_channels) + + return latent_image_ids.to(device=device, dtype=dtype) + + def initialize_latents( + self, + batch_size, + num_channels_latents, + latent_height, + latent_width, + latents_dtype=torch.float32, + ): + latents_dtype = latents_dtype # text_embeddings.dtype + latents_shape = (batch_size, num_channels_latents, latent_height, + latent_width) + latents = torch.randn( + latents_shape, + device=torch.cuda.current_device(), + dtype=latents_dtype, + generator=self.generator, + ) + + latents = self._pack_latents(latents, batch_size, num_channels_latents, + latent_height, latent_width) + + latent_image_ids = self._prepare_latent_image_ids( + latent_height, latent_width, latents_dtype, self.device) + + return latents, latent_image_ids + + def encode_prompt(self, + prompt, + encoder="clip", + max_sequence_length=None, + pooled_output=False): + tokenizer = self.tokenizer2 if encoder == "t5" else self.tokenizer + encoder = self.t5 if encoder == "t5" else self.clip + max_sequence_length = (tokenizer.model_max_length + if max_sequence_length is None else + max_sequence_length) + + def tokenize(prompt, max_sequence_length): + text_input_ids = (tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ).input_ids.type(torch.int32).to(self.device)) + + untruncated_ids = tokenizer(prompt, + padding="longest", + return_tensors="pt").input_ids.type( + torch.int32).to(self.device) + if 
untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, max_sequence_length - 1:-1]) + warnings.warn( + "The following part of your input was truncated because `max_sequence_length` is set to " + f"{max_sequence_length} tokens: {removed_text}") + + # NOTE: output tensor for the encoder must be cloned because it will be overwritten when called again for prompt2 + outputs = encoder.run_async(input_ids=text_input_ids) + output_name = ("main:#output_0" + if not pooled_output else "main:#output_1") + text_encoder_output = outputs[output_name].clone() + + return text_encoder_output + + # Tokenize prompt + text_encoder_output = tokenize(prompt, max_sequence_length) + + # return (text_encoder_output.to(torch.float16) + # if self.fp16 else text_encoder_output) + return text_encoder_output + + def denoise_latent( + self, + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + denoiser="transformer", + guidance=None, + ): + + # handle guidance + if self.flux_transformer.config["guidance_embeds"] and guidance is None: + guidance = torch.full([latents.shape[0]], + self.guidance_scale, + device=self.device, + dtype=torch.float32) + + for step_index, timestep in enumerate(timesteps): + # prepare inputs + timestep_inp = timestep.expand(latents.shape[0]).to(latents.dtype) + params = { + "hidden_states": latents, + "timestep": timestep_inp / 1000, + "pooled_projections": pooled_embeddings, + "encoder_hidden_states": text_embeddings, + "txt_ids": text_ids, + "img_ids": latent_image_ids, + } + if guidance is not None: + params.update({"guidance": guidance}) + + noise_pred = self.flux_transformer.run_async( + **params)["main:#output_0"] + + latents = self.scheduler.step(noise_pred, + timestep, + latents, + return_dict=False)[0] + + return latents.to(dtype=torch.float32) + + def decode_latent(self, latents): + images = self.vae.run_async(latent=latents)["main:#output_0"] + return images + + def infer(self, prompt, prompt2, warmup=False): + assert len(prompt) == len(prompt2) + batch_size = len(prompt) + + self.vae_scale_factor = 2**(len(self.vae.config["block_out_channels"])) + latent_height = 2 * (int(self.height) // self.vae_scale_factor) + latent_width = 2 * (int(self.width) // self.vae_scale_factor) + + num_inference_steps = self.steps + + with torch.inference_mode(): + torch.cuda.synchronize() + + self.e2e_tic = time.perf_counter() + + latents, latent_image_ids = self.initialize_latents( + batch_size=batch_size, + num_channels_latents=self.flux_transformer.config["in_channels"] // 4, + # num_channels_latents=16, + latent_height=latent_height, + latent_width=latent_width, + # latents_dtype=torch.float16 if self.fp16 else torch.float32, + latents_dtype=torch.float32 + ) + + pooled_embeddings = self.encode_prompt(prompt, pooled_output=True) + text_embeddings = self.encode_prompt( + prompt2, + encoder="t5", + max_sequence_length=self.max_sequence_length) + text_ids = torch.zeros(text_embeddings.shape[1], + 3).to(device=self.device, + dtype=text_embeddings.dtype) + + # Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, + num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + self.scheduler.set_timesteps(sigmas=sigmas, mu=mu, 
device=self.device) + timesteps = self.scheduler.timesteps.to(self.device) + num_inference_steps = len(timesteps) + + latents = self.denoise_latent( + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + ) + + latents = self._unpack_latents(latents, self.height, self.width, + self.vae_scale_factor) + latents = (latents / self.vae.config["scaling_factor"] + ) + self.vae.config["shift_factor"] + + + images = self.decode_latent(latents) + torch.cuda.synchronize() + self.e2e_toc = time.perf_counter() + if not warmup: + self.record_times() + + return images + + def record_times(self): + self.times.append(self.e2e_toc - self.e2e_tic) + + def save_image(self, images, prefix, output_dir="./"): + images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + for i in range(images.shape[0]): + path = os.path.join(output_dir, f"{prefix}_{i}.png") + Image.fromarray(images[i]).save(path) + + def print_summary(self): + headers = ["Model", "Latency(ms)"] + rows = [] + for mod in ("clip", "t5", "flux_transformer", "vae"): + name = f"{mod} (x{self.steps})" if mod == "flux_transformer" else mod + rows.append([name, np.average(getattr(self, mod).get_run_times())]) + + rows.append(["e2e", np.average(self.times)*1000]) + print(tabulate(rows, headers=headers)) + + def clear_run_data(self): + self.times = [] + for mod in (self.clip, self.t5, self.flux_transformer, self.vae): + mod.clear_events() + diff --git a/examples/diffusion/python_flux/models.py b/examples/diffusion/python_flux/models.py new file mode 100644 index 00000000000..fb36505d7a1 --- /dev/null +++ b/examples/diffusion/python_flux/models.py @@ -0,0 +1,504 @@ +import os +import torch +from transformers import (CLIPTokenizer, T5TokenizerFast, CLIPTextModel, + T5EncoderModel) +from diffusers import FluxTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler +import migraphx as mgx + + +class MGXModel: + + def __init__(self, + model, + input_shapes=None, + fp16=False, + exhaustive_tune=False, + config=None): + + if isinstance(model, mgx.program): + self.model = model + + elif isinstance(model, str) and os.path.isfile(model): + + if model.endswith(".mxr"): + self.model = mgx.load(model, format="msgpack") + + elif model.endswith(".onnx"): + if not input_shapes: + raise ValueError( + f"input_shapes need to be specified for loading a .onnx file" + ) + self.model = mgx.parse_onnx(model, map_input_dims=input_shapes) + if fp16: + mgx.quantize_fp16(self.model) + self.model.compile(mgx.get_target("gpu"), + exhaustive_tune=exhaustive_tune, + offload_copy=False) + else: + raise ValueError( + f"File type not recognized (should eend with .mxr or .onnx): {model}" + ) + else: + raise ValueError( + f"model should be a migraphx.program object or path to .mxr/.onnx file" + ) + self.config = config + + self.mgx_to_torch_dtype_dict = { + "bool_type": torch.bool, + "uint8_type": torch.uint8, + "int8_type": torch.int8, + "int16_type": torch.int16, + "int32_type": torch.int32, + "int64_type": torch.int64, + "float_type": torch.float32, + "double_type": torch.float64, + "half_type": torch.float16, + } + self.torch_to_mgx_dtype_dict = { + v: k + for k, v in self.mgx_to_torch_dtype_dict.items() + } + + self.input_names = [] + self.output_names = [] + + for n in self.model.get_parameter_names(): + if "main:#output_" in n: + self.output_names.append(n) + else: + self.input_names.append(n) + + self.torch_buffers = {} + self.mgx_args = {} + + self.start_events = [] + 
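+        # Paired CUDA events bracket every run_async call; get_run_times()
+        # reports the elapsed time of each pair as per-model GPU latency in ms.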
self.end_events = [] + + self.prealloc_buffers(self.output_names) + + def run_async(self, stream=None, **inputs): + if stream is None: + stream = torch.cuda.current_stream() + + for name, tensor in inputs.items(): + self.mgx_args[name] = self.tensor_to_arg(tensor) + + self.start_events.append(torch.cuda.Event(enable_timing=True)) + self.end_events.append(torch.cuda.Event(enable_timing=True)) + + self.start_events[-1].record() + self.model.run_async(self.mgx_args, stream.cuda_stream, "ihipStream_t") + self.end_events[-1].record() + + return {p: self.torch_buffers[p] for p in self.output_names} + + def save_model(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + mgx.save(self.model, path, format="msgpack") + + def tensor_to_arg(self, tensor): + mgx_shape = mgx.shape(type=self.torch_to_mgx_dtype_dict[tensor.dtype], + lens=list(tensor.size()), + strides=list(tensor.stride())) + return mgx.argument_from_pointer(mgx_shape, tensor.data_ptr()) + + def prealloc_buffers(self, param_names): + for param_name in param_names: + param_shape = self.model.get_parameter_shapes()[param_name] + + type_str, lens = param_shape.type_string(), param_shape.lens() + strides = param_shape.strides() + torch_dtype = self.mgx_to_torch_dtype_dict[type_str] + tensor = torch.empty_strided(lens, + strides, + dtype=torch_dtype, + device=torch.cuda.current_device()) + self.torch_buffers[param_name] = tensor + self.mgx_args[param_name] = self.tensor_to_arg(tensor) + + def get_run_times(self): + return [s.elapsed_time(e) for s, e in zip(self.start_events, self.end_events)] + + def clear_events(self): + self.start_events = [] + self.end_events = [] + + +def get_scheduler(local_dir, hf_model_path, scheduler_dir="scheduler"): + scheduler_local_dir = os.path.join(local_dir, scheduler_dir) + scheduler_cls = FlowMatchEulerDiscreteScheduler + + if not os.path.exists(scheduler_local_dir): + model = scheduler_cls.from_pretrained(hf_model_path, + subfolder=scheduler_dir) + model.save_pretrained(scheduler_local_dir) + else: + print(f"Loading {scheduler_cls} scheduler from {scheduler_local_dir}") + model = scheduler_cls.from_pretrained(scheduler_local_dir) + + return model + + +def get_tokenizer(local_dir, + hf_model_path, + tokenizer_type="clip", + tokenizer_dir="tokenizer"): + tokenizer_local_dir = os.path.join(local_dir, tokenizer_dir) + if tokenizer_type == "clip": + tokenizer_class = CLIPTokenizer + elif tokenizer_type == "t5": + tokenizer_class = T5TokenizerFast + else: + raise ValueError(f"Unsupported tokenizer: {tokenizer_type}") + + if not os.path.exists(tokenizer_local_dir): + model = tokenizer_class.from_pretrained(hf_model_path, + subfolder=tokenizer_dir) + model.save_pretrained(tokenizer_local_dir) + else: + print(f"Loading {tokenizer_type} tokenizer from {tokenizer_local_dir}") + model = tokenizer_class.from_pretrained(tokenizer_local_dir) + + return model + + +def get_local_path(local_dir, model_dir): + model_local_dir = os.path.join(local_dir, model_dir) + if not os.path.exists(model_local_dir): + os.makedirs(model_local_dir) + return model_local_dir + + +def get_clip_model(local_dir, + hf_model_path, + compiled_dir, + model_dir="text_encoder", + torch_dtype=torch.float32, + bs=1, + exhaustive_tune=False, + fp16=True): + clip_local_dir = get_local_path(local_dir, model_dir) + onnx_file = "model.onnx" + onnx_path = os.path.join(clip_local_dir, onnx_file) + + def get_compiled_file_name(): + name = f"model_b{bs}" + if fp16: name += "_fp16" + if exhaustive_tune: name += f"_exh" + return name + ".mxr" + + 
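+    # Compiled programs are cached per batch size / precision / tuning mode;
+    # if a matching .mxr already exists it is loaded directly instead of
+    # re-exporting the ONNX model and compiling it again.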
clip_compiled_dir = get_local_path(compiled_dir, model_dir) + mxr_file = get_compiled_file_name() + mxr_path = os.path.join(clip_compiled_dir, mxr_file) + + if os.path.isfile(mxr_path): + print(f"found compiled model.. loading CLIP encoder from {mxr_path}") + model = MGXModel(mxr_path) + return model + + sample_inputs = (torch.zeros(bs, 77, dtype=torch.int32), ) + input_names = ["input_ids"] + if not os.path.isfile(onnx_path): + print(f"ONNX file not found.. exporting CLIP encoder to ONNX") + model = CLIPTextModel.from_pretrained(hf_model_path, + subfolder=model_dir, + torch_dtype=torch_dtype) + + output_names = ["text_embeddings"] + dynamic_axes = {"input_ids": {0: 'B'}, "text_embeddings": {0: 'B'}} + + # CLIP export requires nightly pytorch due to bug in onnx parser + with torch.inference_mode(): + torch.onnx.export(model, + sample_inputs, + onnx_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes) + + assert os.path.isfile(onnx_path) + print(f"Generating MXR from ONNX file: {onnx_path}") + input_shapes = {n: list(t.size()) for n, t in zip(input_names, sample_inputs)} + model = MGXModel(onnx_path, + input_shapes=input_shapes, + exhaustive_tune=exhaustive_tune, + fp16=fp16) + model.save_model(os.path.join(clip_compiled_dir, get_compiled_file_name())) + + return model + + # migraphx-driver perf FLUX.1-schnell/text_encoder/model.onnx --input-dim @input_ids 1 77 --fill1 input_ids --fp16 + + +def get_t5_model(local_dir, + hf_model_path, + compiled_dir, + max_len=512, + model_dir="text_encoder_2", + torch_dtype=torch.float32, + bs=1, + exhaustive_tune=False, + fp16=False): + t5_local_dir = get_local_path(local_dir, model_dir) + onnx_file = "model.onnx" + onnx_path = os.path.join(t5_local_dir, onnx_file) + + def get_compiled_file_name(): + name = f"model_b{bs}" + name += f"_l{max_len}" + if fp16: name += "_fp16" + if exhaustive_tune: name += f"_exh" + return name + ".mxr" + + t5_compiled_dir = get_local_path(compiled_dir, model_dir) + mxr_file = get_compiled_file_name() + mxr_path = os.path.join(t5_compiled_dir, mxr_file) + + if os.path.isfile(mxr_path): + print(f"found compiled model.. loading T5 encoder from {mxr_path}") + model = MGXModel(mxr_path) + return model + + sample_inputs = (torch.zeros(bs, max_len, dtype=torch.int32), ) + input_names = ["input_ids"] + if not os.path.isfile(onnx_path): + model = T5EncoderModel.from_pretrained(hf_model_path, + subfolder=model_dir, + torch_dtype=torch_dtype) + output_names = ["text_embeddings"] + dynamic_axes = {"input_ids": {0: 'B'}, "text_embeddings": {0: 'B'}} + + with torch.inference_mode(): + torch.onnx.export(model, + sample_inputs, + onnx_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes) + + assert os.path.isfile(onnx_path) + print(f"Generating MXR from ONNX file: {onnx_path}") + input_shapes = {n: list(t.size()) for n, t in zip(input_names, sample_inputs)} + model = MGXModel(onnx_path, + input_shapes=input_shapes, + exhaustive_tune=exhaustive_tune, + fp16=fp16) + model.save_model(os.path.join(t5_compiled_dir, get_compiled_file_name())) + + return model + + # migraphx-driver perf FLUX.1-schnell/text_encoder_2/model.onnx --input-dim @input_ids 1 512 --fill1 input_ids --fp16 + + +## Following decorators required to apply fp16 inference patch to the transformer blocks +## Note that we do not export fp16 weights directly to ONNX to allow migraphx to +## perform optimizations before quantizing down to fp16. 
This results in better +## accuracy compared to exporting fp16 directly to onnx +def transformer_block_clip_wrapper(fn): + + def new_forward(*args, **kwargs): + encoder_hidden_states, hidden_states = fn(*args, **kwargs) + return encoder_hidden_states.clip(-65504, 65504), hidden_states + + return new_forward + +def single_transformer_block_clip_wrapper(fn): + + def new_forward(*args, **kwargs): + hidden_states = fn(*args, **kwargs) + return hidden_states.clip(-65504, 65504) + + return new_forward + +def add_output_clippings_for_fp16(model): + for b in model.transformer_blocks: + b.forward = transformer_block_clip_wrapper(b.forward) + + for b in model.single_transformer_blocks: + b.forward = single_transformer_block_clip_wrapper(b.forward) + +def get_flux_transformer_model(local_dir, + hf_model_path, + compiled_dir, + img_height=1024, + img_width=1024, + compression_factor=8, + max_len=512, + model_dir="transformer", + torch_dtype=torch.float32, + bs=1, + exhaustive_tune=False, + fp16=True): + + transformer_local_dir = get_local_path(local_dir, model_dir) + onnx_file = "model.onnx" + onnx_path = os.path.join(transformer_local_dir, onnx_file) + latent_h, latent_w = img_height // compression_factor, img_width // compression_factor + + def get_compiled_file_name(): + name = f"model_b{bs}" + name += f"_h{latent_h}_w{latent_w}_l{max_len}" + if fp16: name += "_fp16" + if exhaustive_tune: name += f"_exh" + return name + ".mxr" + + transformer_compiled_dir = get_local_path(compiled_dir, model_dir) + mxr_file = get_compiled_file_name() + mxr_path = os.path.join(transformer_compiled_dir, mxr_file) + + config = FluxTransformer2DModel.load_config(hf_model_path, + subfolder=model_dir) + + if os.path.isfile(mxr_path): + print(f"found compiled model.. loading flux transformer from {mxr_path}") + model = MGXModel(mxr_path, config=config) + return model + + sample_inputs = (torch.randn(bs, (latent_h // 2) * (latent_w // 2), + config["in_channels"], + dtype=torch_dtype), + torch.randn(bs, + max_len, + config['joint_attention_dim'], + dtype=torch_dtype), + torch.randn(bs, + config['pooled_projection_dim'], + dtype=torch_dtype), + torch.tensor([1.]*bs, dtype=torch_dtype), + torch.randn((latent_h // 2) * (latent_w // 2), + 3, + dtype=torch_dtype), + torch.randn(max_len, 3, dtype=torch_dtype), + torch.tensor([1.]*bs, dtype=torch_dtype),) + + input_names = [ + 'hidden_states', 'encoder_hidden_states', 'pooled_projections', + 'timestep', 'img_ids', 'txt_ids', 'guidance' + ] + if not os.path.isfile(onnx_path): + model = FluxTransformer2DModel.from_pretrained(hf_model_path, + subfolder=model_dir, + torch_dtype=torch_dtype) + + add_output_clippings_for_fp16(model) + + output_names = ["latent"] + dynamic_axes = { + 'hidden_states': {0: 'B', 1: 'latent_dim'}, + 'encoder_hidden_states': {0: 'B',1: 'L'}, + 'pooled_projections': {0: 'B'}, + 'timestep': {0: 'B'}, + 'img_ids': {0: 'latent_dim'}, + 'txt_ids': {0: 'L'}, + 'guidance': {0: 'B'}, + } + + with torch.inference_mode(): + torch.onnx.export(model, + sample_inputs, + onnx_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes) + + assert os.path.isfile(onnx_path) + print(f"Generating MXR from ONNX file: {onnx_path}") + input_shapes = {n: list(t.size()) for n, t in zip(input_names, sample_inputs)} + model = MGXModel(onnx_path, + input_shapes=input_shapes, + exhaustive_tune=exhaustive_tune, + fp16=fp16, + config=config) + model.save_model(os.path.join(transformer_compiled_dir, get_compiled_file_name())) + + return 
model + # migraphx-driver perf FLUX.1-schnell/transformer/model.onnx --input-dim @hidden_states 1 4096 64 @encoder_hidden_states 1 512 4096 @pooled_projections 1 768 @timestep 1 @img_ids 4096 3 @txt_ids 512 3 --fp16 + # migraphx-driver perf FLUX.1-dev/transformer/model.onnx --input-dim @hidden_states 1 4096 64 @encoder_hidden_states 1 512 4096 @pooled_projections 1 768 @timestep 1 @img_ids 4096 3 @txt_ids 512 3 @guidance 1 --fp16 + + +def get_vae_model(local_dir, + hf_model_path, + compiled_dir, + img_height=1024, + img_width=1024, + compression_factor=8, + model_dir="vae", + torch_dtype=torch.float32, + bs=1, + exhaustive_tune=False, + fp16=False): + + vae_local_dir = get_local_path(local_dir, model_dir) + onnx_file = "model.onnx" + onnx_path = os.path.join(vae_local_dir, onnx_file) + latent_h, latent_w = img_height // compression_factor, img_width // compression_factor + + def get_compiled_file_name(): + name = f"model_b{bs}" + name += f"_h{latent_h}_w{latent_w}" + if fp16: name += "_fp16" + if exhaustive_tune: name += f"_exh" + return name + ".mxr" + + vae_compiled_dir = get_local_path(compiled_dir, model_dir) + mxr_file = get_compiled_file_name() + mxr_path = os.path.join(vae_compiled_dir, mxr_file) + + config = AutoencoderKL.load_config(hf_model_path, subfolder=model_dir) + + if os.path.isfile(mxr_path): + print(f"found compiled model.. loading VAE decoder from {mxr_path}") + model = MGXModel(mxr_path, config=config) + return model + + sample_inputs = (torch.randn(bs, + config['latent_channels'], + latent_h, + latent_w, + dtype=torch_dtype), ) + input_names = ["latent"] + if not os.path.isfile(onnx_path): + model = AutoencoderKL.from_pretrained(hf_model_path, + subfolder=model_dir, + torch_dtype=torch_dtype) + model.forward = model.decode + + output_names = ["images"] + dynamic_axes = { + 'latent': { + 0: 'B', + 2: 'H', + 3: 'W' + }, + 'images': { + 0: 'B', + 2: '8H', + 3: '8W' + } + } + + with torch.inference_mode(): + torch.onnx.export(model, + sample_inputs, + onnx_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes) + + assert os.path.isfile(onnx_path) + print(f"Generating MXR from ONNX file: {onnx_path}") + input_shapes = {n: list(t.size()) for n, t in zip(input_names, sample_inputs)} + model = MGXModel(onnx_path, + input_shapes=input_shapes, + exhaustive_tune=exhaustive_tune, + fp16=fp16, + config=config) + model.save_model(os.path.join(vae_compiled_dir, get_compiled_file_name())) + + return model + # migraphx-driver perf FLUX.1-schnell/vae/model.onnx --input-dim @latent 1 16 128 128 --fp16 diff --git a/examples/diffusion/python_flux/requirements.txt b/examples/diffusion/python_flux/requirements.txt new file mode 100644 index 00000000000..0b25ac70956 --- /dev/null +++ b/examples/diffusion/python_flux/requirements.txt @@ -0,0 +1,6 @@ +transformers +diffusers +sentencepiece +accelerate +onnx +tabulate \ No newline at end of file diff --git a/examples/diffusion/python_flux/torch_requirements.txt b/examples/diffusion/python_flux/torch_requirements.txt new file mode 100644 index 00000000000..af228b3adfe --- /dev/null +++ b/examples/diffusion/python_flux/torch_requirements.txt @@ -0,0 +1,2 @@ +--index-url https://download.pytorch.org/whl/rocm6.2 +torch \ No newline at end of file diff --git a/examples/diffusion/python_flux/txt2img.py b/examples/diffusion/python_flux/txt2img.py new file mode 100644 index 00000000000..f4b3875522a --- /dev/null +++ b/examples/diffusion/python_flux/txt2img.py @@ -0,0 +1,170 @@ +from argparse 
import ArgumentParser +from flux_pipeline import FluxPipeline + + +def get_args(): + parser = ArgumentParser() + + parser.add_argument( + "--hf-model", + type=str, + choices=["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell",], + default="black-forest-labs/FLUX.1-dev", + help="Specify HF model card. Options: 'black-forest-labs/FLUX.1-dev', 'black-forest-labs/FLUX.1-schnell'", + ) + + parser.add_argument( + "--local-dir", + type=str, + default=None, + help="Specify directory with local onnx files (or where to export)", + ) + + parser.add_argument( + "--compile-dir", + type=str, + default=None, + help="Specify directory with compile mxr files (or where to export)", + ) + + parser.add_argument( + "-d", + "--image-height", + type=int, + default=1024, + help="Output Image height, default 1024", + ) + + parser.add_argument( + "-w", + "--image-width", + type=int, + default=1024, + help="Output Image width, default 1024", + ) + + parser.add_argument( + "-g", + "--guidance-scale", + type=float, + default=3.5, + help="Guidance scale, default 3.5", + ) + + parser.add_argument( + "-l", + "--max-sequence-length", + type=int, + default=512, + help="Max sequence length for T5, default 512", + ) + + parser.add_argument( + "-p", + "--prompt", + default=["A cat holding a sign that says hello world"], + nargs="*", + help="Text prompt(s) to be sent to the CLIP tokenizer and text encoder", + ) + + parser.add_argument( + "--prompt2", + default=None, + nargs="*", + help="Text prompt(s) to be sent to the T5 tokenizer and text encoder. If not defined, prompt will be used instead", + ) + + parser.add_argument( + "-s", + "--denoising-steps", + type=int, + default=50, + help="Number of denoising steps", + ) + + parser.add_argument( + "--fp16", + action='store_true', + help="Apply fp16 quantization." + ) + + parser.add_argument( + "--output-dir", + type=str, + default="./", + help="Specify directory where images should be saved", + ) + + parser.add_argument( + "-o", + "--output-prefix", + type=str, + default="flux", + help="Specify image name prefix for saving result images", + ) + + parser.add_argument( + "-b", + "--benchmark-runs", + type=int, + default=None, + help="Number of runs to do for benchmarking. 
Default: no benchmarking", + ) + + parser.add_argument( + "--exhaustive-tune", + action='store_true', + help="Perform exhaustive tuning when compiling" + ) + + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Set custom batch size (expects len 1 prompt, useful for benchmarking)" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + prompt = args.prompt + prompt2 = args.prompt2 if args.prompt2 else prompt + + if args.batch_size: + assert len(prompt) == 1 and len(prompt2) == 1 + prompt = prompt * args.batch_size + prompt2 = prompt2 * args.batch_size + + pipe = FluxPipeline( + hf_model_path=args.hf_model, + local_dir=args.local_dir, + compile_dir=args.compile_dir, + img_height=args.image_height, + img_width=args.image_width, + guidance_scale=args.guidance_scale, + max_sequence_length=args.max_sequence_length, + batch_size=len(prompt), + denoising_steps=args.denoising_steps, + fp16=args.fp16, + exhaustive_tune=args.exhaustive_tune + ) + + pipe.load_models() + + images = pipe.infer(prompt, prompt2, warmup=True) + + if args.output_dir: + print(f"Saving images to {args.output_dir}") + pipe.save_image(images, args.output_prefix, args.output_dir) + + if args.benchmark_runs: + pipe.clear_run_data() + print("Begin benchmarking...") + for _ in range(args.benchmark_runs): + pipe.infer(prompt, prompt2) + print(f"Run time: {pipe.times[-1]}s") + + pipe.print_summary() \ No newline at end of file
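+
+# Example invocations (mirroring the README; the `huggingface-cli login` step
+# is needed because the default FLUX.1-dev checkpoint is gated on Hugging Face):
+#
+#   python3 txt2img.py -p "A cat holding a sign that says hello world"
+#   python3 txt2img.py -b 10 --fp16
+#
+# The first command writes flux_0.png to the current directory; the second
+# additionally runs 10 timed end-to-end executions and prints the per-model
+# latency table from print_summary().
+#
+# Minimal programmatic sketch of the same flow (a hedged example using the
+# defaults above, not an additional API):
+#
+#   from flux_pipeline import FluxPipeline
+#   pipe = FluxPipeline(fp16=True, denoising_steps=50)
+#   pipe.load_models()
+#   prompt = ["A cat holding a sign that says hello world"]
+#   images = pipe.infer(prompt, prompt, warmup=True)
+#   pipe.save_image(images, "flux")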