From 6d05ff506cb90189b6e9d56481798fa644fb2885 Mon Sep 17 00:00:00 2001 From: junwei Date: Wed, 4 May 2022 11:06:32 +0800 Subject: [PATCH] Sync infranstructure with Dawn --- DEPS | 4 +- build_overrides/webnn.gni | 3 +- examples/LeNet/LeNet.cpp | 2 +- examples/MobileNetV2/Main.cpp | 6 +- examples/ResNet/Main.cpp | 2 +- examples/SampleUtils.cpp | 43 ++-- examples/SampleUtils.h | 5 +- examples/SqueezeNet/Main.cpp | 2 +- generator/webnn_generator.gni | 20 +- include/webnn/BUILD.gn | 19 +- include/webnn/EnumClassBitmasks.h | 145 ------------- include/webnn/native/WebnnNative.h | 1 + include/webnn/webnn.h | 1 + include/webnn/webnn_cpp.h | 1 + include/webnn/webnn_proc.h | 4 +- include/webnn/webnn_proc_table.h | 1 + include/webnn/webnn_thread_dispatch_proc.h | 33 +++ node/src/GraphBuilder.cpp | 2 +- scripts/webnn_overrides_with_defaults.gni | 8 +- src/webnn/BUILD.gn | 18 +- src/webnn/common/BUILD.gn | 67 +++--- src/webnn/native/BUILD.gn | 33 ++- src/webnn/native/Context.cpp | 8 +- src/webnn/native/Context.h | 11 +- src/webnn/native/Error.h | 12 +- src/webnn/native/Forward.h | 3 +- src/webnn/native/Graph.cpp | 10 +- src/webnn/native/Graph.h | 10 +- src/webnn/native/GraphBuilder.cpp | 190 +++++++++--------- src/webnn/native/GraphBuilder.h | 158 +++++++-------- src/webnn/native/Instance.cpp | 14 +- src/webnn/native/Instance.h | 14 +- src/webnn/native/NamedInputs.h | 2 +- src/webnn/native/NamedOperands.h | 2 +- src/webnn/native/NamedOutputs.h | 4 +- src/webnn/native/OperandArray.h | 4 +- src/webnn/native/OperatorArray.h | 6 +- src/webnn/native/Utils.h | 2 +- src/webnn/native/WebnnNative.cpp | 7 +- src/webnn/native/dmlx/GraphDMLX.cpp | 4 +- src/webnn/native/ops/Gru.cpp | 7 +- src/webnn/native/webnn_platform.h | 2 +- src/webnn/tests/BUILD.gn | 11 +- src/webnn/tests/end2end/AddTests.cpp | 6 +- src/webnn/tests/end2end/BatchNormTests.cpp | 2 +- src/webnn/tests/end2end/ClampTests.cpp | 2 +- src/webnn/tests/end2end/ConcatTests.cpp | 2 +- src/webnn/tests/end2end/Conv2dTests.cpp | 2 +- .../tests/end2end/ConvTranspose2dTests.cpp | 2 +- src/webnn/tests/end2end/DivTests.cpp | 4 +- .../tests/end2end/ElementWiseUnaryTests.cpp | 2 +- src/webnn/tests/end2end/GemmTests.cpp | 2 +- src/webnn/tests/end2end/GruTests.cpp | 8 +- src/webnn/tests/end2end/HardSwishTests.cpp | 2 +- src/webnn/tests/end2end/InstanceNormTests.cpp | 2 +- src/webnn/tests/end2end/LeakyReluTests.cpp | 2 +- src/webnn/tests/end2end/MatMulTests.cpp | 18 +- src/webnn/tests/end2end/MaxTests.cpp | 6 +- src/webnn/tests/end2end/MinTests.cpp | 6 +- src/webnn/tests/end2end/MulTests.cpp | 6 +- src/webnn/tests/end2end/PadTests.cpp | 2 +- src/webnn/tests/end2end/Pool2dTests.cpp | 64 +++--- src/webnn/tests/end2end/PowTests.cpp | 10 +- src/webnn/tests/end2end/ReduceTests.cpp | 2 +- src/webnn/tests/end2end/ReluTests.cpp | 2 +- src/webnn/tests/end2end/Resample2dTests.cpp | 2 +- src/webnn/tests/end2end/ReshapeTests.cpp | 2 +- src/webnn/tests/end2end/SigmoidTests.cpp | 4 +- src/webnn/tests/end2end/SliceTests.cpp | 2 +- src/webnn/tests/end2end/SoftmaxTests.cpp | 2 +- src/webnn/tests/end2end/SplitTests.cpp | 4 +- src/webnn/tests/end2end/SqueezeTests.cpp | 2 +- src/webnn/tests/end2end/SubTests.cpp | 4 +- src/webnn/tests/end2end/TanhTests.cpp | 2 +- src/webnn/tests/end2end/TransposeTests.cpp | 2 +- .../models/MobileNetV2BatchNormNchw.cpp | 4 +- .../tests/end2end/models/MobileNetV2Nchw.cpp | 4 +- .../tests/end2end/models/MobileNetV2Nhwc.cpp | 4 +- src/webnn/tests/end2end/models/ResNetNchw.cpp | 4 +- src/webnn/tests/end2end/models/ResNetNhwc.cpp | 4 +- .../tests/end2end/models/SqueezeNetNchw.cpp | 4 +- .../tests/end2end/models/SqueezeNetNhwc.cpp | 4 +- src/webnn/tests/unittests/ObjectBaseTests.cpp | 32 +-- .../validation/GraphValidationTests.cpp | 14 +- .../unittests/validation/ValidationTest.cpp | 2 +- src/webnn/utils/BUILD.gn | 6 +- src/webnn/wire/BUILD.gn | 6 +- webnn.json | 37 +++- 88 files changed, 568 insertions(+), 638 deletions(-) delete mode 100644 include/webnn/EnumClassBitmasks.h create mode 100644 include/webnn/webnn.h create mode 100644 include/webnn/webnn_cpp.h create mode 100644 include/webnn/webnn_proc_table.h create mode 100644 include/webnn/webnn_thread_dispatch_proc.h diff --git a/DEPS b/DEPS index 3535c9b0f..4470b983a 100644 --- a/DEPS +++ b/DEPS @@ -8,7 +8,7 @@ gclient_gn_args = [ vars = { 'chromium_git': 'https://chromium.googlesource.com', - 'dawn_git': 'https://dawn.googlesource.com', + 'dawn_git': 'https://github.com/fujunwei', 'github_git': 'https://github.com', 'dawn_standalone': True, @@ -45,7 +45,7 @@ deps = { # Dependencies required for code generator and infrastructure code. 'third_party/dawn': { - 'url': '{dawn_git}/dawn.git@bf1c0cf52377b4db2bf3a433dc5056620aad7cdd' + 'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895' }, # Dependencies required for backends. diff --git a/build_overrides/webnn.gni b/build_overrides/webnn.gni index 7155c1cfd..9ca5b973e 100644 --- a/build_overrides/webnn.gni +++ b/build_overrides/webnn.gni @@ -15,7 +15,8 @@ webnn_standalone = true # The paths to WebNN's dependencies -webnn_dawn_root = "//third_party/dawn" +webnn_abseil_dir = "//third_party/abseil-cpp" +dawn_root = "//third_party/dawn" webnn_googletest_dir = "//third_party/googletest" webnn_jinja2_dir = "//third_party/jinja2" webnn_gpgmm_dir = "//third_party/gpgmm" diff --git a/examples/LeNet/LeNet.cpp b/examples/LeNet/LeNet.cpp index a9ae2a6b4..74bb17568 100644 --- a/examples/LeNet/LeNet.cpp +++ b/examples/LeNet/LeNet.cpp @@ -48,7 +48,7 @@ wnn::Graph LeNet::Build(const std::string& weigthsPath) { return nullptr; } - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(mContext); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(mContext); uint32_t byteOffset = 0; const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28}); diff --git a/examples/MobileNetV2/Main.cpp b/examples/MobileNetV2/Main.cpp index 340d865e3..bec9759e1 100644 --- a/examples/MobileNetV2/Main.cpp +++ b/examples/MobileNetV2/Main.cpp @@ -39,9 +39,9 @@ int main(int argc, const char* argv[]) { } }, &mobilevetv2); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); - wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNchw(builder) - : mobilevetv2.LoadNhwc(builder); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); + wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNCHW(builder) + : mobilevetv2.LoadNHWC(builder); // Build the graph. const std::chrono::time_point compilationStartTime = diff --git a/examples/ResNet/Main.cpp b/examples/ResNet/Main.cpp index 64773cfd3..462e494c0 100644 --- a/examples/ResNet/Main.cpp +++ b/examples/ResNet/Main.cpp @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) { } }, &resnet); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); wnn::Operand output = resnet.mLayout == "nchw" ? resnet.LoadNchw(builder) : resnet.LoadNhwc(builder); diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index ff9a6c10e..2da6009b6 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -50,7 +50,7 @@ static webnn::wire::WireClient* wireClient = nullptr; static utils::TerribleCommandBuffer* c2sBuf = nullptr; static utils::TerribleCommandBuffer* s2cBuf = nullptr; -static wnn::Instance clientInstance; +static wnn::Instance instance; static std::unique_ptr nativeInstance; wnn::Context CreateCppContext(wnn::ContextOptions const* options) { nativeInstance = std::make_unique(); @@ -67,6 +67,7 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { case CmdBufType::None: procs = backendProcs; context = backendContext; + instance = wnn::Instance(nativeInstance->Get()); break; case CmdBufType::Terrible: { @@ -94,14 +95,12 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { context = contextReservation.context; #else - webnnProcSetProcs(&procs); auto instanceReservation = wireClient->ReserveInstance(); wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id, instanceReservation.generation); // Keep the reference instread of using Acquire. // TODO:: make the instance in the client as singleton object. - clientInstance = wnn::Instance(instanceReservation.instance); - return clientInstance.CreateContext(options); + instance = wnn::Instance(instanceReservation.instance); #endif } default: @@ -109,8 +108,8 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { DAWN_ASSERT(0); } webnnProcSetProcs(&procs); - - return wnn::Context::Acquire(context); + return instance.CreateContext(options); + ; } void DoFlush() { @@ -123,35 +122,15 @@ void DoFlush() { } wnn::NamedInputs CreateCppNamedInputs() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedInputs(); -#else - return wnn::CreateNamedInputs(); -#endif // defined(WEBNN_ENABLE_WIRE) -} - -wnn::NamedOperands CreateCppNamedOperands() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedOperands(); -#else - return wnn::CreateNamedOperands(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateNamedInputs(); } wnn::NamedOutputs CreateCppNamedOutputs() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedOutputs(); -#else - return wnn::CreateNamedOutputs(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateNamedOutputs(); } wnn::OperatorArray CreateCppOperatorArray() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateOperatorArray(); -#else - return wnn::CreateOperatorArray(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateOperatorArray(); } bool ExampleBase::ParseAndCheckExampleOptions(int argc, const char* argv[]) { @@ -264,6 +243,10 @@ namespace utils { return activationOperand; } + wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context) { + return instance.CreateGraphBuilder(context); + } + wnn::Operand BuildInput(const wnn::GraphBuilder& builder, std::string name, const std::vector& dimensions, @@ -283,7 +266,7 @@ namespace utils { } wnn::Graph Build(const wnn::GraphBuilder& builder, const std::vector& outputs) { - wnn::NamedOperands namedOperands = CreateCppNamedOperands(); + wnn::NamedOperands namedOperands = instance.CreateNamedOperands(); for (auto& output : outputs) { namedOperands.Set(output.name.c_str(), output.operand); } diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index 7c424fd18..fa21b9e0a 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -83,6 +83,7 @@ namespace utils { FusedActivation activation, const void* options = nullptr); + wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context); wnn::Operand BuildInput(const wnn::GraphBuilder& builder, std::string name, const std::vector& dimensions, @@ -247,7 +248,7 @@ namespace utils { void Compute(const wnn::Graph& graph, const std::vector>& inputs, const std::vector>& outputs) { - if (graph.GetHandle() == nullptr) { + if (graph.Get() == nullptr) { dawn::ErrorLog() << "The graph is invaild."; } @@ -272,7 +273,7 @@ namespace utils { resource.arrayBufferView.buffer = output.resource.data(); resource.arrayBufferView.byteLength = output.resource.size() * sizeof(float); mlOutputs.push_back(resource); - namedOutputs.Set(output.name.c_str(), &mlOutputs.back()); + namedOutputs.SetOutput(output.name.c_str(), &mlOutputs.back()); } graph.Compute(namedInputs, namedOutputs); DoFlush(); diff --git a/examples/SqueezeNet/Main.cpp b/examples/SqueezeNet/Main.cpp index feba69e18..d250e0e7a 100644 --- a/examples/SqueezeNet/Main.cpp +++ b/examples/SqueezeNet/Main.cpp @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) { } }, &squeezenet); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); wnn::Operand output = squeezenet.mLayout == "nchw" ? squeezenet.LoadNchw(builder) : squeezenet.LoadNhwc(builder); diff --git a/generator/webnn_generator.gni b/generator/webnn_generator.gni index 7c76e1fc3..764387136 100644 --- a/generator/webnn_generator.gni +++ b/generator/webnn_generator.gni @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import("//third_party/dawn/generator/generator_lib.gni") +import("//third_party/dawn/generator/dawn_generator.gni") import("../scripts/webnn_overrides_with_defaults.gni") # Dawn used to put autogenerated files in a lot of different places. When we @@ -39,12 +39,16 @@ import("../scripts/webnn_overrides_with_defaults.gni") # disallowed gen directories. webnn_allowed_gen_output_dirs = [ + "src/dawn/", + "src/dawn/native/", "src/webnn/", "src/webnn/native/", "src/webnn/wire/client/", "src/webnn/wire/server/", "src/webnn/wire/", "include/webnn/", + "emscripten-bits/", + "include/dawn/", ] # Template to help invoking Dawn code generators based on generator_lib @@ -77,7 +81,7 @@ template("webnn_generator") { forward_variables_from(invoker, "*") # Set arguments required to find the python libraries for the generator - generator_lib_dir = "${webnn_dawn_root}/generator" + generator_lib_dir = "${dawn_root}/generator" jinja2_path = webnn_jinja2_dir # Force Dawn's autogenerated file structure to mirror exactly the source @@ -87,15 +91,15 @@ template("webnn_generator") { # Make sure that we delete stale autogenerated file in directories that are # no longer used by code generation to avoid include conflicts. - deps = [ "${webnn_root}/generator:remove_stale_autogen_files" ] + deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ] - template_dir = "${webnn_root}/generator/templates" + template_dir = "${dawn_root}/generator/templates" } } # Helper generator for calling the generator from webnn.json # -# dawn_json_generator("my_target_gen") { +# webnn_json_generator("my_target_gen") { # # Which generator target to output # target = "my_target" # @@ -103,19 +107,17 @@ template("webnn_generator") { # } template("webnn_json_generator") { webnn_generator(target_name) { - script = "${webnn_root}/generator/webnn_json_generator.py" + script = "${dawn_root}/generator/dawn_json_generator.py" # The base arguments for the generator: from this webnn.json, generate this # target using templates in this directory. args = [ - "--webnn-json", + "--dawn-json", rebase_path("${webnn_root}/webnn.json", root_build_dir), "--wire-json", rebase_path("${webnn_root}/webnn_wire.json", root_build_dir), "--targets", invoker.target, - "--dawn-generator-path", - rebase_path("${webnn_root}/third_party/dawn/generator"), ] forward_variables_from(invoker, "*", [ "target" ]) diff --git a/include/webnn/BUILD.gn b/include/webnn/BUILD.gn index 906c76eaf..38cd18e30 100644 --- a/include/webnn/BUILD.gn +++ b/include/webnn/BUILD.gn @@ -15,7 +15,7 @@ import("../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") ############################################################################### @@ -25,8 +25,8 @@ import("${webnn_root}/generator/webnn_generator.gni") webnn_json_generator("headers_gen") { target = "headers" outputs = [ - "include/webnn/webnn_proc_table.h", - "include/webnn/webnn.h", + "include/dawn/webnn_proc_table.h", + "include/dawn/webnn.h", ] } @@ -43,7 +43,10 @@ source_set("headers") { webnn_json_generator("cpp_headers_gen") { target = "cpp_headers" - outputs = [ "include/webnn/webnn_cpp.h" ] + outputs = [ + "include/dawn/webnn_cpp.h", + "include/dawn/webnn_cpp_print.h", + ] } source_set("cpp_headers") { @@ -67,14 +70,8 @@ config("public") { if (build_with_chromium) { include_dirs += [ - "${webnn_dawn_root}/include", + "${dawn_root}/include", "${dawn_gen_root}/include", ] - } else { - # TODO: Remove after upgrading webnn infranstructure align with dawn. - include_dirs += [ - "${webnn_dawn_root}/src/include", - "${dawn_gen_root}/src/include", - ] } } diff --git a/include/webnn/EnumClassBitmasks.h b/include/webnn/EnumClassBitmasks.h deleted file mode 100644 index bdef1df3d..000000000 --- a/include/webnn/EnumClassBitmasks.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2017 The Dawn Authors -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef WEBNN_ENUM_CLASS_BITMASKS_H_ -#define WEBNN_ENUM_CLASS_BITMASKS_H_ - -#include - -namespace webnn { - - template - struct IsDawnBitmask { - static constexpr bool enable = false; - }; - - template - struct LowerBitmask { - static constexpr bool enable = false; - }; - - template - struct LowerBitmask::enable>::type> { - static constexpr bool enable = true; - using type = T; - constexpr static T Lower(T t) { - return t; - } - }; - - template - struct BoolConvertible { - using Integral = typename std::underlying_type::type; - - constexpr BoolConvertible(Integral value) : value(value) { - } - constexpr operator bool() const { - return value != 0; - } - constexpr operator T() const { - return static_cast(value); - } - - Integral value; - }; - - template - struct LowerBitmask> { - static constexpr bool enable = true; - using type = T; - static constexpr type Lower(BoolConvertible t) { - return t; - } - }; - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator|(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) | - static_cast(LowerBitmask::Lower(right)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator&(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) & - static_cast(LowerBitmask::Lower(right)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator^(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) ^ - static_cast(LowerBitmask::Lower(right)); - } - - template - constexpr BoolConvertible::type> operator~(T1 t) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return ~static_cast(LowerBitmask::Lower(t)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator&=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l & r; - return l; - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator|=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l | r; - return l; - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator^=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l ^ r; - return l; - } - - template - constexpr bool HasZeroOrOneBits(T value) { - using Integral = typename std::underlying_type::type; - return (static_cast(value) & (static_cast(value) - 1)) == 0; - } - -} // namespace webnn - -#endif // WEBNN_ENUM_CLASS_BITMASKS_H_ diff --git a/include/webnn/native/WebnnNative.h b/include/webnn/native/WebnnNative.h index bcf76ddd4..bacd3956b 100644 --- a/include/webnn/native/WebnnNative.h +++ b/include/webnn/native/WebnnNative.h @@ -44,6 +44,7 @@ namespace webnn::native { WNNContext CreateTestContext(const wnn::ContextOptions* options = nullptr); WNNContext CreateContext(const wnn::ContextOptions* options = nullptr); + WNNGraphBuilder CreateGraphBuilder(const WNNContext context); // Returns the underlying WNNInstance object. WNNInstance Get() const; diff --git a/include/webnn/webnn.h b/include/webnn/webnn.h new file mode 100644 index 000000000..f45f3b601 --- /dev/null +++ b/include/webnn/webnn.h @@ -0,0 +1 @@ +#include "dawn/webnn.h" diff --git a/include/webnn/webnn_cpp.h b/include/webnn/webnn_cpp.h new file mode 100644 index 000000000..666ea57bf --- /dev/null +++ b/include/webnn/webnn_cpp.h @@ -0,0 +1 @@ +#include diff --git a/include/webnn/webnn_proc.h b/include/webnn/webnn_proc.h index 959453125..2bf0228b3 100644 --- a/include/webnn/webnn_proc.h +++ b/include/webnn/webnn_proc.h @@ -16,8 +16,8 @@ #ifndef WEBNN_WEBNN_PROC_H_ #define WEBNN_WEBNN_PROC_H_ +#include "dawn/webnn_proc_table.h" #include "webnn/webnn.h" -#include "webnn/webnn_proc_table.h" #ifdef __cplusplus extern "C" { @@ -28,7 +28,7 @@ extern "C" { // default value of the proctable. Setting the proctable back to null is good practice when you // are done using libdawn_proc since further usage will cause a segfault instead of calling an // unexpected function. -WEBNN_EXPORT void webnnProcSetProcs(const WebnnProcTable* procs); +WNN_EXPORT void webnnProcSetProcs(const WebnnProcTable* procs); #ifdef __cplusplus } // extern "C" diff --git a/include/webnn/webnn_proc_table.h b/include/webnn/webnn_proc_table.h new file mode 100644 index 000000000..390201d09 --- /dev/null +++ b/include/webnn/webnn_proc_table.h @@ -0,0 +1 @@ +#include "dawn/webnn_proc_table.h" diff --git a/include/webnn/webnn_thread_dispatch_proc.h b/include/webnn/webnn_thread_dispatch_proc.h new file mode 100644 index 000000000..bf0fb012f --- /dev/null +++ b/include/webnn/webnn_thread_dispatch_proc.h @@ -0,0 +1,33 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ +#define DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ + +#include "webnn_proc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Call webnnProcSetProcs(&webnnThreadDispatchProcTable) and then use webnnProcSetPerThreadProcs +// to set per-thread procs. +WNN_EXPORT extern WebnnProcTable webnnThreadDispatchProcTable; +WNN_EXPORT void webnnProcSetPerThreadProcs(const WebnnProcTable* procs); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ diff --git a/node/src/GraphBuilder.cpp b/node/src/GraphBuilder.cpp index d04c657bb..392096d50 100644 --- a/node/src/GraphBuilder.cpp +++ b/node/src/GraphBuilder.cpp @@ -79,7 +79,7 @@ namespace node { : Napi::ObjectWrap(info) { Napi::Object object = info[0].As(); node::Context* context = Napi::ObjectWrap::Unwrap(object); - mImpl = wnn::CreateGraphBuilder(context->GetImpl()); + mImpl = utils::CreateGraphBuilder(context->GetImpl()); } Napi::Value GraphBuilder::Constant(const Napi::CallbackInfo& info) { diff --git a/scripts/webnn_overrides_with_defaults.gni b/scripts/webnn_overrides_with_defaults.gni index 2f90d69a9..ab2d22d0a 100644 --- a/scripts/webnn_overrides_with_defaults.gni +++ b/scripts/webnn_overrides_with_defaults.gni @@ -17,8 +17,8 @@ if (!defined(webnn_standalone)) { webnn_standalone = false } -if (!defined(webnn_dawn_root)) { - webnn_dawn_root = "//third_party/dawn" +if (!defined(dawn_root)) { + dawn_root = "//third_party/dawn" } if (!defined(webnn_jinja2_dir)) { @@ -33,3 +33,7 @@ webnn_gen_root = get_path_info("${webnn_root}", "gen_dir") if (!defined(webnn_googletest_dir)) { webnn_googletest_dir = "//third_party/googletest" } + +if (!defined(webnn_abseil_dir)) { + webnn_abseil_dir = "//third_party/abseil-cpp" +} diff --git a/src/webnn/BUILD.gn b/src/webnn/BUILD.gn index 0c025e22b..576dc8144 100644 --- a/src/webnn/BUILD.gn +++ b/src/webnn/BUILD.gn @@ -15,7 +15,7 @@ import("../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") ############################################################################### @@ -24,7 +24,7 @@ import("${webnn_root}/generator/webnn_generator.gni") webnn_json_generator("cpp_gen") { target = "cpp" - outputs = [ "src/webnn/webnn_cpp.cpp" ] + outputs = [ "src/dawn/webnn_cpp.cpp" ] } source_set("cpp") { @@ -41,11 +41,14 @@ source_set("cpp") { webnn_json_generator("proc_gen") { target = "proc" - outputs = [ "src/webnn/webnn_proc.c" ] + outputs = [ + "src/dawn/webnn_proc.c", + "src/dawn/webnn_thread_dispatch_proc.cpp", + ] } dawn_component("webnn_proc") { - DEFINE_PREFIX = "WEBNN" + DEFINE_PREFIX = "WNN" public_deps = [ "${webnn_root}/include/webnn:headers" ] deps = [ ":proc_gen" ] @@ -59,7 +62,10 @@ dawn_component("webnn_proc") { webnn_json_generator("emscripten_bits_gen") { target = "emscripten_bits" outputs = [ - "src/webnn/webnn_struct_info.json", - "src/webnn/library_webnn_enum_tables.js", + "emscripten-bits/library_webnn_enum_tables.js", + "emscripten-bits/webnn.h", + "emscripten-bits/webnn_cpp.cpp", + "emscripten-bits/webnn_cpp.h", + "emscripten-bits/webnn_struct_info.json", ] } diff --git a/src/webnn/common/BUILD.gn b/src/webnn/common/BUILD.gn index 21b0226c6..872983b6d 100644 --- a/src/webnn/common/BUILD.gn +++ b/src/webnn/common/BUILD.gn @@ -16,8 +16,8 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//build_overrides/build.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") -import("${webnn_dawn_root}/scripts/dawn_overrides_with_defaults.gni") +import("${dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_overrides_with_defaults.gni") import("${webnn_root}/build_overrides/webnn_features.gni") # Use Chromium's dcheck_always_on when available so that we respect it when @@ -41,11 +41,16 @@ config("internal_config") { include_dirs = [ "${target_gen_dir}/../../../src", "${webnn_root}/src", - "${webnn_dawn_root}/src", - "${webnn_dawn_root}/src/dawn", - "${webnn_root}", + "${dawn_root}/src", + "${dawn_root}/src/dawn", + "${webnn_root}/include", + "${dawn_root}/include", ] + if (build_with_chromium) { + include_dirs += [ "${dawn_gen_root}/include" ] + } + defines = [] if (dawn_always_assert || dcheck_always_on || is_debug || use_fuzzing_engine) { @@ -190,42 +195,22 @@ config("internal_config") { # systems we know Dawn is able to compile on. if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) { static_library("common") { - if (build_with_chromium) { - sources = [ - "//third_party/dawn/src/dawn/common/Assert.cpp", - "//third_party/dawn/src/dawn/common/Assert.h", - "//third_party/dawn/src/dawn/common/Compiler.h", - "//third_party/dawn/src/dawn/common/Log.cpp", - "//third_party/dawn/src/dawn/common/Log.h", - "//third_party/dawn/src/dawn/common/Math.cpp", - "//third_party/dawn/src/dawn/common/Math.h", - "//third_party/dawn/src/dawn/common/Platform.h", - "//third_party/dawn/src/dawn/common/RefCounted.cpp", - "//third_party/dawn/src/dawn/common/RefCounted.h", - "//third_party/dawn/src/dawn/common/Result.cpp", - "//third_party/dawn/src/dawn/common/Result.h", - "//third_party/dawn/src/dawn/common/SystemUtils.cpp", - "//third_party/dawn/src/dawn/common/SystemUtils.h", - ] - } else { - # TODO: Remove after upgrading webnn infranstructure align with dawn. - sources = [ - "//third_party/dawn/src/common/Assert.cpp", - "//third_party/dawn/src/common/Assert.h", - "//third_party/dawn/src/common/Compiler.h", - "//third_party/dawn/src/common/Log.cpp", - "//third_party/dawn/src/common/Log.h", - "//third_party/dawn/src/common/Math.cpp", - "//third_party/dawn/src/common/Math.h", - "//third_party/dawn/src/common/Platform.h", - "//third_party/dawn/src/common/RefCounted.cpp", - "//third_party/dawn/src/common/RefCounted.h", - "//third_party/dawn/src/common/Result.cpp", - "//third_party/dawn/src/common/Result.h", - "//third_party/dawn/src/common/SystemUtils.cpp", - "//third_party/dawn/src/common/SystemUtils.h", - ] - } + sources = [ + "${dawn_root}/src/dawn/common/Assert.cpp", + "${dawn_root}/src/dawn/common/Assert.h", + "${dawn_root}/src/dawn/common/Compiler.h", + "${dawn_root}/src/dawn/common/Log.cpp", + "${dawn_root}/src/dawn/common/Log.h", + "${dawn_root}/src/dawn/common/Math.cpp", + "${dawn_root}/src/dawn/common/Math.h", + "${dawn_root}/src/dawn/common/Platform.h", + "${dawn_root}/src/dawn/common/RefCounted.cpp", + "${dawn_root}/src/dawn/common/RefCounted.h", + "${dawn_root}/src/dawn/common/Result.cpp", + "${dawn_root}/src/dawn/common/Result.h", + "${dawn_root}/src/dawn/common/SystemUtils.cpp", + "${dawn_root}/src/dawn/common/SystemUtils.h", + ] public_configs = [ ":internal_config" ] deps = [ diff --git a/src/webnn/native/BUILD.gn b/src/webnn/native/BUILD.gn index 0815b96be..76f9e4340 100644 --- a/src/webnn/native/BUILD.gn +++ b/src/webnn/native/BUILD.gn @@ -16,8 +16,8 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//build_overrides/build.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/build_overrides/webnn_features.gni") import("${webnn_root}/generator/webnn_generator.gni") @@ -30,6 +30,17 @@ if (is_mac) { } } +group("abseil") { + # When build_with_chromium=true we need to include "//third_party/abseil-cpp:absl" while + # it's beneficial to be more specific with standalone Dawn, especially when it comes to + # including it as a dependency in other projects (such as Skia). + if (build_with_chromium) { + public_deps = [ "$webnn_abseil_dir:absl" ] + } else { + public_deps = [ "${webnn_root}/third_party/gn/abseil-cpp:str_format" ] + } +} + config("internal") { configs = [ "${webnn_root}/src/webnn/common:internal_config" ] @@ -51,12 +62,18 @@ config("internal") { webnn_json_generator("utils_gen") { target = "native_utils" outputs = [ + "src/webnn/native/ChainUtils_autogen.cpp", + "src/webnn/native/ChainUtils_autogen.h", + "src/webnn/native/ObjectType_autogen.cpp", + "src/webnn/native/ObjectType_autogen.h", "src/webnn/native/ProcTable.cpp", - "src/webnn/native/webnn_structs_autogen.h", - "src/webnn/native/webnn_structs_autogen.cpp", - "src/webnn/native/ValidationUtils_autogen.h", "src/webnn/native/ValidationUtils_autogen.cpp", + "src/webnn/native/ValidationUtils_autogen.h", + "src/webnn/native/webnn_absl_format_autogen.cpp", + "src/webnn/native/webnn_absl_format_autogen.h", "src/webnn/native/webnn_platform_autogen.h", + "src/webnn/native/wnn_structs_autogen.cpp", + "src/webnn/native/wnn_structs_autogen.h", ] } @@ -89,6 +106,8 @@ source_set("sources") { configs += [ ":internal" ] + public_deps = [ ":abseil" ] + sources = get_target_outputs(":utils_gen") sources += [ @@ -234,7 +253,7 @@ source_set("sources") { include_dirs += [ "${webnn_root}/third_party/microsoft.ai.directml.1.8.2/include" ] if (build_with_chromium) { - include_dirs += [ "${webnn_dawn_root}/src/dawn/native/dml/deps/src" ] + include_dirs += [ "${dawn_root}/src/dawn/native/dml/deps/src" ] } else { include_dirs += [ "${webnn_root}/third_party/DirectML/Libraries" ] } @@ -255,7 +274,7 @@ source_set("sources") { "dawn_wire.dll.lib", ] } - deps += [ "${webnn_dawn_root}/src/dawn/native" ] + deps += [ "${dawn_root}/src/dawn/native" ] } cflags = [ diff --git a/src/webnn/native/Context.cpp b/src/webnn/native/Context.cpp index dc7152502..e490c824c 100644 --- a/src/webnn/native/Context.cpp +++ b/src/webnn/native/Context.cpp @@ -65,7 +65,7 @@ namespace webnn::native { } #endif - void ContextBase::InjectError(wnn::ErrorType type, const char* message) { + void ContextBase::APIInjectError(wnn::ErrorType type, const char* message) { if (ConsumedError(ValidateErrorType(type))) { return; } @@ -80,14 +80,14 @@ namespace webnn::native { HandleError(DAWN_MAKE_ERROR(FromWNNErrorType(type), message)); } - void ContextBase::PushErrorScope(wnn::ErrorFilter filter) { + void ContextBase::APIPushErrorScope(wnn::ErrorFilter filter) { if (ConsumedError(ValidateErrorFilter(filter))) { return; } mCurrentErrorScope = AcquireRef(new ErrorScope(filter, mCurrentErrorScope.Get())); } - bool ContextBase::PopErrorScope(wnn::ErrorCallback callback, void* userdata) { + bool ContextBase::APIPopErrorScope(wnn::ErrorCallback callback, void* userdata) { if (DAWN_UNLIKELY(mCurrentErrorScope.Get() == mRootErrorScope.Get())) { return false; } @@ -97,7 +97,7 @@ namespace webnn::native { return true; } - void ContextBase::SetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata) { + void ContextBase::APISetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata) { mRootErrorScope->SetCallback(callback, userdata); } diff --git a/src/webnn/native/Context.h b/src/webnn/native/Context.h index d01cc16c8..4a5b4e845 100644 --- a/src/webnn/native/Context.h +++ b/src/webnn/native/Context.h @@ -58,11 +58,12 @@ namespace webnn::native { WGPUDevice GetWGPUDevice(); #endif - // Dawn API - void InjectError(wnn::ErrorType type, const char* message); - void PushErrorScope(wnn::ErrorFilter filter); - bool PopErrorScope(wnn::ErrorCallback callback, void* userdata); - void SetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata); + // Webnn API + void APIInjectError(wnn::ErrorType type, const char* message); + void APIPushErrorScope(wnn::ErrorFilter filter); + bool APIPopErrorScope(wnn::ErrorCallback callback, void* userdata); + void APISetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata); + ContextOptions GetContextOptions() { return mContextOptions; } diff --git a/src/webnn/native/Error.h b/src/webnn/native/Error.h index 9bcb918c5..7084fa2d3 100644 --- a/src/webnn/native/Error.h +++ b/src/webnn/native/Error.h @@ -16,8 +16,10 @@ #ifndef WEBNN_NATIVE_ERROR_H_ #define WEBNN_NATIVE_ERROR_H_ +#include "absl/strings/str_format.h" #include "common/Result.h" #include "webnn/native/ErrorData.h" +#include "webnn/native/webnn_absl_format_autogen.h" #include @@ -80,11 +82,11 @@ namespace webnn::native { DAWN_MAKE_ERROR(InternalErrorType::Internal, std::string("Unimplemented: ") + MESSAGE) #define DAWN_OUT_OF_MEMORY_ERROR(MESSAGE) DAWN_MAKE_ERROR(InternalErrorType::OutOfMemory, MESSAGE) -#define DAWN_INVALID_IF(EXPR, ...) \ - if (DAWN_UNLIKELY(EXPR)) { \ - return DAWN_MAKE_ERROR(InternalErrorType::Validation, __VA_ARGS__); \ - } \ - for (;;) \ +#define DAWN_INVALID_IF(EXPR, ...) \ + if (DAWN_UNLIKELY(EXPR)) { \ + return DAWN_MAKE_ERROR(InternalErrorType::Validation, absl::StrFormat(__VA_ARGS__)); \ + } \ + for (;;) \ break #define DAWN_CONCAT1(x, y) x##y diff --git a/src/webnn/native/Forward.h b/src/webnn/native/Forward.h index c979553cd..36d959568 100644 --- a/src/webnn/native/Forward.h +++ b/src/webnn/native/Forward.h @@ -23,9 +23,10 @@ class Ref; namespace webnn::native { - class CompilationBase; + class ContextBase; class GraphBase; class GraphBuilderBase; + class InstanceBase; class NamedInputsBase; class NamedOperandsBase; class NamedOutputsBase; diff --git a/src/webnn/native/Graph.cpp b/src/webnn/native/Graph.cpp index d8359d042..ef5d8bd5d 100644 --- a/src/webnn/native/Graph.cpp +++ b/src/webnn/native/Graph.cpp @@ -137,14 +137,14 @@ namespace webnn::native { return CompileImpl(); } - void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + void GraphBase::APICompute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { GetContext()->ConsumedError(ComputeImpl(inputs, outputs)); } - void GraphBase::ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata) { + void GraphBase::APIComputeAsync(NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata) { if (inputs == nullptr || outputs == nullptr) { callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata); } diff --git a/src/webnn/native/Graph.h b/src/webnn/native/Graph.h index 7709a3855..a53f6d7c8 100644 --- a/src/webnn/native/Graph.h +++ b/src/webnn/native/Graph.h @@ -82,11 +82,11 @@ namespace webnn::native { virtual MaybeError Compile(); // Webnn API - void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs); - void ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata); + void APICompute(NamedInputsBase* inputs, NamedOutputsBase* outputs); + void APIComputeAsync(NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata); GraphBase(ContextBase* context, ObjectBase::ErrorTag tag); static GraphBase* MakeError(ContextBase* context); diff --git a/src/webnn/native/GraphBuilder.cpp b/src/webnn/native/GraphBuilder.cpp index 26a4ac08e..a8511f2ac 100644 --- a/src/webnn/native/GraphBuilder.cpp +++ b/src/webnn/native/GraphBuilder.cpp @@ -70,40 +70,41 @@ namespace webnn::native { GraphBuilderBase::GraphBuilderBase(ContextBase* context) : ObjectBase(context) { } - OperandBase* GraphBuilderBase::Abs(OperandBase* input) { + OperandBase* GraphBuilderBase::APIAbs(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kAbs, input)); } - OperandBase* GraphBuilderBase::Add(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIAdd(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kAdd, a, b)); } - OperandBase* GraphBuilderBase::AveragePool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIAveragePool2d(OperandBase* input, + Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kAveragePool2d, input, options)); } - OperandBase* GraphBuilderBase::BatchNorm(OperandBase* input, - OperandBase* mean, - OperandBase* variance, - BatchNormOptions const* options) { + OperandBase* GraphBuilderBase::APIBatchNorm(OperandBase* input, + OperandBase* mean, + OperandBase* variance, + BatchNormOptions const* options) { VALIDATE_FOR_OPERAND(new op::BatchNorm(this, input, mean, variance, options)); } - OperandBase* GraphBuilderBase::Clamp(OperandBase* input, ClampOptions const* options) { + OperandBase* GraphBuilderBase::APIClamp(OperandBase* input, ClampOptions const* options) { VALIDATE_FOR_OPERAND(new op::Clamp(this, input, options)); } - FusionOperatorBase* GraphBuilderBase::ClampOperator(ClampOptions const* options) { + FusionOperatorBase* GraphBuilderBase::APIClampOperator(ClampOptions const* options) { return new op::FusionClamp(this, options); } - OperandBase* GraphBuilderBase::Ceil(OperandBase* input) { + OperandBase* GraphBuilderBase::APICeil(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kCeil, input)); } - OperandBase* GraphBuilderBase::Concat(uint32_t inputsCount, - OperandBase* const* inputs, - uint32_t axis) { + OperandBase* GraphBuilderBase::APIConcat(uint32_t inputsCount, + OperandBase* const* inputs, + uint32_t axis) { std::vector> operandInputs; operandInputs.reserve(inputsCount); for (uint32_t i = 0; i < inputsCount; ++i) { @@ -112,13 +113,13 @@ namespace webnn::native { VALIDATE_FOR_OPERAND(new op::Concat(this, std::move(operandInputs), axis)); } - OperandBase* GraphBuilderBase::Constant(OperandDescriptor const* desc, - ArrayBufferView const* arrayBuffer) { + OperandBase* GraphBuilderBase::APIConstant(OperandDescriptor const* desc, + ArrayBufferView const* arrayBuffer) { VALIDATE_FOR_OPERAND(new op::Constant(this, desc, arrayBuffer)); } - OperandBase* GraphBuilderBase::ConstantWithGpuBuffer(OperandDescriptor const* desc, - GpuBufferView const* gpuBuffer) { + OperandBase* GraphBuilderBase::APIConstantWithGpuBuffer(OperandDescriptor const* desc, + GpuBufferView const* gpuBuffer) { #if defined(WEBNN_ENABLE_GPU_BUFFER) VALIDATE_FOR_OPERAND(new op::Constant(this, desc, gpuBuffer)); #endif @@ -126,104 +127,105 @@ namespace webnn::native { return nullptr; } - OperandBase* GraphBuilderBase::Conv2d(OperandBase* input, - OperandBase* filter, - Conv2dOptions const* options) { + OperandBase* GraphBuilderBase::APIConv2d(OperandBase* input, + OperandBase* filter, + Conv2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Conv2d(this, input, filter, options)); } - OperandBase* GraphBuilderBase::ConvTranspose2d(OperandBase* input, - OperandBase* filter, - ConvTranspose2dOptions const* options) { + OperandBase* GraphBuilderBase::APIConvTranspose2d(OperandBase* input, + OperandBase* filter, + ConvTranspose2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::ConvTranspose2d(this, input, filter, options)); } - OperandBase* GraphBuilderBase::Cos(OperandBase* input) { + OperandBase* GraphBuilderBase::APICos(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kCos, input)); } - OperandBase* GraphBuilderBase::Div(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIDiv(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kDiv, a, b)); } - OperandBase* GraphBuilderBase::Exp(OperandBase* input) { + OperandBase* GraphBuilderBase::APIExp(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kExp, input)); } - OperandBase* GraphBuilderBase::Floor(OperandBase* input) { + OperandBase* GraphBuilderBase::APIFloor(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kFloor, input)); } - OperandBase* GraphBuilderBase::Gemm(OperandBase* a, - OperandBase* b, - GemmOptions const* options) { + OperandBase* GraphBuilderBase::APIGemm(OperandBase* a, + OperandBase* b, + GemmOptions const* options) { VALIDATE_FOR_OPERAND(new op::Gemm(this, a, b, options)); } - OperandArrayBase* GraphBuilderBase::Gru(OperandBase* input, - OperandBase* weight, - OperandBase* recurrentWeight, - int32_t steps, - int32_t hiddenSize, - GruOptions const* options) { + OperandArrayBase* GraphBuilderBase::APIGru(OperandBase* input, + OperandBase* weight, + OperandBase* recurrentWeight, + int32_t steps, + int32_t hiddenSize, + GruOptions const* options) { VALIDATE_ARRAY_OPERAND( new op::Gru(this, input, weight, recurrentWeight, steps, hiddenSize, options)); } - OperandBase* GraphBuilderBase::HardSwish(OperandBase* input) { + OperandBase* GraphBuilderBase::APIHardSwish(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kHardSwish, input)); } - FusionOperatorBase* GraphBuilderBase::HardSwishOperator() { + FusionOperatorBase* GraphBuilderBase::APIHardSwishOperator() { return new op::FusionUnary(this, FusionType::HardSwish); } - OperandBase* GraphBuilderBase::Input(char const* name, OperandDescriptor const* desc) { + OperandBase* GraphBuilderBase::APIInput(char const* name, OperandDescriptor const* desc) { VALIDATE_FOR_OPERAND(new op::Input(this, std::string(name), desc)); } - OperandBase* GraphBuilderBase::InstanceNorm(OperandBase* input, - InstanceNormOptions const* options) { + OperandBase* GraphBuilderBase::APIInstanceNorm(OperandBase* input, + InstanceNormOptions const* options) { VALIDATE_FOR_OPERAND(new op::InstanceNorm(this, input, options)); } - OperandBase* GraphBuilderBase::LeakyRelu(OperandBase* input, LeakyReluOptions const* options) { + OperandBase* GraphBuilderBase::APILeakyRelu(OperandBase* input, + LeakyReluOptions const* options) { VALIDATE_FOR_OPERAND(new op::LeakyRelu(this, input, options)); } - FusionOperatorBase* GraphBuilderBase::LeakyReluOperator(LeakyReluOptions const* options) { + FusionOperatorBase* GraphBuilderBase::APILeakyReluOperator(LeakyReluOptions const* options) { return new op::FusionLeakyRelu(this, options); } - OperandBase* GraphBuilderBase::Log(OperandBase* input) { + OperandBase* GraphBuilderBase::APILog(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kLog, input)); } - OperandBase* GraphBuilderBase::L2Pool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIL2Pool2d(OperandBase* input, Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kL2Pool2d, input, options)); } - OperandBase* GraphBuilderBase::Matmul(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMatmul(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMatMul, a, b)); } - OperandBase* GraphBuilderBase::Max(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMax(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMax, a, b)); } - OperandBase* GraphBuilderBase::MaxPool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIMaxPool2d(OperandBase* input, Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kMaxPool2d, input, options)); } - OperandBase* GraphBuilderBase::Min(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMin(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMin, a, b)); } - OperandBase* GraphBuilderBase::Mul(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMul(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMul, a, b)); } - OperandBase* GraphBuilderBase::Neg(OperandBase* input) { + OperandBase* GraphBuilderBase::APINeg(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kNeg, input)); } @@ -233,125 +235,129 @@ namespace webnn::native { // PadOptions const* options) { // VALIDATE_FOR_OPERAND(new op::Pad(this, input, padding, padding_count, options)); // } - OperandBase* GraphBuilderBase::Pad(OperandBase* input, - OperandBase* padding, - PadOptions const* options) { + OperandBase* GraphBuilderBase::APIPad(OperandBase* input, + OperandBase* padding, + PadOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pad(this, input, padding, options)); } - OperandBase* GraphBuilderBase::Pow(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIPow(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kPower, a, b)); } - OperandBase* GraphBuilderBase::ReduceArgMax(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceArgMax(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceArgMax, input, options)); } - OperandBase* GraphBuilderBase::ReduceArgMin(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceArgMin(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceArgMin, input, options)); } - OperandBase* GraphBuilderBase::ReduceL2(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceL2(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceL2, input, options)); } - OperandBase* GraphBuilderBase::ReduceL1(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceL1(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceL1, input, options)); } - OperandBase* GraphBuilderBase::ReduceMax(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMax(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMax, input, options)); } - OperandBase* GraphBuilderBase::ReduceMean(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMean(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMean, input, options)); } - OperandBase* GraphBuilderBase::ReduceMin(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMin(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMin, input, options)); } - OperandBase* GraphBuilderBase::ReduceProduct(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceProduct(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceProduct, input, options)); } - OperandBase* GraphBuilderBase::ReduceSum(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceSum(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceSum, input, options)); } - OperandBase* GraphBuilderBase::Relu(OperandBase* input) { + OperandBase* GraphBuilderBase::APIRelu(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kRelu, input)); } - FusionOperatorBase* GraphBuilderBase::ReluOperator() { + FusionOperatorBase* GraphBuilderBase::APIReluOperator() { return new op::FusionUnary(this, FusionType::Relu); } - OperandBase* GraphBuilderBase::Resample2d(OperandBase* input, - Resample2dOptions const* options) { + OperandBase* GraphBuilderBase::APIResample2d(OperandBase* input, + Resample2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Resample2d(this, input, options)); } - OperandBase* GraphBuilderBase::Reshape(OperandBase* input, - int32_t const* new_shape, - size_t new_shape_count) { + OperandBase* GraphBuilderBase::APIReshape(OperandBase* input, + int32_t const* new_shape, + size_t new_shape_count) { VALIDATE_FOR_OPERAND(new op::Reshape(this, input, new_shape, new_shape_count)); } - OperandBase* GraphBuilderBase::Sigmoid(OperandBase* input) { + OperandBase* GraphBuilderBase::APISigmoid(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSigmoid, input)); } - FusionOperatorBase* GraphBuilderBase::SigmoidOperator() { + FusionOperatorBase* GraphBuilderBase::APISigmoidOperator() { return new op::FusionUnary(this, FusionType::Sigmoid); } - OperandBase* GraphBuilderBase::Sin(OperandBase* input) { + OperandBase* GraphBuilderBase::APISin(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSin, input)); } - OperandBase* GraphBuilderBase::Slice(OperandBase* input, - int32_t const* starts, - uint32_t startsCount, - int32_t const* sizes, - uint32_t sizesCount, - SliceOptions const* options) { + OperandBase* GraphBuilderBase::APISlice(OperandBase* input, + int32_t const* starts, + uint32_t startsCount, + int32_t const* sizes, + uint32_t sizesCount, + SliceOptions const* options) { VALIDATE_FOR_OPERAND( new op::Slice(this, input, starts, startsCount, sizes, sizesCount, options)); } - OperandBase* GraphBuilderBase::Softmax(OperandBase* input) { + OperandBase* GraphBuilderBase::APISoftmax(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSoftmax, input)); } - OperandArrayBase* GraphBuilderBase::Split(OperandBase* input, - uint32_t const* splits, - uint32_t splitsCount, - SplitOptions const* options) { + OperandArrayBase* GraphBuilderBase::APISplit(OperandBase* input, + uint32_t const* splits, + uint32_t splitsCount, + SplitOptions const* options) { VALIDATE_ARRAY_OPERAND(new op::Split(this, input, splits, splitsCount, options)); } - OperandBase* GraphBuilderBase::Squeeze(OperandBase* input, SqueezeOptions const* options) { + OperandBase* GraphBuilderBase::APISqueeze(OperandBase* input, SqueezeOptions const* options) { VALIDATE_FOR_OPERAND(new op::Squeeze(this, input, options)); } - OperandBase* GraphBuilderBase::Sub(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APISub(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kSub, a, b)); } - OperandBase* GraphBuilderBase::Tan(OperandBase* input) { + OperandBase* GraphBuilderBase::APITan(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kTan, input)); } - OperandBase* GraphBuilderBase::Tanh(OperandBase* input) { + OperandBase* GraphBuilderBase::APITanh(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kTanh, input)); } - FusionOperatorBase* GraphBuilderBase::TanhOperator() { + FusionOperatorBase* GraphBuilderBase::APITanhOperator() { return new op::FusionUnary(this, FusionType::Tanh); } - OperandBase* GraphBuilderBase::Transpose(OperandBase* input, TransposeOptions const* options) { + OperandBase* GraphBuilderBase::APITranspose(OperandBase* input, + TransposeOptions const* options) { VALIDATE_FOR_OPERAND(new op::Transpose(this, input, options)); } @@ -380,7 +386,7 @@ namespace webnn::native { return std::move(graph); } - GraphBase* GraphBuilderBase::Build(NamedOperandsBase const* namedOperands) { + GraphBase* GraphBuilderBase::APIBuild(NamedOperandsBase const* namedOperands) { Ref result = nullptr; if (GetContext()->ConsumedError(BuildImpl(namedOperands), &result)) { ASSERT(result == nullptr); diff --git a/src/webnn/native/GraphBuilder.h b/src/webnn/native/GraphBuilder.h index 58318a122..cd8ddd326 100644 --- a/src/webnn/native/GraphBuilder.h +++ b/src/webnn/native/GraphBuilder.h @@ -34,86 +34,86 @@ namespace webnn::native { virtual ~GraphBuilderBase() = default; // WebNN API - OperandBase* Abs(OperandBase*); - OperandBase* Add(OperandBase*, OperandBase*); - OperandBase* AveragePool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* BatchNorm(OperandBase*, - OperandBase*, - OperandBase*, - BatchNormOptions const* options); - OperandBase* Clamp(OperandBase*, ClampOptions const* options); - FusionOperatorBase* ClampOperator(ClampOptions const* options); - OperandBase* Ceil(OperandBase*); - OperandBase* Concat(uint32_t inputsCount, OperandBase* const* inputs, uint32_t axis); - OperandBase* Constant(OperandDescriptor const* desc, ArrayBufferView const* arrayBuffer); - OperandBase* ConstantWithGpuBuffer(OperandDescriptor const* desc, - GpuBufferView const* arrayBuffer); - OperandBase* Conv2d(OperandBase*, OperandBase*, Conv2dOptions const* options); - OperandBase* ConvTranspose2d(OperandBase*, - OperandBase*, - ConvTranspose2dOptions const* options); - OperandBase* Cos(OperandBase*); - OperandBase* Div(OperandBase*, OperandBase*); - OperandBase* Exp(OperandBase*); - OperandBase* Floor(OperandBase*); - OperandBase* Gemm(OperandBase*, OperandBase*, GemmOptions const* options); - OperandArrayBase* Gru(OperandBase*, - OperandBase*, - OperandBase*, - int32_t steps, - int32_t hiddenSize, - GruOptions const* options); - OperandBase* HardSwish(OperandBase*); - FusionOperatorBase* HardSwishOperator(); - OperandBase* Input(char const* name, OperandDescriptor const* desc); - OperandBase* InstanceNorm(OperandBase*, InstanceNormOptions const* options); - OperandBase* LeakyRelu(OperandBase*, LeakyReluOptions const* options); - FusionOperatorBase* LeakyReluOperator(LeakyReluOptions const* options); - OperandBase* Log(OperandBase*); - OperandBase* L2Pool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* Matmul(OperandBase* a, OperandBase* b); - OperandBase* Max(OperandBase*, OperandBase*); - OperandBase* MaxPool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* Min(OperandBase*, OperandBase*); - OperandBase* Mul(OperandBase*, OperandBase*); - OperandBase* Neg(OperandBase*); - OperandBase* Pad(OperandBase*, OperandBase*, PadOptions const* options); - OperandBase* Pow(OperandBase*, OperandBase*); - OperandBase* ReduceArgMax(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceArgMin(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceL1(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceL2(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMax(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMean(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMin(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceProduct(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceSum(OperandBase*, ReduceOptions const* options); - OperandBase* Relu(OperandBase*); - FusionOperatorBase* ReluOperator(); - OperandBase* Resample2d(OperandBase*, Resample2dOptions const* options); - OperandBase* Reshape(OperandBase*, int32_t const*, size_t); - OperandBase* Sigmoid(OperandBase*); - FusionOperatorBase* SigmoidOperator(); - OperandBase* Sin(OperandBase*); - OperandBase* Slice(OperandBase*, - int32_t const* starts, - uint32_t startsCount, - int32_t const* sizes, - uint32_t sizesCount, - SliceOptions const* options); - OperandBase* Softmax(OperandBase*); - OperandArrayBase* Split(OperandBase*, - uint32_t const*, - uint32_t, - SplitOptions const* options); - OperandBase* Squeeze(OperandBase*, SqueezeOptions const* options); - OperandBase* Sub(OperandBase*, OperandBase*); - OperandBase* Tan(OperandBase*); - OperandBase* Tanh(OperandBase*); - FusionOperatorBase* TanhOperator(); - OperandBase* Transpose(OperandBase*, TransposeOptions const* options); + OperandBase* APIAbs(OperandBase*); + OperandBase* APIAdd(OperandBase*, OperandBase*); + OperandBase* APIAveragePool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIBatchNorm(OperandBase*, + OperandBase*, + OperandBase*, + BatchNormOptions const* options); + OperandBase* APIClamp(OperandBase*, ClampOptions const* options); + FusionOperatorBase* APIClampOperator(ClampOptions const* options); + OperandBase* APICeil(OperandBase*); + OperandBase* APIConcat(uint32_t inputsCount, OperandBase* const* inputs, uint32_t axis); + OperandBase* APIConstant(OperandDescriptor const* desc, ArrayBufferView const* arrayBuffer); + OperandBase* APIConstantWithGpuBuffer(OperandDescriptor const* desc, + GpuBufferView const* arrayBuffer); + OperandBase* APIConv2d(OperandBase*, OperandBase*, Conv2dOptions const* options); + OperandBase* APIConvTranspose2d(OperandBase*, + OperandBase*, + ConvTranspose2dOptions const* options); + OperandBase* APICos(OperandBase*); + OperandBase* APIDiv(OperandBase*, OperandBase*); + OperandBase* APIExp(OperandBase*); + OperandBase* APIFloor(OperandBase*); + OperandBase* APIGemm(OperandBase*, OperandBase*, GemmOptions const* options); + OperandArrayBase* APIGru(OperandBase*, + OperandBase*, + OperandBase*, + int32_t steps, + int32_t hiddenSize, + GruOptions const* options); + OperandBase* APIHardSwish(OperandBase*); + FusionOperatorBase* APIHardSwishOperator(); + OperandBase* APIInput(char const* name, OperandDescriptor const* desc); + OperandBase* APIInstanceNorm(OperandBase*, InstanceNormOptions const* options); + OperandBase* APILeakyRelu(OperandBase*, LeakyReluOptions const* options); + FusionOperatorBase* APILeakyReluOperator(LeakyReluOptions const* options); + OperandBase* APILog(OperandBase*); + OperandBase* APIL2Pool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIMatmul(OperandBase* a, OperandBase* b); + OperandBase* APIMax(OperandBase*, OperandBase*); + OperandBase* APIMaxPool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIMin(OperandBase*, OperandBase*); + OperandBase* APIMul(OperandBase*, OperandBase*); + OperandBase* APINeg(OperandBase*); + OperandBase* APIPad(OperandBase*, OperandBase*, PadOptions const* options); + OperandBase* APIPow(OperandBase*, OperandBase*); + OperandBase* APIReduceArgMax(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceArgMin(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceL1(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceL2(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMax(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMean(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMin(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceProduct(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceSum(OperandBase*, ReduceOptions const* options); + OperandBase* APIRelu(OperandBase*); + FusionOperatorBase* APIReluOperator(); + OperandBase* APIResample2d(OperandBase*, Resample2dOptions const* options); + OperandBase* APIReshape(OperandBase*, int32_t const*, size_t); + OperandBase* APISigmoid(OperandBase*); + FusionOperatorBase* APISigmoidOperator(); + OperandBase* APISin(OperandBase*); + OperandBase* APISlice(OperandBase*, + int32_t const* starts, + uint32_t startsCount, + int32_t const* sizes, + uint32_t sizesCount, + SliceOptions const* options); + OperandBase* APISoftmax(OperandBase*); + OperandArrayBase* APISplit(OperandBase*, + uint32_t const*, + uint32_t, + SplitOptions const* options); + OperandBase* APISqueeze(OperandBase*, SqueezeOptions const* options); + OperandBase* APISub(OperandBase*, OperandBase*); + OperandBase* APITan(OperandBase*); + OperandBase* APITanh(OperandBase*); + FusionOperatorBase* APITanhOperator(); + OperandBase* APITranspose(OperandBase*, TransposeOptions const* options); - GraphBase* Build(NamedOperandsBase const* namedOperands); + GraphBase* APIBuild(NamedOperandsBase const* namedOperands); private: ResultOrError> BuildImpl(NamedOperandsBase const* namedOperands); diff --git a/src/webnn/native/Instance.cpp b/src/webnn/native/Instance.cpp index 2ef8f04ab..d31bb655a 100644 --- a/src/webnn/native/Instance.cpp +++ b/src/webnn/native/Instance.cpp @@ -175,7 +175,7 @@ namespace webnn::native { return mBackends[wnn::BackendType::Null]->CreateContext(options); } - ContextBase* InstanceBase::CreateContext(const ContextOptions* options) { + ContextBase* InstanceBase::APICreateContext(const ContextOptions* options) { if (mBackends.find(wnn::BackendType::DirectML) != mBackends.end()) { return mBackends[wnn::BackendType::DirectML]->CreateContext(options); } else if (mBackends.find(wnn::BackendType::OpenVINO) != mBackends.end()) { @@ -193,7 +193,7 @@ namespace webnn::native { return nullptr; } - ContextBase* InstanceBase::CreateContextWithGpuDevice(const GpuDevice* wnn_device) { + ContextBase* InstanceBase::APICreateContextWithGpuDevice(const GpuDevice* wnn_device) { #if defined(WEBNN_ENABLE_GPU_BUFFER) WGPUDevice device = reinterpret_cast(wnn_device->device); if (mBackends.find(wnn::BackendType::DirectML) != mBackends.end()) { @@ -210,23 +210,23 @@ namespace webnn::native { return nullptr; } - GraphBuilderBase* InstanceBase::CreateGraphBuilder(ContextBase* context) { + GraphBuilderBase* InstanceBase::APICreateGraphBuilder(ContextBase* context) { return new GraphBuilderBase(context); } - NamedInputsBase* InstanceBase::CreateNamedInputs() { + NamedInputsBase* InstanceBase::APICreateNamedInputs() { return new NamedInputsBase(); } - NamedOperandsBase* InstanceBase::CreateNamedOperands() { + NamedOperandsBase* InstanceBase::APICreateNamedOperands() { return new NamedOperandsBase(); } - NamedOutputsBase* InstanceBase::CreateNamedOutputs() { + NamedOutputsBase* InstanceBase::APICreateNamedOutputs() { return new NamedOutputsBase(); } - OperatorArrayBase* InstanceBase::CreateOperatorArray() { + OperatorArrayBase* InstanceBase::APICreateOperatorArray() { return new OperatorArrayBase(); } diff --git a/src/webnn/native/Instance.h b/src/webnn/native/Instance.h index 9a1a826bc..bc0b4fe96 100644 --- a/src/webnn/native/Instance.h +++ b/src/webnn/native/Instance.h @@ -39,13 +39,13 @@ namespace webnn::native { static InstanceBase* Create(const InstanceDescriptor* descriptor = nullptr); // WebNN API - ContextBase* CreateContext(const ContextOptions* options); - ContextBase* CreateContextWithGpuDevice(const GpuDevice* device); - GraphBuilderBase* CreateGraphBuilder(ContextBase* context); - NamedInputsBase* CreateNamedInputs(); - NamedOperandsBase* CreateNamedOperands(); - NamedOutputsBase* CreateNamedOutputs(); - OperatorArrayBase* CreateOperatorArray(); + ContextBase* APICreateContext(const ContextOptions* options); + ContextBase* APICreateContextWithGpuDevice(const GpuDevice* device); + GraphBuilderBase* APICreateGraphBuilder(ContextBase* context); + NamedInputsBase* APICreateNamedInputs(); + NamedOperandsBase* APICreateNamedOperands(); + NamedOutputsBase* APICreateNamedOutputs(); + OperatorArrayBase* APICreateOperatorArray(); ContextBase* CreateTestContext(const ContextOptions* options); diff --git a/src/webnn/native/NamedInputs.h b/src/webnn/native/NamedInputs.h index 4b73dc687..e4d0686b9 100644 --- a/src/webnn/native/NamedInputs.h +++ b/src/webnn/native/NamedInputs.h @@ -40,7 +40,7 @@ namespace webnn::native { } // WebNN API - void Set(char const* name, const Input* input) { + void APISet(char const* name, const Input* input) { mInputs[std::string(name)] = *input; #if defined(WEBNN_ENABLE_WIRE) // Input data type is Arrary Buffer View. diff --git a/src/webnn/native/NamedOperands.h b/src/webnn/native/NamedOperands.h index af6ddbdce..bd51dab20 100644 --- a/src/webnn/native/NamedOperands.h +++ b/src/webnn/native/NamedOperands.h @@ -25,7 +25,7 @@ namespace webnn::native { class NamedOperandsBase : public RefCounted { public: // WebNN API - void Set(char const* name, const OperandBase* operand) { + void APISet(char const* name, const OperandBase* operand) { mNamedOperands[std::string(name)] = operand; } diff --git a/src/webnn/native/NamedOutputs.h b/src/webnn/native/NamedOutputs.h index 11d4bdf4e..dace4dd3d 100644 --- a/src/webnn/native/NamedOutputs.h +++ b/src/webnn/native/NamedOutputs.h @@ -41,7 +41,7 @@ namespace webnn::native { } // WebNN API - void Set(char const* name, const Resource* resource) { + void APISetOutput(char const* name, const Resource* resource) { mOutputs[std::string(name)] = *resource; if (resource->gpuBufferView.buffer != nullptr) { #if defined(WEBNN_ENABLE_GPU_BUFFER) @@ -61,7 +61,7 @@ namespace webnn::native { } } - void Get(char const* name, ArrayBufferView* arrayBuffer) { + void APIGetOutput(char const* name, ArrayBufferView* arrayBuffer) { if (mOutputs.find(std::string(name)) == mOutputs.end()) { return; } diff --git a/src/webnn/native/OperandArray.h b/src/webnn/native/OperandArray.h index 1d74e7b8e..c584396d4 100644 --- a/src/webnn/native/OperandArray.h +++ b/src/webnn/native/OperandArray.h @@ -31,10 +31,10 @@ namespace webnn::native { return new OperandArrayBase(graphBuilder, ObjectBase::kError); } // WebNN API - size_t Size() { + size_t APISize() { return mOperands.size(); } - OperandBase* Get(size_t index) { + OperandBase* APIGetOperand(size_t index) { return mOperands[index].Get(); } diff --git a/src/webnn/native/OperatorArray.h b/src/webnn/native/OperatorArray.h index 686bd3e80..df94f659a 100644 --- a/src/webnn/native/OperatorArray.h +++ b/src/webnn/native/OperatorArray.h @@ -25,15 +25,15 @@ namespace webnn::native { virtual ~OperatorArrayBase() = default; // WebNN API - size_t Size() { + size_t APISize() { return mOperators.size(); } - void Set(FusionOperatorBase* mlOperator) { + void APISetFusionOperator(FusionOperatorBase* mlOperator) { mOperators.push_back(Ref(mlOperator)); } - FusionOperatorBase* Get(size_t index) { + FusionOperatorBase* APIGetFusionOperator(size_t index) { return mOperators[index].Get(); } diff --git a/src/webnn/native/Utils.h b/src/webnn/native/Utils.h index 42b1055e5..206bf6a53 100644 --- a/src/webnn/native/Utils.h +++ b/src/webnn/native/Utils.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_NATIVEUTILS_H_ #define WEBNN_NATIVE_NATIVEUTILS_H_ -#include +#include #include #include diff --git a/src/webnn/native/WebnnNative.cpp b/src/webnn/native/WebnnNative.cpp index 220b9d977..a5c1645a0 100644 --- a/src/webnn/native/WebnnNative.cpp +++ b/src/webnn/native/WebnnNative.cpp @@ -67,7 +67,12 @@ namespace webnn::native { WNNContext Instance::CreateContext(const wnn::ContextOptions* options) { return reinterpret_cast( - mImpl->CreateContext(reinterpret_cast(options))); + mImpl->APICreateContext(reinterpret_cast(options))); + } + + WNNGraphBuilder Instance::CreateGraphBuilder(const WNNContext context) { + return reinterpret_cast( + mImpl->APICreateGraphBuilder(reinterpret_cast(context))); } WNNInstance Instance::Get() const { diff --git a/src/webnn/native/dmlx/GraphDMLX.cpp b/src/webnn/native/dmlx/GraphDMLX.cpp index c048223c2..962053fef 100644 --- a/src/webnn/native/dmlx/GraphDMLX.cpp +++ b/src/webnn/native/dmlx/GraphDMLX.cpp @@ -1842,8 +1842,8 @@ namespace webnn::native::dmlx { fActivation = ::dml::FusedActivation::Sigmoid(); gActivation = ::dml::FusedActivation::Tanh(); } else { - fActivation = CreateFusedActivation(options->activations->Get(0)); - gActivation = CreateFusedActivation(options->activations->Get(1)); + fActivation = CreateFusedActivation(options->activations->APIGetFusionOperator(0)); + gActivation = CreateFusedActivation(options->activations->APIGetFusionOperator(1)); } std::vector<::dml::FusedActivation> activations; if (direction == diff --git a/src/webnn/native/ops/Gru.cpp b/src/webnn/native/ops/Gru.cpp index ee980789b..6717fa329 100644 --- a/src/webnn/native/ops/Gru.cpp +++ b/src/webnn/native/ops/Gru.cpp @@ -44,9 +44,10 @@ namespace webnn::native::op { } if (options == nullptr || options->activations == nullptr) { mActivations = AcquireRef(new OperatorArrayBase()); - mActivations->Set( + mActivations->APISetFusionOperator( AcquireRef(new FusionOperatorBase(builder, FusionType::Sigmoid)).Get()); - mActivations->Set(AcquireRef(new FusionOperatorBase(builder, FusionType::Tanh)).Get()); + mActivations->APISetFusionOperator( + AcquireRef(new FusionOperatorBase(builder, FusionType::Tanh)).Get()); } else { mActivations = Ref(mOptions.activations); } @@ -108,7 +109,7 @@ namespace webnn::native::op { } } // The activations parameter - if (GetActivations().Get()->Size() != 2) { + if (GetActivations().Get()->APISize() != 2) { return DAWN_VALIDATION_ERROR("Argument activations is not a sequence of length 2."); } diff --git a/src/webnn/native/webnn_platform.h b/src/webnn/native/webnn_platform.h index 29259ce1c..938f38e3e 100644 --- a/src/webnn/native/webnn_platform.h +++ b/src/webnn/native/webnn_platform.h @@ -22,7 +22,7 @@ // Use our autogenerated version of the webnn structures that point to webnn_native // object types #include -#include +#include namespace webnn::native { // kEnumCount is a constant specifying the number of enums in a WebGPU enum type, diff --git a/src/webnn/tests/BUILD.gn b/src/webnn/tests/BUILD.gn index 587f3e7bc..f2196f8a4 100644 --- a/src/webnn/tests/BUILD.gn +++ b/src/webnn/tests/BUILD.gn @@ -16,7 +16,7 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//testing/test.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/generator/webnn_generator.gni") group("webnn_tests") { @@ -96,10 +96,10 @@ if (!build_with_chromium) { ############################################################################### webnn_json_generator("mock_webnn_gen") { - target = "mock_webnn" + target = "mock_api" outputs = [ - "src/webnn/mock_webnn.h", - "src/webnn/mock_webnn.cpp", + "src/dawn/mock_webnn.h", + "src/dawn/mock_webnn.cpp", ] } @@ -132,8 +132,8 @@ test("webnn_unittests") { "${webnn_root}/src/webnn:cpp", "${webnn_root}/src/webnn:webnn_proc", "${webnn_root}/src/webnn/common", - "${webnn_root}/src/webnn/native:webnn_native", "${webnn_root}/src/webnn/native:sources", + "${webnn_root}/src/webnn/native:webnn_native", "${webnn_root}/src/webnn/utils:webnn_utils", ] @@ -142,7 +142,6 @@ test("webnn_unittests") { sources = get_target_outputs(":mock_webnn_gen") sources += [ - "//third_party/dawn/src/tests/unittests/ResultTests.cpp", "unittests/ErrorTests.cpp", "unittests/ObjectBaseTests.cpp", "unittests/native/ContextMockTests.cpp", diff --git a/src/webnn/tests/end2end/AddTests.cpp b/src/webnn/tests/end2end/AddTests.cpp index 5bedd8848..0b5eb0c3e 100644 --- a/src/webnn/tests/end2end/AddTests.cpp +++ b/src/webnn/tests/end2end/AddTests.cpp @@ -17,7 +17,7 @@ class AddTests : public WebnnTest {}; TEST_F(AddTests, AddConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.5781865, -0.49248728, -0.2162451, -0.13176449, -0.52118045, 1.9125274, 0.6508799, @@ -60,7 +60,7 @@ TEST_F(AddTests, AddConstantAndInput) { } TEST_F(AddTests, AddTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Add(a, b); @@ -102,7 +102,7 @@ TEST_F(AddTests, AddTwoInputs) { } TEST_F(AddTests, AddBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Add(a, b); diff --git a/src/webnn/tests/end2end/BatchNormTests.cpp b/src/webnn/tests/end2end/BatchNormTests.cpp index 95728c327..750aef6f1 100644 --- a/src/webnn/tests/end2end/BatchNormTests.cpp +++ b/src/webnn/tests/end2end/BatchNormTests.cpp @@ -16,7 +16,7 @@ class BatchNormTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/ClampTests.cpp b/src/webnn/tests/end2end/ClampTests.cpp index 6df0b8128..30d5c27e4 100644 --- a/src/webnn/tests/end2end/ClampTests.cpp +++ b/src/webnn/tests/end2end/ClampTests.cpp @@ -20,7 +20,7 @@ class ClampTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, const wnn::ClampOptions* options = nullptr) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); const wnn::Operand b = builder.Clamp(a, options); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/ConcatTests.cpp b/src/webnn/tests/end2end/ConcatTests.cpp index 848285e1a..d4b91d7bd 100644 --- a/src/webnn/tests/end2end/ConcatTests.cpp +++ b/src/webnn/tests/end2end/ConcatTests.cpp @@ -26,7 +26,7 @@ class ConcatTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, bool inputsDefined = true) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); std::vector inputsOperand; inputsOperand.reserve(inputs.size()); size_t index = 0; diff --git a/src/webnn/tests/end2end/Conv2dTests.cpp b/src/webnn/tests/end2end/Conv2dTests.cpp index 296637ab0..785422dd0 100644 --- a/src/webnn/tests/end2end/Conv2dTests.cpp +++ b/src/webnn/tests/end2end/Conv2dTests.cpp @@ -16,7 +16,7 @@ class Conv2dTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp index 022e1b70a..87dfa5a78 100644 --- a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp +++ b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp @@ -16,7 +16,7 @@ class ConvTranspose2dTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/DivTests.cpp b/src/webnn/tests/end2end/DivTests.cpp index 83c22b804..0b455d4f6 100644 --- a/src/webnn/tests/end2end/DivTests.cpp +++ b/src/webnn/tests/end2end/DivTests.cpp @@ -17,7 +17,7 @@ class DivTests : public WebnnTest {}; TEST_F(DivTests, Div) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Div(a, b); @@ -62,7 +62,7 @@ TEST_F(DivTests, Div) { } TEST_F(DivTests, DivBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Div(a, b); diff --git a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp index 00241ee13..909eb2c22 100644 --- a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp +++ b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp @@ -32,7 +32,7 @@ class ElementWiseUnaryTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, const std::vector& shape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", shape); wnn::Operand b; switch (type) { diff --git a/src/webnn/tests/end2end/GemmTests.cpp b/src/webnn/tests/end2end/GemmTests.cpp index 6d2ebfebb..ccbacf854 100644 --- a/src/webnn/tests/end2end/GemmTests.cpp +++ b/src/webnn/tests/end2end/GemmTests.cpp @@ -33,7 +33,7 @@ class GemmTests : public WebnnTest { const std::vector& expectedValue, const Options* options = nullptr, bool constantWeight = false) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", aShape); wnn::Operand b; if (constantWeight) { diff --git a/src/webnn/tests/end2end/GruTests.cpp b/src/webnn/tests/end2end/GruTests.cpp index 39f8aa024..757af2ba2 100644 --- a/src/webnn/tests/end2end/GruTests.cpp +++ b/src/webnn/tests/end2end/GruTests.cpp @@ -16,7 +16,7 @@ class GruTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: @@ -43,7 +43,7 @@ class GruTests : public WebnnTest { const size_t outputSize = Y.Size(); std::vector namedOperands; for (size_t i = 0; i < outputSize; ++i) { - namedOperands.push_back({"gru" + std::to_string(i), Y.Get(i)}); + namedOperands.push_back({"gru" + std::to_string(i), Y.GetOperand(i)}); } const wnn::Graph graph = utils::Build(builder, namedOperands); ASSERT_TRUE(graph); @@ -151,9 +151,9 @@ TEST_F(GruTests, GruWithMultiActivitions) { auto activations = CreateCppOperatorArray(); auto activationSigmoid = utils::CreateActivationOperator(builder, utils::FusedActivation::SIGMOID); - activations.Set(activationSigmoid); + activations.SetFusionOperator(activationSigmoid); auto activationTanh = utils::CreateActivationOperator(builder, utils::FusedActivation::TANH); - activations.Set(activationTanh); + activations.SetFusionOperator(activationTanh); options.activations = activations; const std::vector expectedShape = {numDirections, batchSize, hiddenSize}; diff --git a/src/webnn/tests/end2end/HardSwishTests.cpp b/src/webnn/tests/end2end/HardSwishTests.cpp index 2f54a23d8..f57bd48de 100644 --- a/src/webnn/tests/end2end/HardSwishTests.cpp +++ b/src/webnn/tests/end2end/HardSwishTests.cpp @@ -21,7 +21,7 @@ class HardSwishTests : public WebnnTest { void CheckHardSwish(const std::vector& inputShape, const std::vector& inputBuffer, const std::vector& expectedBuffer) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); const wnn::Operand y = builder.HardSwish(x); const wnn::Graph graph = utils::Build(builder, {{"y", y}}); diff --git a/src/webnn/tests/end2end/InstanceNormTests.cpp b/src/webnn/tests/end2end/InstanceNormTests.cpp index c6d5dc257..184b3a35b 100644 --- a/src/webnn/tests/end2end/InstanceNormTests.cpp +++ b/src/webnn/tests/end2end/InstanceNormTests.cpp @@ -16,7 +16,7 @@ class InstanceNormTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/LeakyReluTests.cpp b/src/webnn/tests/end2end/LeakyReluTests.cpp index 303b51ac1..6acd0a2c4 100644 --- a/src/webnn/tests/end2end/LeakyReluTests.cpp +++ b/src/webnn/tests/end2end/LeakyReluTests.cpp @@ -20,7 +20,7 @@ class LeakyReluTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, float alpha = 0.01) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::LeakyReluOptions options; options.alpha = alpha; diff --git a/src/webnn/tests/end2end/MatMulTests.cpp b/src/webnn/tests/end2end/MatMulTests.cpp index 10f94f618..89ce0a499 100644 --- a/src/webnn/tests/end2end/MatMulTests.cpp +++ b/src/webnn/tests/end2end/MatMulTests.cpp @@ -17,7 +17,7 @@ class MatMulTests : public WebnnTest {}; TEST_F(MatMulTests, MatMul1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {4}); const std::vector bData = {0.8782074, 0.22533207, 0.7134056, 0.04190519}; const wnn::Operand b = @@ -33,7 +33,7 @@ TEST_F(MatMulTests, MatMul1d) { } TEST_F(MatMulTests, MatMul1dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {4}); const std::vector bData = { 0.3093976, -1.2924036, -0.64339244, 1.1423386, 1.5052135, 1.8182521, @@ -52,7 +52,7 @@ TEST_F(MatMulTests, MatMul1dx2d) { } TEST_F(MatMulTests, MatMul2dx1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.25528687, 0.2126722, 0.26320502, 0.8297401}; const wnn::Operand b = @@ -71,7 +71,7 @@ TEST_F(MatMulTests, MatMul2dx1d) { } TEST_F(MatMulTests, MatMul2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.17467105, -1.2045133, -0.02621938, 0.6096196, 1.4499376, 1.3465316, 0.03289436, 1.0754977, @@ -93,7 +93,7 @@ TEST_F(MatMulTests, MatMul2d) { } TEST_F(MatMulTests, MatMul3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-2.7142005, 0.41909233, 0.80572236, 0.19983047, -1.9361104, 1.1919757, 0.61684674, 0.23732206, 0.74679494, 0.4595843, @@ -122,7 +122,7 @@ TEST_F(MatMulTests, MatMul3d) { } TEST_F(MatMulTests, MatMul3dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-0.38534147, -0.18395364, -2.548874, 0.4525641, -0.41875792, 0.57480955, -0.41603103, 0.6973883, @@ -147,7 +147,7 @@ TEST_F(MatMulTests, MatMul3dx2d) { } TEST_F(MatMulTests, MatMul3dx2dGet3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 3, 4}); const std::vector bData = {0.2545374, -1.6150205, -0.64508885, -0.3454305, 0.38700557, 1.3147515, -0.3379386, 1.1804152, @@ -170,7 +170,7 @@ TEST_F(MatMulTests, MatMul3dx2dGet3d) { } TEST_F(MatMulTests, MatMul4d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { -0.45605758, -0.43318668, 0.61509126, -2.2228749, 0.50257015, -0.29311436, @@ -200,7 +200,7 @@ TEST_F(MatMulTests, MatMul4d) { } TEST_F(MatMulTests, MatMul4dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { 0.01829041, -0.73948264, -0.95898634, -0.5105271, 2.1705306, 1.2495605, diff --git a/src/webnn/tests/end2end/MaxTests.cpp b/src/webnn/tests/end2end/MaxTests.cpp index 9bdce0581..09062d2dd 100644 --- a/src/webnn/tests/end2end/MaxTests.cpp +++ b/src/webnn/tests/end2end/MaxTests.cpp @@ -17,7 +17,7 @@ class MaxTests : public WebnnTest {}; TEST_F(MaxTests, MaxConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.00724315, -1.4088361, 0.17466596, 1.1395162, 1.3720452, -0.35610083, -0.5597993, @@ -60,7 +60,7 @@ TEST_F(MaxTests, MaxConstantAndInput) { } TEST_F(MaxTests, MaxTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Max(a, b); @@ -102,7 +102,7 @@ TEST_F(MaxTests, MaxTwoInputs) { } TEST_F(MaxTests, MaxBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Max(a, b); diff --git a/src/webnn/tests/end2end/MinTests.cpp b/src/webnn/tests/end2end/MinTests.cpp index 77f29a99a..11f6abd1c 100644 --- a/src/webnn/tests/end2end/MinTests.cpp +++ b/src/webnn/tests/end2end/MinTests.cpp @@ -17,7 +17,7 @@ class MinTests : public WebnnTest {}; TEST_F(MinTests, MinConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.3013072, -0.09710764, 0.19347863, 0.57673335, -0.9459303, -0.311303, -0.51731133, @@ -60,7 +60,7 @@ TEST_F(MinTests, MinConstantAndInput) { } TEST_F(MinTests, MinTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Min(a, b); @@ -102,7 +102,7 @@ TEST_F(MinTests, MinTwoInputs) { } TEST_F(MinTests, MinBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Min(a, b); diff --git a/src/webnn/tests/end2end/MulTests.cpp b/src/webnn/tests/end2end/MulTests.cpp index cfd1a0a21..c721e5857 100644 --- a/src/webnn/tests/end2end/MulTests.cpp +++ b/src/webnn/tests/end2end/MulTests.cpp @@ -17,7 +17,7 @@ class MulTests : public WebnnTest {}; TEST_F(MulTests, MulInputAndConstant) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 2.0435283, 0.07213961, -1.1644137, -1.2209045, 0.8982674, 0.21796915, 0.27658972, @@ -66,7 +66,7 @@ TEST_F(MulTests, MulInputAndConstant) { } TEST_F(MulTests, MulTwoInputs) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); wnn::Operand c = builder.Mul(a, b); @@ -114,7 +114,7 @@ TEST_F(MulTests, MulTwoInputs) { } TEST_F(MulTests, MulBroadcast) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136, diff --git a/src/webnn/tests/end2end/PadTests.cpp b/src/webnn/tests/end2end/PadTests.cpp index fab73be21..d9f43e1bf 100644 --- a/src/webnn/tests/end2end/PadTests.cpp +++ b/src/webnn/tests/end2end/PadTests.cpp @@ -23,7 +23,7 @@ class PadTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, wnn::PaddingMode mode = wnn::PaddingMode::Constant) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); const wnn::Operand padding = utils::BuildConstant(builder, paddingShape, paddingData.data(), diff --git a/src/webnn/tests/end2end/Pool2dTests.cpp b/src/webnn/tests/end2end/Pool2dTests.cpp index 39d8f6dc9..b85c273c7 100644 --- a/src/webnn/tests/end2end/Pool2dTests.cpp +++ b/src/webnn/tests/end2end/Pool2dTests.cpp @@ -17,7 +17,7 @@ class Pool2dTests : public WebnnTest {}; TEST_F(Pool2dTests, MaxPool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -32,7 +32,7 @@ TEST_F(Pool2dTests, MaxPool2dDefault) { } TEST_F(Pool2dTests, MaxPool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -48,7 +48,7 @@ TEST_F(Pool2dTests, MaxPool2dNhwc) { } TEST_F(Pool2dTests, MaxPool2dDilationsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -64,7 +64,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsDefault) { } TEST_F(Pool2dTests, MaxPool2dDilationsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -81,7 +81,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsNhwc) { } TEST_F(Pool2dTests, MaxPool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -99,7 +99,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsDefault) { } TEST_F(Pool2dTests, MaxPool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -118,7 +118,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -136,7 +136,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperDefault) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -159,7 +159,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes3x3Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -182,7 +182,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes3x3Nhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes4x4Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -206,7 +206,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes4x4Nhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeFloorNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -229,7 +229,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeFloorNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeCeilNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -253,7 +253,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeCeilNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -275,7 +275,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameLowerNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -294,7 +294,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperNhwc) { } TEST_F(Pool2dTests, MaxPool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -311,7 +311,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesDefault) { } TEST_F(Pool2dTests, MaxPool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -329,7 +329,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesNhwc) { } TEST_F(Pool2dTests, AveragePool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -344,7 +344,7 @@ TEST_F(Pool2dTests, AveragePool2dDefault) { } TEST_F(Pool2dTests, AveragePool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -360,7 +360,7 @@ TEST_F(Pool2dTests, AveragePool2dNhwc) { } TEST_F(Pool2dTests, AveragePool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -379,7 +379,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsDefault) { } TEST_F(Pool2dTests, AveragePool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -399,7 +399,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -418,7 +418,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperDefault) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -438,7 +438,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -461,7 +461,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes3x3Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -484,7 +484,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes3x3Nhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes4x4Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -508,7 +508,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes4x4Nhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeFloorNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -531,7 +531,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeFloorNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeCeilNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -555,7 +555,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeCeilNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -577,7 +577,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameLowerNhwc) { } TEST_F(Pool2dTests, AveragePool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -594,7 +594,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesDefault) { } TEST_F(Pool2dTests, AveragePool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -612,7 +612,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesNhwc) { } TEST_F(Pool2dTests, GlobalAveragePool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 3, 5, 5}); const wnn::Operand y = builder.AveragePool2d(x); const wnn::Graph graph = utils::Build(builder, {{"y", y}}); @@ -636,7 +636,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dDefault) { } TEST_F(Pool2dTests, GlobalAveragePool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 3}); utils::Pool2dOptions options; options.layout = wnn::InputOperandLayout::Nhwc; diff --git a/src/webnn/tests/end2end/PowTests.cpp b/src/webnn/tests/end2end/PowTests.cpp index 43a11ad43..b2078e925 100644 --- a/src/webnn/tests/end2end/PowTests.cpp +++ b/src/webnn/tests/end2end/PowTests.cpp @@ -17,7 +17,7 @@ class PowTests : public WebnnTest {}; TEST_F(PowTests, Sqrt1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const std::vector bData = {0.5}; const wnn::Operand b = @@ -33,7 +33,7 @@ TEST_F(PowTests, Sqrt1d) { } TEST_F(PowTests, Sqrt3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = {0.5}; const wnn::Operand b = @@ -67,7 +67,7 @@ TEST_F(PowTests, Sqrt3d) { } TEST_F(PowTests, Pow1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const std::vector bData = {2}; const wnn::Operand b = @@ -83,7 +83,7 @@ TEST_F(PowTests, Pow1d) { } TEST_F(PowTests, PowBroadcastScalar) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3}); const std::vector bData = {2}; const wnn::Operand b = @@ -99,7 +99,7 @@ TEST_F(PowTests, PowBroadcastScalar) { } TEST_F(PowTests, PowBroadcast1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3}); const std::vector bData = {1, 2, 3}; const wnn::Operand b = diff --git a/src/webnn/tests/end2end/ReduceTests.cpp b/src/webnn/tests/end2end/ReduceTests.cpp index 65e5c8e87..53d2ab4ed 100644 --- a/src/webnn/tests/end2end/ReduceTests.cpp +++ b/src/webnn/tests/end2end/ReduceTests.cpp @@ -33,7 +33,7 @@ class ReduceTests : public WebnnTest { const std::vector& expectedValue, const std::vector& axes = {}, bool keepDimensions = false) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::ReduceOptions options; if (!axes.empty()) { diff --git a/src/webnn/tests/end2end/ReluTests.cpp b/src/webnn/tests/end2end/ReluTests.cpp index 4b8afc37e..4f3d04661 100644 --- a/src/webnn/tests/end2end/ReluTests.cpp +++ b/src/webnn/tests/end2end/ReluTests.cpp @@ -17,7 +17,7 @@ class ReluTests : public WebnnTest {}; TEST_F(ReluTests, Relu) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = builder.Relu(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/Resample2dTests.cpp b/src/webnn/tests/end2end/Resample2dTests.cpp index 42024904b..d5cfe9f86 100644 --- a/src/webnn/tests/end2end/Resample2dTests.cpp +++ b/src/webnn/tests/end2end/Resample2dTests.cpp @@ -21,7 +21,7 @@ class Resample2dTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, const wnn::Resample2dOptions* options = nullptr) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand inputOperand = utils::BuildInput(builder, "input", inputShape); const wnn::Operand output = builder.Resample2d(inputOperand, options); const wnn::Graph graph = utils::Build(builder, {{"output", output}}); diff --git a/src/webnn/tests/end2end/ReshapeTests.cpp b/src/webnn/tests/end2end/ReshapeTests.cpp index 110203d28..f86b71075 100644 --- a/src/webnn/tests/end2end/ReshapeTests.cpp +++ b/src/webnn/tests/end2end/ReshapeTests.cpp @@ -19,7 +19,7 @@ class ReshapeTests : public WebnnTest { void TestReshape(const std::vector& oldShape, const std::vector& newShape, const std::vector& expectedShape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", oldShape); const wnn::Operand b = builder.Reshape(a, newShape.data(), newShape.size()); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SigmoidTests.cpp b/src/webnn/tests/end2end/SigmoidTests.cpp index 893904d88..2686ea899 100644 --- a/src/webnn/tests/end2end/SigmoidTests.cpp +++ b/src/webnn/tests/end2end/SigmoidTests.cpp @@ -17,7 +17,7 @@ class SigmoidTests : public WebnnTest {}; TEST_F(SigmoidTests, SigmoidWith1DTensor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const wnn::Operand b = builder.Sigmoid(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); @@ -30,7 +30,7 @@ TEST_F(SigmoidTests, SigmoidWith1DTensor) { } TEST_F(SigmoidTests, SigmoidWith3DTensor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = builder.Sigmoid(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SliceTests.cpp b/src/webnn/tests/end2end/SliceTests.cpp index 759b5c9b7..8c02143a7 100644 --- a/src/webnn/tests/end2end/SliceTests.cpp +++ b/src/webnn/tests/end2end/SliceTests.cpp @@ -16,7 +16,7 @@ class SliceTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/SoftmaxTests.cpp b/src/webnn/tests/end2end/SoftmaxTests.cpp index cb8e3b89b..2e7211ee9 100644 --- a/src/webnn/tests/end2end/SoftmaxTests.cpp +++ b/src/webnn/tests/end2end/SoftmaxTests.cpp @@ -17,7 +17,7 @@ class SoftmaxTests : public WebnnTest {}; TEST_F(SoftmaxTests, Softmax) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const wnn::Operand b = builder.Softmax(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SplitTests.cpp b/src/webnn/tests/end2end/SplitTests.cpp index 69b6b21fd..cf9bfb4d7 100644 --- a/src/webnn/tests/end2end/SplitTests.cpp +++ b/src/webnn/tests/end2end/SplitTests.cpp @@ -29,14 +29,14 @@ class SplitTests : public WebnnTest { const std::vector& splits, const std::vector& expectedArray, int32_t axis = 0) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand input = utils::BuildInput(builder, "input", inputShape); wnn::SplitOptions options = {axis}; const wnn::OperandArray splittedOperands = builder.Split(input, splits.data(), splits.size(), &options); std::vector namedOperands; for (size_t i = 0; i < splittedOperands.Size(); ++i) { - namedOperands.push_back({"split" + std::to_string(i), splittedOperands.Get(i)}); + namedOperands.push_back({"split" + std::to_string(i), splittedOperands.GetOperand(i)}); } const wnn::Graph graph = utils::Build(builder, namedOperands); ASSERT_TRUE(graph); diff --git a/src/webnn/tests/end2end/SqueezeTests.cpp b/src/webnn/tests/end2end/SqueezeTests.cpp index af336a1b0..ef581ad6e 100644 --- a/src/webnn/tests/end2end/SqueezeTests.cpp +++ b/src/webnn/tests/end2end/SqueezeTests.cpp @@ -21,7 +21,7 @@ class SqueezeTests : public WebnnTest { void CheckSqueeze(const std::vector& inputShape, const std::vector& axes, const std::vector& expectedShape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); wnn::SqueezeOptions options; if (!axes.empty()) { diff --git a/src/webnn/tests/end2end/SubTests.cpp b/src/webnn/tests/end2end/SubTests.cpp index 16beeb7ab..d9bae408b 100644 --- a/src/webnn/tests/end2end/SubTests.cpp +++ b/src/webnn/tests/end2end/SubTests.cpp @@ -17,7 +17,7 @@ class SubTests : public WebnnTest {}; TEST_F(SubTests, SubTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Sub(a, b); @@ -60,7 +60,7 @@ TEST_F(SubTests, SubTwoInputs) { } TEST_F(SubTests, SubBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Sub(a, b); diff --git a/src/webnn/tests/end2end/TanhTests.cpp b/src/webnn/tests/end2end/TanhTests.cpp index 62b487ee4..fce86dd84 100644 --- a/src/webnn/tests/end2end/TanhTests.cpp +++ b/src/webnn/tests/end2end/TanhTests.cpp @@ -19,7 +19,7 @@ class TanhTests : public WebnnTest { void TestTanh(const std::vector& inputData, const std::vector& expectedData, const std::vector& shape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", shape); const wnn::Operand b = builder.Tanh(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/TransposeTests.cpp b/src/webnn/tests/end2end/TransposeTests.cpp index c93e808cc..2da27f0f3 100644 --- a/src/webnn/tests/end2end/TransposeTests.cpp +++ b/src/webnn/tests/end2end/TransposeTests.cpp @@ -21,7 +21,7 @@ class TransposeTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, const std::vector& permutation = {}) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::TransposeOptions options; options.permutation = permutation.data(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp index 084c841bd..42ca325ff 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp @@ -26,8 +26,8 @@ class MobileNetV2BatchNormNchwTests : public WebnnTest { mobilenetv2.mFused = fused; const std::string nchwPath = kModelPath + "/mobilenetv2_batchnorm_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadBatchNormNchw(builder, false); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = mobilenetv2.LoadBatchNormNCHW(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp index 71a66ff5f..45f016e3a 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp @@ -26,8 +26,8 @@ class MobileNetV2NchwTests : public WebnnTest { mobilenetv2.mFused = fused; const std::string nchwPath = kModelPath + "/mobilenetv2_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadNchw(builder, false); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = mobilenetv2.LoadNCHW(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp index 47842facc..8c3e1195e 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp @@ -27,8 +27,8 @@ class MobileNetV2NhwcTests : public WebnnTest { const std::string nhwcPath = kModelPath + "/mobilenetv2_nhwc/"; mobilenetv2.mWeightsPath = nhwcPath + "weights/"; mobilenetv2.mLayout = "nhwc"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadNhwc(builder); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = mobilenetv2.LoadNHWC(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/ResNetNchw.cpp b/src/webnn/tests/end2end/models/ResNetNchw.cpp index 54e48122a..90a6ccae6 100644 --- a/src/webnn/tests/end2end/models/ResNetNchw.cpp +++ b/src/webnn/tests/end2end/models/ResNetNchw.cpp @@ -26,8 +26,8 @@ class ResNetNchwTests : public WebnnTest { resnet.mFused = fused; const std::string nchwPath = kModelPath + "/resnet50v2_nchw/"; resnet.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = resnet.LoadNchw(builder, false); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = resnet.LoadNCHW(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/ResNetNhwc.cpp b/src/webnn/tests/end2end/models/ResNetNhwc.cpp index ad0fcb23c..22edf774a 100644 --- a/src/webnn/tests/end2end/models/ResNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/ResNetNhwc.cpp @@ -26,8 +26,8 @@ class ResNetNhwcTests : public WebnnTest { resnet.mFused = fused; const std::string nhwcPath = kModelPath + "/resnet50v2_nhwc/"; resnet.mWeightsPath = nhwcPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = resnet.LoadNhwc(builder, true); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = resnet.LoadNHWC(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp index 9430f604f..0ef2291fa 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp @@ -26,8 +26,8 @@ class SqueezeNetNchwTests : public WebnnTest { squeezenet.mFused = fused; const std::string nchwPath = kModelPath + "/squeezenet1.1_nchw/"; squeezenet.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = squeezenet.LoadNchw(builder, false); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = squeezenet.LoadNCHW(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp index adfa1d6ff..8dcae576c 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp @@ -27,8 +27,8 @@ class SqueezeNetNhwcTests : public WebnnTest { const std::string nhwcPath = kModelPath + "/squeezenet1.0_nhwc/"; squeezenet.mWeightsPath = nhwcPath + "weights/"; squeezenet.mLayout = "nhwc"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); - wnn::Operand output = squeezenet.LoadNhwc(builder); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); + wnn::Operand output = squeezenet.LoadNHWC(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/unittests/ObjectBaseTests.cpp b/src/webnn/tests/unittests/ObjectBaseTests.cpp index 3d11a3602..6c9f93b72 100644 --- a/src/webnn/tests/unittests/ObjectBaseTests.cpp +++ b/src/webnn/tests/unittests/ObjectBaseTests.cpp @@ -23,11 +23,11 @@ class Object : public wnn::ObjectBase { using ObjectBase::ObjectBase; using ObjectBase::operator=; - static void WebnnReference(int* handle) { + static void WNNReference(int* handle) { ASSERT_LE(0, *handle); *handle += 1; } - static void WebnnRelease(int* handle) { + static void WNNRelease(int* handle) { ASSERT_LT(0, *handle); *handle -= 1; } @@ -54,14 +54,14 @@ TEST(ObjectBase, AcquireConstruction) { ASSERT_EQ(0, refcount); } -// Test .GetHandle(). +// Test .Get(). TEST(ObjectBase, Get) { int refcount = 1; { Object obj1(&refcount); ASSERT_EQ(2, refcount); - ASSERT_EQ(&refcount, obj1.GetHandle()); + ASSERT_EQ(&refcount, obj1.Get()); } ASSERT_EQ(1, refcount); } @@ -74,7 +74,7 @@ TEST(ObjectBase, Release) { ASSERT_EQ(2, refcount); ASSERT_EQ(&refcount, obj.Release()); - ASSERT_EQ(nullptr, obj.GetHandle()); + ASSERT_EQ(nullptr, obj.Get()); ASSERT_EQ(2, refcount); } ASSERT_EQ(2, refcount); @@ -98,8 +98,8 @@ TEST(ObjectBase, CopyConstructor) { Object source(&refcount); Object destination(source); - ASSERT_EQ(source.GetHandle(), &refcount); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), &refcount); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(3, refcount); destination = Object(); @@ -114,8 +114,8 @@ TEST(ObjectBase, CopyAssignment) { Object destination; destination = source; - ASSERT_EQ(source.GetHandle(), &refcount); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), &refcount); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(3, refcount); destination = Object(); @@ -132,7 +132,7 @@ TEST(ObjectBase, CopyAssignmentSelf) { Object* objPtr = &obj; obj = *objPtr; - ASSERT_EQ(obj.GetHandle(), &refcount); + ASSERT_EQ(obj.Get(), &refcount); ASSERT_EQ(refcount, 2); } @@ -142,8 +142,8 @@ TEST(ObjectBase, MoveConstructor) { Object source(&refcount); Object destination(std::move(source)); - ASSERT_EQ(source.GetHandle(), nullptr); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), nullptr); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(2, refcount); destination = Object(); @@ -158,8 +158,8 @@ TEST(ObjectBase, MoveAssignment) { Object destination; destination = std::move(source); - ASSERT_EQ(source.GetHandle(), nullptr); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), nullptr); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(2, refcount); destination = Object(); @@ -176,14 +176,14 @@ TEST(ObjectBase, MoveAssignmentSelf) { Object* objPtr = &obj; obj = std::move(*objPtr); - ASSERT_EQ(obj.GetHandle(), &refcount); + ASSERT_EQ(obj.Get(), &refcount); ASSERT_EQ(refcount, 2); } // Test the constructor using nullptr TEST(ObjectBase, NullptrConstructor) { Object obj(nullptr); - ASSERT_EQ(obj.GetHandle(), nullptr); + ASSERT_EQ(obj.Get(), nullptr); } // Test assigning nullptr to the object diff --git a/src/webnn/tests/unittests/validation/GraphValidationTests.cpp b/src/webnn/tests/unittests/validation/GraphValidationTests.cpp index 586537f0c..485b233d3 100644 --- a/src/webnn/tests/unittests/validation/GraphValidationTests.cpp +++ b/src/webnn/tests/unittests/validation/GraphValidationTests.cpp @@ -40,13 +40,17 @@ class GraphValidationTest : public ValidationTest { // Test the simple success case. TEST_F(GraphValidationTest, BuildGraphSuccess) { - wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); - namedOperands.Set("output", mOutput); - mBuilder.Build(namedOperands); + // TODO::Use instance->CreateNamedOperands instead of wnn::CreateNamedOperands + // that is removed. + // wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); + // namedOperands.Set("output", mOutput); + // mBuilder.Build(namedOperands); } // Create model with null nameOperands TEST_F(GraphValidationTest, BuildGraphError) { - wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); - DAWN_ASSERT(mBuilder.Build(namedOperands) == nullptr); + // TODO::Use instance->CreateNamedOperands instead of wnn::CreateNamedOperands + // that is removed. + // wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); + // DAWN_ASSERT(mBuilder.Build(namedOperands) == nullptr); } diff --git a/src/webnn/tests/unittests/validation/ValidationTest.cpp b/src/webnn/tests/unittests/validation/ValidationTest.cpp index bec699be3..cb5efe68a 100644 --- a/src/webnn/tests/unittests/validation/ValidationTest.cpp +++ b/src/webnn/tests/unittests/validation/ValidationTest.cpp @@ -29,7 +29,7 @@ void ValidationTest::SetUp() { ASSERT_TRUE(context != nullptr); mContext = wnn::Context::Acquire(context); mContext.SetUncapturedErrorCallback(ErrorCallback, this); - mBuilder = wnn::CreateGraphBuilder(mContext); + mBuilder = wnn::GraphBuilder::Acquire(instance->CreateGraphBuilder(mContext.Get())); } ValidationTest::~ValidationTest() { diff --git a/src/webnn/utils/BUILD.gn b/src/webnn/utils/BUILD.gn index 0defcc9b6..b9638e33d 100644 --- a/src/webnn/utils/BUILD.gn +++ b/src/webnn/utils/BUILD.gn @@ -11,16 +11,16 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. import("../../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/build_overrides/webnn_features.gni") ############################################################################### # Utils for tests and samples -############################################################################### +############################################################################### static_library("webnn_utils") { configs += [ "${webnn_root}/src/webnn/common:internal_config" ] diff --git a/src/webnn/wire/BUILD.gn b/src/webnn/wire/BUILD.gn index 37219331a..b4a9b8e7e 100644 --- a/src/webnn/wire/BUILD.gn +++ b/src/webnn/wire/BUILD.gn @@ -15,7 +15,7 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") # Public webnn_wire headers so they can be publically visible for @@ -35,8 +35,8 @@ webnn_json_generator("gen") { target = "wire" outputs = [ "src/webnn/wire/ObjectType_autogen.h", - "src/webnn/wire/WireCmd_autogen.h", "src/webnn/wire/WireCmd_autogen.cpp", + "src/webnn/wire/WireCmd_autogen.h", "src/webnn/wire/client/ApiObjects_autogen.h", "src/webnn/wire/client/ApiProcs_autogen.cpp", "src/webnn/wire/client/ClientBase_autogen.h", @@ -107,7 +107,7 @@ dawn_component("webnn_wire") { if (is_component_build) { libs = [ "dawn_wire.dll.lib" ] } - deps += [ "${webnn_dawn_root}/src/dawn/wire" ] + deps += [ "${dawn_root}/src/dawn/wire" ] } # Make headers publicly visible diff --git a/webnn.json b/webnn.json index ef4546481..c31803d0d 100644 --- a/webnn.json +++ b/webnn.json @@ -15,6 +15,21 @@ "See the License for the specific language governing permissions and", "limitations under the License." ], + + "_metadata": { + "api": "WebNN", + "c_prefix": "WNN", + "namespace": "wnn", + "native_namespace": "webnn native", + "wire_namespace": "webnn wire", + "proc_table_prefix": "Webnn" + }, + + "proc": { + "category": "function pointer", + "returns": "void", + "args": [] + }, "instance descriptor": { "category": "structure", "members": [] @@ -145,7 +160,7 @@ ] }, "error callback": { - "category": "callback", + "category": "function pointer", "args": [ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*"}, @@ -254,7 +269,7 @@ "returns": "size_t" }, { - "name": "get", + "name": "get operand", "returns": "operand", "args": [ {"name": "index", "type": "size_t"} @@ -270,14 +285,14 @@ "returns": "size_t" }, { - "name": "set", + "name": "set fusion operator", "returns": "void", "args": [ - {"name": "operator", "type": "fusion operator"} + {"name": "fusion operator", "type": "fusion operator"} ] }, { - "name": "get", + "name": "get fusion operator", "returns": "fusion operator", "args": [ {"name": "index", "type": "size_t"} @@ -1048,14 +1063,14 @@ "category": "object", "methods": [ { - "name": "set", + "name": "set output", "args": [ {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "resource", "type": "resource", "annotation": "const*"} ] }, { - "name": "get", + "name": "get output", "args": [ {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "resource", "type": "array buffer view", "annotation": "*"} @@ -1064,7 +1079,7 @@ ] }, "compute async callback": { - "category": "callback", + "category": "function pointer", "args": [ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, @@ -1093,5 +1108,11 @@ ] } ] + }, + "s type": { + "category": "enum", + "values": [ + {"value": 0, "name": "invalid", "valid": false} + ] } }