diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 8753f15..7208f20 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -724,11 +724,14 @@ def monkeypatch_or_replace_lora_extended(
         target_replace_module,
         search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
     ):
 
+        temp_proj = []
         if (_child_module.__class__ == nn.Linear) or (
             _child_module.__class__ == LoraInjectedLinear
         ):
-            if len(loras[0].shape) != 2:
+            if len(loras) == 0 and name == 'proj':
+                pass
+            elif len(loras[0].shape) != 2:
                 continue
 
             _source = (
@@ -783,8 +786,17 @@ def monkeypatch_or_replace_lora_extended(
         # switch the module
         _module._modules[name] = _tmp
 
-        up_weight = loras.pop(0)
-        down_weight = loras.pop(0)
+        if name == 'proj':
+            up_weight = temp_proj.pop(0)
+            down_weight = temp_proj.pop(0)
+        else:
+            up_weight = loras.pop(0)
+            down_weight = loras.pop(0)
+            if up_weight.shape[0] not in weight.shape or down_weight.shape[1] not in weight.shape:
+                temp_proj.append(up_weight)
+                temp_proj.append(down_weight)
+                up_weight = loras.pop(0)
+                down_weight = loras.pop(0)
 
         _module._modules[name].lora_up.weight = nn.Parameter(
             up_weight.type(weight.dtype)
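
A minimal standalone sketch of the buffering logic this diff introduces, assuming the `temp_proj` list persists across the whole module traversal and using a hypothetical `modules` list in place of the real `_find_modules` iteration (the names, shapes, and rank below are illustrative, not taken from the repository):

import torch

# Hypothetical stand-ins for what _find_modules would yield:
# (module name, that module's weight tensor). Illustrative only.
modules = [
    ("to_q", torch.zeros(320, 320)),
    ("proj", torch.zeros(1280, 320)),
]

# Flattened (up, down) LoRA weight pairs, with the "proj" pair arriving
# before the traversal reaches the proj layer -- the ordering mismatch
# this patch works around.
loras = [
    torch.zeros(1280, 4),  # up weight, only fits "proj"
    torch.zeros(4, 320),   # its down weight
    torch.zeros(320, 4),   # up weight, fits "to_q"
    torch.zeros(4, 320),   # its down weight
]

temp_proj = []  # holding buffer for deferred proj weight pairs

for name, weight in modules:
    if name == "proj" and temp_proj:
        # Consume the pair that was set aside on an earlier module.
        up_weight = temp_proj.pop(0)
        down_weight = temp_proj.pop(0)
    else:
        up_weight = loras.pop(0)
        down_weight = loras.pop(0)
        # Same shape test as the patch: if this pair cannot belong to
        # the current module, park it in temp_proj and take the next.
        if (up_weight.shape[0] not in weight.shape
                or down_weight.shape[1] not in weight.shape):
            temp_proj.append(up_weight)
            temp_proj.append(down_weight)
            up_weight = loras.pop(0)
            down_weight = loras.pop(0)
    print(name, tuple(up_weight.shape), tuple(down_weight.shape))

Running this prints the to_q pair first and the deferred proj pair second, which is the reordering the patch relies on when a proj layer's up/down weights appear in `loras` ahead of the module that owns them.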