Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jun 3, 2024
1 parent 2b01660 commit 2bba511
Show file tree
Hide file tree
Showing 37 changed files with 264 additions and 270 deletions.
20 changes: 6 additions & 14 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@
#include <type_traits>
#include <unordered_map>
#include <unordered_set>

#ifdef _WIN32
#pragma warning(push)
// disable some warnings from protobuf to pass Windows build
#pragma warning(disable : 4244)
#endif

#ifdef _WIN32
#pragma warning(pop)
#endif
#include <filesystem>

#include "core/common/flatbuffers.h"

Expand Down Expand Up @@ -147,7 +138,7 @@ class Node {
const std::string& Domain() const noexcept { return domain_; }

/** Gets the path of the owning model if any. */
const Path& ModelPath() const noexcept;
const std::filesystem::path& ModelPath() const noexcept;

/** Gets the Node's execution priority.
@remarks Lower value means higher priority */
Expand Down Expand Up @@ -693,7 +684,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const std::string& Description() const noexcept;

/** Gets the path of the owning model, if any. */
const Path& ModelPath() const;
const std::filesystem::path& ModelPath() const;

/** Returns true if this is a subgraph or false if it is a high-level graph. */
bool IsSubgraph() const { return parent_graph_ != nullptr; }
Expand Down Expand Up @@ -1149,13 +1140,14 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
ONNX_NAMESPACE::GraphProto ToGraphProto() const;

/** Gets the GraphProto representation of this Graph
@params external_file_name name of the binary file to use for initializers
@params external_file_name name of the binary file to use for initializers. Must be a UTF-8 string.
@params destination_file_path path of the model file.
@param initializer_size_threshold initializers larger or equal to this threshold (in bytes) are saved
in the external file. Initializer smaller than this threshold are included in the onnx file.
@returns GraphProto serialization of the graph.
*/
ONNX_NAMESPACE::GraphProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name,
const PathString& file_path,
const std::filesystem::path& file_path,
size_t initializer_size_threshold) const;

/** Gets the ISchemaRegistry instances being used with this Graph. */
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/graph/graph_viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class GraphViewer {
const std::string& Description() const noexcept;

/** Gets the path of the owning model if any **/
const Path& ModelPath() const noexcept { return graph_->ModelPath(); }
const std::filesystem::path& ModelPath() const noexcept { return graph_->ModelPath(); }

/**
Gets a tensor created from an initializer.
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide

static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
const Graph& graph,
const std::string& ep_context_path,
const std::filesystem::path& ep_context_path,
const logging::Logger& logger) {
InlinedVector<const Node*> all_ep_context_nodes;
for (const auto& ep : execution_providers) {
Expand All @@ -658,13 +658,14 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
return std::make_pair(false, static_cast<const Node*>(nullptr));
};

onnxruntime::PathString context_cache_path;
PathString model_pathstring = graph.ModelPath().ToPathString();
std::filesystem::path context_cache_path;
const std::filesystem::path& model_pathstring = graph.ModelPath();

if (!ep_context_path.empty()) {
// On Windows here we explicitly cast the ep_context_path string to UTF-16 because we assume ep_context_path is in UTF-8
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
context_cache_path = model_pathstring / ORT_TSTR("_ctx.onnx");
}

{
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/model_metadef_id_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_vi

// prefer path the model was loaded from
// this may not be available if the model was loaded from a stream or in-memory bytes
const auto& model_path_str = main_graph.ModelPath().ToPathString();
const auto& model_path_str = main_graph.ModelPath().string();
if (!model_path_str.empty()) {
MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast<int32_t>(model_path_str.size()), hash[0], &hash);
} else {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpTyp
const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); }
int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); }
NodeIndex NodeUnit::Index() const noexcept { return target_node_.Index(); }
const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
const std::filesystem::path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); }
ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); }

void NodeUnit::InitForSingleNode() {
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <string>
#include <optional>
#include <vector>
#include <filesystem>

#include "core/graph/basic_types.h"
#include "core/graph/graph.h"
Expand Down Expand Up @@ -78,7 +79,7 @@ class NodeUnit {
const std::string& Name() const noexcept;
int SinceVersion() const noexcept;
NodeIndex Index() const noexcept;
const Path& ModelPath() const noexcept;
const std::filesystem::path& ModelPath() const noexcept;
ProviderType GetExecutionProviderType() const noexcept;

const Node& GetNode() const noexcept { return target_node_; }
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct SessionOptions {
//
// If session config value is not set, it will be assumed to be ONNX
// unless the filepath ends in '.ort' (case insensitive).
std::basic_string<ORTCHAR_T> optimized_model_filepath;
std::filesystem::path optimized_model_filepath;

// enable the memory pattern optimization.
// The idea is if the input shapes are the same, we could trace the internal memory allocation
Expand Down
Loading

0 comments on commit 2bba511

Please sign in to comment.