diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 4ca2791e26ab9..d6b7e0f3df0ba 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -622,6 +622,21 @@ typedef struct OrtMIGraphXProviderOptions {
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
+
+ /** \brief MIGraphX memory limit. To use all available memory, pass the maximum value of size_t.
+ * Defaults to SIZE_MAX.
+ * \note If a ::OrtArenaCfg has been applied, it will override this field
+ */
+ size_t migraphx_mem_limit;
+
+ /** \brief Strategy used to grow the memory arena
+ * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0.
+ * \note If a ::OrtArenaCfg has been applied, it will override this field
+ */
+ int migraphx_arena_extend_strategy;
+
} OrtMIGraphXProviderOptions;
/** \brief OpenVINO Provider Options
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index fd36b8ae5f678..1a7a3a27b755b 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -13,10 +13,11 @@
#include "core/common/safeint.h"
#include "core/common/logging/severity.h"
#include "migraphx_execution_provider.h"
+#include "migraphx_execution_provider_info.h"
#include "migraphx_execution_provider_utils.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
-#include "migraphx_inc.h"
+#include "migraphx_call.h"
#include "migraphx_stream_handle.h"
@@ -208,12 +209,50 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {  // empty body: the destructor itself performs no explicit cleanup
}
+AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id,  // Returns the result of CreateAllocator() for one of the two AllocatorCreationInfo setups built below.
+ size_t migx_mem_limit,  // only read when default_memory_arena_cfg is null
+ ArenaExtendStrategy arena_extend_strategy,  // only read when default_memory_arena_cfg is null
+ MIGraphXExecutionProviderExternalAllocatorInfo
+ external_allocator_info,  // supplies alloc/free/empty_cache; copied into the first factory lambda
+ const OrtArenaCfg* default_memory_arena_cfg) {  // may be null; when set, its value is used instead of the limit/strategy pair
+ if (external_allocator_info.UseExternalAllocator()) {  // external path: forward the caller-supplied callbacks
+ AllocatorCreationInfo default_memory_info(
+ [external_allocator_info](OrtDevice::DeviceId id) {
+ return std::make_unique(id, HIP,  // NOTE(review): make_unique has no template argument as shown - looks stripped in extraction; confirm the allocator type
+ external_allocator_info.alloc,
+ external_allocator_info.free,
+ external_allocator_info.empty_cache);
+ },
+ device_id,
+ false);  // presumably the use-arena flag (off on the external path) - TODO confirm against AllocatorCreationInfo
+
+ return CreateAllocator(default_memory_info);
+ } else {  // default path: allocator configured with an OrtArenaCfg
+ AllocatorCreationInfo default_memory_info(
+ [](OrtDevice::DeviceId id) {
+ return std::make_unique(id, HIP);  // NOTE(review): template argument missing here as well - confirm the allocator type
+ },
+ device_id,
+ true,  // presumably the use-arena flag, matching the "always use an arena" remark below - TODO confirm
+ {default_memory_arena_cfg ? *default_memory_arena_cfg  // a caller-provided arena config takes precedence
+ : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy),  // NOTE(review): static_cast target type is missing as shown - confirm
+ -1, -1, -1, -1L)},  // remaining arena settings passed as -1; presumably "use the default" - TODO confirm
+ // make it stream aware
+ true,
+ // enable cross stream sharing?
+ false);
+
+ // ROCM malloc/free is expensive so always use an arena
+ return CreateAllocator(default_memory_info);
+ }
+}
+
std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
- [](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id);
+ [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, info_.device_id);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
- return CreateMIGraphXPinnedAllocator(device_id, onnxruntime::CUDA_PINNED);
+ return std::make_unique(device_id, onnxruntime::CUDA_PINNED);
},
0);
return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
index 2be6c09551a71..f79b720e758e3 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -7,7 +7,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
-#include "core/providers/migraphx/migraphx_inc.h"
+#include "core/providers/migraphx/migraphx_call.h"
#include