
Commit: Update utils.py
danielhanchen committed Jun 6, 2024
1 parent 471565f commit c1e1646
Showing 1 changed file with 22 additions and 1 deletion.
unsloth/kernels/utils.py: 22 additions & 1 deletion
@@ -63,6 +63,25 @@ def get_lora_parameters(proj):
 pass


+def get_lora_parameters_bias(proj):
+    # For DPO or disabled adapters
+    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
+    W = base_layer.weight
+    bias = base_layer.bias
+
+    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+        return W, QUANT_STATE(W), None, None, None, bias
+    pass
+
+    active_adapter = proj.active_adapters[0] if \
+        hasattr(proj, "active_adapters") else proj.active_adapter
+    A = proj.lora_A [active_adapter].weight
+    B = proj.lora_B [active_adapter].weight
+    s = proj.scaling[active_adapter]
+    return W, QUANT_STATE(W), A, B, s, bias
+pass
+
+
 def fast_dequantize(W, quant_state = None, out = None):
     if quant_state is None: return W
     if type(quant_state) is not list:
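
The helper added above mirrors get_lora_parameters but also returns the base layer's bias, so the full tuple is (W, quant_state, A, B, s, bias); when adapters are merged or disabled, the three LoRA slots come back as None. A minimal pure-PyTorch sketch of the forward pass that tuple describes (shapes and values below are illustrative assumptions, not unsloth code):

import torch

bsz, q_len, in_dim, out_dim, r = 2, 1, 16, 32, 4
X    = torch.randn(bsz, q_len, in_dim)
W    = torch.randn(out_dim, in_dim)   # base_layer.weight
bias = torch.randn(out_dim)           # base_layer.bias, newly surfaced by this commit
A    = torch.randn(r, in_dim)         # proj.lora_A[adapter].weight
B    = torch.randn(out_dim, r)        # proj.lora_B[adapter].weight
s    = 0.5                            # proj.scaling[adapter]

# Adapters active:  out = X W^T + s * (X A^T) B^T + bias
out_lora = X @ W.t() + s * ((X @ A.t()) @ B.t()) + bias
# Adapters merged or disabled: A, B, s come back as None, so only the base path runs.
out_base = X @ W.t() + bias
print(out_lora.shape, out_base.shape)   # torch.Size([2, 1, 32]) for both
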
@@ -181,7 +200,7 @@ def fast_gemv(X, W, quant_state, out = None):

 def fast_linear_forward(proj, X, temp_lora = None, out = None):

-    W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
+    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
     bsz, q_len, in_dim = X.shape
     if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)

@@ -216,6 +235,8 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None):
         out = out.view(bsz, 1, out_dim)
     pass

+    if bias is not None: out += bias
+
     return out
 pass

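
Within fast_linear_forward, the behavioural change is confined to the single-token path (q_len == 1): after the fused matmul and LoRA steps, the bias is added when present, while the q_len != 1 branch still returns matmul_lora unchanged. A self-contained check of that arithmetic against torch.nn.functional.linear (pure PyTorch; the tensors below are made-up stand-ins and no unsloth kernels are involved):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, in_dim, out_dim, r = 3, 16, 32, 4
X    = torch.randn(bsz, 1, in_dim)    # q_len == 1, i.e. one decode step
W    = torch.randn(out_dim, in_dim)
bias = torch.randn(out_dim)
A, B = torch.randn(r, in_dim), torch.randn(out_dim, r)
s    = 0.5

out  = X @ W.t()                      # stand-in for the fused base matvec
out += s * ((X @ A.t()) @ B.t())      # stand-in for the LoRA correction
if bias is not None: out += bias      # the line this commit adds

ref = F.linear(X, W, bias) + s * F.linear(F.linear(X, A), B)
print(torch.allclose(out, ref, atol = 1e-5))  # True
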
