Add flux inference pipeline (NVIDIA#10752)
* VAE added and matched Flux checkpoint. Signed-off-by: mingyuanm <[email protected]>
* Flux model added. Signed-off-by: mingyuanm <[email protected]>
* Copy FlowMatchEulerScheduler over. Signed-off-by: mingyuanm <[email protected]>
* WIP: Start to test the pipeline forward pass. Signed-off-by: mingyuanm <[email protected]>
* VAE added and matched Flux checkpoint. Signed-off-by: mingyuanm <[email protected]>
* Inference pipeline runs with offloading function. Signed-off-by: mingyuanm <[email protected]>
* Start to test image generation. Signed-off-by: mingyuanm <[email protected]>
* Decoding with the VAE has been verified; still need to check the denoising loop. Signed-off-by: mingyuanm <[email protected]>
* The inference pipeline is verified. Signed-off-by: mingyuanm <[email protected]>
* Add arg parsers and refactoring. Signed-off-by: mingyuanm <[email protected]>
* Tested on multiple batch sizes and prompts. Signed-off-by: mingyuanm <[email protected]>
* Add headers. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Renaming. Signed-off-by: mingyuanm <[email protected]>
* Move scheduler to sampler folder. Signed-off-by: mingyuanm <[email protected]>
* Merge folders. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Tested after path change. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Move MMDIT block to NeMo. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Add joint attention and single attention to NeMo. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Joint attention updated. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>
* Remove redundant imports. Signed-off-by: mingyuanm <[email protected]>
* Refactor to inherit Megatron module. Signed-off-by: mingyuanm <[email protected]>
* Apply isort and black reformatting. Signed-off-by: Victor49152 <[email protected]>

---------
Signed-off-by: mingyuanm <[email protected]>
Signed-off-by: Victor49152 <[email protected]>
Co-authored-by: Victor49152 <[email protected]>
1 parent c7a539a · commit 47f2446
Showing 17 changed files with 2,971 additions and 9 deletions.
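
The commit message above mentions copying a FlowMatchEulerScheduler into the new sampler folder. As a rough orientation only (a minimal sketch, not the NeMo implementation added by this commit), a flow-matching Euler sampler advances the latent along the model's predicted velocity by the gap between consecutive noise levels:

import torch

def flow_match_euler_step(latent: torch.Tensor, velocity_pred: torch.Tensor, sigma: float, sigma_next: float) -> torch.Tensor:
    # Hypothetical helper: one Euler step of the flow-matching ODE.
    # The denoiser predicts a velocity; the latent moves along it by (sigma_next - sigma).
    return latent + (sigma_next - sigma) * velocity_pred

Repeating this step over a decreasing sigma schedule is the general shape of a denoising loop built on such a scheduler.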
@@ -0,0 +1,13 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,199 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


class AbstractEmbModel(nn.Module):
    def __init__(self, enable_lora_finetune=False, target_block=[], target_module=[]):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

        self.TARGET_BLOCK = target_block
        self.TARGET_MODULE = target_module
        if enable_lora_finetune:
            self.lora_layers = []

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key

    def encode(self, *args, **kwargs):
        raise NotImplementedError

    def _enable_lora(self, lora_model):
        for module_name, module in lora_model.named_modules():
            if module.__class__.__name__ in self.TARGET_BLOCK:
                tmp = {}
                for sub_name, sub_module in module.named_modules():
                    if sub_module.__class__.__name__ in self.TARGET_MODULE:
                        if hasattr(sub_module, "input_size") and hasattr(
                            sub_module, "output_size"
                        ):  # for megatron ParallelLinear
                            lora = LoraWrapper(sub_module, sub_module.input_size, sub_module.output_size)
                        else:  # for nn.Linear
                            lora = LoraWrapper(sub_module, sub_module.in_features, sub_module.out_features)
                        self.lora_layers.append(lora)
                        if sub_name not in tmp.keys():
                            tmp.update({sub_name: lora})
                        else:
                            print(f"Duplicate subnames are found in module {module_name}")
                for sub_name, lora_layer in tmp.items():
                    lora_name = f'{sub_name}_lora'
                    module.add_module(lora_name, lora_layer)


class FrozenCLIPEmbedder(AbstractEmbModel):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        enable_lora_finetune=False,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
        dtype=torch.float,
    ):
        super().__init__(enable_lora_finetune, target_block=["CLIPAttention", "CLIPMLP"], target_module=["Linear"])
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=dtype).to(device)
        self.device = device
        self.max_length = max_length
        self.freeze()
        if enable_lora_finetune:
            self._enable_lora(self.transformer)
            print(f"CLIP transformer encoder add {len(self.lora_layers)} lora layers.")

        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, max_sequence_length=None):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=max_sequence_length if max_sequence_length else self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=(self.layer == "hidden"))

        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]

        # Pad the seq length to multiple of 8
        seq_len = (z.shape[1] + 8 - 1) // 8 * 8
        z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0)
        if self.return_pooled:
            return z, outputs.pooler_output
        return z

    def encode(self, text):
        return self(text)


class FrozenT5Embedder(AbstractEmbModel):
    def __init__(
        self,
        version="google/t5-v1_1-xxl",
        max_length=512,
        device="cuda",
        dtype=torch.float,
    ):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl", max_length=max_length)
        self.transformer = T5EncoderModel.from_pretrained(version, torch_dtype=dtype).to(device)
        self.max_length = max_length
        self.freeze()
        self.device = device
        self.dtype = dtype

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, max_sequence_length=None):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=max_sequence_length if max_sequence_length else self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=None)

        return outputs.last_hidden_state
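
A hypothetical usage sketch for the two encoders above (not part of this commit; it assumes the module is importable and a CUDA device is available): the CLIP embedder returns per-token embeddings, plus the pooled output when always_return_pooled=True, while the T5 embedder returns its last hidden state padded to the requested sequence length.

import torch

# Assumed: FrozenCLIPEmbedder and FrozenT5Embedder are imported from the module defined above.
clip = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device="cuda", always_return_pooled=True, dtype=torch.bfloat16)
t5 = FrozenT5Embedder(version="google/t5-v1_1-xxl", device="cuda", dtype=torch.bfloat16)

prompts = ["A cat holding a sign that says hello world"]
with torch.no_grad():
    clip_embeds, pooled = clip(prompts)               # token embeddings (seq padded to a multiple of 8) and pooled output
    t5_embeds = t5(prompts, max_sequence_length=256)  # last hidden state, seq length padded to 256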
@@ -0,0 +1,113 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch

from nemo.collections.diffusion.models.flux.pipeline import FluxInferencePipeline
from nemo.collections.diffusion.utils.flux_pipeline_utils import configs
from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils


def parse_args():
    parser = argparse.ArgumentParser(
        description="The Flux inference pipeline uses the Megatron Core transformer.\nPlease prepare the necessary Flux model checkpoints on local disk before using this script."
    )

    parser.add_argument("--flux_ckpt", type=str, default="", help="Path to Flux transformer checkpoint(s)")
    parser.add_argument("--vae_ckpt", type=str, default="/ckpts/ae.safetensors", help="Path to \'ae.safetensors\'")
    parser.add_argument(
        "--clip_version",
        type=str,
        default='/ckpts/text_encoder',
        help="Clip version, provide either ckpt dir or clip version like openai/clip-vit-large-patch14",
    )
    parser.add_argument(
        "--t5_version",
        type=str,
        default='/ckpts/text_encoder_2',
        help="T5 version, provide either ckpt dir or T5 version like google/t5-v1_1-xxl",
    )
    parser.add_argument(
        "--do_convert_from_hf",
        action='store_true',
        default=False,
        help="Must be true if provided checkpoint is not already converted to NeMo version",
    )
    parser.add_argument(
        "--save_converted_model",
        action="store_true",
        default=False,
        help="Whether to save the converted NeMo transformer checkpoint for Flux",
    )
    parser.add_argument(
        "--version",
        type=str,
        default='dev',
        choices=['dev', 'schnell'],
        help="Must align with the checkpoint provided.",
    )
    parser.add_argument("--height", type=int, default=1024, help="Image height.")
    parser.add_argument("--width", type=int, default=1024, help="Image width.")
    parser.add_argument("--inference_steps", type=int, default=10, help="Number of inference steps to run.")
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=1, help="Number of images to generate for each prompt."
    )
    parser.add_argument("--guidance", type=float, default=0.0, help="Guidance scale.")
    parser.add_argument(
        "--offload", action='store_true', default=False, help="Offload modules to cpu after being called."
    )
    parser.add_argument(
        "--prompts",
        type=str,
        default="A cat holding a sign that says hello world",
        help="Inference prompts, use \',\' to separate if multiple prompts are provided.",
    )
    parser.add_argument("--bf16", action='store_true', default=False, help="Use bf16 in inference.")
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print('Initializing model parallel config')
    Utils.initialize_distributed(1, 1, 1)

    print('Initializing flux inference pipeline')
    params = configs[args.version]
    params.vae_params.ckpt = args.vae_ckpt
    params.clip_params['version'] = args.clip_version
    params.t5_params['version'] = args.t5_version
    pipe = FluxInferencePipeline(params)

    print('Loading transformer weights')
    pipe.load_from_pretrained(
        args.flux_ckpt,
        do_convert_from_hf=args.do_convert_from_hf,
        save_converted_model=args.save_converted_model,
    )
    dtype = torch.bfloat16 if args.bf16 else torch.float32
    text = args.prompts.split(',')
    pipe(
        text,
        max_sequence_length=256,
        height=args.height,
        width=args.width,
        num_inference_steps=args.inference_steps,
        num_images_per_prompt=args.num_images_per_prompt,
        offload=args.offload,
        guidance_scale=args.guidance,
        dtype=dtype,
    )
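
Usage note (hypothetical): the script's file name and launcher requirements are not visible in this diff (distributed setup is delegated to Utils.initialize_distributed), but the flags defined in parse_args above combine along these lines, with multiple prompts separated by commas:

python <flux_inference_script>.py --flux_ckpt /path/to/flux_transformer --vae_ckpt /ckpts/ae.safetensors --do_convert_from_hf --version dev --inference_steps 10 --bf16 --prompts "A cat holding a sign that says hello world, a robot reading a newspaper"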