From 20f6cea7fc883fd6f62f6ad51fd625002f60d084 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Tue, 30 Jul 2024 12:31:58 -0700 Subject: [PATCH] resolve comments --- onnxruntime/core/framework/tensorprotoutils.cc | 2 ++ .../selectors_actions/qdq_actions.cc | 17 ++++++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index dd6e6e5376cdc..cbd53298ab2ad 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1268,6 +1268,8 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, *mutable_string_data->Add() = *f; } } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { + // The logic aligns with + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); static_assert(sizeof(void*) <= sizeof(ExternalDataInfo::OFFSET_TYPE)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 9870b9c14575a..8f99b7409d4fe 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -348,28 +348,31 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); std::optional zp_src_ptr; + auto cpu_allocator = std::make_shared(); auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); auto weight_dst_ptr = std::make_unique(uint8_type, TensorShape{N, quant_num, blob_bytes}, - std::make_shared()); + cpu_allocator); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); + auto scale_size = (TensorShape{N, quant_num}).Size(); auto scale_dst_ptr = std::make_unique(scale_type, - TensorShape{N * quant_num}, - std::make_shared()); + TensorShape{scale_size}, + cpu_allocator); std::string zp_dst_name; std::unique_ptr zp_dst_ptr; + auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{N * ((quant_num + 1) / 2)}, - std::make_shared()); + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{N * ((quant_num + 1) / 2)}, - std::make_shared()); + TensorShape{zp_size}, + cpu_allocator); memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); }