From 2254bf9fee0ef2a6f407413c7ea3d6f78484ba2e Mon Sep 17 00:00:00 2001 From: rohitrango Date: Wed, 10 Jul 2024 21:36:46 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: rohitrango --- .../text_to_image/controlnet/controlnet.py | 100 ++++++++++-------- .../convert_stablediffusion_hf_to_nemo.py | 72 +++++++++---- 2 files changed, 105 insertions(+), 67 deletions(-) diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index 50f20377b4da7..7e41c3cca2235 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -380,7 +380,9 @@ def __init__( time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ) self.input_blocks = nn.ModuleList( @@ -505,24 +507,26 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( # always uses a self-attn - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - use_flash_attention=use_flash_attention, + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ) ), ResBlock( ch, @@ -693,7 +697,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -737,12 +744,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. 
+ Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # we zero grads here because we also call backward in the apex fwd/bwd functions self._optimizer.zero_grad() @@ -786,20 +793,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -812,8 +819,8 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. """ # noise_map, condition batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) @@ -823,7 +830,8 @@ def process_batch(batch): # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): x, c = self.model.get_input(batch, self.cfg.first_stage_key) @@ -890,7 +898,7 @@ def validation_step(self, batch, batch_idx): self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. 
Args: @@ -944,7 +952,8 @@ def build_train_valid_test_datasets(self): if self.cfg.first_stage_key.endswith("encoded"): self._train_ds, self._validation_ds = build_train_valid_precached_datasets( - model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + model_cfg=self.cfg, + consumed_samples=self.compute_consumed_samples(0), ) else: self._train_ds, self._validation_ds = build_train_valid_datasets( @@ -998,20 +1007,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index e56298f4e2d1c..8baa47c59cfd1 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -14,36 +14,42 @@ r""" Conversion script to convert HuggingFace StableDiffusion checkpoints (unet and vae) into checkpoints with nemo naming convention. 
""" -import torch -import numpy as np -import safetensors import os - from argparse import ArgumentParser + +import numpy as np +import safetensors import torch import torch.nn + from nemo.utils import logging + def filter_keys(rule, dict): keys = list(dict.keys()) nd = {k: dict[k] for k in keys if rule(k)} return nd + def map_keys(rule, dict): new = {rule(k): v for k, v in dict.items()} return new + def split_name(name, dots=0): l = name.split(".") - return ".".join(l[:dots+1]), ".".join(l[dots+1:]) + return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :]) + def is_prefix(shortstr, longstr): # is the first string a prefix of the second one return longstr == shortstr or longstr.startswith(shortstr + ".") + def numdots(str): return str.count(".") + class SegTree: def __init__(self): self.nodes = dict() @@ -53,10 +59,10 @@ def __init__(self): def __len__(self): return len(self.nodes) - + def is_leaf(self): return len(self.nodes) == 0 - + def add(self, name, val=0): prefix, subname = split_name(name) if subname == '': @@ -66,10 +72,10 @@ def add(self, name, val=0): if self.nodes.get(prefix) is None: self.nodes[prefix] = SegTree() self.nodes[prefix].add(subname, val) - + def change(self, name, val): self.add(name, val) - + def __getitem__(self, name: str): if hasattr(self, name): return getattr(self, name) @@ -90,6 +96,7 @@ def __getitem__(self, name: str): return self.nodes[prefix][substr] return val + def model_to_tree(model): keys = list(model.keys()) tree = SegTree() @@ -97,6 +104,7 @@ def model_to_tree(model): tree.add(k, "leaf") return tree + def get_args(): parser = ArgumentParser() parser.add_argument( @@ -110,14 +118,16 @@ def get_args(): parser.add_argument("--precision", type=str, default="32", help="Model precision") parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae']) parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.") - + args = parser.parse_args() return args + def make_tiny_config(config): - ''' dial down the config file to make things tractable ''' + '''dial down the config file to make things tractable''' return config + def load_hf_ckpt(in_dir, args): # takes a directory as input, loads the checkpoint into a dict ckpt = {} @@ -125,10 +135,11 @@ def load_hf_ckpt(in_dir, args): with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - return args, ckpt + return args, ckpt + def dup_convert_name_recursive(tree: SegTree, convert_name=None): - ''' inside this tree, convert all nodes recursively + '''inside this tree, convert all nodes recursively optionally, convert the name of the root as given by name (if not None) ''' if tree is None: @@ -139,6 +150,7 @@ def dup_convert_name_recursive(tree: SegTree, convert_name=None): for k, v in tree.nodes.items(): dup_convert_name_recursive(v, k) + def sanity_check(hf_tree, hf_unet, nemo_unet): # check if i'm introducing new keys for hfk, nk in hf_to_nemo_mapping(hf_tree).items(): @@ -147,8 +159,9 @@ def sanity_check(hf_tree, hf_unet, nemo_unet): if hfk not in hf_unet.keys(): print(hfk) + def convert_input_keys(hf_tree: SegTree): - ''' map the input blocks of huggingface model ''' + '''map the input blocks of huggingface model''' # map `conv_in` to first input block dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0') @@ -163,7 +176,7 @@ def convert_input_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) downsamplers = 
block.nodes.get('downsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a DownBlock2d + if len(attentions) == 0: # no attentions, this is a DownBlock2d for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0" @@ -188,15 +201,18 @@ def convert_input_keys(hf_tree: SegTree): dup_convert_name_recursive(downsamplers[k]['conv'], 'op') nemo_inp_blk += 1 + def clean_convert_names(tree): tree.convert_name = None for k, v in tree.nodes.items(): clean_convert_names(v) + def map_attention_block(att_tree: SegTree): - ''' this HF tree can either be an AttentionBlock or a DualAttention block + '''this HF tree can either be an AttentionBlock or a DualAttention block currently assumed AttentionBlock ''' + # TODO(@rohitrango): Add check for dual attention block, but right now this works with SD and SDXL def check_att_type(tree): return "att_block" @@ -222,8 +238,9 @@ def check_att_type(tree): else: logging.warning("failed to identify type of attention block here.") + def map_resnet_block(resnet_tree: SegTree): - ''' this HF tree is supposed to have all the keys for a resnet ''' + '''this HF tree is supposed to have all the keys for a resnet''' dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1') dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0') dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1') @@ -231,6 +248,7 @@ def map_resnet_block(resnet_tree: SegTree): dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2') dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection') + def hf_to_nemo_mapping(tree: SegTree): mapping = {} for nodename, subtree in tree.nodes.items(): @@ -244,6 +262,7 @@ def hf_to_nemo_mapping(tree: SegTree): mapping[nodename + "." 
+ k] = convert_name + v return mapping + def convert_cond_keys(tree: SegTree): # map all conditioning keys if tree.nodes.get("add_embedding"): @@ -257,8 +276,9 @@ def convert_cond_keys(tree: SegTree): dup_convert_name_recursive(tree['time_embedding.linear_1'], '0') dup_convert_name_recursive(tree['time_embedding.linear_2'], '2') + def convert_middle_keys(tree: SegTree): - ''' middle block is fixed (resnet -> attention -> resnet) ''' + '''middle block is fixed (resnet -> attention -> resnet)''' mid = tree['mid_block'] resnets = mid['resnets'] attns = mid['attentions'] @@ -270,8 +290,9 @@ def convert_middle_keys(tree: SegTree): map_resnet_block(resnets['1']) map_attention_block(attns['0']) + def convert_output_keys(hf_tree: SegTree): - ''' output keys is similar to input keys ''' + '''output keys is similar to input keys''' nemo_inp_blk = 0 up_blocks = hf_tree['up_blocks'] up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int) @@ -283,7 +304,7 @@ def convert_output_keys(hf_tree: SegTree): attentions = block.nodes.get('attentions', SegTree()) upsamplers = block.nodes.get('upsamplers', SegTree()) - if len(attentions) == 0: # no attentions, this is a UpBlock2D + if len(attentions) == 0: # no attentions, this is a UpBlock2D for resid in sorted(list(resnets.nodes.keys()), key=int): resid = str(resid) resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0" @@ -305,15 +326,19 @@ def convert_output_keys(hf_tree: SegTree): # if there is an upsampler, then also append it if len(upsamplers) > 0: nemo_inp_blk -= 1 - upsamplenum = 1 if len(attentions) == 0 else 2 # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) + upsamplenum = ( + 1 if len(attentions) == 0 else 2 + ) # if there are attention modules, upsample is module2, else it is module 1 (to stay consistent with SD) upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}" dup_convert_name_recursive(upsamplers['0.conv'], 'conv') nemo_inp_blk += 1 + def convert_finalout_keys(hf_tree: SegTree): dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0") dup_convert_name_recursive(hf_tree['conv_out'], "out.1") + def convert_encoder(hf_tree: SegTree): encoder = hf_tree['encoder'] encoder.convert_name = 'encoder' @@ -370,7 +395,7 @@ def convert_decoder(hf_tree: SegTree): dup_convert_name_recursive(att['to_v'], 'v') dup_convert_name_recursive(att['to_out.0'], 'proj_out') - # up blocks contain resnets and upsamplers + # up blocks contain resnets and upsamplers decoder['up_blocks'].convert_name = 'up' num_up_blocks = len(decoder['up_blocks']) for upid, upblock in decoder['up_blocks'].nodes.items(): @@ -409,7 +434,7 @@ def convert(args): else: logging.error("incorrect model specification.") return - + # check mapping mapping = hf_to_nemo_mapping(hf_tree) if len(mapping) != len(hf_ckpt.keys()): @@ -422,6 +447,7 @@ def convert(args): torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}") + if __name__ == '__main__': args = get_args() convert(args)
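
Note for reviewers (not part of the patch): the converter's key mapping treats dotted HuggingFace parameter names as paths in a trie (SegTree), splitting each name into a head segment and a remainder with split_name() and matching prefixes whole-segment-at-a-time with is_prefix(). The sketch below is a standalone illustration under that assumption; it mirrors only those two small helpers from convert_stablediffusion_hf_to_nemo.py, and the example key name is made up for illustration.

    # Standalone illustration, not part of the patch: mirrors split_name()/is_prefix()
    # from convert_stablediffusion_hf_to_nemo.py to show segment-wise name matching.

    def split_name(name, dots=0):
        # "a.b.c.d" -> ("a", "b.c.d") for dots=0, ("a.b", "c.d") for dots=1
        parts = name.split(".")
        return ".".join(parts[: dots + 1]), ".".join(parts[dots + 1 :])


    def is_prefix(shortstr, longstr):
        # True only when shortstr matches longstr on whole dotted segments
        return longstr == shortstr or longstr.startswith(shortstr + ".")


    if __name__ == "__main__":
        hf_key = "down_blocks.0.resnets.1.conv1.weight"  # illustrative HF UNet key
        head, rest = split_name(hf_key)
        print(head, rest)                           # down_blocks 0.resnets.1.conv1.weight
        print(is_prefix("down_blocks.0", hf_key))   # True: whole-segment prefix
        print(is_prefix("down_blocks.00", hf_key))  # False: partial segments never match

Splitting one segment at a time is what lets the SegTree walk a name like the one above level by level and attach a convert_name at each level; hf_to_nemo_mapping() then concatenates those per-level names to produce the final NeMo key.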