diff --git a/models/control-lora-canny-rank128.yaml b/models/control-lora-canny-rank128.yaml
index 68d2c79dc..3dbfc7b7e 100644
--- a/models/control-lora-canny-rank128.yaml
+++ b/models/control-lora-canny-rank128.yaml
@@ -13,10 +13,11 @@ model:
       model_channels: 320
       num_res_blocks: 2
       attention_resolutions: [2, 4]
-      transformer_depth: 10
+      transformer_depth: [0, 2, 10]
+      transformer_depth_middle: 10
       channel_mult: [1, 2, 4]
       use_linear_in_transformer: True
-      context_dim: [2048, 2048,2048,2048,2048,2048,2048,2048,2048,2048]
+      context_dim: [2048,2048,2048,2048,2048,2048,2048,2048,2048,2048]
       num_heads: -1
       num_head_channels: 64
       hint_channels: 3
diff --git a/scripts/cldm.py b/scripts/cldm.py
index 134a7bd20..7920908d7 100644
--- a/scripts/cldm.py
+++ b/scripts/cldm.py
@@ -77,8 +77,6 @@ def use_controlnet_lora_operations():
 
 
 def set_attr(obj, attr, value):
-    print(f"setting {attr}")
-
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
@@ -182,6 +180,7 @@ def __init__(
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        transformer_depth_middle=None,
     ):
         use_fp16 = getattr(devices, 'dtype_unet', devices.dtype) == th.float16 and not getattr(shared.cmd_opts, "no_half_controlnet", False)
 
@@ -204,6 +203,13 @@ def __init__(
         if num_head_channels == -1:
             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
 
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
+
+        self.max_transformer_depth = max([*transformer_depth, transformer_depth_middle])
+
         self.dims = dims
         self.image_size = image_size
         self.in_channels = in_channels
@@ -313,7 +319,7 @@ def __init__(
                                 num_head_channels=dim_head,
                                 use_new_attention_order=use_new_attention_order,
                             ) if not use_spatial_transformer else SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint
                             )
@@ -373,7 +379,7 @@ def __init__(
                 use_new_attention_order=use_new_attention_order,
                 # always uses a self-attn
             ) if not use_spatial_transformer else SpatialTransformer(
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint
             ),
@@ -409,6 +415,12 @@ def forward(self, x, hint, timesteps, context, **kwargs):
         guided_hint = self.align(guided_hint, h1, w1)
 
         h = x.type(self.dtype)
+
+        # `context` is only used in SpatialTransformer.
+        if not isinstance(context, list):
+            context = [context] * self.max_transformer_depth
+        assert len(context) >= self.max_transformer_depth
+
         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
             if guided_hint is not None:
                 h = module(h, emb, context)
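The patch lets `transformer_depth` be given per resolution level (as the SDXL control-lora config above does) instead of as a single integer, with a separate depth for the middle block, and lets `forward()` accept either one conditioning tensor or a list of them. Below is a minimal, illustrative sketch of that normalization logic using the values from control-lora-canny-rank128.yaml; it mirrors the added lines but is standalone code, not something executed by the extension, and the string stand-in for the conditioning tensor is purely hypothetical.

```python
# Values taken from control-lora-canny-rank128.yaml above.
channel_mult = [1, 2, 4]
transformer_depth = [0, 2, 10]   # new per-level form; legacy configs pass a single int
transformer_depth_middle = None  # falls back to the deepest input level

# Legacy int -> per-level list; middle-block depth defaults to the last level's depth.
if isinstance(transformer_depth, int):
    transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
    transformer_depth_middle = transformer_depth[-1]

max_transformer_depth = max([*transformer_depth, transformer_depth_middle])

for level, depth in enumerate(transformer_depth):
    print(f"input level {level}: SpatialTransformer depth = {depth}")
print(f"middle block: SpatialTransformer depth = {transformer_depth_middle}")

# forward() now tolerates both a single context and a per-depth list of contexts;
# a lone value is repeated max_transformer_depth times, as in the added code.
context = "cond_tensor"  # hypothetical stand-in for the real conditioning tensor
if not isinstance(context, list):
    context = [context] * max_transformer_depth
assert len(context) >= max_transformer_depth
```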