Add flux inference pipeline (NVIDIA#10752)
* VAE added and matched Flux checkpoint

Signed-off-by: mingyuanm <[email protected]>

* Flux model added.

Signed-off-by: mingyuanm <[email protected]>

* Copying FlowMatchEulerScheduler over

Signed-off-by: mingyuanm <[email protected]>

* WIP: Start to test the pipeline forward pass

Signed-off-by: mingyuanm <[email protected]>

* VAE added and matched Flux checkpoint

Signed-off-by: mingyuanm <[email protected]>

* Inference pipeline runs with offloading function

Signed-off-by: mingyuanm <[email protected]>

* Start to test image generation

Signed-off-by: mingyuanm <[email protected]>

* The VAE decoding path has been verified. Still need to check the denoising loop.

Signed-off-by: mingyuanm <[email protected]>

* The inference pipeline is verified.

Signed-off-by: mingyuanm <[email protected]>

* Add arg parsers and refactoring

Signed-off-by: mingyuanm <[email protected]>

* Tested with multiple batch sizes and prompts.

Signed-off-by: mingyuanm <[email protected]>

* Add headers

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Renaming

Signed-off-by: mingyuanm <[email protected]>

* Move scheduler to sampler folder

Signed-off-by: mingyuanm <[email protected]>

* Merging folders.

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Tested after path changes.

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Move MMDIT block to NeMo

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Add joint attention and single attention to NeMo

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Joint attention updated

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

* Remove redundant importing

Signed-off-by: mingyuanm <[email protected]>

* Refactor to inherit from MegatronModule

Signed-off-by: mingyuanm <[email protected]>

* Apply isort and black reformatting

Signed-off-by: Victor49152 <[email protected]>

---------

Signed-off-by: mingyuanm <[email protected]>
Signed-off-by: Victor49152 <[email protected]>
Co-authored-by: Victor49152 <[email protected]>
Victor49152 authored Oct 22, 2024
1 parent c7a539a commit 47f2446
Showing 17 changed files with 2,971 additions and 9 deletions.
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/encoders/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
199 changes: 199 additions & 0 deletions nemo/collections/diffusion/encoders/conditioner.py
@@ -0,0 +1,199 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


class AbstractEmbModel(nn.Module):
    def __init__(self, enable_lora_finetune=False, target_block=None, target_module=None):
super().__init__()
self._is_trainable = None
self._ucg_rate = None
self._input_key = None

        self.TARGET_BLOCK = target_block or []
        self.TARGET_MODULE = target_module or []
if enable_lora_finetune:
self.lora_layers = []

@property
def is_trainable(self) -> bool:
return self._is_trainable

@property
def ucg_rate(self) -> Union[float, torch.Tensor]:
return self._ucg_rate

@property
def input_key(self) -> str:
return self._input_key

@is_trainable.setter
def is_trainable(self, value: bool):
self._is_trainable = value

@ucg_rate.setter
def ucg_rate(self, value: Union[float, torch.Tensor]):
self._ucg_rate = value

@input_key.setter
def input_key(self, value: str):
self._input_key = value

@is_trainable.deleter
def is_trainable(self):
del self._is_trainable

@ucg_rate.deleter
def ucg_rate(self):
del self._ucg_rate

@input_key.deleter
def input_key(self):
del self._input_key

def encode(self, *args, **kwargs):
raise NotImplementedError

def _enable_lora(self, lora_model):
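        """Wrap target Linear (or Megatron ParallelLinear) submodules inside the
        configured target blocks with LoRA adapters, and register each adapter on
        its parent module."""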
for module_name, module in lora_model.named_modules():
if module.__class__.__name__ in self.TARGET_BLOCK:
tmp = {}
for sub_name, sub_module in module.named_modules():
if sub_module.__class__.__name__ in self.TARGET_MODULE:
if hasattr(sub_module, "input_size") and hasattr(
sub_module, "output_size"
): # for megatron ParallelLinear
lora = LoraWrapper(sub_module, sub_module.input_size, sub_module.output_size)
else: # for nn.Linear
lora = LoraWrapper(sub_module, sub_module.in_features, sub_module.out_features)
self.lora_layers.append(lora)
                        if sub_name not in tmp:
tmp.update({sub_name: lora})
else:
print(f"Duplicate subnames are found in module {module_name}")
for sub_name, lora_layer in tmp.items():
lora_name = f'{sub_name}_lora'
module.add_module(lora_name, lora_layer)
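
# NOTE (editorial): `LoraWrapper` is referenced above but is neither defined nor
# imported in this diff; it is assumed to come from elsewhere in NeMo. A minimal
# sketch of such a low-rank adapter wrapper (hypothetical, for orientation only)
# might look like:
#
#     class LoraWrapper(nn.Module):
#         def __init__(self, orig_module, in_features, out_features, rank=4, alpha=1.0):
#             super().__init__()
#             self.orig_module = orig_module
#             self.scale = alpha / rank
#             self.down = nn.Linear(in_features, rank, bias=False)
#             self.up = nn.Linear(rank, out_features, bias=False)
#             nn.init.zeros_(self.up.weight)  # adapter starts as a no-op
#
#         def forward(self, x):
#             return self.orig_module(x) + self.scale * self.up(self.down(x))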


class FrozenCLIPEmbedder(AbstractEmbModel):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""

LAYERS = ["last", "pooled", "hidden"]

def __init__(
self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
enable_lora_finetune=False,
layer="last",
layer_idx=None,
always_return_pooled=False,
dtype=torch.float,
):
super().__init__(enable_lora_finetune, target_block=["CLIPAttention", "CLIPMLP"], target_module=["Linear"])
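        # Note: the tokenizer is always loaded from the canonical HF repo here,
        # even when `version` points to a local checkpoint directory.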
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=dtype).to(device)
self.device = device
self.max_length = max_length
self.freeze()
if enable_lora_finetune:
self._enable_lora(self.transformer)
print(f"CLIP transformer encoder add {len(self.lora_layers)} lora layers.")

self.layer = layer
self.layer_idx = layer_idx
self.return_pooled = always_return_pooled
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12

def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False

def forward(self, text, max_sequence_length=None):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=max_sequence_length if max_sequence_length else self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True)
outputs = self.transformer(input_ids=tokens, output_hidden_states=(self.layer == "hidden"))

if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]

        # Pad the sequence length to a multiple of 8
seq_len = (z.shape[1] + 8 - 1) // 8 * 8
z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0)
if self.return_pooled:
return z, outputs.pooler_output
return z

def encode(self, text):
return self(text)


class FrozenT5Embedder(AbstractEmbModel):
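    """Uses the T5 transformer encoder for text (from Hugging Face)"""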
def __init__(
self,
version="google/t5-v1_1-xxl",
max_length=512,
device="cuda",
dtype=torch.float,
):
super().__init__()
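        # As with CLIP above, the tokenizer is always loaded from the canonical
        # HF repo, even when `version` points to a local checkpoint directory.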
self.tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl", max_length=max_length)
self.transformer = T5EncoderModel.from_pretrained(version, torch_dtype=dtype).to(device)
self.max_length = max_length
self.freeze()
self.device = device
self.dtype = dtype

def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False

def forward(self, text, max_sequence_length=None):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=max_sequence_length if max_sequence_length else self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)

tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True)
outputs = self.transformer(input_ids=tokens, output_hidden_states=None)

return outputs.last_hidden_state
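
For orientation, here is a minimal sketch of how these encoders could be exercised on their own. It is illustrative only and not part of this commit; it assumes the default Hugging Face checkpoints are downloadable, and the annotated shapes correspond to those defaults.

import torch

from nemo.collections.diffusion.encoders.conditioner import FrozenCLIPEmbedder, FrozenT5Embedder

clip = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device="cuda", always_return_pooled=True)
t5 = FrozenT5Embedder(version="google/t5-v1_1-xxl", device="cuda", dtype=torch.bfloat16)

prompts = ["A cat holding a sign that says hello world"]
# CLIP pads the 77-token sequence up to a multiple of 8, hence 80.
clip_hidden, clip_pooled = clip(prompts)          # (1, 80, 768) and (1, 768)
t5_hidden = t5(prompts, max_sequence_length=256)  # (1, 256, 4096)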
113 changes: 113 additions & 0 deletions nemo/collections/diffusion/flux_infer.py
@@ -0,0 +1,113 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch

from nemo.collections.diffusion.models.flux.pipeline import FluxInferencePipeline
from nemo.collections.diffusion.utils.flux_pipeline_utils import configs
from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils


def parse_args():
parser = argparse.ArgumentParser(
description="The flux inference pipeline is utilizing megatron core transformer.\nPlease prepare the necessary checkpoints for flux model on local disk in order to use this script"
)

parser.add_argument("--flux_ckpt", type=str, default="", help="Path to Flux transformer checkpoint(s)")
parser.add_argument("--vae_ckpt", type=str, default="/ckpts/ae.safetensors", help="Path to \'ae.safetensors\'")
parser.add_argument(
"--clip_version",
type=str,
default='/ckpts/text_encoder',
help="Clip version, provide either ckpt dir or clip version like openai/clip-vit-large-patch14",
)
parser.add_argument(
"--t5_version",
type=str,
default='/ckpts/text_encoder_2',
help="Clip version, provide either ckpt dir or clip version like google/t5-v1_1-xxl",
)
parser.add_argument(
"--do_convert_from_hf",
action='store_true',
default=False,
help="Must be true if provided checkpoint is not already converted to NeMo version",
)
parser.add_argument(
"--save_converted_model",
action="store_true",
default=False,
help="Whether to save the converted NeMo transformer checkpoint for Flux",
)
parser.add_argument(
"--version",
type=str,
default='dev',
choices=['dev', 'schnell'],
help="Must align with the checkpoint provided.",
)
parser.add_argument("--height", type=int, default=1024, help="Image height.")
parser.add_argument("--width", type=int, default=1024, help="Image width.")
parser.add_argument("--inference_steps", type=int, default=10, help="Number of inference steps to run.")
parser.add_argument(
"--num_images_per_prompt", type=int, default=1, help="Number of images to generate for each prompt."
)
parser.add_argument("--guidance", type=float, default=0.0, help="Guidance scale.")
parser.add_argument(
"--offload", action='store_true', default=False, help="Offload modules to cpu after being called."
)
parser.add_argument(
"--prompts",
type=str,
default="A cat holding a sign that says hello world",
help="Inference prompts, use \',\' to separate if multiple prompts are provided.",
)
parser.add_argument("--bf16", action='store_true', default=False, help="Use bf16 in inference.")
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_args()
print('Initializing model parallel config')
Utils.initialize_distributed(1, 1, 1)

print('Initializing flux inference pipeline')
params = configs[args.version]
params.vae_params.ckpt = args.vae_ckpt
params.clip_params['version'] = args.clip_version
params.t5_params['version'] = args.t5_version
pipe = FluxInferencePipeline(params)

print('Loading transformer weights')
pipe.load_from_pretrained(
args.flux_ckpt,
do_convert_from_hf=args.do_convert_from_hf,
save_converted_model=args.save_converted_model,
)
dtype = torch.bfloat16 if args.bf16 else torch.float32
text = args.prompts.split(',')
pipe(
text,
max_sequence_length=256,
height=args.height,
width=args.width,
num_inference_steps=args.inference_steps,
num_images_per_prompt=args.num_images_per_prompt,
offload=args.offload,
guidance_scale=args.guidance,
dtype=dtype,
)
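
As a usage sketch, with checkpoints on local disk (all paths below are hypothetical), the script could be invoked along these lines; every flag corresponds to an argument defined in parse_args above.

python nemo/collections/diffusion/flux_infer.py \
    --flux_ckpt /ckpts/flux/transformer.safetensors \
    --vae_ckpt /ckpts/ae.safetensors \
    --clip_version /ckpts/text_encoder \
    --t5_version /ckpts/text_encoder_2 \
    --version dev \
    --do_convert_from_hf \
    --height 1024 --width 1024 \
    --inference_steps 10 \
    --num_images_per_prompt 1 \
    --prompts "A cat holding a sign that says hello world" \
    --bf16 --offload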