Merge branch 'yufeng/matmul_int4' into kvaishnavi/llama_int4_gqa
yufenglee committed Oct 2, 2023
2 parents 6c0cdea + a1a7e8d · commit 5925510
Showing 1 changed file with 2 additions and 2 deletions.
@@ -91,7 +91,6 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto])
         packed, scales, zero_points = self.int4_block_quant(B_array)
         B_quant = onnx.numpy_helper.from_array(packed)  # noqa: N806
         B_quant.name = B.name + "_Q4"
-        Bs_graph.initializer.remove(B)
         for input in Bs_graph.input:
             if input.name == inputB:
                 Bs_graph.input.remove(input)
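For context, the surrounding code block-quantizes the MatMul weight B: `int4_block_quant` returns packed 4-bit data plus per-block scales (and zero points), the packed data becomes a new initializer named with a `_Q4` suffix, and the original float weight is dropped from the graph inputs. A minimal sketch of symmetric int4 block quantization, assuming one shared scale per block and low-nibble-first packing; the function name and packing details are illustrative, not the library's exact implementation:

# Illustrative only; the real int4_block_quant packing format is defined
# by onnxruntime, not by this sketch.
import numpy as np

def int4_block_quant_sketch(weights, block_size=32):
    flat = weights.astype(np.float32).ravel()
    pad = (-flat.size) % block_size              # pad so the length divides evenly
    blocks = np.pad(flat, (0, pad)).reshape(-1, block_size)
    # One scale per block maps the block's max magnitude to the int4 range [-8, 7].
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0                    # avoid division by zero for all-zero blocks
    q = np.clip(np.round(blocks / scales), -8, 7).astype(np.int8)
    # Offset to unsigned [0, 15] and pack two 4-bit values per byte, low nibble first.
    u = (q + 8).astype(np.uint8).reshape(-1, 2)
    packed = (u[:, 0] | (u[:, 1] << 4)).astype(np.uint8)
    return packed, scales.ravel()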
@@ -179,6 +178,7 @@ def process(self):
         opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
 
         self._process_subgraph(graph_stack)
+        self.model.clean_initializers()
 
 
 def parse_args():
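The added `self.model.clean_initializers()` call runs once after all subgraphs have been processed and takes over the cleanup that the removed `Bs_graph.initializer.remove(B)` previously did inline for each quantized weight. A hypothetical sketch of what such a pass can look like (the actual `ONNXModel.clean_initializers` implementation may differ, e.g. by also recursing into subgraphs of If/Loop/Scan nodes):

# Hypothetical sketch; not the real ONNXModel.clean_initializers.
def clean_initializers_sketch(graph):
    # Tensor names still referenced by at least one node input.
    used = {name for node in graph.node for name in node.input}
    for init in list(graph.initializer):
        if init.name not in used:
            graph.initializer.remove(init)
            # Drop a graph input of the same name, if one was declared.
            for graph_input in list(graph.input):
                if graph_input.name == init.name:
                    graph.input.remove(graph_input)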
@@ -222,7 +222,7 @@ def parse_args():
         logger.error(f"file {output_model_path} already exists")
         raise Exception(f"file {output_model_path} already exists")
 
-    model = load_model_with_shape_infer(Path(input_model_path))
+    model = onnx.load(input_model_path)
     quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude = args.nodes_to_exclude)
     quant.process()
     quant.model.save_model_to_file(output_model_path, True)
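The last hunk swaps `load_model_with_shape_infer` for a plain `onnx.load`, so the script now deserializes the model directly instead of running shape inference first. A hedged programmatic equivalent of the new flow; the import path, file names, and argument values below are assumptions, not part of the commit:

# Assumes MatMul4BitsQuantizer is importable from onnxruntime's quantization
# tools; paths and parameter values are illustrative.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model.onnx")                # plain protobuf load, no shape inference
quant = MatMul4BitsQuantizer(model, 32, True)  # block_size=32, symmetric quantization
quant.process()
quant.model.save_model_to_file("model_int4.onnx", True)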
