diff --git a/DEPS b/DEPS index af7e24d8e..8c2aaec80 100644 --- a/DEPS +++ b/DEPS @@ -55,7 +55,7 @@ deps = { }, # GPGMM support for fast DML allocation and residency management. 'third_party/gpgmm': { - 'url': '{github_git}/intel/gpgmm.git@72f910e8dcb0048d84c11f705b96b768b9eee156', + 'url': '{github_git}/intel/gpgmm.git@61fcfcbd872b7643423ac8c9555f3a6a366904ee', 'condition': 'checkout_win', }, 'third_party/oneDNN': { diff --git a/src/webnn/native/dmlx/deps/src/device.cpp b/src/webnn/native/dmlx/deps/src/device.cpp index 595930ead..f80101985 100644 --- a/src/webnn/native/dmlx/deps/src/device.cpp +++ b/src/webnn/native/dmlx/deps/src/device.cpp @@ -13,10 +13,14 @@ using namespace pydml; using Microsoft::WRL::ComPtr; SVDescriptorHeap::SVDescriptorHeap( - ComPtr heap, - uint64_t size) - : gpgmm::d3d12::Heap(heap, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, size), - m_Heap(std::move(heap)) { + ComPtr heap) + : m_Heap(std::move(heap)) { +} + +ID3D12DescriptorHeap* SVDescriptorHeap::GetDescriptorHeap() const { + ComPtr descriptorHeap; + m_Heap->As(&descriptorHeap); + return descriptorHeap.Get(); } Device::Device(bool useGpu, bool useDebugLayer, DXGI_GPU_PREFERENCE gpuPreference) : m_useGpu(useGpu), m_useDebugLayer(useDebugLayer), m_gpuPreference(gpuPreference) {} @@ -101,20 +105,16 @@ HRESULT Device::Init() nullptr, // initial pipeline state IID_GRAPHICS_PPV_ARGS(m_commandList.GetAddressOf()))); - D3D12_FEATURE_DATA_ARCHITECTURE arch = {}; - ReturnIfFailed(m_d3d12Device->CheckFeatureSupport(D3D12_FEATURE_ARCHITECTURE, &arch, sizeof(arch))); - D3D12_FEATURE_DATA_D3D12_OPTIONS options = {}; ReturnIfFailed(m_d3d12Device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS, &options, sizeof(options))); gpgmm::d3d12::ALLOCATOR_DESC allocatorDesc = {}; allocatorDesc.Adapter = dxgiAdapter; allocatorDesc.Device = m_d3d12Device; - allocatorDesc.IsUMA = arch.UMA; allocatorDesc.ResourceHeapTier = options.ResourceHeapTier; #ifdef WEBNN_ENABLE_RESOURCE_DUMP - allocatorDesc.RecordOptions.Flags |= gpgmm::d3d12::ALLOCATOR_RECORD_FLAG_ALL_EVENTS; + allocatorDesc.RecordOptions.Flags |= gpgmm::d3d12::EVENT_RECORD_FLAG_ALL_EVENTS; allocatorDesc.RecordOptions.MinMessageLevel = D3D12_MESSAGE_SEVERITY_MESSAGE; allocatorDesc.RecordOptions.UseDetailedTimingEvents = true; #endif @@ -299,8 +299,8 @@ HRESULT Device::DispatchOperator( DML_BINDING_TABLE_DESC bindingTableDesc = {}; bindingTableDesc.Dispatchable = op; - bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap->GetCPUDescriptorHandleForHeapStart(); - bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap->GetGPUDescriptorHandleForHeapStart(); + bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap()->GetCPUDescriptorHandleForHeapStart(); + bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap()->GetGPUDescriptorHandleForHeapStart(); bindingTableDesc.SizeInDescriptors = bindingProps.RequiredDescriptorCount; ReturnIfFailed(m_bindingTable->Reset(&bindingTableDesc)); @@ -339,7 +339,8 @@ HRESULT Device::DispatchOperator( } // Record and execute commands, and wait for completion - m_commandList->SetDescriptorHeaps(1, m_descriptorHeap->m_Heap.GetAddressOf()); + ID3D12DescriptorHeap* descriptorHeap = m_descriptorHeap->GetDescriptorHeap(); + m_commandList->SetDescriptorHeaps(1, &descriptorHeap); m_commandRecorder->RecordDispatch(m_commandList.Get(), op, m_bindingTable.Get()); RecordOutputReadBack(outputsResourceSize); ReturnIfFailed(ExecuteCommandListAndWait()); @@ -536,8 +537,8 @@ HRESULT Device::InitializeOperator( DML_BINDING_TABLE_DESC bindingTableDesc = {}; bindingTableDesc.Dispatchable = m_operatorInitializer.Get(); - bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap->GetCPUDescriptorHandleForHeapStart(); - bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap->GetGPUDescriptorHandleForHeapStart(); + bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap()->GetCPUDescriptorHandleForHeapStart(); + bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap()->GetGPUDescriptorHandleForHeapStart(); bindingTableDesc.SizeInDescriptors = descriptorHeapSize; ReturnIfFailed(m_bindingTable->Reset(&bindingTableDesc)); @@ -560,7 +561,8 @@ HRESULT Device::InitializeOperator( } // Record and execute commands, and wait for completion - m_commandList->SetDescriptorHeaps(1, m_descriptorHeap->m_Heap.GetAddressOf()); + ID3D12DescriptorHeap* descriptorHeap = m_descriptorHeap->GetDescriptorHeap(); + m_commandList->SetDescriptorHeaps(1, &descriptorHeap); m_commandRecorder->RecordDispatch(m_commandList.Get(), m_operatorInitializer.Get(), m_bindingTable.Get()); ReturnIfFailed(ExecuteCommandListAndWait()); return S_OK; @@ -655,35 +657,42 @@ HRESULT Device::EnsureDefaultBufferSize(uint64_t requestedSizeInBytes, _Inout_ C HRESULT Device::EnsureDescriptorHeapSize(uint32_t requestedSizeInDescriptors) { - uint32_t existingSize = m_descriptorHeap ? m_descriptorHeap->m_Heap->GetDesc().NumDescriptors : 0; + uint32_t existingSize = m_descriptorHeap ? m_descriptorHeap->GetDescriptorHeap()->GetDesc().NumDescriptors : 0; uint32_t newSize = RoundUpToPow2(requestedSizeInDescriptors); // ensures geometric growth if (newSize != existingSize) { if (m_descriptorHeap != nullptr && m_residencyManager != nullptr){ - m_residencyManager->UnlockHeap(m_descriptorHeap.get()); + m_residencyManager->UnlockHeap(m_descriptorHeap->m_Heap.Get()); } m_descriptorHeap = nullptr; - if (m_residencyManager != nullptr){ - ReturnIfFailed(m_residencyManager->Evict(newSize, DXGI_MEMORY_SEGMENT_GROUP_LOCAL)); - } - - D3D12_DESCRIPTOR_HEAP_DESC desc = {}; - desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - desc.NumDescriptors = newSize; - desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - - ComPtr d3d12DescriptorHeap; - ReturnIfFailed(m_d3d12Device->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(d3d12DescriptorHeap.GetAddressOf()))); - - m_descriptorHeap = std::make_unique(std::move(d3d12DescriptorHeap), newSize); + auto createHeapFn = [&](ID3D12Pageable** ppPageableOut) -> HRESULT { + ComPtr d3d12DescriptorHeap; + D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + desc.NumDescriptors = newSize; + desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + ReturnIfFailed(m_d3d12Device->CreateDescriptorHeap( + &desc, IID_PPV_ARGS(&d3d12DescriptorHeap))); + *ppPageableOut = d3d12DescriptorHeap.Detach(); + return S_OK; + }; + + gpgmm::d3d12::HEAP_DESC heapDesc = {}; + heapDesc.SizeInBytes = newSize * m_d3d12Device->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); + heapDesc.MemorySegment = gpgmm::d3d12::RESIDENCY_SEGMENT_LOCAL; + + ComPtr descriptorHeap; + ReturnIfFailed(gpgmm::d3d12::Heap::CreateHeap(heapDesc, m_residencyManager.Get(), createHeapFn, + &descriptorHeap)); if (m_residencyManager != nullptr){ - ReturnIfFailed(m_residencyManager->InsertHeap(m_descriptorHeap.get())); - ReturnIfFailed(m_residencyManager->LockHeap(m_descriptorHeap.get())); + ReturnIfFailed(m_residencyManager->LockHeap(descriptorHeap.Get())); } + + m_descriptorHeap = std::make_unique(std::move(descriptorHeap)); } return S_OK; } diff --git a/src/webnn/native/dmlx/deps/src/device.h b/src/webnn/native/dmlx/deps/src/device.h index 7400ee655..55e3f7b7e 100644 --- a/src/webnn/native/dmlx/deps/src/device.h +++ b/src/webnn/native/dmlx/deps/src/device.h @@ -10,10 +10,12 @@ namespace pydml { - class SVDescriptorHeap : public gpgmm::d3d12::Heap { + class SVDescriptorHeap { public: - SVDescriptorHeap(ComPtr heap, uint64_t size); - ComPtr m_Heap; + SVDescriptorHeap(ComPtr heap); + ID3D12DescriptorHeap* GetDescriptorHeap() const; + + ComPtr m_Heap; }; class Device