[LoRA] feat: support unload_lora_weights() for Flux Control. (huggingface#10206)

* feat: support unload_lora_weights() for Flux Control.

* tighten test

* minor

* updates

* meta device fixes.
sayakpaul authored Dec 25, 2024
1 parent cd991d1 commit 1b202c5
Showing 2 changed files with 122 additions and 0 deletions.
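
In practical terms: loading a Flux Control LoRA expands the transformer's `x_embedder` (and `config.in_channels`) so it can accept the concatenated control latents, and before this commit `unload_lora_weights()` left those expanded shapes in place. A rough usage sketch of the behavior this commit enables — the pipeline class, checkpoint IDs, and printed values are illustrative assumptions, not taken from this commit:

import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
print(pipe.transformer.config.in_channels)   # base value, e.g. 64

# Loading a Control LoRA widens x_embedder to take the concatenated control latents.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
print(pipe.transformer.config.in_channels)   # expanded, e.g. 128

# With this commit, unloading also restores the original layer shapes and config entry.
pipe.unload_lora_weights()
print(pipe.transformer.config.in_channels)   # back to the base value
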
56 changes: 56 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2286,6 +2286,50 @@ def unload_lora_weights(self):
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None

if getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()

for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
original_module = torch.nn.Linear(
in_features,
out_features,
bias=bias,
dtype=module_weight.dtype,
)

tmp_state_dict = {"weight": current_param_weight}
if module_bias is not None:
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
setattr(parent_module, current_module_name, original_module)

del tmp_state_dict

if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
new_value = int(current_param_weight.shape[1])
old_value = getattr(transformer.config, attribute_name)
setattr(transformer.config, attribute_name, new_value)
logger.info(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

@classmethod
def _maybe_expand_transformer_param_shape_or_error_(
cls,
@@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_(

# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}

is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
@@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

# For `unload_lora_weights()`.
# TODO: this could lead to more memory overhead if the number of overwritten params
# is large. Should be revisited later and tackled through a `discard_original_layers` arg.
overwritten_params[f"{current_module_name}.weight"] = module_weight
if module_bias is not None:
overwritten_params[f"{current_module_name}.bias"] = module_bias

if len(overwritten_params) > 0:
transformer._overwritten_params = overwritten_params

return has_param_with_shape_update

@classmethod
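
The unload path above rebuilds each expanded `nn.Linear` from the cached original parameters by constructing the replacement on the meta device and materializing it with `load_state_dict(..., assign=True)`, so no throwaway initialization is allocated. A minimal standalone sketch of that pattern; the shapes and names are made up for illustration:

import torch

# Original (pre-expansion) parameters, standing in for an entry of `_overwritten_params`.
cached_weight = torch.randn(32, 64)
cached_bias = torch.randn(32)

# Allocate the replacement on the meta device so initialization costs no real memory ...
with torch.device("meta"):
    restored = torch.nn.Linear(64, 32, bias=True, dtype=cached_weight.dtype)

# ... then materialize it straight from the cached tensors; assign=True swaps the
# parameter storages in instead of copying into meta tensors (which would fail).
restored.load_state_dict({"weight": cached_weight, "bias": cached_bias}, assign=True)

print(restored.weight.shape)  # torch.Size([32, 64]): back to the pre-expansion shape
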
66 changes: 66 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)

# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
)

# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
components["transformer"] = transformer
pipe = FluxPipeline(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

_, _, inputs = self.get_dummy_inputs(with_generator=False)
control_image = inputs.pop("control_image")
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

control_pipe = self.pipeline_class(**components)
out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
rank = 4

dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

control_pipe.unload_lora_weights()
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)

@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
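
For context on the `from_pipe` assertions in the test above: `from_pipe()` wraps the same component objects rather than copying them, so the transformer whose shapes and config were just restored is shared with the derived `FluxPipeline`. A brief sketch of that reuse outside the test harness — model IDs, dtype, and the prompt are illustrative assumptions:

import torch
from diffusers import FluxControlPipeline, FluxPipeline

control_pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
control_pipe.unload_lora_weights()

# from_pipe() reuses the very same (now-restored) transformer, text encoders, and VAE,
# so the plain FluxPipeline runs without a control_image and without re-loading weights.
pipe = FluxPipeline.from_pipe(control_pipe)
assert pipe.transformer is control_pipe.transformer
image = pipe("a serene mountain lake at dawn", num_inference_steps=28).images[0]
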
