
[Neural Speed] Add compute_dtype For Quant Model Convert (#267)
* add compute_dtype for quant model convert

Signed-off-by: Yu, Zhentao <[email protected]>

* fix pylint

Signed-off-by: Yu, Zhentao <[email protected]>

* set cmp_dtype=int8 by default

Signed-off-by: Yu, Zhentao <[email protected]>

* add llama2-gptq in test

Signed-off-by: Yu, Zhentao <[email protected]>

---------

Signed-off-by: Yu, Zhentao <[email protected]>
zhentaoyu authored May 28, 2024

Parent: bf251bd · Commit: cadb01e
Showing 10 changed files with 154 additions and 76 deletions.
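The sketch below condenses the pattern this commit repeats across the converter scripts: a new `--compute_dtype` CLI flag (defaulting to `int8`) that is threaded into every BestLA repacking call. It is an illustrative summary of the diff, not code from the repository; the argument list passed to `parse_args` is a made-up invocation and `./model-dir` is a placeholder path.

```python
# Illustrative condensation of the change: each convert_quantized_*.py script
# gains a --compute_dtype flag and forwards it to the tensor-repacking helpers.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--compute_dtype",
                    choices=["fp32", "bf16", "int8"],
                    default="int8",
                    help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")

# Hypothetical invocation, just to show the parsed result.
args = parser.parse_args(["--compute_dtype", "bf16", "./model-dir"])

cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
# Each quantized GEMM tensor is then repacked with this dtype instead of the
# previously hard-coded "int8", e.g.:
#   convert_to_qx_bestla_tensor(src, dst, list_vars, fout, quantize_config,
#                               compute_dtype=cmp_dtype)
```

Since the flag defaults to `int8`, which was the previously hard-coded value, existing conversion commands keep their old behavior.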
9 changes: 5 additions & 4 deletions neural_speed/convert/common.py
@@ -646,7 +646,8 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
print(f"converting {dst_name} quantized tensor to fp32 tensor")


def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None,
compute_dtype="int8"):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
qzeros = model[f"{src_name}.qzeros"]
@@ -708,12 +709,12 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
weight_dtype="int4" if q_config['bits'] == 4 else "int8",
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")


def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config, compute_dtype="int8"):
# unpack weight and repack into 3bits / 4bits BestLA format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
@@ -791,6 +792,6 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"convert_to_qx_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
16 changes: 11 additions & 5 deletions neural_speed/convert/convert_quantized_baichuan.py
@@ -54,6 +54,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
choices=["huggingface", "modelscope"],
default="huggingface",
help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -168,6 +172,8 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, fout)
convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
for i in range(hparams["num_hidden_layers"]):
prefix = "model.layers." + str(i)

@@ -177,17 +183,17 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
f"{prefix}.post_attention_layernorm.weight", list_vars, fout)
# qkv GEMM
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.W_pack.weight", f"{prefix}.self_attn.W_pack.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.o_proj.weight", f"{prefix}.self_attn.o_proj.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=cmp_dtype)

# ffn GEMM
convert_to_qx_bestla_tensor(f"{prefix}.mlp.gate_proj", f"{prefix}.mlp.gate_proj.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.mlp.down_proj", f"{prefix}.mlp.down_proj.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.mlp.up_proj", f"{prefix}.mlp.up_proj.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=cmp_dtype)

fout.close()
print(f"Success! saved as {out_path}")
15 changes: 11 additions & 4 deletions neural_speed/convert/convert_quantized_falcon.py
@@ -33,6 +33,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
choices=["huggingface", "modelscope"],
default="huggingface",
help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -142,6 +146,8 @@ def convert_to_fp32_tensor(src_name, dst_name, model, fout):
convert_to_fp32_tensor("transformer.ln_f.bias", "transformer.ln_f.bias", list_vars, fout)
convert_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
for i in range(hparams["n_layer"]):
prefix = "transformer.h." + str(i)

@@ -157,15 +163,16 @@ def convert_to_fp32_tensor(src_name, dst_name, model, fout):

# qkv GEMM
convert_to_qx_bestla_tensor(f"{prefix}.self_attention.query_key_value.weight",
f"{prefix}.self_attention.query_key_value.weight", list_vars, fout, quantize_config)
f"{prefix}.self_attention.query_key_value.weight", list_vars, fout, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.self_attention.dense.weight", f"{prefix}.self_attention.dense.weight",
list_vars, fout, quantize_config)
list_vars, fout, quantize_config, compute_dtype=cmp_dtype)

# ffn GEMM
convert_to_qx_bestla_tensor(f"{prefix}.mlp.dense_h_to_4h", f"{prefix}.mlp.dense_h_to_4h.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.mlp.dense_4h_to_h", f"{prefix}.mlp.dense_4h_to_h.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=cmp_dtype)

fout.close()
print(f"Success! saved as {out_path}")
23 changes: 15 additions & 8 deletions neural_speed/convert/convert_quantized_gptj.py
@@ -23,7 +23,7 @@
from transformers import AutoTokenizer


def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config, compute_dtype="int8"):
# unpack weight and repack into 3bits / 4bits BestLA format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
@@ -89,7 +89,7 @@ def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} quantized tensor to bestla q{q_config['bits']} block")

@@ -100,6 +100,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
default="huggingface", help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -180,20 +184,23 @@ def main(args_in: Optional[List[str]] = None) -> None:
convert_to_fp32_tensor("lm_head.bias", "lm_head.bias", list_vars, fout)
convert_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
for i in tqdm(range(n_layer), desc="Processing layers"):
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.q_proj.weight",
f"transformer.h.{i}.attn.q_proj.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.attn.q_proj.weight", list_vars, fout, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.k_proj.weight",
f"transformer.h.{i}.attn.k_proj.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.attn.k_proj.weight", list_vars, fout, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.v_proj.weight",
f"transformer.h.{i}.attn.v_proj.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.attn.v_proj.weight", list_vars, fout, quantize_config, compute_dtype=cmp_dtype)

convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.out_proj.weight",
f"transformer.h.{i}.attn.out_proj.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.attn.out_proj.weight", list_vars, fout, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_in.weight",
f"transformer.h.{i}.mlp.fc_in.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.mlp.fc_in.weight", list_vars, fout, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_out.weight",
f"transformer.h.{i}.mlp.fc_out.weight", list_vars, fout, quantize_config)
f"transformer.h.{i}.mlp.fc_out.weight", list_vars, fout, quantize_config, compute_dtype=cmp_dtype)

convert_to_fp32_tensor(f"transformer.h.{i}.mlp.fc_in.bias",
f"transformer.h.{i}.mlp.fc_in.bias", list_vars, fout)
30 changes: 20 additions & 10 deletions neural_speed/convert/convert_quantized_llama.py
@@ -28,7 +28,8 @@ def permute_func(weights, n_head: int, n_head_kv: int):
*weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))


def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None,
compute_dtype="int8"):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
@@ -95,7 +96,7 @@ def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_hea
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} quantized tensor to bestla q4 block")

@@ -108,6 +109,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
choices=["huggingface", "modelscope"],
default="huggingface",
help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -193,13 +198,16 @@ def llama(model, config, quantize_config, f, model_path, out_path):
f.write(struct.pack("f", score))

# 3. write tensors
cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
list_vars = model
convert_to_fp32_tensor("model.embed_tokens.weight", "tok_embeddings.weight", list_vars, f)
convert_to_fp32_tensor("model.norm.weight", "norm.weight", list_vars, f)
if list_vars.get("lm_head.qweight") is None:
convert_to_fp32_tensor("lm_head.weight", "output.weight", list_vars, f)
else:
convert_to_q4_bestla_tensor(f"lm_head.weight", f"output.weight", list_vars, f, quantize_config, n_head)
convert_to_q4_bestla_tensor(f"lm_head.weight", f"output.weight", list_vars, f, quantize_config, n_head,
compute_dtype=cmp_dtype)

for i in range(n_layer):
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.q_proj",
@@ -209,25 +217,27 @@ def llama(model, config, quantize_config, f, model_path, out_path):
quantize_config,
n_head,
n_head,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
f"layers.{i}.attention.wk.weight",
list_vars,
f,
quantize_config,
n_head,
n_head_kv,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight",
list_vars, f, quantize_config, n_head)
list_vars, f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight",
list_vars, f, quantize_config, n_head)
list_vars, f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight",
list_vars, f, quantize_config, n_head)
list_vars, f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight",
list_vars, f, quantize_config, n_head)
list_vars, f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight",
list_vars, f, quantize_config, n_head)
list_vars, f, quantize_config, n_head, compute_dtype=cmp_dtype)

convert_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight",
list_vars, f)
27 changes: 18 additions & 9 deletions neural_speed/convert/convert_quantized_mistral.py
@@ -30,7 +30,8 @@ def permute_func(weights, n_head: int, n_head_kv: int):
*weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))


def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None,
compute_dtype="int8"):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
@@ -97,7 +98,7 @@ def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_hea
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} quantized tensor to bestla q4 block")

@@ -108,6 +109,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
default="huggingface", help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -186,6 +191,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("f", score))

# 3. write tensors
cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
list_vars = model
convert_to_fp32_tensor("model.embed_tokens.weight", "tok_embeddings.weight", list_vars, f)
convert_to_fp32_tensor("model.norm.weight", "norm.weight", list_vars, f)
@@ -199,25 +206,27 @@ def main(args_in: Optional[List[str]] = None) -> None:
quantize_config,
n_head,
n_head,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
f"layers.{i}.attention.wk.weight",
list_vars,
f,
quantize_config,
n_head,
n_head_kv,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight", list_vars, f,
quantize_config, n_head)
quantize_config, n_head, compute_dtype=cmp_dtype)

convert_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight",
list_vars, f)
33 changes: 23 additions & 10 deletions neural_speed/convert/convert_quantized_mixtral.py
@@ -30,7 +30,8 @@ def permute_func(weights, n_head: int, n_head_kv: int):
*weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))


def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None,
compute_dtype="int8"):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
@@ -103,7 +104,7 @@ def convert_to_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_hea
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
compute_dtype=compute_dtype)
dst.flatten()[:byte_size].tofile(fout)
print(f"convert_to_q4_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")

@@ -114,6 +115,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
default="huggingface", help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -233,6 +238,8 @@ def convert_mixtral_to_fp32_tensor(src_name, dst_name, model, fout):
data.tofile(f)

# 3. write tensors
cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
list_vars = model
quant_moe_gate = False
for name in list_vars.keys():
@@ -251,34 +258,40 @@ def convert_mixtral_to_fp32_tensor(src_name, dst_name, model, fout):
quantize_config,
n_head,
n_head,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.k_proj",
f"layers.{i}.attention.wk.weight",
list_vars,
f,
quantize_config,
n_head,
n_head_kv,
permute_func=permute_func)
permute_func=permute_func,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight", list_vars,
f, quantize_config, n_head)
f, quantize_config, n_head, compute_dtype=cmp_dtype)

if quant_moe_gate:
convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.gate.weight",
f"layers.{i}.ffn_gate_inp.weight", list_vars, f, quantize_config, n_head)
f"layers.{i}.ffn_gate_inp.weight", list_vars, f, quantize_config, n_head,
compute_dtype=cmp_dtype)
else:
convert_mixtral_to_fp32_tensor(f"model.layers.{i}.block_sparse_moe.gate.weight",
f"layers.{i}.ffn_gate_inp.weight", list_vars, f)

for j in range(8):
convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w1",
f"layers.{i}.ffn_gate.{j}.weight", list_vars, f, quantize_config, n_head)
f"layers.{i}.ffn_gate.{j}.weight", list_vars, f, quantize_config, n_head,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w2",
f"layers.{i}.ffn_down.{j}.weight", list_vars, f, quantize_config, n_head)
f"layers.{i}.ffn_down.{j}.weight", list_vars, f, quantize_config, n_head,
compute_dtype=cmp_dtype)
convert_to_q4_bestla_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.w3",
f"layers.{i}.ffn_up.{j}.weight", list_vars, f, quantize_config, n_head)
f"layers.{i}.ffn_up.{j}.weight", list_vars, f, quantize_config, n_head,
compute_dtype=cmp_dtype)

convert_mixtral_to_fp32_tensor(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight",
list_vars, f)
38 changes: 24 additions & 14 deletions neural_speed/convert/convert_quantized_phi.py
@@ -25,7 +25,8 @@
from transformers import AutoModelForCausalLM, AutoTokenizer


def convert_phi1_5_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams, quantize_config):
def convert_phi1_5_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams, quantize_config,
compute_dtype="int8"):
list_vars = model
for name in list_vars.keys():
print(name)
@@ -130,6 +131,7 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)
convert_qwen_to_fp32_tensor("lm_head.bias", "lm_head.bias", list_vars, fout)

print("model compute_dtype is {}".format(compute_dtype))
for i in range(hparams["num_hidden_layers"]):
prefix = "model.layers." + str(i)
renamed_prefix = "model.layers." + str(i)
@@ -141,13 +143,13 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# qkv GEMM
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.q_proj.weight", f"{prefix}.self_attn.q_proj.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=compute_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.k_proj.weight", f"{prefix}.self_attn.k_proj.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=compute_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.v_proj.weight", f"{prefix}.self_attn.v_proj.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=compute_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.self_attn.dense.weight", f"{prefix}.self_attn.dense.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=compute_dtype)

convert_qwen_to_fp32_tensor(f"{prefix}.self_attn.q_proj.bias", f"{prefix}.self_attn.q_proj.bias", list_vars,
fout)
@@ -159,9 +161,9 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# ffn GEMM
convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc1.weight", f"{renamed_prefix}.mlp.fc1.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=compute_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc2.weight", f"{renamed_prefix}.mlp.fc2.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=compute_dtype)

convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc1.bias", f"{renamed_prefix}.mlp.fc1.bias", list_vars, fout)
convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc2.bias", f"{renamed_prefix}.mlp.fc2.bias", list_vars, fout)
@@ -170,7 +172,8 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
print(f"Success! saved as {out_path}")


def convert_phi2_gptq_to_bestTLA(model_path, model, out_path, hparams, quantize_config):
def convert_phi2_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams, quantize_config,
compute_dtype="int8"):
list_vars = model
for name in list_vars.keys():
print(name)
@@ -274,6 +277,7 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
convert_qwen_to_fp32_tensor("lm_head.linear.weight", "lm_head.weight", list_vars, fout)
convert_qwen_to_fp32_tensor("lm_head.linear.bias", "lm_head.bias", list_vars, fout)

print("model compute_dtype is {}".format(compute_dtype))
for i in range(hparams["n_layer"]):
prefix = "transformer.h." + str(i)
renamed_prefix = "model.layers." + str(i)
@@ -283,19 +287,19 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# qkv GEMM
convert_to_qx_bestla_tensor(f"{prefix}.mixer.Wqkv.weight", f"{renamed_prefix}.mixer.Wqkv.weight", list_vars,
fout, quantize_config)
fout, quantize_config, compute_dtype=compute_dtype)
convert_qwen_to_fp32_tensor(f"{prefix}.mixer.Wqkv.bias", f"{renamed_prefix}.mixer.Wqkv.bias", list_vars, fout)

convert_to_qx_bestla_tensor(f"{prefix}.mixer.out_proj.weight", f"{renamed_prefix}.mixer.out_proj.weight",
list_vars, fout, quantize_config)
list_vars, fout, quantize_config, compute_dtype=compute_dtype)
convert_qwen_to_fp32_tensor(f"{prefix}.mixer.out_proj.bias", f"{renamed_prefix}.mixer.out_proj.bias", list_vars,
fout)

# ffn GEMM
convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc1.weight", f"{renamed_prefix}.mlp.fc1.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=compute_dtype)
convert_to_qx_bestla_tensor(f"{prefix}.mlp.fc2.weight", f"{renamed_prefix}.mlp.fc2.weight", list_vars, fout,
quantize_config)
quantize_config, compute_dtype=compute_dtype)

convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc1.bias", f"{renamed_prefix}.mlp.fc1.bias", list_vars, fout)
convert_qwen_to_fp32_tensor(f"{prefix}.mlp.fc2.bias", f"{renamed_prefix}.mlp.fc2.bias", list_vars, fout)
@@ -317,6 +321,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
default="NE",
choices=["NE", "GGUF"],
help="convert to the GGUF or NE format")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -326,9 +334,11 @@ def main(args_in: Optional[List[str]] = None) -> None:
model, hparams, quantize_config = load_quantized_safetensors(model_path)

if hparams['model_type'] == "phi":
convert_phi1_5_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config)
convert_phi1_5_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config,
compute_dtype=args.compute_dtype)
elif hparams['model_type'] == "phi-msft":
convert_phi2_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config)
convert_phi2_gptq_to_bestTLA(model_path, out_path, args.outtype, model, hparams, quantize_config,
compute_dtype=args.compute_dtype)


if __name__ == '__main__':
38 changes: 26 additions & 12 deletions neural_speed/convert/convert_quantized_qwen.py
@@ -32,6 +32,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
choices=["huggingface", "modelscope"],
default="huggingface",
help="hub to load model")
parser.add_argument("--compute_dtype",
choices=["fp32", "bf16", "int8"],
default="int8",
help="compute_dtype for model inference")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

@@ -147,6 +151,8 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
data.tofile(f)

#3. write tensors
cmp_dtype = args.compute_dtype
print("model compute_dtype is {}".format(cmp_dtype))
if hparams['model_type'] == 'qwen':
convert_qwen_to_fp32_tensor("transformer.wte.weight", "transformer.wte.weight", list_vars, f)
convert_qwen_to_fp32_tensor("transformer.ln_f.weight", "transformer.ln_f.weight", list_vars, f)
@@ -160,19 +166,21 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# qkv GEMM
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_attn.weight",
f"transformer.h.{i}.attn.c_attn.weight", list_vars, f, quantize_config)
f"transformer.h.{i}.attn.c_attn.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_qwen_to_fp32_tensor(f"transformer.h.{i}.attn.c_attn.bias", f"transformer.h.{i}.attn.c_attn.bias",
list_vars, f)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.c_proj.weight",
f"transformer.h.{i}.attn.c_proj.weight", list_vars, f, quantize_config)
f"transformer.h.{i}.attn.c_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)

# ffn GEMM
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w1.weight", f"transformer.h.{i}.mlp.w1.weight",
list_vars, f, quantize_config)
list_vars, f, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.w2.weight", f"transformer.h.{i}.mlp.w2.weight",
list_vars, f, quantize_config)
list_vars, f, quantize_config, compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.c_proj.weight", f"transformer.h.{i}.mlp.c_proj.weight",
list_vars, f, quantize_config)
list_vars, f, quantize_config, compute_dtype=cmp_dtype)

f.close()
print(f"Success! saved as {out_path}")
@@ -190,13 +198,17 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# qkv GEMM
convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.q_proj.weight",
f"model.layers.{i}.self_attn.q_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.self_attn.q_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.k_proj.weight",
f"model.layers.{i}.self_attn.k_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.self_attn.k_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.v_proj.weight",
f"model.layers.{i}.self_attn.v_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.self_attn.v_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"model.layers.{i}.self_attn.o_proj.weight",
f"model.layers.{i}.self_attn.o_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.self_attn.o_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)

convert_qwen_to_fp32_tensor(f"model.layers.{i}.self_attn.q_proj.bias",
f"model.layers.{i}.self_attn.q_proj.bias", list_vars, f)
@@ -207,11 +219,13 @@ def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):

# ffn GEMM
convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.down_proj.weight",
f"model.layers.{i}.mlp.down_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.mlp.down_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.gate_proj.weight",
f"model.layers.{i}.mlp.gate_proj.weight", list_vars, f, quantize_config)
f"model.layers.{i}.mlp.gate_proj.weight", list_vars, f, quantize_config,
compute_dtype=cmp_dtype)
convert_to_qx_bestla_tensor(f"model.layers.{i}.mlp.up_proj.weight", f"model.layers.{i}.mlp.up_proj.weight",
list_vars, f, quantize_config)
list_vars, f, quantize_config, compute_dtype=cmp_dtype)

f.close()
print(f"Success! saved as {out_path}")
1 change: 1 addition & 0 deletions tests/model-test/cpp_graph_inference.sh
@@ -160,6 +160,7 @@ model_name_map["stablelm"]="stabilityai/stablelm-2-1_6b"
model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat"
model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1"
model_name_map["gemma-2b"]="google/gemma-2b-it"
model_name_map["llama2-gptq"]="TheBloke/Llama-2-7B-Chat-GPTQ"
model_name_map["mixtral-gptq"]="Mixtral-8x7B-Instruct-v0.1-GPTQ"
model_name_map["qwen1.5-gptq"]="Qwen/Qwen1.5-7B-Chat-GPTQ"
model_name_map["qwen-gptq"]="TheBloke/Qwen-7B-Chat-GPTQ"
