Skip to content

Commit

Permalink
change target name generation
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jun 21, 2024
1 parent cc10612 commit 9857148
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,17 +341,17 @@ void DQMatMulReplaceWithMatMulNBits::AddTransposedInitializers(Graph& graph,
Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
std::unique_ptr<Initializer> zp_src_ptr = nullptr;
Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
weight_arg->Name() + "_T",
graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
std::vector<int64_t>{N, quant_num, blob_bytes});
Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
scale_arg->Name() + "_T",
graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
std::vector<int64_t>{N * quant_num});
std::unique_ptr<Initializer> zp_dst_ptr = nullptr;

if (zp_tensor_proto) {
zp_src_ptr = std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath());
zp_dst_ptr = std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
zp_arg->Name() + "_T",
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)});
}

Expand Down

0 comments on commit 9857148

Please sign in to comment.