remove unnecessary comments
Rohit Jena committed Jul 8, 2024
1 parent f73abe3 commit f411298
Showing 2 changed files with 2 additions and 66 deletions.
First changed file:
@@ -25,7 +25,6 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from pprint import pprint

from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
@@ -1227,11 +1226,6 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

# pprint(sorted(missing_keys))
# input("these are missing...")
# pprint(sorted(unexpected_keys))
# input("these are unexpected...")

# SDXL specific mapping
if 'output_blocks.2.2.conv.bias' in missing_keys and 'output_blocks.2.1.conv.bias' in loaded_keys:
state_dict['output_blocks.2.2.conv.bias'] = state_dict['output_blocks.2.1.conv.bias']
Second changed file:
@@ -22,32 +22,11 @@

import torch
import numpy as np
import json
from pprint import pprint
from safetensors import torch as torch_s
import safetensors

import json
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.nn
from omegaconf import OmegaConf
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.trainer.trainer import Trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.utils import logging

intkey = lambda x: int(x)
@@ -122,9 +101,6 @@ def model_to_tree(model):
keys = list(model.keys())
tree = SegTree()
for k in keys:
# wk = model.get(wk, torch.tensor([]))
# bk = model.get(bk, torch.tensor([]))
# tree.add(k, (wk, bk))
tree.add(k, "leaf")
return tree

@@ -151,28 +127,10 @@ def make_tiny_config(config):
return config

def load_hf_ckpt(in_dir, args):
# takes a directory as input
# params_file = os.path.join(in_dir, 'config.json')
# assert os.path.exists(params_file)
# with open(params_file, 'r') as fp:
# model_args = json.load(fp)
# if args.debug:
# model_args = make_tiny_config(model_args)

# # model = AutoModel.from_pretrained(in_dir)
# model = AutoModel.from_config(model_args)
# if args.model == 'unet':
# model = UNet2DConditionModel.from_pretrained(in_dir)
# elif args.model == 'vae':
# model = AutoencoderKL.from_pretrained(in_dir)
# model = torch_s.load(in_dir + "/diffusion")
# print(model)
ckpt = {}
with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f:
for k in f.keys():
ckpt[k] = f.get_tensor(k)
# input("enter to continue...")
# ckpt = model.state_dict()
return args, ckpt

def dup_convert_name_recursive(tree: SegTree, convert_name=None):
@@ -246,7 +204,7 @@ def map_attention_block(att_tree: SegTree):
currently assumed AttentionBlock
'''
# TODO: Add check for dual attention block
# TODO (rohit): Add check for dual attention block
def check_att_type(tree):
return "att_block"

@@ -263,7 +221,7 @@ def check_att_type(tree):
dup_convert_name_recursive(tblock['norm1'], 'attn1.norm')
dup_convert_name_recursive(tblock['norm2'], 'attn2.norm')
dup_convert_name_recursive(tblock['norm3'], 'ff.net.0')
# map ff
# map ff module
tblock['ff'].convert_name = "ff"
tblock['ff.net'].convert_name = 'net'
dup_convert_name_recursive(tblock['ff.net.0'], '1')
@@ -382,13 +340,6 @@ def convert_encoder(hf_tree: SegTree):

# map the `mid_block` ( NeMo's mid layer is hardcoded in terms of number of modules)
encoder['mid_block'].convert_name = 'mid'
# encoder['mid_block.resnets.0'].convert_name = 'block_1'
# encoder['mid_block.resnets.1'].convert_name = 'block_2'
# map_resnet_block(encoder['mid_block.resnets.0'])
# map_resnet_block(encoder['mid_block.resnets.1'])
# for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}:
# dup_convert_name_recursive(encoder[f'mid_block.resnets.0.{reskey}'], reskey)
# dup_convert_name_recursive(encoder[f'mid_block.resnets.1.{reskey}'], reskey)
dup_convert_name_recursive(encoder[f'mid_block.resnets.0'], 'block_1')
dup_convert_name_recursive(encoder[f'mid_block.resnets.1'], 'block_2')

@@ -413,13 +364,6 @@ def convert_decoder(hf_tree: SegTree):
decoder['mid_block'].convert_name = 'mid'
dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1')
dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2')
# decoder['mid_block.resnets.0'].convert_name = 'block_1'
# decoder['mid_block.resnets.1'].convert_name = 'block_2'
# map_resnet_block(encoder['mid_block.resnets.0'])
# map_resnet_block(encoder['mid_block.resnets.1'])
# for reskey in {'conv1', 'conv2', 'norm1', 'norm2'}:
# dup_convert_name_recursive(decoder[f'mid_block.resnets.0.{reskey}'], reskey)
# dup_convert_name_recursive(decoder[f'mid_block.resnets.1.{reskey}'], reskey)
att = decoder['mid_block.attentions.0']
att.convert_name = 'attn_1'
dup_convert_name_recursive(att['group_norm'], 'norm')
@@ -476,8 +420,6 @@ def convert(args):

for hf_key, nemo_key in mapping.items():
nemo_ckpt[nemo_key] = hf_ckpt[hf_key]
# save this
# torch.save(args.output_path, nemo_ckpt)
torch.save(nemo_ckpt, args.output_path)
logging.info(f"Saved nemo file to {args.output_path}")

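The second changed file is a checkpoint-conversion script: it loads a Hugging Face safetensors checkpoint, remaps parameter names to NeMo's layout, and saves the result with torch.save. Below is a minimal, self-contained sketch of that flow, not the repository's code; the file names and the identity `mapping` dict are placeholders (the actual script derives the mapping by walking a SegTree of parameter names).

import torch
import safetensors

def load_safetensors_ckpt(path):
    # Read every tensor from a .safetensors file into a plain dict,
    # mirroring the load_hf_ckpt helper shown in the diff above.
    ckpt = {}
    with safetensors.safe_open(path, framework="pt") as f:
        for key in f.keys():
            ckpt[key] = f.get_tensor(key)
    return ckpt

if __name__ == "__main__":
    hf_ckpt = load_safetensors_ckpt("diffusion_pytorch_model.safetensors")
    # Placeholder identity mapping; the real script builds the
    # HF-key -> NeMo-key mapping from the model's parameter-name tree.
    mapping = {k: k for k in hf_ckpt}
    nemo_ckpt = {mapping[k]: v for k, v in hf_ckpt.items()}
    torch.save(nemo_ckpt, "converted_ckpt.pt")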
