Support ControlNet LoRA #1936

Closed · wants to merge 1 commit
23 changes: 23 additions & 0 deletions models/control-lora-sdxl.yaml
@@ -0,0 +1,23 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        use_checkpoint: False
+        image_size: 32
+        use_spatial_transformer: True
+        legacy: False
+        use_fp16: True
+        in_channels: 4
+        model_channels: 320
+        num_res_blocks: 2
+        attention_resolutions: [2, 4]
+        transformer_depth: [0, 2, 10]
+        transformer_depth_middle: 10
+        channel_mult: [1, 2, 4]
+        use_linear_in_transformer: True
+        context_dim: [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048]
+        num_heads: -1
+        num_head_channels: 64
+        hint_channels: 3
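
For orientation (this note is not part of the diff): the extension loads this YAML with OmegaConf and passes everything under `control_stage_config.params` as keyword arguments to the `ControlNet` constructor, as `PlugableControlModel.__init__` in scripts/cldm.py below shows. A minimal sketch of that path, assuming the file sits at the path the config lookup in scripts/controlnet.py resolves:

    from omegaconf import OmegaConf

    from scripts.cldm import ControlNet  # the extension's ControlNet, defined below

    config = OmegaConf.load("models/control-lora-sdxl.yaml")
    params = config.model.params.control_stage_config.params

    # Every key under `params:` in the YAML becomes a constructor argument.
    control_model = ControlNet(**params)
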
84 changes: 74 additions & 10 deletions scripts/cldm.py
@@ -1,14 +1,17 @@
+from contextlib import contextmanager
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from modules import devices, shared

+from scripts.controlnet_lora import ControlLoraOps
+
cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)

from ldm.util import exists
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import conv_nd, linear, zero_module, timestep_embedding
-from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock


class TorchHijackForUnet:
@@ -57,12 +60,42 @@ def get_node_name(name, parent_name):
    return True, name[len(parent_name):]


+@contextmanager
+def use_controlnet_lora_operations():
+    original_torch_nn_linear = torch.nn.Linear
+    original_torch_nn_conv_2d = torch.nn.Conv2d
+    ops = ControlLoraOps()
+
+    torch.nn.Linear = ops.Linear
+    torch.nn.Conv2d = ops.Conv2d
+
+    try:
+        yield
+    finally:
+        torch.nn.Linear = original_torch_nn_linear
+        torch.nn.Conv2d = original_torch_nn_conv_2d


+def set_attr(obj, attr, value):
+    attrs = attr.split(".")
+    for name in attrs[:-1]:
+        obj = getattr(obj, name)
+    prev = getattr(obj, attrs[-1])
+    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    del prev
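
`set_attr` walks a dotted state-dict key down to the parent module, replaces the leaf attribute with a fresh `torch.nn.Parameter`, and drops the previous value. This is what lets plain tensors from the base model's state dict land on the `weight`/`up`/`down` placeholders that `ControlLoraOps` modules start out with. A tiny illustration on a throwaway module (the module and key here are invented for the example):

    import torch

    net = torch.nn.Sequential(torch.nn.Linear(4, 4))

    # Equivalent to: net[0].weight = torch.nn.Parameter(torch.zeros(4, 4))
    set_attr(net, "0.weight", torch.zeros(4, 4))
    assert isinstance(net[0].weight, torch.nn.Parameter)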


class PlugableControlModel(nn.Module):
    def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
        super().__init__()
-        self.config = OmegaConf.load(config_path)
-        self.control_model = ControlNet(**self.config.model.params.control_stage_config.params)
+        is_controlnet_lora = "lora_controlnet" in state_dict
+        self.config = OmegaConf.load(config_path)
+        if is_controlnet_lora:
+            with use_controlnet_lora_operations():
+                self.control_model = ControlNet(**self.config.model.params.control_stage_config.params)
+        else:
+            self.control_model = ControlNet(**self.config.model.params.control_stage_config.params)

        if any([k.startswith("control_model.") for k, v in state_dict.items()]):
            if 'difference' in state_dict and base_model is not None:
                print('We will stop supporting diff models soon because of its lack of robustness.')
@@ -87,17 +120,34 @@ def __init__(self, state_dict, config_path, lowvram=False, base_model=None) -> None:
                print(f'Diff model cloned: {counter} values')
                state_dict = final_state_dict
            state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items() if k.startswith("control_model.")}

-        self.control_model.load_state_dict(state_dict)
+        if is_controlnet_lora:
+            # For ControlNet LoRA, the state_dict contains up/down tensors used to
+            # dynamically compute the full weights from base_model's state_dict.
+            control_weights = state_dict
+            sd = base_model.state_dict()
+            for k in sd:
+                try:
+                    set_attr(self.control_model, k, sd[k])
+                except Exception:
+                    pass
+
+            for k in control_weights:
+                if k != "lora_controlnet" and 'label_emb' not in k:
+                    set_attr(self.control_model, k, control_weights[k].to(devices.get_device_for("controlnet")))
+        else:
+            self.control_model.load_state_dict(state_dict)

        if not lowvram:
            self.control_model.to(devices.get_device_for("controlnet"))

    def reset(self):
        pass

    def forward(self, *args, **kwargs):
        return self.control_model(*args, **kwargs)
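
Taken together: a Control-LoRA checkpoint is recognized purely by its `lora_controlnet` marker key, and it cannot be loaded without the base model, since most of the ControlNet's weights are copied from the SDXL UNet and only the low-rank factors (plus any full tensors the checkpoint ships outright) come from the file. A hedged usage sketch; the checkpoint filename is an example, and `base_model` stands for the loaded SDXL UNet:

    import safetensors.torch

    state_dict = safetensors.torch.load_file("control-lora-canny-rank256.safetensors")
    assert "lora_controlnet" in state_dict  # marker key of the Control-LoRA format

    plugable = PlugableControlModel(
        state_dict,
        config_path="models/control-lora-sdxl.yaml",
        base_model=base_model,  # supplies every weight the checkpoint does not override
    )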


class ControlNet(nn.Module):
    def __init__(
@@ -130,6 +180,7 @@ def __init__(
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
+        transformer_depth_middle=None,
    ):
        use_fp16 = getattr(devices, 'dtype_unet', devices.dtype) == th.float16 and not getattr(shared.cmd_opts, "no_half_controlnet", False)

@@ -152,6 +203,13 @@ def __init__(
        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
+
+        self.max_transformer_depth = max([*transformer_depth, transformer_depth_middle])
+
        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
@@ -261,7 +319,7 @@ def __init__(
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint
                        )
@@ -321,7 +379,7 @@ def __init__(
                use_new_attention_order=use_new_attention_order,
                # always uses a self-attn
            ) if not use_spatial_transformer else SpatialTransformer(
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
@@ -357,6 +415,12 @@ def forward(self, x, hint, timesteps, context, **kwargs):
            guided_hint = self.align(guided_hint, h1, w1)

        h = x.type(self.dtype)
+
+        # `context` is only used in SpatialTransformer.
+        if not isinstance(context, list):
+            context = [context] * self.max_transformer_depth
+        assert len(context) >= self.max_transformer_depth
+
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
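
Two of the hunks above are easiest to read with the SDXL config values plugged in: `transformer_depth: [0, 2, 10]` passes through unchanged (a legacy scalar would be broadcast to one entry per `channel_mult` level), `transformer_depth_middle` only defaults to the last entry when the YAML omits it, and `forward` now accepts either a single conditioning tensor or a list. A standalone sketch of both normalizations, with the conditioning shape chosen as a typical SDXL example:

    import torch

    # Depth normalization from ControlNet.__init__, using control-lora-sdxl.yaml values.
    channel_mult = [1, 2, 4]
    transformer_depth = [0, 2, 10]  # a bare int such as 2 would become [2, 2, 2]
    if isinstance(transformer_depth, int):
        transformer_depth = len(channel_mult) * [transformer_depth]
    transformer_depth_middle = transformer_depth[-1]  # 10, matching the YAML's explicit value
    max_transformer_depth = max([*transformer_depth, transformer_depth_middle])  # 10

    # Context broadcasting from ControlNet.forward: one tensor is repeated so every
    # transformer depth can index its own conditioning entry.
    context = torch.randn(1, 77, 2048)
    if not isinstance(context, list):
        context = [context] * max_transformer_depth
    assert len(context) >= max_transformer_depth
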
10 changes: 5 additions & 5 deletions scripts/controlnet.py
@@ -312,6 +312,8 @@ def load_control_model(p, unet, model, lowvram):

    @staticmethod
    def build_control_model(p, unet, model, lowvram):
+        is_sdxl = getattr(p.sd_model, 'is_sdxl', False)
+
        if model is None or model == 'None':
            raise RuntimeError(f"You have not selected any ControlNet Model.")

@@ -361,6 +363,9 @@ def build_control_model(p, unet, model, lowvram):
            os.path.join(global_state.script_dir, 'models', model_stem.replace('-diff', '') + ".yaml")
        ]

+        if is_sdxl:
+            possible_config_filenames.append(os.path.join(model_dir_name, 'control-lora-sdxl.yaml'))
+
        override_config = possible_config_filenames[0]

        for possible_config_filename in possible_config_filenames:
@@ -720,11 +725,6 @@ def process(self, p, *args):
            self.latest_network = None
            return

-        is_sdxl = getattr(p.sd_model, 'is_sdxl', False)
-        if is_sdxl:
-            logger.warning('ControlNet does not support SDXL -- disabling')
-            return
-
        detected_maps = []
        forward_params = []
        post_processors = []
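
The controlnet.py side is plumbing: the hard SDXL bail-out in `process` is removed, and for SDXL checkpoints the shared control-lora-sdxl.yaml is appended as a final config candidate. Assuming the collapsed loop body behaves like the rest of the config lookup (keep the first candidate that exists on disk), the selection reduces to roughly:

    import os

    # possible_config_filenames: per-model .yaml guesses built earlier in
    # build_control_model; control-lora-sdxl.yaml is appended only when is_sdxl.
    override_config = possible_config_filenames[0]
    for possible_config_filename in possible_config_filenames:
        if os.path.exists(possible_config_filename):
            override_config = possible_config_filename
            break
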
121 changes: 121 additions & 0 deletions scripts/controlnet_lora.py
@@ -0,0 +1,121 @@
+# ComfyUI (https://github.com/comfyanonymous/ComfyUI)
+# Copyright (C) 2023

+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.

+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.

+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.

+import torch
+
+
+class ControlLoraOps:
+    class Linear(torch.nn.Module):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super().__init__()
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = None
+            self.up = None
+            self.down = None
+            self.bias = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.linear(
+                    input,
+                    self.weight
+                    + (
+                        torch.mm(
+                            self.up.flatten(start_dim=1), self.down.flatten(start_dim=1)
+                        )
+                    )
+                    .reshape(self.weight.shape)
+                    .type(self.weight.dtype),
+                    self.bias,
+                )
+            else:
+                return torch.nn.functional.linear(input, self.weight, self.bias)
+
+    class Conv2d(torch.nn.Module):
+        def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            bias=True,
+            padding_mode="zeros",
+            device=None,
+            dtype=None,
+        ):
+            super().__init__()
+            self.in_channels = in_channels
+            self.out_channels = out_channels
+            self.kernel_size = kernel_size
+            self.stride = stride
+            self.padding = padding
+            self.dilation = dilation
+            self.transposed = False
+            self.output_padding = 0
+            self.groups = groups
+            self.padding_mode = padding_mode
+
+            self.weight = None
+            self.bias = None
+            self.up = None
+            self.down = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.conv2d(
+                    input,
+                    self.weight
+                    + (
+                        torch.mm(
+                            self.up.flatten(start_dim=1), self.down.flatten(start_dim=1)
+                        )
+                    )
+                    .reshape(self.weight.shape)
+                    .type(self.weight.dtype),
+                    self.bias,
+                    self.stride,
+                    self.padding,
+                    self.dilation,
+                    self.groups,
+                )
+            else:
+                return torch.nn.functional.conv2d(
+                    input,
+                    self.weight,
+                    self.bias,
+                    self.stride,
+                    self.padding,
+                    self.dilation,
+                    self.groups,
+                )
+
+    def conv_nd(self, dims, *args, **kwargs):
+        if dims == 2:
+            return self.Conv2d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
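
Both layer types compute the effective weight as `weight + (up @ down).reshape(weight.shape)` on every forward pass, so the low-rank residual is applied on the fly rather than merged once into the checkpoint. A self-contained check of the Linear path, with rank-4 factors invented for the example:

    import torch

    from scripts.controlnet_lora import ControlLoraOps

    ops = ControlLoraOps()
    layer = ops.Linear(16, 8)
    layer.weight = torch.randn(8, 16)  # base weight, normally copied from the SDXL UNet
    layer.up = torch.randn(8, 4)       # low-rank LoRA factors (rank 4 here)
    layer.down = torch.randn(4, 16)

    x = torch.randn(2, 16)
    expected = x @ (layer.weight + layer.up @ layer.down).t()
    assert torch.allclose(layer(x), expected, atol=1e-5)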