From 87e8a5dfa843f37598b344a28660e3bca86d03f9 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 14 Oct 2024 12:26:50 -0700 Subject: [PATCH] Implement DML copy for Lora Adapters (#22396) ### Description Request and create DML EP and its data transfer. Use to copy on device. The PR includes changes to fix issues in DML provider. ### Motivation and Context This enables Lora users to run it with DML which is important for GenAI. Co-authored-by: @PatriceVignola --------- Co-authored-by: Patrice Vignola --- .../src/BucketizedBufferAllocator.cpp | 2 +- .../src/BucketizedBufferAllocator.h | 6 ++ .../DmlExecutionProvider/src/CommandQueue.cpp | 10 +-- .../DmlExecutionProvider/src/CommandQueue.h | 4 +- .../src/ExecutionContext.cpp | 37 +--------- .../src/ExecutionContext.h | 12 +--- .../src/ExecutionProvider.cpp | 21 +++++- .../providers/dml/dml_provider_factory.cc | 4 +- onnxruntime/core/session/lora_adapters.cc | 69 ++++++++++++++----- .../python/onnxruntime_pybind_mlvalue.cc | 2 +- onnxruntime/test/lora/lora_test.cc | 47 +++++++++++-- 11 files changed, 136 insertions(+), 78 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index b1714a8220cd1..801cceb3bd99f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -186,7 +186,7 @@ namespace Dml } else { - if (!m_context->IsClosed()) + if (!m_closed) { // Free the underlying allocation once queued work has completed. #ifdef _GAMING_XBOX diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 16283d5b19c9c..65bc9b7f69316 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -46,6 +46,11 @@ namespace Dml void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); + void Close() + { + m_closed = true; + } + public: // onnxruntime::IAllocator void* Alloc(size_t size, AllocatorRoundingMode roundingMode); void* Alloc(size_t size) final; @@ -83,6 +88,7 @@ namespace Dml std::vector m_pool; size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; + bool m_closed = false; // Unless specifically requested, allocation sizes are not rounded to enable pooling // until SetDefaultRoundingMode is called. This should be done at completion of session diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp index 988324bab1174..67faf333d21e1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp @@ -55,7 +55,7 @@ namespace Dml // for example, an allocation from BucketizedBufferAllocator attempts to queue a reference // to its underlying D3D resource when freed. Furthermore, these references are unnecessary // since Close() already blocks for scheduled GPU work before clearing m_queuedReferences. 
- if (!m_closing) + if (!m_clearingQueue) { QueuedReference queuedReference = {GetLastFenceValue(), object}; @@ -70,15 +70,15 @@ namespace Dml } } - void CommandQueue::Close() + void CommandQueue::WaitForSignalAndClearQueue() { // Wait for flushed work: - assert(!m_closing); - m_closing = true; + assert(!m_clearingQueue); + m_clearingQueue = true; GpuEvent event = GetCurrentCompletionEvent(); event.WaitForSignal(m_cpuSyncSpinningEnabled); m_queuedReferences.clear(); - m_closing = false; + m_clearingQueue = false; } void CommandQueue::ReleaseCompletedReferences() diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h index 71d5eb173cfec..9a4728d5845d4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h @@ -44,7 +44,7 @@ namespace Dml } #endif - void Close(); + void WaitForSignalAndClearQueue(); void ReleaseCompletedReferences(); private: @@ -61,7 +61,7 @@ namespace Dml ComPtr m_fence; uint64_t m_lastFenceValue = 0; - bool m_closing = false; + bool m_clearingQueue = false; bool m_cpuSyncSpinningEnabled = false; }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index 5dc1213bd76f0..ececf13fc8cdf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -11,13 +11,10 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled, - bool keepOpen - ) + bool cpuSyncSpinningEnabled) : m_queue(std::make_shared(queue, cpuSyncSpinningEnabled)) , m_dmlRecorder(d3d12Device, dmlDevice, m_queue) , m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled) - , m_keepOpen(keepOpen) { ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } @@ -36,8 +33,6 @@ namespace Dml D3D12_RESOURCE_STATES srcState, uint64_t byteCount) { - assert(!m_closed); - SetCommandRecorder(&m_dmlRecorder); std::vector barriers; @@ -84,8 +79,6 @@ namespace Dml _Out_ uint64_t* completionValue ) { - assert(!m_closed); - SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue); } @@ -95,7 +88,6 @@ namespace Dml const DML_BINDING_DESC& persistentResourceBinding, const DML_BINDING_DESC& inputArrayBinding) { - assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding); @@ -107,7 +99,6 @@ namespace Dml gsl::span inputBindings, gsl::span outputBindings) { - assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings); @@ -115,7 +106,6 @@ namespace Dml void ExecutionContext::AddUAVBarrier() { - assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.AddUAVBarrier(); @@ -123,7 +113,6 @@ namespace Dml void ExecutionContext::ResourceBarrier(gsl::span barriers) { - assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); m_dmlRecorder.ResourceBarrier(barriers); @@ -131,7 +120,6 @@ namespace Dml void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList) { - assert(!m_closed); SetCommandRecorder(&m_dmlRecorder); // Ensure the descriptor 
heap is reset to D3D as something external may change it before recording @@ -142,8 +130,6 @@ namespace Dml void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder) { - assert(!m_closed); - // If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct // ordering of operations on the command queue. if (m_currentRecorder != newRecorder) @@ -160,8 +146,6 @@ namespace Dml void ExecutionContext::Flush() { - assert(!m_closed); - if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork()) { // Nothing to flush @@ -180,34 +164,21 @@ namespace Dml void ExecutionContext::QueueReference(IUnknown* object) { - assert(!m_closed); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence // value is the one to signal completion. bool waitForUnsubmittedWork = (m_currentRecorder != nullptr); m_queue->QueueReference(object, waitForUnsubmittedWork); } - void ExecutionContext::Close() + void ExecutionContext::WaitForSignalAndClearQueue() { - assert(!m_closed); - // Discard unflushed work and clear queued references. This prevents the circular reference: // Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel - m_queue->Close(); - - // Keep the execution context open when requested, e.g. when used through the python API where there's a single context - // and single command queue - if (!m_keepOpen) - { - m_currentRecorder = nullptr; - m_closed = true; - } + m_queue->WaitForSignalAndClearQueue(); } GpuEvent ExecutionContext::GetCurrentCompletionEvent() { - assert(!m_closed); - GpuEvent event = m_queue->GetCurrentCompletionEvent(); // If something has been recorded into a command list but not submitted yet, it means that the *next* fence @@ -223,13 +194,11 @@ namespace Dml void ExecutionContext::ReleaseCompletedReferences() { - assert(!m_closed); m_queue->ReleaseCompletedReferences(); } D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const { - assert(!m_closed); return m_queue->GetType(); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index e7a6fa3d07296..71aa26f4a0148 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -23,14 +23,13 @@ namespace Dml ID3D12Device* d3d12Device, IDMLDevice* dmlDevice, ID3D12CommandQueue* queue, - bool cpuSyncSpinningEnabled, - bool keepOpen); + bool cpuSyncSpinningEnabled); void SetAllocator(std::weak_ptr allocator); // Waits for flushed work, discards unflushed work, and discards associated references to - // prevent circular references. Must be the last call on the object before destruction. - void Close(); + // prevent circular references. + void WaitForSignalAndClearQueue(); // Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. 
Transition // barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and @@ -87,7 +86,6 @@ namespace Dml D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const; bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } - bool IsClosed() const { return m_closed; } private: Microsoft::WRL::ComPtr m_d3dDevice; @@ -103,10 +101,6 @@ namespace Dml bool m_closed = false; bool m_cpuSyncSpinningEnabled = false; - - // The python API has a global state used for I/O binding where the execution context is shared between session, - // so we don't want to close the context when one of the sessions is destroyed - bool m_keepOpen = false; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9c01df13741e1..6b0faaad43175 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -106,7 +106,26 @@ namespace Dml // Release the cached command list references before closing the context m_capturedGraphs.clear(); - m_context->Close(); + // Close the allocator before clearing the command queue to stop it from + // appending resources to it in an attempt to keep them alive. + if (m_allocator) + { + m_allocator->Close(); + } + + // Destroy the allocators. We are closing the execution provider, so from now on the + // only thing it will be used for is doing copies via the DataTransfer, which doesn't + // require allocating any memory. + // TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly + // destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive. + m_allocator = nullptr; + m_cpuInputAllocator = nullptr; + + // Wait for all pending commands to be done executing and empty the command queue. 
This will + // Force all kernels and resources in flight to get destroyed and, from this point forward, + // ExecutionProviderImpl will only be used to execute transfer between resources that are + // already existing via the DataTransfer; + m_context->WaitForSignalAndClearQueue(); } void ExecutionProviderImpl::WaitForOutstandingWork() diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index e8fe235fc1d46..89decfef6fef6 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -86,11 +86,11 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { // First, check if an I/O binding API that was used before this session or another session has already created a queue if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get())); } } else { - execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); + execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_); } auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index 466edce187a56..a095027a1d417 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -4,10 +4,9 @@ #include "core/session/lora_adapters.h" #include "lora/adapter_format_utils.h" -#include - #include "core/framework/data_transfer.h" #include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" @@ -16,6 +15,15 @@ #include "core/providers/cuda/cuda_provider_factory.h" #endif +#ifdef USE_DML +#include "core/session/abi_session_options_impl.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/providers/dml/dml_provider_factory.h" +#endif + +#include +#include + namespace onnxruntime { #ifdef USE_CUDA @@ -50,28 +58,56 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) { InitializeParamsValues(); } -static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_info) { +namespace { +struct DataTransfer { + std::unique_ptr ep; std::unique_ptr data_transfer; - - if (strcmp(mem_info.name, onnxruntime::CPU) == 0) { - return data_transfer; + Status CopyTensor(const Tensor& src, Tensor& dst) const { + return data_transfer->CopyTensor(src, dst); + } + Status Sync() const { +#if USE_DML + return ep->Sync(); +#else + return Status::OK(); +#endif } +}; +} // namespace +static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) { + ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter"); + + Status status; if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) { 
#ifdef USE_CUDA auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if (cuda_provider_info != nullptr) { - data_transfer = cuda_provider_info->CreateGPUDataTransfer(); + dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer(); + } else { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded"); } +#else + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build"); +#endif + } else if (strcmp(mem_info.name, onnxruntime::DML) == 0) { +#ifdef USE_DML + auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false); + dt.ep = ep_factory->CreateProvider(); + dt.data_transfer = dt.ep->GetDataTransfer(); +#else + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build"); #endif + } else { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator"); } - return data_transfer; + return status; } static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped, const AllocatorPtr& device_allocator, - const IDataTransfer& data_transfer, + const DataTransfer& data_transfer, OrtValue& out) { OrtValue result; const auto& src = ort_value_mapped.Get(); @@ -87,12 +123,9 @@ void LoraAdapter::InitializeParamsValues() { ORT_THROW("Adapter is not loaded yet."); } - std::unique_ptr data_transfer; + DataTransfer data_transfer; if (device_allocator_) { - data_transfer = GetDataTransfer(device_allocator_->Info()); - if (data_transfer == nullptr) { - ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator"); - } + ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer)); } const auto* params = adapter_->parameters(); @@ -100,12 +133,12 @@ void LoraAdapter::InitializeParamsValues() { std::unordered_map params_values; params_values.reserve(params->size()); // Re-work in two separate loops due to compiler issues - if (data_transfer) { + if (device_allocator_) { for (const auto* param : *params) { auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param); OrtValue ort_value_ondevice; ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_, - *data_transfer, ort_value_ondevice)); + data_transfer, ort_value_ondevice)); Param lora_param(std::move(ort_value), std::move(ort_value_ondevice)); params_values.emplace(std::move(name), std::move(lora_param)); } @@ -117,6 +150,10 @@ void LoraAdapter::InitializeParamsValues() { } } + if (device_allocator_) { + ORT_THROW_IF_ERROR(data_transfer.Sync()); + } + params_values_.swap(params_values); } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 084ee6bc50698..ebb1a54facbeb 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get()); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get())); - context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true); + context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true); ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get())); } diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index 
e8291a36447ca..fde603858f9a9 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -200,13 +200,11 @@ TEST(LoraAdapterTest, Load) { } #ifdef USE_CUDA -TEST(LoraAdapterTest, VerifyDeviceCopy) { +TEST(LoraAdapterTest, VerifyCudaDeviceCopy) { auto cpu_ep = DefaultCpuExecutionProvider(); auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; - auto cuda_ep = DefaultCudaExecutionProvider(); - auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0]; - - auto gpu_transfer = cuda_ep->GetDataTransfer(); + auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0]; + auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer(); auto test_params = GenerateTestParameters()(); lora::LoraAdapter adapter(std::move(cuda_allocator)); @@ -222,9 +220,43 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) { ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); - ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device, + ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device, + copy.Location().device)); + ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy)); + + auto expected_span = tensor_cpu.DataAsSpan(); + auto copy_span = copy.DataAsSpan(); + + ASSERT_EQ(expected_span, copy_span); + } +} +#endif + +#ifdef USE_DML +TEST(LoraAdapterTest, VerifyDmlDeviceCopy) { + auto cpu_ep = DefaultCpuExecutionProvider(); + auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0]; + + auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0]; + auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer(); + + auto test_params = GenerateTestParameters()(); + lora::LoraAdapter adapter(std::move(dml_allocator)); + adapter.Load(std::move(test_params)); + + auto [begin, end] = adapter.GetParamIterators(); + for (; begin != end; ++begin) { + const auto& [_, param] = *begin; + const auto& tensor_device = param.GetDeviceOrMapped().Get(); + ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML)); + + const auto& tensor_cpu = param.GetMapped().Get(); + ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); + + Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator); + ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device, copy.Location().device)); - ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy)); + ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy)); auto expected_span = tensor_cpu.DataAsSpan(); auto copy_span = copy.DataAsSpan(); @@ -233,5 +265,6 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) { } } #endif + } // namespace test } // namespace onnxruntime
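
For context on what this change enables end to end, here is a minimal sketch of exercising the new path through the public C++ API: register the DML EP, obtain the session's DML device allocator, and load a LoRA adapter so its parameters are copied onto the device at load time via the EP's data transfer (the `GetDataTransfer`/`CreateOrtValueOnDevice` path in `lora_adapters.cc` above). The model and adapter file names are placeholders, and resolving the "DML" allocator through `Ort::Allocator(session, memory_info)` is an assumption about the allocators registered in a DML-enabled build rather than something this patch adds.

```cpp
// Sketch only: file names are placeholders; the "DML" allocator lookup below is assumed
// to resolve against the session's registered DML device allocator.
#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "lora_dml"};

  Ort::SessionOptions session_options;
  // Register the DML EP (device 0) so the session carries a DML device allocator
  // and the DML data transfer used for the on-device copy.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));

  Ort::Session session{env, ORT_TSTR("model_with_lora.onnx"), session_options};

  // Look up the DML device allocator owned by the session.
  Ort::MemoryInfo dml_mem_info{"DML", OrtDeviceAllocator, /*device_id=*/0, OrtMemTypeDefault};
  Ort::Allocator dml_allocator{session, dml_mem_info};

  // Passing a device allocator makes the LoraAdapter copy its memory-mapped parameters
  // onto the device at load time instead of keeping them on CPU.
  auto adapter = Ort::LoraAdapter::CreateLoraAdapter(ORT_TSTR("adapter.onnx_adapter"), dml_allocator);

  Ort::RunOptions run_options;
  run_options.AddActiveLoraAdapter(adapter);

  // session.Run(run_options, ...) then executes with the adapter's device-resident parameters.
  return 0;
}
```

Passing `nullptr` instead of a device allocator keeps the adapter parameters on CPU, as before; the DML branch added to `GetDataTransfer` is what makes the device-resident variant possible.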