diff --git a/src/webnn/native/BUILD.gn b/src/webnn/native/BUILD.gn index 5e682b374..bb5f4792f 100644 --- a/src/webnn/native/BUILD.gn +++ b/src/webnn/native/BUILD.gn @@ -210,6 +210,19 @@ source_set("sources") { } } + if (webnn_enable_dml) { + sources += [ + "dml/BackendDML.cpp", + "dml/BackendDML.h", + "dml/ContextDML.cpp", + "dml/ContextDML.h", + "dml/ExecutionContextDML.cpp", + "dml/ExecutionContextDML.h", + "dml/GraphDML.cpp", + "dml/GraphDML.h", + ] + } + if (webnn_enable_dmlx) { if (webnn_enable_gpu_buffer == false) { sources += [ diff --git a/src/webnn/native/dml/BackendDML.cpp b/src/webnn/native/dml/BackendDML.cpp new file mode 100644 index 000000000..4b9af9195 --- /dev/null +++ b/src/webnn/native/dml/BackendDML.cpp @@ -0,0 +1,46 @@ +// Copyright 2019 The Dawn Authors +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "webnn/native/dml/BackendDML.h" + +#include "webnn/native/Instance.h" +#include "webnn/native/dml/ContextDML.h" + +namespace webnn::native::dml { + + Backend::Backend(InstanceBase* instance) + : BackendConnection(instance, wnn::BackendType::DirectML) { + } + + MaybeError Backend::Initialize() { + return {}; + } + + ContextBase* Backend::CreateContext(ContextOptions const* options) { + return new Context(options); + } + + BackendConnection* Connect(InstanceBase* instance) { + Backend* backend = new Backend(instance); + + if (instance->ConsumedError(backend->Initialize())) { + delete backend; + return nullptr; + } + + return backend; + } + +} // namespace webnn::native::dml diff --git a/src/webnn/native/dml/BackendDML.h b/src/webnn/native/dml/BackendDML.h new file mode 100644 index 000000000..0279bf0b8 --- /dev/null +++ b/src/webnn/native/dml/BackendDML.h @@ -0,0 +1,38 @@ +// Copyright 2019 The Dawn Authors +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_NATIVE_DML_BACKEND_DML_H_ +#define WEBNN_NATIVE_DML_BACKEND_DML_H_ + +#include +#include "webnn/native/BackendConnection.h" +#include "webnn/native/Context.h" +#include "webnn/native/Error.h" + +namespace webnn::native::dml { + + class Backend : public BackendConnection { + public: + Backend(InstanceBase* instance); + + MaybeError Initialize(); + ContextBase* CreateContext(ContextOptions const* options = nullptr) override; + + private: + }; + +} // namespace webnn::native::dml + +#endif // WEBNN_NATIVE_DML_BACKEND_DML_H_ diff --git a/src/webnn/native/dml/ContextDML.cpp b/src/webnn/native/dml/ContextDML.cpp new file mode 100644 index 000000000..871c6f2be --- /dev/null +++ b/src/webnn/native/dml/ContextDML.cpp @@ -0,0 +1,29 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "webnn/native/dml/ContextDML.h" + +#include "common/RefCounted.h" +#include "webnn/native/dml/GraphDML.h" + +namespace webnn::native::dml { + + Context::Context(ContextOptions const* options) : ContextBase(options) { + } + + GraphBase* Context::CreateGraphImpl() { + return new Graph(this); + } + +} // namespace webnn::native::dml diff --git a/src/webnn/native/dml/ContextDML.h b/src/webnn/native/dml/ContextDML.h new file mode 100644 index 000000000..8f6a98293 --- /dev/null +++ b/src/webnn/native/dml/ContextDML.h @@ -0,0 +1,34 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_NATIVE_DML_CONTEXT_DML_H_ +#define WEBNN_NATIVE_DML_CONTEXT_DML_H_ + +#include "webnn/native/Context.h" +#include "webnn/native/Graph.h" + +namespace webnn::native::dml { + + class Context : public ContextBase { + public: + explicit Context(ContextOptions const* options); + ~Context() override = default; + + private: + GraphBase* CreateGraphImpl() override; + }; + +} // namespace webnn::native::dml + +#endif // WEBNN_NATIVE_DML_CONTEXT_DML_H_ diff --git a/src/webnn/native/dml/ExecutionContextDML.cpp b/src/webnn/native/dml/ExecutionContextDML.cpp new file mode 100644 index 000000000..a584af294 --- /dev/null +++ b/src/webnn/native/dml/ExecutionContextDML.cpp @@ -0,0 +1,106 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ExecutionContextDML.h" + +namespace webnn::native::dml { + + // An adapter called the "Microsoft Basic Render Driver" is always present. This adapter is a + // render-only device that has no display outputs. + inline bool IsSoftwareAdapter(IDXGIAdapter1* pAdapter) { + DXGI_ADAPTER_DESC1 pDesc; + pAdapter->GetDesc1(&pDesc); + // See here for documentation on filtering WARP adapter: + // https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8 + return pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE || + (pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c); + } + + HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference, + bool useGpu, + ComPtr adapter) { + ComPtr dxgiFactory; + RETURN_IF_FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory))); + if (useGpu) { + UINT adapterIndex = 0; + while (dxgiFactory->EnumAdapterByGpuPreference(adapterIndex++, gpuPreference, + IID_PPV_ARGS(&adapter)) != + DXGI_ERROR_NOT_FOUND) { + if (!IsSoftwareAdapter(adapter.Get())) { + break; + } + } + } else { + RETURN_IF_FAILED(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter))); + } + return S_OK; + } + + ExecutionContext::ExecutionContext(ComPtr adapter, bool useDebugLayer) + : mAdapter(std::move(adapter)), mUseDebugLayer(useDebugLayer) { + } + + // static + std::unique_ptr ExecutionContext::Create(ComPtr adapter, + bool useDebugLayer) { + std::unique_ptr executionContext( + new ExecutionContext(adapter, useDebugLayer)); + if (FAILED(executionContext->Initialize())) { + dawn::ErrorLog() << "Failed to initialize Device."; + return nullptr; + } + return executionContext; + } + + HRESULT ExecutionContext::Initialize() { + if (mUseDebugLayer) { + ComPtr debug; + if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) { + debug->EnableDebugLayer(); + } + } + RETURN_IF_FAILED( + D3D12CreateDevice(mAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mD3D12Device))); + D3D12_COMMAND_QUEUE_DESC commandQueueDesc{}; + commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + RETURN_IF_FAILED( + mD3D12Device->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&mCommandQueue))); + RETURN_IF_FAILED(mD3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, + IID_PPV_ARGS(&mCommandAllocator))); + RETURN_IF_FAILED(mD3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, + mCommandAllocator.Get(), nullptr, + IID_PPV_ARGS(&mCommandList))); + + // Create the DirectML device. + DML_CREATE_DEVICE_FLAGS dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_NONE; +#if defined(_DEBUG) + dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_DEBUG; +#endif + if (dmlCreateDeviceFlags == DML_CREATE_DEVICE_FLAG_DEBUG) { + if (FAILED(DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, + IID_PPV_ARGS(&mDevice)))) { + dawn::WarningLog() << "Failed to create a DirectML device with debug flag, " + "will fall back to use none flag."; + RETURN_IF_FAILED(DMLCreateDevice(mD3D12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE, + IID_PPV_ARGS(&mDevice))); + } + } else { + RETURN_IF_FAILED( + DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, IID_PPV_ARGS(&mDevice))); + } + return S_OK; + }; + +} // namespace webnn::native::dml diff --git a/src/webnn/native/dml/ExecutionContextDML.h b/src/webnn/native/dml/ExecutionContextDML.h new file mode 100644 index 000000000..4df3b5722 --- /dev/null +++ b/src/webnn/native/dml/ExecutionContextDML.h @@ -0,0 +1,63 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_ +#define WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_ + +#include +#include + +#include "common/Log.h" +#include "dml_platform.h" +#include "webnn/native/NamedOutputs.h" +#include "webnn/native/webnn_platform.h" + +#define RETURN_IF_FAILED(EXPR) \ + do { \ + auto HR = EXPR; \ + if (FAILED(HR)) { \ + dawn::ErrorLog() << "Failed to do " << #EXPR << " Return HRESULT " << std::hex << HR; \ + return HR; \ + } \ + } while (0) + +namespace webnn::native::dml { + + HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference, + bool useGpu, + ComPtr adapter); + + class ExecutionContext { + public: + static std::unique_ptr Create(ComPtr adapter, + bool useDebugLayer); + + private: + ExecutionContext(ComPtr adapter, bool useDebugLayer); + HRESULT Initialize(); + + ComPtr mDevice; + ComPtr mD3D12Device; + ComPtr mCommandRecorder; + ComPtr mCommandQueue; + ComPtr mCommandAllocator; + ComPtr mCommandList; + + ComPtr mAdapter; + bool mUseDebugLayer = false; + }; + +} // namespace webnn::native::dml + +#endif // WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_ diff --git a/src/webnn/native/dml/GraphDML.cpp b/src/webnn/native/dml/GraphDML.cpp new file mode 100644 index 000000000..21e8a1da6 --- /dev/null +++ b/src/webnn/native/dml/GraphDML.cpp @@ -0,0 +1,60 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "webnn/native/dml/GraphDML.h" + +#include "webnn/native/NamedInputs.h" +#include "webnn/native/NamedOutputs.h" + +namespace webnn::native ::dml { + + Graph::Graph(Context* context) : GraphBase(context) { + wnn::DevicePreference devicePreference = GetContext()->GetContextOptions().devicePreference; + bool useGpu = devicePreference == wnn::DevicePreference::Cpu ? false : true; + DXGI_GPU_PREFERENCE gpuPreference = DXGI_GPU_PREFERENCE_UNSPECIFIED; + wnn::PowerPreference powerPreference = GetContext()->GetContextOptions().powerPreference; + switch (powerPreference) { + case wnn::PowerPreference::High_performance: + gpuPreference = DXGI_GPU_PREFERENCE::DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE; + break; + case wnn::PowerPreference::Low_power: + gpuPreference = DXGI_GPU_PREFERENCE::DXGI_GPU_PREFERENCE_MINIMUM_POWER; + break; + default: + break; + } + + bool useDebugLayer = false; +#ifdef _DEBUG + useDebugLayer = true; +#endif + ComPtr adapter; + if (FAILED(EnumAdapter(gpuPreference, useGpu, adapter))) { + dawn::ErrorLog() << "Failed to enumerate adapters."; + DAWN_ASSERT(0); + } + + mExecutionContext = ExecutionContext::Create(adapter, useDebugLayer); + DAWN_ASSERT(mExecutionContext != nullptr); + } + + MaybeError Graph::CompileImpl() { + return {}; + } + + MaybeError Graph::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + return {}; + } + +} // namespace webnn::native::dml diff --git a/src/webnn/native/dml/GraphDML.h b/src/webnn/native/dml/GraphDML.h new file mode 100644 index 000000000..0e240bf57 --- /dev/null +++ b/src/webnn/native/dml/GraphDML.h @@ -0,0 +1,61 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_NATIVE_DML_GRAPH_DML_H_ +#define WEBNN_NATIVE_DML_GRAPH_DML_H_ + +#include "ExecutionContextDML.h" +#include "webnn/native/Graph.h" +#include "webnn/native/Operand.h" +#include "webnn/native/Operator.h" +#include "webnn/native/dml/ContextDML.h" +#include "webnn/native/ops/BatchNorm.h" +#include "webnn/native/ops/Binary.h" +#include "webnn/native/ops/Clamp.h" +#include "webnn/native/ops/Concat.h" +#include "webnn/native/ops/Constant.h" +#include "webnn/native/ops/Conv2d.h" +#include "webnn/native/ops/Gemm.h" +#include "webnn/native/ops/Gru.h" +#include "webnn/native/ops/Input.h" +#include "webnn/native/ops/InstanceNorm.h" +#include "webnn/native/ops/LeakyRelu.h" +#include "webnn/native/ops/Pad.h" +#include "webnn/native/ops/Pool2d.h" +#include "webnn/native/ops/Reduce.h" +#include "webnn/native/ops/Resample2d.h" +#include "webnn/native/ops/Reshape.h" +#include "webnn/native/ops/Slice.h" +#include "webnn/native/ops/Split.h" +#include "webnn/native/ops/Squeeze.h" +#include "webnn/native/ops/Transpose.h" +#include "webnn/native/ops/Unary.h" + +namespace webnn::native::dml { + + class Graph : public GraphBase { + public: + explicit Graph(Context* context); + ~Graph() override = default; + + private: + MaybeError CompileImpl() override; + MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + + std::unique_ptr mExecutionContext; + }; + +} // namespace webnn::native::dml + +#endif // WEBNN_NATIVE_DML_GRAPH_DML_H_ diff --git a/src/webnn/native/dml/dml_platform.h b/src/webnn/native/dml/dml_platform.h new file mode 100644 index 000000000..49aa73475 --- /dev/null +++ b/src/webnn/native/dml/dml_platform.h @@ -0,0 +1,28 @@ +// Copyright 2022 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_NATIVE_DML_DMLPLATFORM_H_ +#define WEBNN_NATIVE_DML_DMLPLATFORM_H_ + +// This micro definition must be added before including "DirectML.h". +#define DML_TARGET_VERSION_USE_LATEST 1 + +#include +#include + +#include "DirectML.h" + +using Microsoft::WRL::ComPtr; + +#endif // WEBNN_NATIVE_DML_DMLPLATFORM_H_