Skip to content

Commit

Permalink
Made changes according to PR microsoft#21211
Browse files Browse the repository at this point in the history
  • Loading branch information
glen-amd committed Jul 12, 2024
1 parent a2e5de9 commit e3033e4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/vitisai/imp/ep_context_utils.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Standard headers/libs.
#include <fstream>
#include <filesystem>
#include <sstream>
#include <cctype>
#include <cstring>
Expand Down Expand Up @@ -190,7 +189,8 @@ std::string SerializeOrigialGraph(const GraphViewer& graph_viewer) {
}
}
j_obj["orig_graph_name"] = graph_viewer.Name();
j_obj["orig_model_path"] = PathToUTF8String(graph_viewer.ModelPath().ToPathString());
// TODO: platform dependency (Linux vs Windows).
j_obj["orig_model_path"] = graph_viewer.ModelPath().string();

// XXX: `ModelProto::SerializeToString` will lose some info,
// e.g., ModelProto.opset_import.
Expand Down Expand Up @@ -263,7 +263,7 @@ ONNX_NAMESPACE::ModelProto* CreateEPContexModel(
p_attr_3->set_name(kONNXModelFileNameAttr);
// p_attr_3->set_type(onnx::AttributeProto_AttributeType_STRING);
p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
p_attr_3->set_s(fs::path(graph_viewer.ModelPath().ToPathString()).filename().string());
p_attr_3->set_s(graph_viewer.ModelPath().filename().string());
// Attr "notes".
auto p_attr_4 = ONNX_NAMESPACE::AttributeProto::Create();
p_attr_4->set_name(kNotesAttr);
Expand Down Expand Up @@ -435,7 +435,7 @@ void CreateEPContexNodes(
auto p_attr_3 = ONNX_NAMESPACE::AttributeProto::Create();
p_attr_3->set_name(kONNXModelFileNameAttr);
p_attr_3->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
p_attr_3->set_s(fs::path(graph_viewer.ModelPath().ToPathString()).filename().string());
p_attr_3->set_s(graph_viewer.ModelPath().filename().string());
p_node_attrs->emplace(kONNXModelFileNameAttr, *p_attr_3);
// Attr "partition_name".
auto p_attr_6 = ONNX_NAMESPACE::AttributeProto::Create();
Expand Down Expand Up @@ -595,7 +595,7 @@ bool FusedGraphHasEPContextNode(
return false;
}

const Path& GetTopLevelModelPath(const GraphViewer& graph_viewer) {
const fs::path& GetTopLevelModelPath(const GraphViewer& graph_viewer) {
const auto& graph = graph_viewer.GetGraph();
const Graph* p_graph = &graph;
while (p_graph->IsSubgraph()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

// Standard headers/libs.
#include <filesystem>
#include <vector>
#include <string>
#include <memory>
Expand Down Expand Up @@ -67,7 +68,7 @@ bool GraphHasEPContextNode(const Graph&);
bool FusedGraphHasEPContextNode(
const std::vector<IExecutionProvider::FusedNodeAndGraph>&);

const Path& GetTopLevelModelPath(const GraphViewer&);
const fs::path& GetTopLevelModelPath(const GraphViewer&);

bool GetEPContextModelFileLocation(
const std::string&, const PathString&, bool, PathString&);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ void VitisAIExecutionProvider::LoadEPContexModelFromFile() const {
void VitisAIExecutionProvider::PrepareEPContextEnablement(
const onnxruntime::GraphViewer& graph_viewer) const {
if (model_path_str_.empty()) {
model_path_str_ = GetTopLevelModelPath(graph_viewer).ToPathString();
// TODO: platform dependency (Linux vs Windows).
model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string());
}
std::string backend_cache_dir, backend_cache_key;
get_backend_compilation_cache(model_path_str_, graph_viewer, info_, kXCCode, backend_cache_dir, backend_cache_key, backend_cache_data_);
Expand Down Expand Up @@ -123,7 +124,8 @@ void VitisAIExecutionProvider::FulfillEPContextEnablement(
std::vector<std::unique_ptr<ComputeCapability>> VitisAIExecutionProvider::GetCapability(
const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const {
bool is_ep_ctx_model = GraphHasEPContextNode(graph_viewer.GetGraph());
model_path_str_ = GetTopLevelModelPath(graph_viewer).ToPathString();
// TODO: platform dependency (Linux vs Windows).
model_path_str_ = ToPathString(GetTopLevelModelPath(graph_viewer).string());
if (GetEPContextModelFileLocation(
ep_ctx_model_path_cfg_, model_path_str_, is_ep_ctx_model, ep_ctx_model_file_loc_)) {
if (is_ep_ctx_model) {
Expand Down

0 comments on commit e3033e4

Please sign in to comment.