Skip to content

Commit

Permalink
Address code review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavsharma committed Oct 28, 2024
1 parent 7690e26 commit 19d7f35
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 6 deletions.
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
strcmp(name1, onnxruntime::DML) == 0 ||
strcmp(name1, onnxruntime::HIP) == 0 ||
strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace Dml
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)
)
),
m_device(device),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace Dml
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)
))
{
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ namespace Dml
bool enableGraphCapture,
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ namespace Dml

bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
{
return (srcDevice.Type() == OrtDevice::GPU) ||
(dstDevice.Type() == OrtDevice::GPU);
return (srcDevice.Type() == OrtDevice::DML) ||
(dstDevice.Type() == OrtDevice::DML);
}

private:
Expand Down

0 comments on commit 19d7f35

Please sign in to comment.