def get_lora_parameters_bias(proj):
    """
    Gather the tensors needed to evaluate a projection layer, bias included.

    `proj` is either a plain layer exposing `.weight` / `.bias`, or a wrapper
    that exposes the underlying layer as `.base_layer` together with
    `lora_A`, `lora_B` and `scaling` mappings keyed by adapter name.

    Returns the 6-tuple `(W, QUANT_STATE(W), A, B, s, bias)`. `A`, `B` and `s`
    are all `None` when `proj` has no `disable_adapters` attribute, when
    `proj.disable_adapters` is truthy, or when `proj.merged` is truthy
    (the "DPO or disabled adapters" case).
    """
    # Weight and bias always come from the wrapped layer when there is one.
    layer = getattr(proj, "base_layer", proj)
    W = layer.weight
    bias = layer.bias

    # Short-circuit order matters: `disable_adapters` / `merged` are only
    # read when `proj` actually has a `disable_adapters` attribute.
    skip_adapters = (
        not hasattr(proj, "disable_adapters")
        or proj.disable_adapters
        or proj.merged
    )
    if skip_adapters:
        return W, QUANT_STATE(W), None, None, None, bias
    pass

    # Prefer the first entry of `active_adapters` when that attribute exists,
    # otherwise fall back to the singular `active_adapter`.
    if hasattr(proj, "active_adapters"):
        adapter_name = proj.active_adapters[0]
    else:
        adapter_name = proj.active_adapter
    pass

    lora_A_weight = proj.lora_A[adapter_name].weight
    lora_B_weight = proj.lora_B[adapter_name].weight
    lora_scale = proj.scaling[adapter_name]
    return W, QUANT_STATE(W), lora_A_weight, lora_B_weight, lora_scale, bias
pass