Add DML on-device copy
yuslepukhin committed Oct 2, 2024
1 parent 224f065 commit 1a60d8a
Showing 2 changed files with 74 additions and 4 deletions.
33 changes: 33 additions & 0 deletions onnxruntime/core/session/lora_adapters.cc
@@ -16,6 +16,13 @@
#include "core/providers/cuda/cuda_provider_factory.h"
#endif

#ifdef USE_DML
#include "core/framework/execution_provider.h"
#include "core/session/abi_session_options_impl.h"
#include "core/providers/dml/dml_provider_factory_creator.h"
#include "core/providers/dml/dml_provider_factory.h"
#endif

namespace onnxruntime {

#ifdef USE_CUDA
@@ -63,6 +70,32 @@ static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info)
if (cuda_provider_info != nullptr) {
data_transfer = cuda_provider_info->CreateGPUDataTransfer();
}
#endif
} else if (strcmp(mem_info.name, onnxruntime::DML) == 0) {
#ifdef USE_DML
auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false);
auto dml_ep = ep_factory->CreateProvider();
data_transfer = dml_ep->GetDataTransfer();

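// Alternative approach, kept commented out below: resolve the public OrtDmlApi via
// GetExecutionProviderApi and append the DML EP to session options to obtain the same IDataTransfer.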
//constexpr uint32_t dml_api_version = 0; // This is ignored
//const void* dml_api = nullptr;
//auto* ort_status = OrtApis::GetExecutionProviderApi("DML", dml_api_version, &dml_api);
//if (ort_status == nullptr) {
// const auto* dml_provider_api = reinterpret_cast<const OrtDmlApi*>(dml_api);
// OrtSessionOptions sess_options;
// OrtDmlDeviceOptions dml_dev_options{OrtDmlPerformancePreference::Default, OrtDmlDeviceFilter::Gpu};
// ort_status = dml_provider_api->SessionOptionsAppendExecutionProvider_DML2(&sess_options, &dml_dev_options);
// if (ort_status) {
// Ort::Status status(ort_status);
// ORT_THROW(status.GetErrorMessage());
// }
// ORT_ENFORCE(sess_options.provider_factories.size() == 1, "Expecting a single factory");
// auto dml_ep = sess_options.provider_factories[0]->CreateProvider();
// data_transfer = dml_ep->GetDataTransfer();
//} else {
// Ort::Status status(ort_status);
// ORT_THROW(status.GetErrorMessage());
//}
#endif
}

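For orientation (illustrative, not part of the diff): the new DML branch mirrors the CUDA path above, creating a throwaway DML execution provider purely to obtain its IDataTransfer. A minimal sketch of the copy flow this enables, assuming a USE_DML build and the internal onnxruntime headers; the tensor shape and allocator choices here are hypothetical:

#ifdef USE_DML
// Create the DML EP the same way the new branch does and grab its IDataTransfer.
auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false);
auto dml_ep = ep_factory->CreateProvider();
auto transfer = dml_ep->GetDataTransfer();

// Hypothetical round trip: device tensor -> CPU tensor.
auto dml_allocator = dml_ep->CreatePreferredAllocators()[0];
Tensor device_tensor(DataTypeImpl::GetType<float>(), TensorShape({4}), dml_allocator);

AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
Tensor cpu_copy(device_tensor.DataType(), device_tensor.Shape(), cpu_allocator);

ORT_ENFORCE(transfer->CanCopy(device_tensor.Location().device, cpu_copy.Location().device));
ORT_THROW_IF_ERROR(transfer->CopyTensor(device_tensor, cpu_copy));
#endif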
45 changes: 41 additions & 4 deletions onnxruntime/test/lora/lora_test.cc
@@ -200,13 +200,13 @@ TEST(LoraAdapterTest, Load) {
}

#ifdef USE_CUDA
-TEST(LoraAdapterTest, VerifyDeviceCopy) {
+TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
auto cuda_ep = DefaultCudaExecutionProvider();
auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0];

-auto gpu_transfer = cuda_ep->GetDataTransfer();
+auto dml_transfer = cuda_ep->GetDataTransfer();

auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(cuda_allocator));
@@ -222,9 +222,9 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());

Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
-ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device,
+ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
-ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy));
+ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy));

auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();
@@ -233,5 +233,42 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
}
}
#endif

#ifdef USE_DML
TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];

auto dml_ep = DefaultDmlExecutionProvider();
auto dml_allocator = dml_ep->CreatePreferredAllocators()[0];

auto dml_transfer = dml_ep->GetDataTransfer();

auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(dml_allocator));
adapter.Load(std::move(test_params));

auto [begin, end] = adapter.GetParamIterators();
for (; begin != end; ++begin) {
const auto& [_, param] = *begin;
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML));

const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());

Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy));

auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();

ASSERT_EQ(expected_span, copy_span);
}
}
#endif
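
From the application side (illustrative, not part of the diff): the device copy exercised by these tests is what a user triggers by creating a LoRA adapter with a device allocator. A hedged sketch against the public C++ API, assuming the Ort::LoraAdapter wrapper and RunOptions::AddActiveLoraAdapter from the release that introduced LoRA adapters; the file names are placeholders and exact signatures should be checked against onnxruntime_cxx_api.h:

#include "onnxruntime_cxx_api.h"

Ort::Env env;
Ort::SessionOptions session_options;  // a real DML build would append the DML EP here
Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

// A non-null device allocator asks ORT to copy adapter parameters on-device,
// which is the path this commit adds for DML; nullptr keeps them on CPU.
Ort::LoraAdapter adapter = Ort::LoraAdapter::CreateLoraAdapter(ORT_TSTR("adapter.onnx_adapter"), /*allocator*/ nullptr);

Ort::RunOptions run_options;
run_options.AddActiveLoraAdapter(adapter);
// session.Run(run_options, ...) would then see the adapter's parameters.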

} // namespace test
} // namespace onnxruntime
