Skip to content

Commit

Permalink
Replace "DML CPU" Allocator with onnxruntime::CpuAllocator (#21818)
Browse files Browse the repository at this point in the history
### Description
Replace "DML CPU" Allocator with onnxruntime::CpuAllocator

### Motivation and Context
This allocator is ignored by ORTExtensions, which causes CPU memory to be
treated as non-CPU memory and leads to a crash in SentencepieceTokenizer.

In general, it seems that this allocator is not used, and its allocations
can be handled just fine by the default allocator.

---------

Co-authored-by: Sheil Kumar <[email protected]>
  • Loading branch information
smk2007 and Sheil Kumar authored Aug 23, 2024
1 parent 5726318 commit 44dcc3a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,28 +223,4 @@ namespace Dml
{
m_defaultRoundingMode = roundingMode;
}

// Constructs the allocator by initializing the onnxruntime::IAllocator base
// with an OrtMemoryInfo labelled "DML CPU". The memory info is built from an
// OrtDeviceAllocator allocator type, an OrtDevice(CPU, DEFAULT, 0) device,
// a literal 0, and the caller-supplied memType.
CPUAllocator::CPUAllocator(OrtMemType memType)
    : onnxruntime::IAllocator(OrtMemoryInfo(
          "DML CPU",
          OrtAllocatorType::OrtDeviceAllocator,
          OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0),
          0,
          memType))
{
    // Nothing to do here: all state lives in the base-class memory info.
}

// Forwards the request to onnxruntime::AllocatorDefaultAlloc and returns
// whatever pointer that call produces.
void* CPUAllocator::Alloc(size_t size)
{
    void* const allocation = onnxruntime::AllocatorDefaultAlloc(size);
    return allocation;
}

// Forwards the pointer to onnxruntime::AllocatorDefaultFree, the counterpart
// of the default allocation routine used by Alloc.
void CPUAllocator::Free(void* p)
{
    onnxruntime::AllocatorDefaultFree(p);
}

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ namespace Dml
std::make_unique<DmlCommittedResourceAllocator>(m_d3d12Device.Get()));
m_context->SetAllocator(m_allocator);
// CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators.
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator);
memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput;
m_cpuInputAllocator = std::make_shared<onnxruntime::CPUAllocator>(memoryInfo);
}

return std::vector<onnxruntime::AllocatorPtr>{m_allocator, m_cpuInputAllocator,};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ namespace Dml
class ReadbackHeap;
class ExecutionContext;
class BucketizedBufferAllocator;
class CPUAllocator;
class ExecutionProvider;

class ExecutionProviderImpl : public WRL::Base<Dml::IExecutionProvider,
Expand Down Expand Up @@ -213,7 +212,7 @@ namespace Dml
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
std::shared_ptr<BucketizedBufferAllocator> m_allocator;
std::shared_ptr<CPUAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::IAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::KernelRegistry> m_kernelRegistry;
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap> m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;
Expand Down

0 comments on commit 44dcc3a

Please sign in to comment.