diff --git a/tools/fastllm_pytools/torch2flm.py b/tools/fastllm_pytools/torch2flm.py index b81387f9..9a9da3f6 100644 --- a/tools/fastllm_pytools/torch2flm.py +++ b/tools/fastllm_pytools/torch2flm.py @@ -39,10 +39,18 @@ def write_int8(fo, v): fo.write(v.data) def write_int4(fo, v): - c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1) - c_max = np.expand_dims(np.abs(v).max(axis = -1), -1) - c_scale = c_max / 7.0 - c_min = c_scale * -8.0 + # c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1) + # c_max = np.expand_dims(np.abs(v).max(axis = -1), -1) + # c_scale = c_max / 7.0 + # c_min = c_scale * -8.0 + + c_min = np.expand_dims(v.min(axis = -1), -1) + c_max = np.expand_dims(v.max(axis = -1), -1) + c_scale = (c_max - c_min) / 15.0 + c_zero = np.round(0.0 - c_min / c_scale) + c_zero = c_zero.clip(0, 15) + c_min = -c_scale * c_zero + v = (v - c_min) / c_scale v = (v + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8) v = v[:, 0::2] * 16 + v[:, 1::2]