diff --git a/onnxruntime/core/framework/prepacked_weights_container.h b/onnxruntime/core/framework/prepacked_weights_container.h index a0880239bba5f..394dfdc5a93eb 100644 --- a/onnxruntime/core/framework/prepacked_weights_container.h +++ b/onnxruntime/core/framework/prepacked_weights_container.h @@ -82,7 +82,8 @@ class PrepackedWeightsContainer final { /// /// If saving is OFF, it is used to contain the weights memory mapped from disk. /// Those weights are then moved to the shared container if weight sharing is enabled. -/// And also the interested kernels. +/// If x-session weight sharing is not enabled, the weights are stored in this container, +/// and shared with the interested kernels. /// class PrepackedForSerialization final { public: @@ -91,11 +92,11 @@ class PrepackedForSerialization final { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PrepackedForSerialization); + // Maps a pre-packed weight blob key to PrepackedWeights instance using KeyToBlobMap = std::unordered_map; - // Maps weight name to iterators in key_to_blobs_. It associates a weight name with its pre-packs. - // Normally, a single weight produces a single PrePackedWeights. But it is possible that a weight - // is pre-packed by different kernels. + // WeightToPrePacksMap maps weight name to a set of pre-packed + // keys contained in the KeyToBlobMap using KeysPerWeight = std::unordered_set; // blob keys using WeightToPrePacksMap = std::unordered_map; @@ -138,7 +139,7 @@ class PrepackedForSerialization final { // The function would add or replace existing entry with references to it. // If the entry is present, it would replace it with references to the existing entry. // If the entry is not present, it would add reference to refer_if_absent - // If present it would return the existing entry otherwise std::nullopt + // If the entry is present it would return the existing entry otherwise std::nullopt std::optional ReplaceWithReferenceIfSaving(const std::string& weight_name, const std::string& key, const PrePackedWeights& refer_if_absent); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 5edea481ba7f0..16d2e53732dde 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -495,7 +495,7 @@ Status SessionState::PrepackConstantInitializedTensors( // everybody can share the same memory mapped entry // the shared container takes ownership of the memory mapped entries - // The next like replaces the existing entry with references to it + // The next line replaces the existing entry with references to it // and returns the container that holds the memory mapped entries // so we can transfer it to shared container. // if there is not an entry, we replace it with references to weights_to_be_filled_in diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index badd2137f3472..110c1122868da 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -2,14 +2,15 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include #include #include -#include "core/common/status.h" #include "core/common/path_string.h" +#include "core/common/status.h" #include "core/framework/prepacked_weights_container.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 758c717045537..e2e9987fc8940 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4209,7 +4209,7 @@ Status Graph::ToGraphProtoWithExternalInitiallizersImpl( } if (!blob_keys_to_external_data.empty()) { - ORT_RETURN_IF_NOT(ExternalDataInfo::WritePrepackedToFileAndAddToProto( + ORT_RETURN_IF_NOT(!!ExternalDataInfo::WritePrepackedToFileAndAddToProto( *model_saving_options.prepacked_for_save, blob_keys_to_external_data, model_saving_options.align_offset, model_saving_options.allocation_granularity, diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index a7dcddfe61c00..aa01772d6fdd3 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -75,10 +75,10 @@ void PrepackedForSerialization::Subgraph::TestHarness kData = {1.2345f, 2.4690f}; + const size_t buffer_size = kData.size() * sizeof(float); prepacked_weights.buffers_.push_back(BufferUniquePtr(kData.data(), BufferDeleter(nullptr))); prepacked_weights.buffer_sizes_.push_back(buffer_size); @@ -57,7 +57,7 @@ TEST(TensorProtoUtilsTest, SetExternalDataInformation) { prepacked_for_serialization.MainGraph().WritePacked(init_name, blob_key, std::move(prepacked_weights)); - const int64_t starting_offset = 300; + constexpr const int64_t starting_offset = 300; int64_t external_offset = starting_offset; std::stringstream ss; const auto* blobs_for_weight = prepacked_for_serialization.MainGraph().GetBlobsForWeight(init_name); diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 8f2d0f6531500..60708b05626c5 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -11,6 +11,7 @@ #include "core/session/inference_session.h" #include "core/session/environment.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/graph/model_saving_options.h" #include "core/graph/graph_utils.h" #include "orttraining/training_api/checkpoint.h"