Skip to content

Commit

Permalink
Fix an issue that QNN models shared from other session use the sessio…
Browse files Browse the repository at this point in the history
…n logger from that session (microsoft#22170)

### Description
Fix an issue that QNN models shared from other session use the session logger from that producer session also which cause confusion. Make QNN model compute function use the session logger from current session.
  • Loading branch information
HectorSVC authored Sep 22, 2024
1 parent 171b901 commit b636b27
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
QnnModelLookupTable& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
Expand All @@ -97,7 +96,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
logger,
qnn_models);
}

Expand Down Expand Up @@ -147,7 +145,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
logger,
qnn_models);
}

Expand All @@ -158,7 +155,7 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const logging::Logger& logger) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
logger, qnn_models);
qnn_models);

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
QnnModelLookupTable& qnn_models);

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
Expand Down
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,6 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6

Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
const logging::Logger& logger,
QnnModelLookupTable& qnn_models) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
Expand Down Expand Up @@ -665,12 +664,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
if (1 == graph_count) {
// in case the EPContext node is generated from script
// the graph name from the context binary may not match the EPContext node name
auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
auto qnn_model = std::make_unique<qnn::QnnModel>(this);
ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context));
qnn_models.emplace(node_name, std::move(qnn_model));
} else {
for (uint32_t i = 0; i < graph_count; ++i) {
auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
auto qnn_model = std::make_unique<qnn::QnnModel>(this);
ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context));
qnn_models.emplace(graphs_info[i].graphInfoV1.graphName, std::move(qnn_model));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class QnnBackendManager {

Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
Expand Down
73 changes: 39 additions & 34 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,35 @@
namespace onnxruntime {
namespace qnn {

bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) {
bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger) {
bool rt = true;

graph_info_ = std::make_unique<GraphInfo>(model_wrapper.GetQnnGraph(),
model_wrapper.GetQnnGraphName(),
std::move(model_wrapper.GetGraphInputTensorWrappers()),
std::move(model_wrapper.GetGraphOutputTensorWrappers()));
if (graph_info_ == nullptr) {
LOGS(logger_, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo.";
LOGS(logger, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo.";
return false;
}

return rt;
}

Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node) {
const onnxruntime::Node& fused_node,
const logging::Logger& logger) {
auto graph_initializers = graph_viewer.GetAllInitializedTensors();
for (auto graph_ini : graph_initializers) {
initializer_inputs_.emplace(graph_ini.first);
}
auto input_defs = fused_node.InputDefs();
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true));
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_,
model_input_index_map_, logger, true));

auto output_defs = fused_node.OutputDefs();
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_));
ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_,
model_output_index_map_, logger));

return Status::OK();
}
Expand All @@ -51,6 +54,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeA
std::vector<std::string>& input_output_names,
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
std::unordered_map<std::string, size_t>& input_output_index_map,
const logging::Logger& logger,
bool is_input) {
for (size_t i = 0, end = input_output_defs.size(), index = 0; i < end; ++i) {
const auto& name = input_output_defs[i]->Name();
Expand All @@ -60,7 +64,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeA
}
}
// Validate input/output shape
LOGS(logger_, VERBOSE) << (is_input ? "input " : "output ") << i << " " << name;
LOGS(logger, VERBOSE) << (is_input ? "input " : "output ") << i << " " << name;
input_output_index_map.emplace(name, index++);
const auto* shape_proto = input_output_defs[i]->Shape(); // consider use qnn_model_wrapper.GetOnnxShape
ORT_RETURN_IF(shape_proto == nullptr, "shape_proto cannot be null for output: ", name);
Expand Down Expand Up @@ -91,8 +95,9 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node,

Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node,
const logging::Logger& logger,
const QnnGraph_Config_t** graph_configs) {
LOGS(logger_, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name();
LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name();

// Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is
// valid throughout the lifetime of the ModelBuilder
Expand All @@ -102,9 +107,9 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,

// This name must be same with the EPContext node name
const auto& graph_name = fused_node.Name();
ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node));
ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node, logger));

QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_,
QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger,
qnn_backend_manager_->GetQnnInterface(),
qnn_backend_manager_->GetQnnBackendHandle(),
model_input_index_map_,
Expand All @@ -121,65 +126,65 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
qnn_node_groups.reserve(node_unit_holder.size());

ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map,
node_unit_holder.size(), logger_));
node_unit_holder.size(), logger));

for (const std::unique_ptr<qnn::IQnnNodeGroup>& qnn_node_group : qnn_node_groups) {
Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_);
Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger);

if (!status.IsOK()) {
LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: "
<< status.ErrorMessage() << std::endl;
LOGS(logger, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: "
<< status.ErrorMessage() << std::endl;
return status;
}
}

ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph.");

rt = GetGraphInfoFromModel(qnn_model_wrapper);
rt = GetGraphInfoFromModel(qnn_model_wrapper, logger);
if (!rt) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetGraphInfoFromModel failed.");
}
LOGS(logger_, VERBOSE) << "GetGraphInfoFromModel completed.";
LOGS(logger, VERBOSE) << "GetGraphInfoFromModel completed.";
return Status::OK();
}

Status QnnModel::FinalizeGraphs() {
LOGS(logger_, VERBOSE) << "FinalizeGraphs started.";
Status QnnModel::FinalizeGraphs(const logging::Logger& logger) {
LOGS(logger, VERBOSE) << "FinalizeGraphs started.";
Qnn_ErrorHandle_t status = qnn_backend_manager_->GetQnnInterface().graphFinalize(graph_info_->Graph(),
qnn_backend_manager_->GetQnnProfileHandle(),
nullptr);
if (QNN_GRAPH_NO_ERROR != status) {
LOGS(logger_, ERROR) << "Failed to finalize QNN graph. Error code: " << status;
LOGS(logger, ERROR) << "Failed to finalize QNN graph. Error code: " << status;
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to finalize QNN graph.");
}

ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo());

LOGS(logger_, VERBOSE) << "FinalizeGraphs completed.";
LOGS(logger, VERBOSE) << "FinalizeGraphs completed.";
return Status::OK();
}

Status QnnModel::SetupQnnInputOutput() {
LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();
Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
LOGS(logger, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();

auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());

if (Status::OK() != result) {
LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
}

result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
if (Status::OK() != result) {
LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
}

return Status::OK();
}

Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs";
Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) {
LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs";
const size_t num_inputs = context.GetInputCount();
const size_t num_outputs = context.GetOutputCount();
ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes");
Expand All @@ -198,12 +203,12 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
qnn_inputs.reserve(qnn_input_infos_.size());

for (const auto& qnn_input_info : qnn_input_infos_) {
LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
<< " index = " << qnn_input_info.ort_index;
LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
<< " index = " << qnn_input_info.ort_index;
auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index);
auto ort_tensor_size = TensorDataSize(ort_input_tensor);
LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
<< "Ort tensor size: " << ort_tensor_size;
LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
<< "Ort tensor size: " << ort_tensor_size;
ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size,
"ORT Tensor data size does not match QNN tensor data size.");

Expand All @@ -217,13 +222,13 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {

for (auto& qnn_output_info : qnn_output_infos_) {
const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName();
LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index;
LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index;
const auto& ort_output_info = GetOutputInfo(model_output_name);
const std::vector<int64_t>& output_shape = ort_output_info->shape_;
auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size());
auto ort_tensor_size = TensorDataSize(ort_output_tensor);
LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size
<< "Ort tensor size: " << ort_tensor_size;
LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size
<< "Ort tensor size: " << ort_tensor_size;
ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size,
"ORT Tensor data size does not match QNN tensor data size");

Expand All @@ -232,7 +237,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
const_cast<void*>(ort_output_tensor.GetTensorData<void>()), qnn_output_info.tensor_byte_size);
}

LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name();
LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name();
auto qnn_interface = qnn_backend_manager_->GetQnnInterface();
auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle();
Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR;
Expand All @@ -257,7 +262,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {

if (QNN_COMMON_ERROR_SYSTEM_COMMUNICATION == execute_status) {
auto error_message = "NPU crashed. SSR detected. Caused QNN graph execute error. Error code: ";
LOGS(logger_, ERROR) << error_message << execute_status;
LOGS(logger, ERROR) << error_message << execute_status;
return ORT_MAKE_STATUS(ONNXRUNTIME, ENGINE_ERROR, error_message, execute_status);
}

Expand Down
20 changes: 10 additions & 10 deletions onnxruntime/core/providers/qnn/builder/qnn_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ struct QnnTensorInfo {

class QnnModel {
public:
QnnModel(const logging::Logger& logger,
QnnBackendManager* qnn_backend_manager)
: logger_(logger),
qnn_backend_manager_(qnn_backend_manager) {
QnnModel(QnnBackendManager* qnn_backend_manager)
: qnn_backend_manager_(qnn_backend_manager) {
qnn_backend_type_ = qnn_backend_manager_->GetQnnBackendType();
}

Expand All @@ -37,13 +35,14 @@ class QnnModel {

Status ComposeGraph(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node,
const logging::Logger& logger,
const QnnGraph_Config_t** graph_configs = nullptr);

Status FinalizeGraphs();
Status FinalizeGraphs(const logging::Logger& logger);

Status SetupQnnInputOutput();
Status SetupQnnInputOutput(const logging::Logger& logger);

Status ExecuteGraph(const Ort::KernelContext& context);
Status ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger);

const OnnxTensorInfo* GetOutputInfo(const std::string& name) const {
auto it = outputs_info_.find(name);
Expand All @@ -55,11 +54,13 @@ class QnnModel {
}

Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer,
const onnxruntime::Node& fused_node);
const onnxruntime::Node& fused_node,
const logging::Logger& logger);
Status ParseGraphInputOrOutput(ConstPointerContainer<std::vector<NodeArg*>>& input_output_defs,
std::vector<std::string>& input_output_names,
std::unordered_map<std::string, OnnxTensorInfo>& input_output_info_table,
std::unordered_map<std::string, size_t>& input_output_index,
const logging::Logger& logger,
bool is_input = false);

const std::unordered_set<std::string>& GetInitializerInputs() const { return initializer_inputs_; }
Expand Down Expand Up @@ -107,7 +108,7 @@ class QnnModel {
private:
const NodeUnit& GetNodeUnit(const Node* node,
const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map) const;
bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper);
bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger);

Status GetQnnTensorDataLength(const std::vector<uint32_t>& dims,
Qnn_DataType_t data_type,
Expand All @@ -125,7 +126,6 @@ class QnnModel {
}

private:
const logging::Logger& logger_;
std::unique_ptr<GraphInfo> graph_info_;
QnnBackendManager* qnn_backend_manager_ = nullptr;
// <input_name, input_index>, initializer inputs are excluded, keep the input index here
Expand Down
Loading

0 comments on commit b636b27

Please sign in to comment.