Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Reserve() in OrtAllocator to allow custom allocators to work when session.use_device_allocator_for_initializers is specified. #19904

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class IAllocator {

virtual void Free(void* p) = 0;

// TODO: Find a better name than Reserve() and update in all places.
// Reserve() is an interface exposed for an implementation of IAllocator
// to optionally implement some allocation logic that by-passes any arena-based
// logic that may be housed in the Alloc() implementation.
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@
void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes
void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc
const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator
/**
* @brief Optional allocation function to use for memory allocations made during session initialization.
* Use this function if you want to separate allocations made by ORT during Run() calls from
* those made during session initialization. This allows for separate memory management strategies for these allocations.
*/
void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes

Check warning on line 327 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_c_api.h:327: Lines should be <= 120 characters long [whitespace/line_length] [2]
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
} OrtAllocator;

typedef void(ORT_API_CALL* OrtLoggingFunction)(
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/session/allocator_adapters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@
[](OrtAllocator* this_, void* p) { static_cast<OrtAllocatorImplWrappingIAllocator*>(this_)->Free(p); };
OrtAllocator::Info =
[](const OrtAllocator* this_) { return static_cast<const OrtAllocatorImplWrappingIAllocator*>(this_)->Info(); };
if (OrtAllocator::version >= 18) {
OrtAllocator::Reserve =
[](OrtAllocator* this_, size_t size) { return static_cast<OrtAllocatorImplWrappingIAllocator*>(this_)->Reserve(size); };

Check warning on line 22 in onnxruntime/core/session/allocator_adapters.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/allocator_adapters.cc:22: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
}

void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) {
return i_allocator_->Alloc(size);
}

void* OrtAllocatorImplWrappingIAllocator::Reserve(size_t size) {
return i_allocator_->Reserve(size);
}

void OrtAllocatorImplWrappingIAllocator::Free(void* p) {
i_allocator_->Free(p);
}
Expand All @@ -42,6 +50,14 @@
return ort_allocator_->Alloc(ort_allocator_, size);
}

void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) {
if (ort_allocator_->version >= 18 && ort_allocator_->Reserve) {
return ort_allocator_->Reserve(ort_allocator_, size);
}

return ort_allocator_->Alloc(ort_allocator_, size);
}

void IAllocatorImplWrappingOrtAllocator::Free(void* p) {
return ort_allocator_->Free(ort_allocator_, p);
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/allocator_adapters.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl {
void Free(void* p);

const OrtMemoryInfo* Info() const;
void* Reserve(size_t size);

ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator);

Expand All @@ -43,6 +44,7 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator {
~IAllocatorImplWrappingOrtAllocator() override = default;

void* Alloc(size_t size) override;
void* Reserve(size_t size) override;

void Free(void* p) override;

Expand Down
16 changes: 15 additions & 1 deletion onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2861,6 +2861,17 @@
expected_dims_y,
expected_values_y,
nullptr);

// create session 3 to test separate allocation for initializers
session_options.AddConfigEntry("session.use_device_allocator_for_initializers", "1");
Ort::Session session3(*ort_env, MODEL_URI, session_options);
RunSession<float>(allocator_for_input_memory_allocation.get(),
session3,
inputs,
"Y",
expected_dims_y,
expected_values_y,
nullptr);
}

// Remove the registered shared allocator from the global environment
Expand All @@ -2871,7 +2882,10 @@
// We should have seen 2 allocations per session (one for the sole initializer
// and one for the output). So, for two sessions, we should have seen 4 allocations.
size_t num_allocations = custom_allocator.NumAllocations();
ASSERT_TRUE(num_allocations == 4);
pranavsharma marked this conversation as resolved.
Show resolved Hide resolved
ASSERT_TRUE(num_allocations == 6);

Check warning on line 2886 in onnxruntime/test/shared_lib/test_inference.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Consider using ASSERT_EQ instead of ASSERT_TRUE(a == b) [readability/check] [2] Raw Output: onnxruntime/test/shared_lib/test_inference.cc:2886: Consider using ASSERT_EQ instead of ASSERT_TRUE(a == b) [readability/check] [2]
size_t num_reserve_allocations = custom_allocator.NumReserveAllocations();
ASSERT_TRUE(num_reserve_allocations == 1);

// Ensure that there was no leak
custom_allocator.LeakCheck();
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/util/include/test_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ struct MockedOrtAllocator : OrtAllocator {
void* Alloc(size_t size);
void Free(void* p);
const OrtMemoryInfo* Info() const;
void* Reserve(size_t size);
size_t NumAllocations() const;
size_t NumReserveAllocations() const;

void LeakCheck();

Expand All @@ -24,5 +26,6 @@ struct MockedOrtAllocator : OrtAllocator {

std::atomic<size_t> memory_inuse{0};
std::atomic<size_t> num_allocations{0};
std::atomic<size_t> num_reserve_allocations{0};
OrtMemoryInfo* cpu_memory_info;
};
17 changes: 17 additions & 0 deletions onnxruntime/test/util/test_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast<MockedOrtAllocator*>(this_)->Alloc(size); };
OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast<MockedOrtAllocator*>(this_)->Free(p); };
OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast<const MockedOrtAllocator*>(this_)->Info(); };
OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast<MockedOrtAllocator*>(this_)->Reserve(size); };

Check warning on line 12 in onnxruntime/test/util/test_allocator.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/test/util/test_allocator.cc:12: Lines should be <= 120 characters long [whitespace/line_length] [2]
Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info));
}

Expand All @@ -30,6 +31,18 @@
return (char*)p + extra_len;
}

void* MockedOrtAllocator::Reserve(size_t size) {
constexpr size_t extra_len = sizeof(size_t);
memory_inuse.fetch_add(size += extra_len);
void* p = new (std::nothrow) uint8_t[size];
if (p == nullptr)
return p;
num_allocations.fetch_add(1);
num_reserve_allocations.fetch_add(1);
*(size_t*)p = size;

Check warning on line 42 in onnxruntime/test/util/test_allocator.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use reinterpret_cast<size_t*>(...) instead [readability/casting] [4] Raw Output: onnxruntime/test/util/test_allocator.cc:42: Using C-style cast. Use reinterpret_cast<size_t*>(...) instead [readability/casting] [4]
return (char*)p + extra_len;

Check warning on line 43 in onnxruntime/test/util/test_allocator.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use reinterpret_cast<char*>(...) instead [readability/casting] [4] Raw Output: onnxruntime/test/util/test_allocator.cc:43: Using C-style cast. Use reinterpret_cast<char*>(...) instead [readability/casting] [4]
}

void MockedOrtAllocator::Free(void* p) {
constexpr size_t extra_len = sizeof(size_t);
if (!p) return;
Expand All @@ -47,6 +60,10 @@
return num_allocations.load();
}

size_t MockedOrtAllocator::NumReserveAllocations() const {
return num_reserve_allocations.load();
}

void MockedOrtAllocator::LeakCheck() {
if (memory_inuse.load())
ORT_THROW("memory leak!!!");
Expand Down
Loading