Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Jul 30, 2024
1 parent 5a9d15c commit 20f6cea
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,8 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor,
*mutable_string_data->Add() = *f;
}
} else if (use_tensor_buffer && tensor.SizeInBytes() > 127) {
// The logic aligns with
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302
const auto* raw_data = tensor.DataRaw();
ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor.");
static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,28 +348,31 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType();
auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType();
std::optional<Initializer> zp_src_ptr;
auto cpu_allocator = std::make_shared<CPUAllocator>();
auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T");
auto weight_dst_ptr = std::make_unique<Tensor>(uint8_type,
TensorShape{N, quant_num, blob_bytes},
std::make_shared<CPUAllocator>());
cpu_allocator);
auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T");
auto scale_size = (TensorShape{N, quant_num}).Size();
auto scale_dst_ptr = std::make_unique<Tensor>(scale_type,
TensorShape{N * quant_num},
std::make_shared<CPUAllocator>());
TensorShape{scale_size},
cpu_allocator);
std::string zp_dst_name;
std::unique_ptr<Tensor> zp_dst_ptr;
auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size();

if (zp_tensor_proto) {
zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath());
zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T");
zp_dst_ptr = std::make_unique<Tensor>(uint8_type,
TensorShape{N * ((quant_num + 1) / 2)},
std::make_shared<CPUAllocator>());
TensorShape{zp_size},
cpu_allocator);
} else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T");
zp_dst_ptr = std::make_unique<Tensor>(uint8_type,
TensorShape{N * ((quant_num + 1) / 2)},
std::make_shared<CPUAllocator>());
TensorShape{zp_size},
cpu_allocator);
memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes());
}

Expand Down

0 comments on commit 20f6cea

Please sign in to comment.