-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ortvalue features support for MGX EP #81
base: rocm6.3_internal_testing
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,10 +7,36 @@ | |
#include <string> | ||
|
||
#include "core/framework/ortdevice.h" | ||
#include "core/common/hash_combine.h" | ||
#include "core/framework/arena_extend_strategy.h" | ||
#include "core/framework/provider_options.h" | ||
#include "core/session/onnxruntime_c_api.h" | ||
|
||
namespace onnxruntime { | ||
|
||
// Information needed to construct MIGraphX execution providers. | ||
struct MIGraphXExecutionProviderExternalAllocatorInfo { | ||
void* alloc{nullptr}; | ||
void* free{nullptr}; | ||
void* empty_cache{nullptr}; | ||
|
||
MIGraphXExecutionProviderExternalAllocatorInfo() { | ||
alloc = nullptr; | ||
free = nullptr; | ||
empty_cache = nullptr; | ||
} | ||
|
||
MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { | ||
alloc = a; | ||
free = f; | ||
empty_cache = e; | ||
} | ||
|
||
bool UseExternalAllocator() const { | ||
return (alloc != nullptr) && (free != nullptr); | ||
} | ||
}; | ||
|
||
// Information needed to construct trt execution providers. | ||
struct MIGraphXExecutionProviderInfo { | ||
std::string target_device; | ||
|
@@ -25,8 +51,43 @@ | |
std::string load_model_file{"./compiled_model.mxr"}; | ||
bool exhaustive_tune{false}; | ||
|
||
size_t mem_limit{std::numeric_limits<size_t>::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) | ||
Check warning on line 54 in onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h#L54
Raw output
|
||
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) | ||
Check warning on line 55 in onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h#L55
Raw output
|
||
|
||
OrtArenaCfg* default_memory_arena_cfg{nullptr}; | ||
MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{}; | ||
|
||
static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); | ||
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); | ||
static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); | ||
}; | ||
} // namespace onnxruntime | ||
|
||
template <> | ||
struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { | ||
size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const { | ||
size_t value{0xbc9f1d34}; // seed | ||
|
||
// Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each) | ||
size_t data = static_cast<size_t>(info.device_id) ^ | ||
(static_cast<size_t>(info.arena_extend_strategy) << 16) ^ | ||
(static_cast<size_t>(info.fp16_enable) << 18) ^ | ||
(static_cast<size_t>(info.int8_enable) << 19) ^ | ||
(static_cast<size_t>(info.int8_use_native_calibration_table) << 20) ^ | ||
(static_cast<size_t>(info.model_cache_enable) << 21) ^ | ||
(static_cast<size_t>(info.save_compiled_model) << 22) ^ | ||
(static_cast<size_t>(info.load_compiled_model) << 23) ^ | ||
(static_cast<size_t>(info.exhaustive_tune) << 24); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Going forward is the intent to add the other flags (fp16/int8) and other quantize modes in here as well? |
||
onnxruntime::HashCombine(data, value); | ||
|
||
onnxruntime::HashCombine(info.mem_limit, value); | ||
|
||
// Memory pointers | ||
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.alloc), value); | ||
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.free), value); | ||
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache), value); | ||
|
||
// The default memory arena cfg is not used in hashing right now. | ||
return value; | ||
} | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
#include "core/providers/shared_library/provider_api.h" | ||
#include "core/providers/migraphx/migraphx_provider_factory.h" | ||
#include "migraphx_execution_provider.h" | ||
#include "migraphx_execution_provider_info.h" | ||
Check warning on line 8 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L8
Raw output
|
||
#include "migraphx_provider_factory_creator.h" | ||
#include "migraphx_allocator.h" | ||
#include "gpu_data_transfer.h" | ||
|
@@ -42,6 +43,27 @@ | |
return std::make_unique<HIPPinnedAllocator>(device_id, name); | ||
} | ||
|
||
void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { | ||
// hipMemcpy() operates on the default stream | ||
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); | ||
|
||
// To ensure that the copy has completed, invoke a stream sync for the default stream. | ||
// For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. | ||
Check warning on line 51 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L51
Raw output
|
||
// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer | ||
// to device memory, but the DMA to final destination may not have completed. | ||
|
||
HIP_CALL_THROW(hipStreamSynchronize(0)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we always want to be using hipstream 0 for this? |
||
} | ||
|
||
// Used by onnxruntime_pybind_state.cc | ||
void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { | ||
// For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. | ||
Check warning on line 60 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L60
Raw output
|
||
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); | ||
} | ||
|
||
std::shared_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { | ||
Check warning on line 64 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L64
Raw output
Check warning on line 64 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L64
Raw output
|
||
return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); | ||
Check warning on line 65 in onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc GitHub Actions / cpplint[cpplint] onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc#L65
Raw output
|
||
} | ||
} g_info; | ||
|
||
struct MIGraphX_Provider : Provider { | ||
|
@@ -77,6 +99,8 @@ | |
if (options.migraphx_load_model_path != nullptr) { | ||
info.load_model_file = options.migraphx_load_model_path; | ||
} | ||
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(options.migraphx_arena_extend_strategy); | ||
info.mem_limit = options.migraphx_mem_limit; | ||
return std::make_shared<MIGraphXProviderFactory>(info); | ||
} | ||
|
||
|
@@ -109,6 +133,8 @@ | |
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); | ||
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; | ||
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); | ||
migx_options.migraphx_arena_extend_strategy = static_cast<int>(internal_options.arena_extend_strategy); | ||
migx_options.migraphx_mem_limit = internal_options.mem_limit; | ||
} | ||
|
||
ProviderOptions GetProviderOptions(const void* provider_options) override { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something we want to make controllable from he API later?