Skip to content

Commit

Permalink
loader v2
Browse files Browse the repository at this point in the history
  • Loading branch information
lllyasviel committed Aug 25, 2023
1 parent 17915db commit b34c50f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
4 changes: 2 additions & 2 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=[context_dim] * transformer_depth[level],
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(
use_scale_shift_norm=use_scale_shift_norm
),
SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=[context_dim] * transformer_depth_middle,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
Expand Down
43 changes: 31 additions & 12 deletions scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from modules import devices

from scripts.adapter import PlugableAdapter, Adapter, StyleAdapter, Adapter_light
from scripts.cldm import PlugableControlModel, ControlNet
from scripts.cldm import PlugableControlModel
from scripts.logging import logger
from scripts.diffuser import convert_from_diffuser_state_dict

controlnet_default_config = {'adm_in_channels': None,
Expand Down Expand Up @@ -97,10 +98,32 @@ def build_model_by_guess(state_dict, unet, model_path):
network = None

if 'input_hint_block.0.weight' in state_dict:
config = copy.deepcopy(controlnet_default_config)
config['global_average_pooling'] = model_has_shuffle_in_filename
config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])
config['context_dim'] = int(state_dict['input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight'].shape[1])
if 'label_emb.0.0.bias' not in state_dict:
config = copy.deepcopy(controlnet_default_config)
logger.info('controlnet_default_config')
config['global_average_pooling'] = model_has_shuffle_in_filename
config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])
config['context_dim'] = int(state_dict['input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight'].shape[1])
for key in state_dict.keys():
p = state_dict[key]
if 'proj_in.weight' in key or 'proj_out.weight' in key:
if len(p.shape) == 2:
p = p[..., None, None]
state_dict[key] = p
else:
has_full_layers = 'input_blocks.8.1.transformer_blocks.9.norm3.weight' in state_dict
has_mid_layers = 'input_blocks.8.1.transformer_blocks.0.norm3.weight' in state_dict
if has_full_layers:
config = copy.deepcopy(controlnet_sdxl_config)
logger.info('controlnet_sdxl_config')
elif has_mid_layers:
config = copy.deepcopy(controlnet_sdxl_mid_config)
logger.info('controlnet_sdxl_mid_config')
else:
config = copy.deepcopy(controlnet_sdxl_small_config)
logger.info('controlnet_sdxl_small_config')
config['global_average_pooling'] = False
config['hint_channels'] = int(state_dict['input_hint_block.0.weight'].shape[1])

if 'difference' in state_dict and unet is not None:
unet_state_dict = unet.state_dict()
Expand All @@ -115,32 +138,28 @@ def build_model_by_guess(state_dict, unet, model_path):
final_state_dict[key] = p_new
state_dict = final_state_dict

for key in state_dict.keys():
p = state_dict[key]
if 'proj_in.weight' in key or 'proj_out.weight' in key:
if len(p.shape) == 2:
p = p[..., None, None]
state_dict[key] = p

config['use_fp16'] = devices.dtype_unet == torch.float16

network = PlugableControlModel(config, state_dict)

if 'conv_in.weight' in state_dict:
config = copy.deepcopy(t2i_adapter_config)
logger.info('t2i_adapter_config')
config['cin'] = int(state_dict['conv_in.weight'].shape[1])
adapter = Adapter(**config).cpu()
adapter.load_state_dict(state_dict, strict=False)
network = PlugableAdapter(adapter)

if 'style_embedding' in state_dict:
config = copy.deepcopy(t2i_adapter_style_config)
logger.info('t2i_adapter_style_config')
adapter = StyleAdapter(**config).cpu()
adapter.load_state_dict(state_dict, strict=False)
network = PlugableAdapter(adapter)

if 'body.0.in_conv.weight' in state_dict:
config = copy.deepcopy(t2i_adapter_light_config)
logger.info('t2i_adapter_light_config')
config['cin'] = int(state_dict['body.0.in_conv.weight'].shape[1])
adapter = Adapter_light(**config).cpu()
adapter.load_state_dict(state_dict, strict=False)
Expand Down

0 comments on commit b34c50f

Please sign in to comment.