diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 638116352674..35eab9df25b3 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -25,7 +25,6 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from pprint import pprint from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( @@ -1227,11 +1226,6 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - # pprint(sorted(missing_keys)) - # input("these are missing...") - # pprint(sorted(unexpected_keys)) - # input("these are unexpected...") - # SDXL specific mapping if 'output_blocks.2.2.conv.bias' in missing_keys and 'output_blocks.2.1.conv.bias' in loaded_keys: state_dict['output_blocks.2.2.conv.bias'] = state_dict['output_blocks.2.1.conv.bias'] diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py index bfab58e6750f..0c49215f1ebb 100644 --- a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py @@ -22,32 +22,11 @@ import torch import numpy as np -import json -from pprint import pprint -from safetensors import torch as torch_s import safetensors - -import json -import os from argparse import ArgumentParser -from collections import OrderedDict import torch import torch.nn -from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel -from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL - -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.parts.nlp_overrides import ( - GradScaler, - MegatronHalfPrecisionPlugin, - NLPDDPStrategy, - NLPSaveRestoreConnector, - PipelineMixedPrecisionPlugin, -) from nemo.utils import logging intkey = lambda x: int(x) @@ -122,9 +101,6 @@ def model_to_tree(model): keys = list(model.keys()) tree = SegTree() for k in keys: - # wk = model.get(wk, torch.tensor([])) - # bk = model.get(bk, torch.tensor([])) - # tree.add(k, (wk, bk)) tree.add(k, "leaf") return tree @@ -151,28 +127,10 @@ def make_tiny_config(config): return config def load_hf_ckpt(in_dir, args): - # takes a directory as input - # params_file = os.path.join(in_dir, 'config.json') - # assert os.path.exists(params_file) - # with open(params_file, 'r') as fp: - # model_args = json.load(fp) - # if args.debug: - # model_args = make_tiny_config(model_args) - - # # model = AutoModel.from_pretrained(in_dir) - # model = AutoModel.from_config(model_args) - # if args.model == 'unet': - # model = UNet2DConditionModel.from_pretrained(in_dir) - # elif args.model == 'vae': - # model = AutoencoderKL.from_pretrained(in_dir) - # model = torch_s.load(in_dir + "/diffusion") - # print(model) ckpt = {} with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f: for k in f.keys(): ckpt[k] = f.get_tensor(k) - # input("enter to continue...") - # ckpt = model.state_dict() return args, ckpt def dup_convert_name_recursive(tree: SegTree, convert_name=None): @@ -246,7 +204,7 @@ def map_attention_block(att_tree: SegTree): currently assumed AttentionBlock ''' - # TODO: Add check for dual attention block + # TODO (rohit): Add check for dual attention block def check_att_type(tree): return "att_block" @@ -263,7 +221,7 @@ def check_att_type(tree): dup_convert_name_recursive(tblock['norm1'], 'attn1.norm') dup_convert_name_recursive(tblock['norm2'], 'attn2.norm') dup_convert_name_recursive(tblock['norm3'], 'ff.net.0') - # map ff + # map ff module tblock['ff'].convert_name = "ff" tblock['ff.net'].convert_name = 'net' dup_convert_name_recursive(tblock['ff.net.0'], '1') @@ -382,13 +340,6 @@ def convert_encoder(hf_tree: SegTree): # map the `mid_block` ( NeMo's mid layer is hardcoded in terms of number of modules) encoder['mid_block'].convert_name = 'mid' - # encoder['mid_block.resnets.0'].convert_name = 'block_1' - # encoder['mid_block.resnets.1'].convert_name = 'block_2' - # map_resnet_block(encoder['mid_block.resnets.0']) - # map_resnet_block(encoder['mid_block.resnets.1']) - # for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}: - # dup_convert_name_recursive(encoder[f'mid_block.resnets.0.{reskey}'], reskey) - # dup_convert_name_recursive(encoder[f'mid_block.resnets.1.{reskey}'], reskey) dup_convert_name_recursive(encoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(encoder[f'mid_block.resnets.1'], 'block_2') @@ -413,13 +364,6 @@ def convert_decoder(hf_tree: SegTree): decoder['mid_block'].convert_name = 'mid' dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1') dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2') - # decoder['mid_block.resnets.0'].convert_name = 'block_1' - # decoder['mid_block.resnets.1'].convert_name = 'block_2' - # map_resnet_block(encoder['mid_block.resnets.0']) - # map_resnet_block(encoder['mid_block.resnets.1']) - # for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}: - # dup_convert_name_recursive(decoder[f'mid_block.resnets.0.{reskey}'], reskey) - # dup_convert_name_recursive(decoder[f'mid_block.resnets.1.{reskey}'], reskey) att = decoder['mid_block.attentions.0'] att.convert_name = 'attn_1' dup_convert_name_recursive(att['group_norm'], 'norm') @@ -476,8 +420,6 @@ def convert(args): for hf_key, nemo_key in mapping.items(): nemo_ckpt[nemo_key] = hf_ckpt[hf_key] - # save this - # torch.save(args.output_path, nemo_ckpt) torch.save(nemo_ckpt, args.output_path) logging.info(f"Saved nemo file to {args.output_path}")