FIX: not use assert and Exception, add license
5eqn committed Nov 22, 2024
1 parent bfc8fb3 commit aed135d
Showing 2 changed files with 50 additions and 28 deletions.
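The change shared by both files below swaps assert statements for explicit raise ValueError(...) calls. The distinction matters because Python drops assertions entirely when run with optimizations enabled (python -O), so they cannot be relied on to guard input. A minimal illustration of the two behaviors (hypothetical helper, not part of the diff):

def set_rank_unsafe(r: int) -> int:
    # Vanishes under python -O: invalid ranks slip through silently.
    assert r > 0, "rank must be positive"
    return r

def set_rank(r: int) -> int:
    # Survives python -O and gives callers a typed, catchable error.
    if r <= 0:
        raise ValueError(f"rank must be positive, got {r}")
    return r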
26 changes: 11 additions & 15 deletions src/peft/tuners/lora/layer.py
@@ -268,26 +268,21 @@ def corda_init(self, adapter_name, init_lora_weights):
         in_dim = weight.data.size(1)
 
         # Calculate WC from covariance matrix
-        assert hasattr(linear, "eigens")
+        if not hasattr(linear, "eigens"):
+            raise ValueError("eigens attribute not found, it's expected to be set by the pre-injection hook.")
         eigens = linear.eigens
         U = eigens.U_WC
         S = eigens.S_WC
         V = eigens.V_WC
         r = self.r[adapter_name]
 
         # nan or inf check
-        # if (S!=S).any():
         if torch.isnan(S).any() or torch.isinf(S).any():
-            # print("nan in S")
-            raise Exception("nan or inf in S")
-        # if (U!=U).any():
+            raise ValueError("nan or inf in S")
         if torch.isnan(U).any() or torch.isinf(U).any():
-            # print("nan in U")
-            raise Exception("nan or inf in U")
-        # if (V!=V).any():
+            raise ValueError("nan or inf in U")
         if torch.isnan(V).any() or torch.isinf(V).any():
-            # print("nan in V")
-            raise Exception("nan or inf in V")
+            raise ValueError("nan or inf in V")
 
         # Sanity check
         logging.info(f"U.device = {U.device}, S.device = {S.device}, V.device = {V.device}")
@@ -296,11 +291,12 @@ def corda_init(self, adapter_name, init_lora_weights):
         scale_u = torch.linalg.norm(U) / math.sqrt(r)
         scale_v = torch.linalg.norm(V) / math.sqrt(r)
         logging.info(f"scale_u: {scale_u:.2f}, scale_v: {scale_v:.2f}, svd_error: {svd_error:.2f}")
-        assert U.size(0) == out_dim
-        assert U.size(1) == r
-        assert S.size(0) == r
-        assert V.size(0) == in_dim
-        assert V.size(1) == r
+        if U.size(0) != out_dim or U.size(1) != r:
+            raise ValueError(f"U size mismatch: {U.size()} vs. ({out_dim}, {r})")
+        if S.size(0) != r:
+            raise ValueError(f"S size mismatch: {S.size()} vs. ({r},)")
+        if V.size(0) != in_dim or V.size(1) != r:
+            raise ValueError(f"V size mismatch: {V.size()} vs. ({in_dim}, {r})")
 
         # Apply alpha
         S /= self.scaling[adapter_name]
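Both hunks above guard the same truncated-SVD contract: for an (out_dim, in_dim) weight decomposed at rank r, U must be (out_dim, r), S must be (r,), and V must be (in_dim, r), so that U @ diag(S) @ V.T reconstructs the weight. A minimal sketch of that contract (illustrative, not from the diff):

import torch

out_dim, in_dim, r = 16, 8, 4
W = torch.randn(out_dim, in_dim)

# Thin SVD, then truncate to rank r; V is kept as (in_dim, r).
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
U, S, V = U[:, :r], S[:r], Vh[:r].T

# The same shape invariants the new ValueError checks enforce.
if U.size(0) != out_dim or U.size(1) != r:
    raise ValueError(f"U size mismatch: {U.size()} vs. ({out_dim}, {r})")
if S.size(0) != r:
    raise ValueError(f"S size mismatch: {S.size()} vs. ({r},)")
if V.size(0) != in_dim or V.size(1) != r:
    raise ValueError(f"V size mismatch: {V.size()} vs. ({in_dim}, {r})")

W_r = U @ torch.diag(S) @ V.T  # rank-r approximation of W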
52 changes: 39 additions & 13 deletions src/peft/utils/corda_utils.py
@@ -1,3 +1,20 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Reference code: https://github.com/iboing/CorDA/blob/main/cordalib/decomposition.py
+# Reference paper: https://arxiv.org/abs/2406.05223
+
 import logging
 import os
 from typing import Any, Callable, Iterable, Optional
@@ -91,7 +108,8 @@ def preprocess_corda(
     logging.info("CorDA cache file not found, building...")
 
     # Specify CorDA method for each layer
-    assert corda_method is not None, "corda_method is required when cache_file is not provided"
+    if corda_method is None:
+        raise ValueError("corda_method is required when cache_file is not provided")
     for name, module in target_modules(model, config):
         module.corda_method = corda_method
 
@@ -144,7 +162,8 @@ def calib_cov_distribution(
             module.covariance_matrix = all_covariance_matrix[name]
         return
 
-    assert run_model is not None, "run_model must be specified when covariance file and cache file aren't built"
+    if run_model is None:
+        raise ValueError("run_model must be specified when covariance file and cache file aren't built")
     if hooked_model is None:
         hooked_model = model
     hooked_model.eval()
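The run_model guard above belongs to the calibration pass: calib_cov_distribution runs the model over calibration data and stores a covariance_matrix on each hooked linear layer. A hedged sketch of how such per-layer input covariance is typically collected with forward hooks (the helper name and accumulation details are assumptions, not the library's exact code):

import torch
import torch.nn as nn

def attach_covariance_hook(linear: nn.Linear):
    # Accumulate X^T X over every batch that flows through this layer.
    def hook(module, args, output):
        x = args[0].reshape(-1, module.in_features).float()
        cov = x.t() @ x / x.size(0)
        prev = getattr(module, "covariance_matrix", None)
        module.covariance_matrix = cov if prev is None else prev + cov
    return linear.register_forward_hook(hook)

# Usage: attach, run calibration batches through the model, detach.
# handle = attach_covariance_hook(model.proj)  # hypothetical layer name
# run_model()
# handle.remove()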
@@ -241,7 +260,8 @@ def collect_eigens_for_layer(
     in_dim = w.size(1)
     min_dim = min(in_dim, out_dim)
 
-    assert hasattr(linear, "covariance_matrix")
+    if not hasattr(linear, "covariance_matrix"):
+        raise ValueError("build covariance matrix with calib_cov_distribution first")
     covariance_matrix = linear.covariance_matrix.float()
 
     damp = 0.01
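The damp = 0.01 context line suggests the covariance matrix is regularized before decomposition. A hedged sketch of the usual damping pattern, an assumption about how such a factor is applied (the diff itself does not show it):

import torch

def dampen(cov: torch.Tensor, damp: float = 0.01) -> torch.Tensor:
    # Assumed pattern: add damp * mean(diag) to the diagonal so the
    # covariance stays well-conditioned for the decomposition.
    mean_diag = torch.mean(torch.diag(cov))
    eye = torch.eye(cov.size(0), device=cov.device, dtype=cov.dtype)
    return cov + damp * mean_diag * eye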
@@ -272,11 +292,12 @@ def collect_eigens_for_layer(
     logging.info(f"S: {S[:16]} ... {S[-16:]}")
 
     # Sanity check: U and V are temporarily large; they will be cropped after rank search
-    assert U.size(0) == out_dim
-    assert U.size(1) == min_dim
-    assert S.size(0) == min_dim
-    assert V.size(0) == in_dim
-    assert V.size(1) == min_dim
+    if U.size(0) != out_dim or U.size(1) != min_dim:
+        raise ValueError(f"U size mismatch: {U.size()} vs. ({out_dim}, {min_dim})")
+    if S.size(0) != min_dim:
+        raise ValueError(f"S size mismatch: {S.size()} vs. ({min_dim},)")
+    if V.size(0) != in_dim or V.size(1) != min_dim:
+        raise ValueError(f"V size mismatch: {V.size()} vs. ({in_dim}, {min_dim})")
 
     # Offload U and V to CPU, they consume too much memory
     U = U.cpu()
@@ -306,9 +327,14 @@ def crop_corda_eigens(model: nn.Module, config: LoraConfig):
             raise ValueError("Invalid corda_method")
 
         # Sanity check
-        assert module.eigens.S_WC.size(0) == module.rank
-        assert module.eigens.U_WC.size(0) == module.weight.size(0)
-        assert module.eigens.U_WC.size(1) == module.rank
-        assert module.eigens.V_WC.size(0) == module.weight.size(1)
-        assert module.eigens.V_WC.size(1) == module.rank
+        if module.eigens.S_WC.size(0) != module.rank:
+            raise ValueError(f"rank mismatch: {module.eigens.S_WC.size(0)} vs. {module.rank}")
+        if module.eigens.U_WC.size(0) != module.weight.size(0):
+            raise ValueError(f"U size mismatch: {module.eigens.U_WC.size(0)} vs. {module.weight.size(0)}")
+        if module.eigens.U_WC.size(1) != module.rank:
+            raise ValueError(f"U size mismatch: {module.eigens.U_WC.size(1)} vs. {module.rank}")
+        if module.eigens.V_WC.size(0) != module.weight.size(1):
+            raise ValueError(f"V size mismatch: {module.eigens.V_WC.size(0)} vs. {module.weight.size(1)}")
+        if module.eigens.V_WC.size(1) != module.rank:
+            raise ValueError(f"V size mismatch: {module.eigens.V_WC.size(1)} vs. {module.rank}")
 
     logging.info("CorDA eigens cropped")
