diff --git a/nemo/collections/diffusion/encoders/__init__.py b/nemo/collections/diffusion/encoders/__init__.py new file mode 100644 index 000000000000..9e3250071955 --- /dev/null +++ b/nemo/collections/diffusion/encoders/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/encoders/conditioner.py b/nemo/collections/diffusion/encoders/conditioner.py new file mode 100644 index 000000000000..2bfb008c5d84 --- /dev/null +++ b/nemo/collections/diffusion/encoders/conditioner.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch +import torch.nn as nn +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + + +class AbstractEmbModel(nn.Module): + def __init__(self, enable_lora_finetune=False, target_block=[], target_module=[]): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + self.TARGET_BLOCK = target_block + self.TARGET_MODULE = target_module + if enable_lora_finetune: + self.lora_layers = [] + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def _enable_lora(self, lora_model): + for module_name, module in lora_model.named_modules(): + if module.__class__.__name__ in self.TARGET_BLOCK: + tmp = {} + for sub_name, sub_module in module.named_modules(): + if sub_module.__class__.__name__ in self.TARGET_MODULE: + if hasattr(sub_module, "input_size") and hasattr( + sub_module, "output_size" + ): # for megatron ParallelLinear + lora = LoraWrapper(sub_module, sub_module.input_size, sub_module.output_size) + else: # for nn.Linear + lora = 
LoraWrapper(sub_module, sub_module.in_features, sub_module.out_features) + self.lora_layers.append(lora) + if sub_name not in tmp.keys(): + tmp.update({sub_name: lora}) + else: + print(f"Duplicate subnames are found in module {module_name}") + for sub_name, lora_layer in tmp.items(): + lora_name = f'{sub_name}_lora' + module.add_module(lora_name, lora_layer) + + +class FrozenCLIPEmbedder(AbstractEmbModel): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + enable_lora_finetune=False, + layer="last", + layer_idx=None, + always_return_pooled=False, + dtype=torch.float, + ): + super().__init__(enable_lora_finetune, target_block=["CLIPAttention", "CLIPMLP"], target_module=["Linear"]) + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=dtype).to(device) + self.device = device + self.max_length = max_length + self.freeze() + if enable_lora_finetune: + self._enable_lora(self.transformer) + print(f"CLIP transformer encoder add {len(self.lora_layers)} lora layers.") + + self.layer = layer + self.layer_idx = layer_idx + self.return_pooled = always_return_pooled + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, max_sequence_length=None): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=max_sequence_length if max_sequence_length else self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True) + outputs = self.transformer(input_ids=tokens, output_hidden_states=(self.layer == "hidden")) + + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + + # Pad the seq length to multiple of 8 + seq_len = (z.shape[1] + 8 - 1) // 8 * 8 + z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0) + if self.return_pooled: + return z, outputs.pooler_output + return z + + def encode(self, text): + return self(text) + + +class FrozenT5Embedder(AbstractEmbModel): + def __init__( + self, + version="google/t5-v1_1-xxl", + max_length=512, + device="cuda", + dtype=torch.float, + ): + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl", max_length=max_length) + self.transformer = T5EncoderModel.from_pretrained(version, torch_dtype=dtype).to(device) + self.max_length = max_length + self.freeze() + self.device = device + self.dtype = dtype + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, max_sequence_length=None): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=max_sequence_length if max_sequence_length else self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + tokens = batch_encoding["input_ids"].to(self.transformer.device, non_blocking=True) + outputs = self.transformer(input_ids=tokens, 
output_hidden_states=None) + + return outputs.last_hidden_state diff --git a/nemo/collections/diffusion/flux_infer.py b/nemo/collections/diffusion/flux_infer.py new file mode 100644 index 000000000000..f914dbf50258 --- /dev/null +++ b/nemo/collections/diffusion/flux_infer.py @@ -0,0 +1,113 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch + +from nemo.collections.diffusion.models.flux.pipeline import FluxInferencePipeline +from nemo.collections.diffusion.utils.flux_pipeline_utils import configs +from nemo.collections.diffusion.utils.mcore_parallel_utils import Utils + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The flux inference pipeline is utilizing megatron core transformer.\nPlease prepare the necessary checkpoints for flux model on local disk in order to use this script" + ) + + parser.add_argument("--flux_ckpt", type=str, default="", help="Path to Flux transformer checkpoint(s)") + parser.add_argument("--vae_ckpt", type=str, default="/ckpts/ae.safetensors", help="Path to \'ae.safetensors\'") + parser.add_argument( + "--clip_version", + type=str, + default='/ckpts/text_encoder', + help="Clip version, provide either ckpt dir or clip version like openai/clip-vit-large-patch14", + ) + parser.add_argument( + "--t5_version", + type=str, + default='/ckpts/text_encoder_2', + help="Clip version, provide either ckpt dir or clip version like google/t5-v1_1-xxl", + ) + parser.add_argument( + "--do_convert_from_hf", + action='store_true', + default=False, + help="Must be true if provided checkpoint is not already converted to NeMo version", + ) + parser.add_argument( + "--save_converted_model", + action="store_true", + default=False, + help="Whether to save the converted NeMo transformer checkpoint for Flux", + ) + parser.add_argument( + "--version", + type=str, + default='dev', + choices=['dev', 'schnell'], + help="Must align with the checkpoint provided.", + ) + parser.add_argument("--height", type=int, default=1024, help="Image height.") + parser.add_argument("--width", type=int, default=1024, help="Image width.") + parser.add_argument("--inference_steps", type=int, default=10, help="Number of inference steps to run.") + parser.add_argument( + "--num_images_per_prompt", type=int, default=1, help="Number of images to generate for each prompt." + ) + parser.add_argument("--guidance", type=float, default=0.0, help="Guidance scale.") + parser.add_argument( + "--offload", action='store_true', default=False, help="Offload modules to cpu after being called." 
+ ) + parser.add_argument( + "--prompts", + type=str, + default="A cat holding a sign that says hello world", + help="Inference prompts, use \',\' to separate if multiple prompts are provided.", + ) + parser.add_argument("--bf16", action='store_true', default=False, help="Use bf16 in inference.") + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + print('Initializing model parallel config') + Utils.initialize_distributed(1, 1, 1) + + print('Initializing flux inference pipeline') + params = configs[args.version] + params.vae_params.ckpt = args.vae_ckpt + params.clip_params['version'] = args.clip_version + params.t5_params['version'] = args.t5_version + pipe = FluxInferencePipeline(params) + + print('Loading transformer weights') + pipe.load_from_pretrained( + args.flux_ckpt, + do_convert_from_hf=args.do_convert_from_hf, + save_converted_model=args.save_converted_model, + ) + dtype = torch.bfloat16 if args.bf16 else torch.float32 + text = args.prompts.split(',') + pipe( + text, + max_sequence_length=256, + height=args.height, + width=args.width, + num_inference_steps=args.inference_steps, + num_images_per_prompt=args.num_images_per_prompt, + offload=args.offload, + guidance_scale=args.guidance, + dtype=dtype, + ) diff --git a/nemo/collections/diffusion/models/dit/dit_attention.py b/nemo/collections/diffusion/models/dit/dit_attention.py new file mode 100644 index 000000000000..9e60b11dd1c6 --- /dev/null +++ b/nemo/collections/diffusion/models/dit/dit_attention.py @@ -0,0 +1,428 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass +from typing import Union + +import torch +from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.transformer.attention import Attention, SelfAttention +from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + + +@dataclass +class JointSelfAttentionSubmodules: + linear_qkv: Union[ModuleSpec, type] = None + added_linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + added_q_layernorm: Union[ModuleSpec, type] = None + added_k_layernorm: Union[ModuleSpec, type] = None + + +class JointSelfAttention(Attention): + """Joint Self-attention layer class + + Used for MMDIT-like transformer block. 
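+
+    Projects the image stream and the additional context (text) stream with separate
+    QKV projections, concatenates them along the sequence dimension, and runs a single
+    fused self-attention over the joint sequence.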
+ """ + + def __init__( + self, + config: TransformerConfig, + submodules: JointSelfAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + context_pre_only: bool = False, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + attention_type="self", + ) + + self.linear_qkv = build_module( + submodules.linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + if submodules.added_linear_qkv is not None: + self.added_linear_qkv = build_module( + submodules.added_linear_qkv, + self.config.hidden_size, + self.query_projection_size + 2 * self.kv_projection_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + ) + + if not context_pre_only: + self.added_linear_proj = build_module( + submodules.linear_proj, + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='proj', + ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.q_layernorm = None + + if submodules.k_layernorm is not None: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.k_layernorm = None + + if submodules.added_q_layernorm is not None: + self.added_q_layernorm = build_module( + submodules.added_q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.added_q_layernorm = None + + if submodules.added_k_layernorm is not None: + self.added_k_layernorm = build_module( + submodules.added_k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + self.added_k_layernorm = None + + def _split_qkv(self, mixed_qkv): + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim( + mixed_qkv, + 3, + split_arg_list, + ) + else: + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split( + mixed_qkv, + split_arg_list, + dim=3, + ) + + # [sq, b, 
ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + return query, key, value + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + query, key, value = self._split_qkv(mixed_qkv) + + if self.config.test_mode: + self.run_realtime_tests() + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + if self.k_layernorm is not None: + key = self.k_layernorm(key) + + return query, key, value + + def get_added_query_key_value_tensors(self, added_hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.added_linear_qkv(added_hidden_states) + + query, key, value = self._split_qkv(mixed_qkv) + + if self.config.test_mode: + self.run_realtime_tests() + + if self.added_q_layernorm is not None: + query = self.added_q_layernorm(query) + + if self.added_k_layernorm is not None: + key = self.added_k_layernorm(key) + + return query, key, value + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + packed_seq_params=None, + additional_hidden_states=None, + ): + # hidden_states: [sq, b, h] + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + + query, key, value = self.get_query_key_value_tensors(hidden_states) + added_query, added_key, added_value = self.get_added_query_key_value_tensors(additional_hidden_states) + + query = torch.cat([added_query, query], dim=0) + key = torch.cat([added_key, key], dim=0) + value = torch.cat([added_value, value], dim=0) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, key, value, rotary_pos_emb + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + ) + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + ) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. 
+ # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + encoder_attention_output = core_attn_out[: additional_hidden_states.shape[0], :, :] + attention_output = core_attn_out[additional_hidden_states.shape[0] :, :, :] + + output, bias = self.linear_proj(attention_output) + encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output) + + output = output + bias + encoder_output = encoder_output + encoder_bias + + return output, encoder_output + + +class FluxSingleAttention(SelfAttention): + """Self-attention layer class + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + packed_seq_params=None, + ): + # hidden_states: [sq, b, h] + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. 
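+        # Single-stream Flux block: q/k/v are all derived from hidden_states (plain self-attention).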
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + # print(f'megatron q before ln: {query.transpose(0, 1).contiguous()}, {query.transpose(0, 1).contiguous().shape}') + # print(f'megatron k before ln: {key.transpose(0, 1).contiguous()}, {key.transpose(0, 1).contiguous().shape}') + # print(f'megatron v before ln: {value.transpose(0, 1).contiguous()}, {value.transpose(0, 1).contiguous().shape}') + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, key, value, rotary_pos_emb + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + ) + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + ) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + return core_attn_out diff --git a/nemo/collections/diffusion/models/dit/dit_layer_spec.py b/nemo/collections/diffusion/models/dit/dit_layer_spec.py index 672dcff3ba00..cb7c520493f0 100644 --- a/nemo/collections/diffusion/models/dit/dit_layer_spec.py +++ b/nemo/collections/diffusion/models/dit/dit_layer_spec.py @@ -42,6 +42,12 @@ from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.utils import make_viewless_tensor +from nemo.collections.diffusion.models.dit.dit_attention import ( + FluxSingleAttention, + JointSelfAttention, + JointSelfAttentionSubmodules, +) + @dataclass class DiTWithAdaLNSubmodules(TransformerLayerSubmodules): @@ -75,7 +81,14 @@ class AdaLN(MegatronModule): Adaptive Layer Normalization Module for DiT. 
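+
+    Produces shift/scale/gate chunks from a conditioning embedding and applies
+    ``x * (1 + scale) + shift`` around parameter-free LayerNorm(s).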
""" - def __init__(self, config: TransformerConfig, n_adaln_chunks=9, norm=nn.LayerNorm): + def __init__( + self, + config: TransformerConfig, + n_adaln_chunks=9, + norm=nn.LayerNorm, + modulation_bias=False, + use_second_norm=False, + ): super().__init__(config) if norm == TENorm: self.ln = norm(config, config.hidden_size, config.layernorm_epsilon) @@ -83,8 +96,11 @@ def __init__(self, config: TransformerConfig, n_adaln_chunks=9, norm=nn.LayerNor self.ln = norm(config.hidden_size, elementwise_affine=False, eps=self.config.layernorm_epsilon) self.n_adaln_chunks = n_adaln_chunks self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(config.hidden_size, self.n_adaln_chunks * config.hidden_size, bias=False) + nn.SiLU(), nn.Linear(config.hidden_size, self.n_adaln_chunks * config.hidden_size, bias=modulation_bias) ) + self.use_second_norm = use_second_norm + if self.use_second_norm: + self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) nn.init.constant_(self.adaLN_modulation[-1].weight, 0) setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel) @@ -92,29 +108,59 @@ def __init__(self, config: TransformerConfig, n_adaln_chunks=9, norm=nn.LayerNor def forward(self, timestep_emb): return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1) - @jit_fuser + # @jit_fuser def modulate(self, x, shift, scale): return x * (1 + scale) + shift - @jit_fuser + # @jit_fuser def scale_add(self, residual, x, gate): return residual + gate * x - @jit_fuser - def modulated_layernorm(self, x, shift, scale): + # @jit_fuser + def modulated_layernorm(self, x, shift, scale, layernorm_idx=0): + if self.use_second_norm and layernorm_idx == 1: + layernorm = self.ln2 + else: + layernorm = self.ln # Optional Input Layer norm - input_layernorm_output = self.ln(x).type_as(x) + input_layernorm_output = layernorm(x).type_as(x) # DiT block specific return self.modulate(input_layernorm_output, shift, scale) # @jit_fuser - def scaled_modulated_layernorm(self, residual, x, gate, shift, scale): + def scaled_modulated_layernorm(self, residual, x, gate, shift, scale, layernorm_idx=0): hidden_states = self.scale_add(residual, x, gate) - shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale) + shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale, layernorm_idx) return hidden_states, shifted_pre_mlp_layernorm_output +class AdaLNContinuous(MegatronModule): + def __init__( + self, + config: TransformerConfig, + conditioning_embedding_dim: int, + modulation_bias: bool = True, + norm_type: str = "layer_norm", + ): + super().__init__(config) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(conditioning_embedding_dim, config.hidden_size * 2, bias=modulation_bias) + ) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6, bias=modulation_bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(config.hidden_size, eps=1e-6) + else: + raise ValueError("Unknown normalization type {}".format(norm_type)) + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.adaLN_modulation(conditioning_embedding) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale) + shift + return x + + class STDiTLayerWithAdaLN(TransformerLayer): """A single transformer layer. 
@@ -407,6 +453,225 @@ def forward( return output, context +class DiTLayer(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + Original DiT layer implementation from [https://arxiv.org/pdf/2212.09748]. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + mlp_ratio: int = 4, + n_adaln_chunks: int = 6, + modulation_bias: bool = True, + ): + # Modify the mlp layer hidden_size of a dit layer according to mlp_ratio + config.ffn_hidden_size = int(mlp_ratio * config.hidden_size) + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + + self.adaLN = AdaLN( + config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=True + ) + + def forward( + self, + hidden_states, + attention_mask, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + ): + # passing in conditioning information via attention mask here + c = attention_mask + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN(c) + + shifted_input_layernorm_output = self.adaLN.modulated_layernorm( + hidden_states, shift=shift_msa, scale=scale_msa, layernorm_idx=0 + ) + + x, bias = self.self_attention(shifted_input_layernorm_output, attention_mask=None) + + hidden_states = self.adaLN.scale_add(hidden_states, x=(x + bias), gate=gate_msa) + + residual = hidden_states + + shited_pre_mlp_layernorm_output = self.adaLN.modulated_layernorm( + hidden_states, shift=shift_mlp, scale=scale_mlp, layernorm_idx=1 + ) + + x, bias = self.mlp(shited_pre_mlp_layernorm_output) + + hidden_states = self.adaLN.scale_add(residual, x=(x + bias), gate=gate_mlp) + + return hidden_states, context + + +class MMDiTLayer(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + MMDiT layer implementation from [https://arxiv.org/pdf/2403.03206]. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + context_pre_only: bool = False, + ): + + hidden_size = config.hidden_size + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + + self.adaln = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True) + + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + if context_norm_type == "ada_norm_continous": + self.adaln_context = AdaLNContinous(config, hidden_size, modulation_bias=True, norm_type="layer_norm") + elif context_norm_type == "ada_norm_zero": + self.adaln_context = AdaLN(config, modulation_bias=True, n_adaln_chunks=6, use_second_norm=True) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + # Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. 
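+        # The copied config below is only used to build the context-stream MLP;
+        # the image-stream path keeps the original parallel settings.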
+ cp_override_config = copy.deepcopy(config) + cp_override_config.context_parallel_size = 1 + cp_override_config.tp_comm_overlap = False + + if not context_pre_only: + self.context_mlp = build_module( + submodules.mlp, + config=cp_override_config, + ) + else: + self.context_mlp = None + + def forward( + self, + hidden_states, + encoder_hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + emb=None, + ): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaln(emb) + + norm_hidden_states = self.adaln.modulated_layernorm( + hidden_states, shift=shift_msa, scale=scale_msa, layernorm_idx=0 + ) + if self.context_pre_only: + norm_encoder_hidden_states = self.adaln_context(encoder_hidden_states, emb) + else: + c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.adaln_context(emb) + norm_encoder_hidden_states = self.adaln_context.modulated_layernorm( + encoder_hidden_states, shift=c_shift_msa, scale=c_scale_msa, layernorm_idx=0 + ) + + attention_output, encoder_attention_output = self.self_attention( + norm_hidden_states, + attention_mask=attention_mask, + key_value_states=None, + additional_hidden_states=norm_encoder_hidden_states, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = self.adaln.scale_add(hidden_states, x=attention_output, gate=gate_msa) + norm_hidden_states = self.adaln.modulated_layernorm( + hidden_states, shift=shift_mlp, scale=scale_mlp, layernorm_idx=1 + ) + + mlp_output, mlp_output_bias = self.mlp(norm_hidden_states) + hidden_states = self.adaln.scale_add(hidden_states, x=(mlp_output + mlp_output_bias), gate=gate_mlp) + + if self.context_pre_only: + encoder_hidden_states = None + else: + encoder_hidden_states = self.adaln_context.scale_add( + encoder_hidden_states, x=encoder_attention_output, gate=c_gate_msa + ) + norm_encoder_hidden_states = self.adaln_context.modulated_layernorm( + encoder_hidden_states, shift=c_shift_mlp, scale=c_scale_mlp, layernorm_idx=1 + ) + + context_mlp_output, context_mlp_output_bias = self.context_mlp(norm_encoder_hidden_states) + encoder_hidden_states = self.adaln.scale_add( + encoder_hidden_states, x=(context_mlp_output + context_mlp_output_bias), gate=c_gate_mlp + ) + + return hidden_states, encoder_hidden_states + + +class FluxSingleTransformerBlock(TransformerLayer): + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + mlp_ratio: int = 4, + n_adaln_chunks: int = 3, + modulation_bias: bool = True, + ): + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + hidden_size = config.hidden_size + self.adaln = AdaLN( + config=config, n_adaln_chunks=n_adaln_chunks, modulation_bias=modulation_bias, use_second_norm=False + ) + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.proj_in = nn.Linear(hidden_size, self.mlp_hidden_dim) + self.activation = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + packed_seq_params=None, + emb=None, + ): + residual = hidden_states + + shift, scale, gate = self.adaln(emb) + + norm_hidden_states = self.adaln.modulated_layernorm(hidden_states, shift=shift, scale=scale) + + mlp_hidden_states = self.activation(self.proj_in(norm_hidden_states)) + + attention_output = 
self.self_attention( + norm_hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb + ) + + hidden_states = torch.cat((attention_output, mlp_hidden_states), dim=2) + + hidden_states = self.proj_out(hidden_states) + + hidden_states = self.adaln.scale_add(residual, x=hidden_states, gate=gate) + + return hidden_states + + def get_stdit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: params = {"attn_mask_type": AttnMaskType.padding} return ModuleSpec( @@ -530,3 +795,77 @@ def get_official_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: ), ), ) + + +def get_mm_dit_block_with_transformer_engine_spec() -> ModuleSpec: + + return ModuleSpec( + module=MMDiTLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=JointSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=JointSelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + added_linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_flux_single_transformer_engine_spec() -> ModuleSpec: + return ModuleSpec( + module=FluxSingleTransformerBlock, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=FluxSingleAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + q_layernorm=RMSNorm, + k_layernorm=RMSNorm, + linear_proj=IdentityOp, + ), + ), + ), + ) + + +def get_flux_double_transformer_engine_spec() -> ModuleSpec: + return ModuleSpec( + module=MMDiTLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=JointSelfAttention, + params={"attn_mask_type": AttnMaskType.no_mask}, + submodules=JointSelfAttentionSubmodules( + q_layernorm=RMSNorm, + k_layernorm=RMSNorm, + added_q_layernorm=RMSNorm, + added_k_layernorm=RMSNorm, + linear_qkv=TEColumnParallelLinear, + added_linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/nemo/collections/diffusion/models/flux/__init__.py b/nemo/collections/diffusion/models/flux/__init__.py new file mode 100644 index 000000000000..9e3250071955 --- /dev/null +++ b/nemo/collections/diffusion/models/flux/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
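For orientation, the Flux model added later in this PR consumes the two specs above by passing their `.submodules` directly to the block constructors. A condensed sketch of that usage (assumes Megatron parallel state is already initialized, e.g. `Utils.initialize_distributed(1, 1, 1)` as in `flux_infer.py`; the config values mirror `FluxParams` but are abbreviated here):

```python
from megatron.core.transformer.transformer_config import TransformerConfig

from nemo.collections.diffusion.models.dit.dit_layer_spec import (
    FluxSingleTransformerBlock,
    MMDiTLayer,
    get_flux_double_transformer_engine_spec,
    get_flux_single_transformer_engine_spec,
)

# Abbreviated config mirroring FluxParams defaults (hidden_size=3072, 24 heads).
cfg = TransformerConfig(
    num_layers=1,
    hidden_size=3072,
    num_attention_heads=24,
    use_cpu_initialization=True,
    hidden_dropout=0,
    attention_dropout=0,
    layernorm_epsilon=1e-6,
    add_qkv_bias=True,
    rotary_interleaved=True,
)

# Double-stream (MMDiT) block: joint attention over text and image tokens.
double_block = MMDiTLayer(
    config=cfg,
    submodules=get_flux_double_transformer_engine_spec().submodules,
    layer_number=1,
    context_pre_only=False,
)

# Single-stream block: fused attention + MLP on the concatenated sequence.
single_block = FluxSingleTransformerBlock(
    config=cfg,
    submodules=get_flux_single_transformer_engine_spec().submodules,
    layer_number=1,
)
```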
diff --git a/nemo/collections/diffusion/models/flux/layers.py b/nemo/collections/diffusion/models/flux/layers.py new file mode 100644 index 000000000000..222a9a1d67ae --- /dev/null +++ b/nemo/collections/diffusion/models/flux/layers.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from torch import Tensor, nn + + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + """ + Different from the original ROPE used for flux. + Megatron attention takes the out product and calculate sin/cos inside, so we only need to get the freqs here + in the shape of [seq, ..., dim] + """ + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + out = torch.einsum("...n,d->...nd", pos, omega) + + return out.float() + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-1, + ) + emb = emb.unsqueeze(1).permute(2, 0, 1, 3) + return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = True, + downscale_freq_shift: float = 0, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
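+
+    Example
+        get_timestep_embedding(torch.tensor([0.0, 999.0]), 128) returns a tensor of shape (2, 128).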
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + def __init__( + self, + embedding_dim: int, + flip_sin_to_cos: bool = True, + downscale_freq_shift: float = 0, + scale: float = 1, + max_period: int = 10000, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + self.max_period = max_period + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.embedding_dim, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + max_period=self.max_period, + ) + return t_emb + + +class TimeStepEmbedder(nn.Module): + def __init__( + self, + embedding_dim: int, + hidden_dim: int, + flip_sin_to_cos: bool = True, + downscale_freq_shift: float = 0, + scale: float = 1, + max_period: int = 10000, + ): + + super().__init__() + + self.time_proj = Timesteps( + embedding_dim=embedding_dim, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + max_period=max_period, + ) + self.time_embedder = MLPEmbedder(in_dim=embedding_dim, hidden_dim=hidden_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timesteps) + timesteps_emb = self.time_embedder(timesteps_proj) + + return timesteps_emb diff --git a/nemo/collections/diffusion/models/flux/model.py b/nemo/collections/diffusion/models/flux/model.py new file mode 100644 index 000000000000..4d42c80a75a1 --- /dev/null +++ b/nemo/collections/diffusion/models/flux/model.py @@ -0,0 +1,156 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Callable + +import torch +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import openai_gelu +from torch import nn + +from nemo.collections.diffusion.models.dit.dit_layer_spec import ( + AdaLNContinuous, + FluxSingleTransformerBlock, + MMDiTLayer, + get_flux_double_transformer_engine_spec, + get_flux_single_transformer_engine_spec, +) +from nemo.collections.diffusion.models.flux.layers import EmbedND, MLPEmbedder, TimeStepEmbedder + + +@dataclass +class FluxParams: + num_joint_layers: int = 19 + num_single_layers: int = 38 + hidden_size: int = 3072 + num_attention_heads: int = 24 + activation_func: Callable = openai_gelu + add_qkv_bias: bool = True + ffn_hidden_size: int = 16384 + in_channels: int = 64 + context_dim: int = 4096 + model_channels: int = 256 + patch_size: int = 1 + guidance_embed: bool = False + vec_in_dim: int = 768 + + +class Flux(VisionModule): + def __init__(self, config: FluxParams): + + self.out_channels = config.in_channels + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.patch_size = config.patch_size + self.in_channels = config.in_channels + self.guidance_embed = config.guidance_embed + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=self.hidden_size, + num_attention_heads=self.num_attention_heads, + use_cpu_initialization=True, + activation_func=config.activation_func, + hidden_dropout=0, + attention_dropout=0, + layernorm_epsilon=1e-6, + add_qkv_bias=config.add_qkv_bias, + rotary_interleaved=True, + ) + super().__init__(transformer_config) + + self.pos_embed = EmbedND(dim=self.hidden_size, theta=10000, axes_dim=[16, 56, 56]) + self.img_embed = nn.Linear(config.in_channels, self.hidden_size) + self.txt_embed = nn.Linear(config.context_dim, self.hidden_size) + self.timestep_embedding = TimeStepEmbedder(config.model_channels, self.hidden_size) + self.vector_embedding = MLPEmbedder(in_dim=config.vec_in_dim, hidden_dim=self.hidden_size) + if config.guidance_embed: + self.guidance_embedding = ( + MLPEmbedder(in_dim=config.model_channels, hidden_dim=self.hidden_size) + if config.guidance_embed + else nn.Identity() + ) + + self.double_blocks = nn.ModuleList( + [ + MMDiTLayer( + config=transformer_config, + submodules=get_flux_double_transformer_engine_spec().submodules, + layer_number=i, + context_pre_only=False, + ) + for i in range(config.num_joint_layers) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + config=transformer_config, + submodules=get_flux_single_transformer_engine_spec().submodules, + layer_number=i, + ) + for i in range(config.num_single_layers) + ] + ) + + self.norm_out = AdaLNContinuous(config=transformer_config, conditioning_embedding_dim=self.hidden_size) + self.proj_out = nn.Linear(self.hidden_size, self.patch_size * self.patch_size * self.out_channels, bias=True) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor = None, + y: torch.Tensor = None, + timesteps: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + ): + hidden_states = self.img_embed(img) + encoder_hidden_states = self.txt_embed(txt) + + timesteps = timesteps.to(img.dtype) * 1000 + vec_emb = self.timestep_embedding(timesteps) + + if guidance is not None: + vec_emb = vec_emb + 
self.guidance_embedding(self.timestep_embedding.time_proj(guidance * 1000)) + vec_emb = vec_emb + self.vector_embedding(y) + + ids = torch.cat((txt_ids, img_ids), dim=1) + rotary_pos_emb = self.pos_embed(ids) + for id_block, block in enumerate(self.double_blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_pos_emb=rotary_pos_emb, + emb=vec_emb, + ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0) + + for id_block, block in enumerate(self.single_blocks): + hidden_states = block( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + emb=vec_emb, + ) + + hidden_states = hidden_states[encoder_hidden_states.shape[0] :, ...] + + hidden_states = self.norm_out(hidden_states, vec_emb) + output = self.proj_out(hidden_states) + + return output diff --git a/nemo/collections/diffusion/models/flux/pipeline.py b/nemo/collections/diffusion/models/flux/pipeline.py new file mode 100644 index 000000000000..e460f8f115bd --- /dev/null +++ b/nemo/collections/diffusion/models/flux/pipeline.py @@ -0,0 +1,342 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from safetensors.torch import load_file as load_safetensors +from safetensors.torch import save_file as save_safetensors +from torch import nn +from tqdm import tqdm + +from nemo.collections.diffusion.encoders.conditioner import FrozenCLIPEmbedder, FrozenT5Embedder +from nemo.collections.diffusion.models.flux.model import Flux, FluxParams +from nemo.collections.diffusion.sampler.flow_matching.flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from nemo.collections.diffusion.utils.flux_ckpt_converter import flux_transformer_converter +from nemo.collections.diffusion.utils.flux_pipeline_utils import FluxModelParams +from nemo.collections.diffusion.vae.autoencoder import AutoEncoder + + +class FluxInferencePipeline(nn.Module): + def __init__(self, params: FluxModelParams): + super().__init__() + self.device = params.device + params.clip_params['device'] = self.device + params.t5_params['device'] = self.device + + self.vae = AutoEncoder(params.vae_params).to(self.device).eval() + self.clip_encoder = FrozenCLIPEmbedder(**params.clip_params) + self.t5_encoder = FrozenT5Embedder(**params.t5_params) + self.transformer = Flux(params.flux_params).to(self.device).eval() + self.vae_scale_factor = 2 ** (len(self.vae.params.ch_mult)) + self.scheduler = FlowMatchEulerDiscreteScheduler(**params.scheduler_params) + self.params = params + + def load_from_pretrained(self, ckpt_path, do_convert_from_hf=True, save_converted_model=None): + if do_convert_from_hf: + ckpt = flux_transformer_converter(ckpt_path, self.transformer.config) + if save_converted_model: + save_path = os.path.join(ckpt_path, 'nemo_flux_transformer.safetensors') + save_safetensors(ckpt, 
save_path) + print(f'saving converted transformer checkpoint to {save_path}') + else: + ckpt = load_safetensors(ckpt_path) + missing, unexpected = self.transformer.load_state_dict(ckpt, strict=False) + missing = [ + k for k in missing if not k.endswith('_extra_state') + ] # These keys are mcore specific and should not affect the model performance + if len(missing) > 0: + print( + f"The folloing keys are missing during checkpoint loading, please check the ckpt provided or the image quality may be compromised.\n {missing}" + ) + print(f"Found unexepected keys: \n {unexpected}") + + def encoder_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = 'cuda', + dtype: Optional[torch.dtype] = torch.float, + ): + if prompt is not None: + batch_size = len(prompt) + elif prompt_embeds is not None: + batch_size = prompt_embeds.shape[0] + else: + raise ValueError("Either prompt or prompt_embeds must be provided.") + if device == 'cuda' and self.t5_encoder.device != device: + self.t5_encoder.to(device) + if prompt_embeds is None: + prompt_embeds = self.t5_encoder(prompt, max_sequence_length=max_sequence_length) + seq_len = prompt_embeds.shape[1] + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1).to(dtype=dtype) + + if device == 'cuda' and self.clip_encoder.device != device: + self.clip_encoder.to(device) + if pooled_prompt_embeds is None: + _, pooled_prompt_embeds = self.clip_encoder(prompt) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1).to(dtype=dtype) + + dtype = dtype if dtype is not None else self.t5_encoder.dtype + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds.transpose(0, 1), pooled_prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_image_ids(batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, 
height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + @staticmethod + def _calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, + ): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * int(height) // self.vae_scale_factor + width = 2 * int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = FluxInferencePipeline._generate_rand_latents(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents.transpose(0, 1), latent_image_ids + + @staticmethod + def _generate_rand_latents( + shape, + generator, + device, + dtype, + ): + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device=device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
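+        # add a batch dimension for single images, then map [0, 1] floats to uint8 for PIL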
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + @staticmethod + def torch_to_numpy(images): + numpy_images = images.float().cpu().permute(0, 2, 3, 1).numpy() + return numpy_images + + @staticmethod + def denormalize(image): + return (image / 2 + 0.5).clamp(0, 1) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: int = 28, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + max_sequence_length: int = 512, + device: torch.device = 'cuda', + dtype: torch.dtype = torch.float32, + save_to_disk: bool = True, + offload: bool = True, + ): + assert device == 'cuda', 'Transformer blocks in Mcore must run on cuda devices' + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif prompt_embeds is not None and isinstance(prompt_embeds, torch.FloatTensor): + batch_size = prompt_embeds.shape[0] + else: + raise ValueError("Either prompt or prompt_embeds must be provided.") + + ## get text prompt embeddings + prompt_embeds, pooled_prompt_embeds, text_ids = self.encoder_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + if offload: + self.t5_encoder.to('cpu') + self.clip_encoder.to('cpu') + torch.cuda.empty_cache() + + ## prepare image latents + num_channels_latents = self.transformer.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents + ) + # prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[0] + + mu = FluxInferencePipeline._calculate_shift( + image_seq_len, + self.scheduler.base_image_seq_len, + self.scheduler.max_image_seq_len, + self.scheduler.base_shift, + self.scheduler.max_shift, + ) + + self.scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu) + timesteps = self.scheduler.timesteps + + if device == 'cuda' and device != self.device: + self.transformer.to(device) + with torch.no_grad(): + for i, t in tqdm(enumerate(timesteps)): + timestep = t.expand(latents.shape[1]).to(device=latents.device, dtype=latents.dtype) + if self.transformer.guidance_embed: + guidance = torch.tensor([guidance_scale], device=device).expand(latents.shape[1]) + else: + guidance = None + with torch.autocast(device_type='cuda', dtype=latents.dtype): + pred = self.transformer( + img=latents, + txt=prompt_embeds, + y=pooled_prompt_embeds, + timesteps=timestep / 1000, + img_ids=latent_image_ids, + txt_ids=text_ids, + guidance=guidance, + ) + latents = self.scheduler.step(pred, t, latents)[0] + if offload: + self.transformer.to('cpu') + torch.cuda.empty_cache() + + if output_type == "latent": + return latents.transpose(0, 1) + elif output_type == "pil": + latents = self._unpack_latents(latents.transpose(0, 1), 
height, width, self.vae_scale_factor) + latents = (latents / self.vae.params.scale_factor) + self.vae.params.shift_factor + if device == 'cuda' and device != self.device: + self.vae.to(device) + with torch.autocast(device_type='cuda', dtype=latents.dtype): + image = self.vae.decode(latents) + if offload: + self.vae.to('cpu') + torch.cuda.empty_cache() + image = FluxInferencePipeline.denormalize(image) + image = FluxInferencePipeline.torch_to_numpy(image) + image = FluxInferencePipeline.numpy_to_pil(image) + if save_to_disk: + print('Saving to disk') + assert len(image) == int(len(prompt) * num_images_per_prompt) + prompt = [p[:40] + f'_{idx}' for p in prompt for idx in range(num_images_per_prompt)] + for file_name, image in zip(prompt, image): + image.save(f'{file_name}.png') + + return image diff --git a/nemo/collections/diffusion/sampler/flow_matching/__init__.py b/nemo/collections/diffusion/sampler/flow_matching/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/diffusion/sampler/flow_matching/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/sampler/flow_matching/flow_match_euler_discrete.py b/nemo/collections/diffusion/sampler/flow_matching/flow_match_euler_discrete.py new file mode 100644 index 000000000000..5bde6b0d1dc1 --- /dev/null +++ b/nemo/collections/diffusion/sampler/flow_matching/flow_match_euler_discrete.py @@ -0,0 +1,284 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from abc import ABC +from typing import List, Optional, Tuple, Union + + +import numpy as np +import torch + + +class FlowMatchEulerDiscreteScheduler(ABC): + """ + Euler scheduler. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. 
+ """ + + _compatibles = [] + order = 1 + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.base_shift = base_shift + self.max_shift = max_shift + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.use_dynamic_shifting = use_dynamic_shifting + self.num_train_timesteps = num_train_timesteps + self.shift = shift + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. 
+ """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.num_train_timesteps + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.num_train_timesteps + + if self.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.num_train_timesteps + + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + ) -> Tuple: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + + Returns: + A tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + prev_sample = sample + (sigma_next - sigma) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + return (prev_sample,) + + def __len__(self): + return self.num_train_timesteps diff --git a/nemo/collections/diffusion/utils/__init__.py b/nemo/collections/diffusion/utils/__init__.py new file mode 100644 index 000000000000..9e3250071955 --- /dev/null +++ b/nemo/collections/diffusion/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/diffusion/utils/flux_ckpt_converter.py b/nemo/collections/diffusion/utils/flux_ckpt_converter.py new file mode 100644 index 000000000000..444a77bfad68 --- /dev/null +++ b/nemo/collections/diffusion/utils/flux_ckpt_converter.py @@ -0,0 +1,206 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from safetensors.torch import load_file as load_safetensors + + +def _import_qkv_bias(transformer_config, qb, kb, vb): + + head_num = transformer_config.num_attention_heads + num_query_groups = transformer_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = transformer_config.hidden_size + head_num = transformer_config.num_attention_heads + head_size = hidden_size // head_num + + new_q_bias_tensor_shape = (head_num, head_size) + new_kv_bias_tensor_shape = (num_query_groups, head_size) + + qb = qb.view(*new_q_bias_tensor_shape) + kb = kb.view(*new_kv_bias_tensor_shape) + vb = vb.view(*new_kv_bias_tensor_shape) + + qkv_bias_l = [] + for i in range(num_query_groups): + qkv_bias_l.append(qb[i * heads_per_group : (i + 1) * heads_per_group, :]) + qkv_bias_l.append(kb[i : i + 1, :]) + qkv_bias_l.append(vb[i : i + 1, :]) + + qkv_bias = torch.cat(qkv_bias_l) + qkv_bias = qkv_bias.reshape([head_size * (head_num + 2 * num_query_groups)]) + + return qkv_bias + + +def _import_qkv(transformer_config, q, k, v): + + head_num = transformer_config.num_attention_heads + num_query_groups = transformer_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = transformer_config.hidden_size + head_num = transformer_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +key_mapping = { + 'double_blocks': { + 'norm1.linear.weight': 'adaln.adaLN_modulation.1.weight', + 'norm1.linear.bias': 'adaln.adaLN_modulation.1.bias', + 'norm1_context.linear.weight': 'adaln_context.adaLN_modulation.1.weight', + 'norm1_context.linear.bias': 'adaln_context.adaLN_modulation.1.bias', + 'attn.norm_q.weight': 'self_attention.q_layernorm.weight', + 'attn.norm_k.weight': 'self_attention.k_layernorm.weight', + 'attn.norm_added_q.weight': 'self_attention.added_q_layernorm.weight', + 'attn.norm_added_k.weight': 'self_attention.added_k_layernorm.weight', + 'attn.to_out.0.weight': 'self_attention.linear_proj.weight', + 
'attn.to_out.0.bias': 'self_attention.linear_proj.bias', + 'attn.to_add_out.weight': 'self_attention.added_linear_proj.weight', + 'attn.to_add_out.bias': 'self_attention.added_linear_proj.bias', + 'ff.net.0.proj.weight': 'mlp.linear_fc1.weight', + 'ff.net.0.proj.bias': 'mlp.linear_fc1.bias', + 'ff.net.2.weight': 'mlp.linear_fc2.weight', + 'ff.net.2.bias': 'mlp.linear_fc2.bias', + 'ff_context.net.0.proj.weight': 'context_mlp.linear_fc1.weight', + 'ff_context.net.0.proj.bias': 'context_mlp.linear_fc1.bias', + 'ff_context.net.2.weight': 'context_mlp.linear_fc2.weight', + 'ff_context.net.2.bias': 'context_mlp.linear_fc2.bias', + }, + 'single_blocks': { + 'norm.linear.weight': 'adaln.adaLN_modulation.1.weight', + 'norm.linear.bias': 'adaln.adaLN_modulation.1.bias', + 'proj_mlp.weight': 'proj_in.weight', + 'proj_mlp.bias': 'proj_in.bias', + 'proj_out.weight': 'proj_out.weight', + 'proj_out.bias': 'proj_out.bias', + 'attn.norm_q.weight': 'self_attention.q_layernorm.weight', + 'attn.norm_k.weight': 'self_attention.k_layernorm.weight', + }, + 'norm_out.linear.bias': 'norm_out.adaLN_modulation.1.bias', + 'norm_out.linear.weight': 'norm_out.adaLN_modulation.1.weight', + 'proj_out.bias': 'proj_out.bias', + 'proj_out.weight': 'proj_out.weight', + 'time_text_embed.guidance_embedder.linear_1.bias': 'guidance_embedding.in_layer.bias', + 'time_text_embed.guidance_embedder.linear_1.weight': 'guidance_embedding.in_layer.weight', + 'time_text_embed.guidance_embedder.linear_2.bias': 'guidance_embedding.out_layer.bias', + 'time_text_embed.guidance_embedder.linear_2.weight': 'guidance_embedding.out_layer.weight', + 'x_embedder.bias': 'img_embed.bias', + 'x_embedder.weight': 'img_embed.weight', + 'time_text_embed.timestep_embedder.linear_1.bias': 'timestep_embedding.time_embedder.in_layer.bias', + 'time_text_embed.timestep_embedder.linear_1.weight': 'timestep_embedding.time_embedder.in_layer.weight', + 'time_text_embed.timestep_embedder.linear_2.bias': 'timestep_embedding.time_embedder.out_layer.bias', + 'time_text_embed.timestep_embedder.linear_2.weight': 'timestep_embedding.time_embedder.out_layer.weight', + 'context_embedder.bias': 'txt_embed.bias', + 'context_embedder.weight': 'txt_embed.weight', + 'time_text_embed.text_embedder.linear_1.bias': 'vector_embedding.in_layer.bias', + 'time_text_embed.text_embedder.linear_1.weight': 'vector_embedding.in_layer.weight', + 'time_text_embed.text_embedder.linear_2.bias': 'vector_embedding.out_layer.bias', + 'time_text_embed.text_embedder.linear_2.weight': 'vector_embedding.out_layer.weight', +} + + +def flux_transformer_converter(ckpt_path=None, transformer_config=None): + diffuser_state_dict = {} + if os.path.isdir(ckpt_path): + files = os.listdir(ckpt_path) + for file in files: + if file.endswith('.safetensors'): + loaded_dict = load_safetensors(os.path.join(ckpt_path, file)) + diffuser_state_dict.update(loaded_dict) + elif os.path.isfile(ckpt_path): + diffuser_state_dict = load_safetensors(ckpt_path) + else: + raise FileNotFoundError("Please provide a valid ckpt path.") + new_state_dict = {} + num_single_blocks = 0 + num_double_blocks = 0 + for key, value in diffuser_state_dict.items(): + if 'attn.to_q' in key or 'attn.to_k' in key or 'attn.to_v' in key: + continue + if 'attn.add_q_proj' in key or 'attn.add_k_proj' in key or 'attn.add_v_proj' in key: + continue + if key.startswith('transformer_blocks'): + temp = key.split('.') + idx, k = temp[1], '.'.join(temp[2:]) + num_double_blocks = max(int(idx), num_double_blocks) + new_key = '.'.join(['double_blocks', idx, 
key_mapping['double_blocks'][k]]) + elif key.startswith('single_transformer_blocks'): + temp = key.split('.') + idx, k = temp[1], '.'.join(temp[2:]) + num_single_blocks = max(int(idx), num_single_blocks) + new_key = '.'.join(['single_blocks', idx, key_mapping['single_blocks'][k]]) + else: + new_key = key_mapping[key] + new_state_dict[new_key] = value + + for i in range(num_double_blocks + 1): + new_key = f'double_blocks.{str(i)}.self_attention.linear_qkv.weight' + qk, kk, vk = [f'transformer_blocks.{str(i)}.attn.to_{n}.weight' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + new_key = f'double_blocks.{str(i)}.self_attention.linear_qkv.bias' + qk, kk, vk = [f'transformer_blocks.{str(i)}.attn.to_{n}.bias' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv_bias( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + new_key = f'double_blocks.{str(i)}.self_attention.added_linear_qkv.weight' + qk, kk, vk = [f'transformer_blocks.{str(i)}.attn.add_{n}_proj.weight' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + new_key = f'double_blocks.{str(i)}.self_attention.added_linear_qkv.bias' + qk, kk, vk = [f'transformer_blocks.{str(i)}.attn.add_{n}_proj.bias' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv_bias( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + + for i in range(num_single_blocks + 1): + new_key = f'single_blocks.{str(i)}.self_attention.linear_qkv.weight' + qk, kk, vk = [f'single_transformer_blocks.{str(i)}.attn.to_{n}.weight' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + new_key = f'single_blocks.{str(i)}.self_attention.linear_qkv.bias' + qk, kk, vk = [f'single_transformer_blocks.{str(i)}.attn.to_{n}.bias' for n in ('q', 'k', 'v')] + new_state_dict[new_key] = _import_qkv_bias( + transformer_config, diffuser_state_dict[qk], diffuser_state_dict[kk], diffuser_state_dict[vk] + ) + + return new_state_dict diff --git a/nemo/collections/diffusion/utils/flux_pipeline_utils.py b/nemo/collections/diffusion/utils/flux_pipeline_utils.py new file mode 100644 index 000000000000..77dcfa58450f --- /dev/null +++ b/nemo/collections/diffusion/utils/flux_pipeline_utils.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
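For reference, a hedged usage sketch of `flux_transformer_converter` defined above; the checkpoint path and the `flux_transformer` module are placeholders for illustration, not part of this PR:

```python
from nemo.collections.diffusion.utils.flux_ckpt_converter import flux_transformer_converter

# flux_transformer is assumed to be an Mcore-based Flux module whose .config carries
# num_attention_heads, num_query_groups and hidden_size, as the converter expects.
state_dict = flux_transformer_converter(
    ckpt_path="/path/to/hf/flux/transformer",  # directory of *.safetensors shards or a single file
    transformer_config=flux_transformer.config,
)
missing, unexpected = flux_transformer.load_state_dict(state_dict, strict=False)
missing = [k for k in missing if not k.endswith('_extra_state')]  # Mcore bookkeeping keys, safe to ignore
```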
+ +from dataclasses import dataclass + +import torch +from megatron.core.transformer.utils import openai_gelu + +from nemo.collections.diffusion.models.flux.model import FluxParams +from nemo.collections.diffusion.vae.autoencoder import AutoEncoderParams + + +@dataclass +class FluxModelParams: + flux_params: FluxParams + vae_params: AutoEncoderParams + clip_params: dict | None + t5_params: dict | None + scheduler_params: dict | None + device: str | torch.device + + +configs = { + "dev": FluxModelParams( + flux_params=FluxParams( + num_joint_layers=19, + num_single_layers=38, + hidden_size=3072, + num_attention_heads=24, + activation_func=openai_gelu, + add_qkv_bias=True, + ffn_hidden_size=16384, + in_channels=64, + context_dim=4096, + model_channels=256, + patch_size=1, + guidance_embed=True, + vec_in_dim=768, + ), + vae_params=AutoEncoderParams( + ch_mult=[1, 2, 4, 4], + attn_resolutions=[], + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ckpt=None, + ), + clip_params={ + 'max_length': 77, + 'always_return_pooled': True, + }, + t5_params={ + 'max_length': 512, + }, + scheduler_params={ + 'num_train_timesteps': 1000, + }, + device='cpu', + ) +} diff --git a/nemo/collections/diffusion/utils/mcore_parallel_utils.py b/nemo/collections/diffusion/utils/mcore_parallel_utils.py new file mode 100644 index 000000000000..0b9bdec97464 --- /dev/null +++ b/nemo/collections/diffusion/utils/mcore_parallel_utils.py @@ -0,0 +1,80 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
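The `configs` dictionary above currently exposes a single "dev" preset. A minimal sketch of how it might be consumed follows; the pipeline constructor in the trailing comment is defined elsewhere in this PR, so treat that call as an assumption:

```python
import torch

from nemo.collections.diffusion.utils.flux_pipeline_utils import configs

params = configs["dev"]
params.device = 'cuda' if torch.cuda.is_available() else 'cpu'
params.t5_params['max_length'] = 256  # e.g. trade prompt length for memory
# pipe = FluxInferencePipeline(params)  # assumed constructor signature
```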
+ +""" +Megatron Model Parallel Initialization +""" + +import os + +import megatron.core.parallel_state as ps +import torch + + +class Utils: + world_size = torch.cuda.device_count() + # rank = int(os.environ["LOCAL_RANK"]) + rank = 0 + + @staticmethod + def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1): + ps.destroy_model_parallel() + + # Torch setup for distributed training + rank = int(os.environ['LOCAL_RANK']) + world_size = 1 # torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + + # Megatron core distributed training initialization + ps.initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size + ) + + @staticmethod + def set_world_size(world_size=None, rank=None): + Utils.world_size = torch.cuda.device_count() if world_size is None else world_size + if torch.distributed.is_initialized() and Utils.world_size != torch.distributed.get_world_size(): + torch.distributed.destroy_process_group() + + if rank is None: + # Utils.rank = int(os.environ["LOCAL_RANK"]) + Utils.rank = 0 + if Utils.rank >= Utils.world_size: + Utils.rank = -1 + else: + Utils.rank = rank + + @staticmethod + def destroy_model_parallel(): + ps.destroy_model_parallel() + torch.distributed.barrier() + + @staticmethod + def initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + **kwargs, + ): + ps.destroy_model_parallel() + Utils.initialize_distributed() + ps.initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank, + **kwargs, + ) diff --git a/nemo/collections/diffusion/vae/autoencoder.py b/nemo/collections/diffusion/vae/autoencoder.py new file mode 100644 index 000000000000..b356d74baac1 --- /dev/null +++ b/nemo/collections/diffusion/vae/autoencoder.py @@ -0,0 +1,334 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +import numpy as np +import torch +from torch import Tensor, nn + +from nemo.collections.diffusion.vae.blocks import AttnBlock, Downsample, Normalize, ResnetBlock, Upsample, make_attn + + +@dataclass +class AutoEncoderParams: + ch_mult: list[int] + attn_resolutions: list[int] + resolution: int = 256 + in_channels: int = 3 + ch: int = 128 + out_ch: int = 3 + num_res_blocks: int = 2 + z_channels: int = 16 + scale_factor: float = 0.3611 + shift_factor: float = 0.1159 + attn_type: str = 'vanilla' + double_z: bool = True + dropout: float = 0.0 + ckpt: str = None + + +def nonlinearity(x): + # swish + return torch.nn.functional.silu(x) + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + in_channels: int, + resolution: int, + z_channels: int, + dropout=0.0, + resamp_with_conv=True, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + in_channels: int, + resolution: int, + z_channels: 
int, + dropout=0.0, + resamp_with_conv=True, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + 
z_channels=params.z_channels,
+            double_z=params.double_z,
+            attn_type=params.attn_type,
+            dropout=params.dropout,
+            out_ch=params.out_ch,
+            attn_resolutions=params.attn_resolutions,
+        )
+        self.decoder = Decoder(
+            resolution=params.resolution,
+            in_channels=params.in_channels,
+            ch=params.ch,
+            out_ch=params.out_ch,
+            ch_mult=params.ch_mult,
+            num_res_blocks=params.num_res_blocks,
+            z_channels=params.z_channels,
+            double_z=params.double_z,
+            attn_type=params.attn_type,
+            dropout=params.dropout,
+            attn_resolutions=params.attn_resolutions,
+        )
+        self.reg = DiagonalGaussian()
+
+        self.scale_factor = params.scale_factor
+        self.shift_factor = params.shift_factor
+        self.params = params
+
+        if params.ckpt is not None:
+            self.load_from_checkpoint(params.ckpt)
+
+    def encode(self, x: Tensor) -> Tensor:
+        z = self.reg(self.encoder(x))
+        z = self.scale_factor * (z - self.shift_factor)
+        return z
+
+    def decode(self, z: Tensor) -> Tensor:
+        z = z / self.scale_factor + self.shift_factor
+        return self.decoder(z)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.decode(self.encode(x))
+
+    def load_from_checkpoint(self, ckpt_path):
+        from safetensors.torch import load_file as load_sft
+
+        state_dict = load_sft(ckpt_path)
+        missing, unexpected = self.load_state_dict(state_dict, strict=False)
+        if len(missing) > 0:
+            print(f"The following keys are missing from the loaded checkpoint: {missing}")
diff --git a/nemo/collections/diffusion/vae/blocks.py b/nemo/collections/diffusion/vae/blocks.py
new file mode 100644
index 000000000000..ad38a7a463cf
--- /dev/null
+++ b/nemo/collections/diffusion/vae/blocks.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
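A quick smoke-test sketch for the `AutoEncoder` defined above, with parameters mirroring the "dev" preset in flux_pipeline_utils.py; it assumes Apex's fused GroupNorm (imported by vae/blocks.py below) is installed and loads no pretrained weights:

```python
import torch

from nemo.collections.diffusion.vae.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(ch_mult=[1, 2, 4, 4], attn_resolutions=[], z_channels=16, ckpt=None)
vae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 256, 256)  # [B, C, H, W], values roughly in [-1, 1]
with torch.no_grad():
    z = vae.encode(x)      # -> [1, 16, 32, 32]: 16 latent channels at 1/8 spatial resolution
    recon = vae.decode(z)  # -> [1, 3, 256, 256]
```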
+ +import torch +from einops import rearrange +from torch import Tensor, nn + +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False + + +def Normalize(in_channels, num_groups=32, act=""): + return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=0): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, act="silu") + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, act="silu") + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(yuya): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if dtype == torch.bfloat16: + x = x.to(dtype) + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, act="silu") + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + 
def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """ + to match AttnBlock usage + """ + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels)
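Finally, a standalone sketch of the Euler flow-matching update driven by the `FlowMatchEulerDiscreteScheduler` introduced earlier in this PR; the `mu` value and latent shape are illustrative stand-ins for what `FluxInferencePipeline` computes at runtime:

```python
import numpy as np
import torch

from nemo.collections.diffusion.sampler.flow_matching.flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, use_dynamic_shifting=True)

num_steps = 28
sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
scheduler.set_timesteps(sigmas=sigmas, device="cpu", mu=0.8)  # mu normally comes from _calculate_shift

sample = torch.randn(2, 1024, 64)  # stand-in for packed image latents
for t in scheduler.timesteps:
    velocity = torch.zeros_like(sample)  # stand-in for the transformer's prediction
    sample = scheduler.step(velocity, t, sample)[0]  # x_next = x + (sigma_next - sigma) * velocity
```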