Skip to content

Commit

Permalink
Test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Dec 5, 2024
1 parent 0584f09 commit 4c15b4c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
19 changes: 12 additions & 7 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,16 +489,16 @@ Status SessionState::PrepackConstantInitializedTensors(
// In the saving mode we choose to overwrite the pre-packed weight in the container so we
// write out the most recent version of the pre-packed data
if (prepacked_weights_for_serialization_.IsSaveModeOn()) {
// Here we take references to the shared container owned data, so we unmap any entries
// that we are mapping from disk
// Here we take references to the shared container owned data, so we unmap this entry
// if it came from disk in this session.
WritePrepackedForSaving(input_name, prepacked_weights_container_key, prepacked_shared,
prepacked_subgraph);
}

} else { // container doesn't contain the pre-packed weight - so write into it for sharing across kernel instances

if (!prepacked_weights_for_serialization_.IsSaveModeOn()) {
// Check if we loaded it from disk, then shared it in the container
// Check if we loaded it from disk, then put it into the shared container
// the shared container takes ownership of the memory mapped entries
auto prepacked_from_disk =
prepacked_weights_for_serialization_.TakePrepackedWeights(prepacked_weights_container_key);
Expand All @@ -520,8 +520,6 @@ Status SessionState::PrepackConstantInitializedTensors(
// In the saving mode we choose to overwrite the pre-packed weight in the container so we
// write out the most recent version of the pre-packed data
if (prepacked_weights_for_serialization_.IsSaveModeOn()) {
// Here we take references to the shared container owned data, so we unmap any entries
// that we are mapping from disk, so we write the most fresh data possible
WritePrepackedForSaving(input_name, prepacked_weights_container_key, shared_prepacked,
prepacked_subgraph);
}
Expand All @@ -541,7 +539,10 @@ Status SessionState::PrepackConstantInitializedTensors(
is_packed,
&weights_to_be_filled_in));

if (is_packed) {
// Some kernels (matmul_nbits) do not share their pre-packed results even though
// they set is_packed = true, so we leave that decision up to them.
if (is_packed && !weights_to_be_filled_in.buffers_.empty()) {
const auto& op_type = node.OpType();
const std::string prepacked_weights_container_key = GenerateKeyForPrepackedWeightsMap(
op_type,
Expand All @@ -551,13 +552,16 @@ Status SessionState::PrepackConstantInitializedTensors(
const auto* weights_to_use = prepacked_subgraph.GetPrepackedWeights(
prepacked_weights_container_key);

// In both saving mode and non-saving mode, we use the serialization container to own
// the data and share it.
if (prepacked_subgraph.IsSaveModeOn() || weights_to_use == nullptr) {
// In this case pre-packed container owns the data
prepacked_subgraph.WritePackedForSaving(input_name, prepacked_weights_container_key,
std::move(weights_to_be_filled_in));
weights_to_use = prepacked_subgraph.GetPrepackedWeights(prepacked_weights_container_key);
assert(weights_to_use != nullptr);
}

ORT_RETURN_IF_ERROR(KernelUseSharedPrePackedBuffers(*kernel, input_idx,
*weights_to_use,
node.Name()));
Expand Down Expand Up @@ -603,7 +607,8 @@ Status SessionState::PrepackConstantInitializedTensors(
}
}

static int64_t CalculateMemoryPatternsKey(const gsl::span<const OrtValue>& tensor_inputs) {
static int64_t
CalculateMemoryPatternsKey(const gsl::span<const OrtValue>& tensor_inputs) {
int64_t key = 0;
for (const auto& input : tensor_inputs) {
for (auto dim : input.Get<Tensor>().Shape().GetDims()) key ^= dim;
Expand Down
21 changes: 12 additions & 9 deletions onnxruntime/test/framework/session_state_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ class SessionStateTestSharedInitalizersWithPrePacking : public ::testing::Test {
}
};

// Pre-packing enabled + no shared initializers = no pre-packed weights caching
// Pre-packing enabled + no shared initializers; however, we put all the pre-packs
// in a session_state container for ownership.
TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) {
SessionOptions sess_options;
sess_options.enable_mem_pattern = true;
Expand Down Expand Up @@ -679,10 +680,11 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) {

const auto* kernel = reinterpret_cast<const PrePackingTestOpKernel*>(session_state_1.GetKernel(0));

// Assert that a pre-pack call was made and that no mechanism to store weight from shared container was invoked
// Assert that a pre-pack call was made. However, the sharing call is still made from the serialized container.
ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast<size_t>(1));
ASSERT_EQ(kernel->prepack_calls_count, 1);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 0);
// In this case the sharing comes from the serialized container
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 1);

// Second session/model
Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
Expand All @@ -706,10 +708,11 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test1) {

kernel = reinterpret_cast<const PrePackingTestOpKernel*>(session_state_2.GetKernel(0));

// Assert that a pre-pack call was made and that no mechanism to store weight from shared container was invoked
// Assert that a pre-pack call was made. The weights are still shared from the serialized container,
// either because they were loaded from disk or because we share them from there.
ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast<size_t>(1));
ASSERT_EQ(kernel->prepack_calls_count, 1);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 0);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 1);
}

// Pre-packing enabled + shared initializers + no pre-packed weights container = no pre-packed weights caching
Expand Down Expand Up @@ -754,10 +757,10 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) {

const auto* kernel = reinterpret_cast<const PrePackingTestOpKernel*>(session_state_1.GetKernel(0));

// Assert that a pre-pack call was made and that no mechanism to store weight from shared container was invoked
// Assert that a pre-pack call was made, but sharing still takes place from the serialized container
ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast<size_t>(1));
ASSERT_EQ(kernel->prepack_calls_count, 1);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 0);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 1);

// Second session/model
Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
Expand All @@ -781,10 +784,10 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test2) {

kernel = reinterpret_cast<const PrePackingTestOpKernel*>(session_state_2.GetKernel(0));

// Assert that a pre-pack call was made and that no mechanism to store weight from shared container was invoked
// Assert that a pre-pack call was made, but sharing still takes place from the serialized container
ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast<size_t>(1));
ASSERT_EQ(kernel->prepack_calls_count, 1);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 0);
ASSERT_EQ(kernel->store_pre_packed_weight_calls_count, 1);
}

// Pre-packing enabled + shared initializers + pre-packed weights container = pre-packed weights caching enabled
Expand Down

0 comments on commit 4c15b4c

Please sign in to comment.