diff --git a/DEPS b/DEPS index 3535c9b0f..09d555211 100644 --- a/DEPS +++ b/DEPS @@ -8,7 +8,8 @@ gclient_gn_args = [ vars = { 'chromium_git': 'https://chromium.googlesource.com', - 'dawn_git': 'https://dawn.googlesource.com', + # 'dawn_git': 'https://github.com/fujunwei', + 'dawn_git': 'https://github.com/lisa0314', 'github_git': 'https://github.com', 'dawn_standalone': True, @@ -45,9 +46,15 @@ deps = { # Dependencies required for code generator and infrastructure code. 'third_party/dawn': { - 'url': '{dawn_git}/dawn.git@bf1c0cf52377b4db2bf3a433dc5056620aad7cdd' + # 'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895' + 'url': '{dawn_git}/dawn.git@5e6f6fbfcb038e7a0f7857cda186a8771c6eba05' }, + 'third_party/abseil-cpp': { + 'url': '{chromium_git}/chromium/src/third_party/abseil-cpp@789af048b388657987c59d4da406859034fe310f', + 'condition': 'dawn_standalone', + }, + # Dependencies required for backends. 'third_party/DirectML': { 'url': '{github_git}/microsoft/DirectML.git@c3f16a701beeeefc9ce5b67c71b554a6903c0f67', @@ -136,7 +143,7 @@ deps = { # Jinja2 and MarkupSafe for the code generator 'third_party/jinja2': { - 'url': '{chromium_git}/chromium/src/third_party/jinja2@a82a4944a7f2496639f34a89c9923be5908b80aa', + 'url': '{chromium_git}/chromium/src/third_party/jinja2@ee69aa00ee8536f61db6a451f3858745cf587de6', 'condition': 'dawn_standalone', }, 'third_party/markupsafe': { diff --git a/build_overrides/webnn.gni b/build_overrides/webnn.gni index 7155c1cfd..9ca5b973e 100644 --- a/build_overrides/webnn.gni +++ b/build_overrides/webnn.gni @@ -15,7 +15,8 @@ webnn_standalone = true # The paths to WebNN's dependencies -webnn_dawn_root = "//third_party/dawn" +webnn_abseil_dir = "//third_party/abseil-cpp" +dawn_root = "//third_party/dawn" webnn_googletest_dir = "//third_party/googletest" webnn_jinja2_dir = "//third_party/jinja2" webnn_gpgmm_dir = "//third_party/gpgmm" diff --git a/examples/LeNet/LeNet.cpp b/examples/LeNet/LeNet.cpp index a9ae2a6b4..74bb17568 100644 --- a/examples/LeNet/LeNet.cpp +++ b/examples/LeNet/LeNet.cpp @@ -48,7 +48,7 @@ wnn::Graph LeNet::Build(const std::string& weigthsPath) { return nullptr; } - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(mContext); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(mContext); uint32_t byteOffset = 0; const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28}); diff --git a/examples/MobileNetV2/Main.cpp b/examples/MobileNetV2/Main.cpp index 340d865e3..94ef00563 100644 --- a/examples/MobileNetV2/Main.cpp +++ b/examples/MobileNetV2/Main.cpp @@ -39,7 +39,8 @@ int main(int argc, const char* argv[]) { } }, &mobilevetv2); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); + + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNchw(builder) : mobilevetv2.LoadNhwc(builder); diff --git a/examples/ResNet/Main.cpp b/examples/ResNet/Main.cpp index 64773cfd3..462e494c0 100644 --- a/examples/ResNet/Main.cpp +++ b/examples/ResNet/Main.cpp @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) { } }, &resnet); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); wnn::Operand output = resnet.mLayout == "nchw" ? resnet.LoadNchw(builder) : resnet.LoadNhwc(builder); diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index ff9a6c10e..18175ead5 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -50,7 +50,7 @@ static webnn::wire::WireClient* wireClient = nullptr; static utils::TerribleCommandBuffer* c2sBuf = nullptr; static utils::TerribleCommandBuffer* s2cBuf = nullptr; -static wnn::Instance clientInstance; +static wnn::Instance instance; static std::unique_ptr nativeInstance; wnn::Context CreateCppContext(wnn::ContextOptions const* options) { nativeInstance = std::make_unique(); @@ -62,11 +62,14 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { // Choose whether to use the backend procs and context directly, or set up the wire. WNNContext context = nullptr; WebnnProcTable procs; + WNNInstance wnnInstance; + switch (cmdBufType) { case CmdBufType::None: procs = backendProcs; context = backendContext; + wnnInstance = nativeInstance->Get(); break; case CmdBufType::Terrible: { @@ -94,14 +97,13 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { context = contextReservation.context; #else - webnnProcSetProcs(&procs); auto instanceReservation = wireClient->ReserveInstance(); wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id, instanceReservation.generation); // Keep the reference instread of using Acquire. // TODO:: make the instance in the client as singleton object. - clientInstance = wnn::Instance(instanceReservation.instance); - return clientInstance.CreateContext(options); + wnnInstance = instanceReservation.instance; + break; #endif } default: @@ -109,8 +111,9 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { DAWN_ASSERT(0); } webnnProcSetProcs(&procs); - - return wnn::Context::Acquire(context); + instance = wnn::Instance(wnnInstance); + return instance.CreateContext(options); + ; } void DoFlush() { @@ -123,35 +126,15 @@ void DoFlush() { } wnn::NamedInputs CreateCppNamedInputs() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedInputs(); -#else - return wnn::CreateNamedInputs(); -#endif // defined(WEBNN_ENABLE_WIRE) -} - -wnn::NamedOperands CreateCppNamedOperands() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedOperands(); -#else - return wnn::CreateNamedOperands(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateNamedInputs(); } wnn::NamedOutputs CreateCppNamedOutputs() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateNamedOutputs(); -#else - return wnn::CreateNamedOutputs(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateNamedOutputs(); } wnn::OperatorArray CreateCppOperatorArray() { -#if defined(WEBNN_ENABLE_WIRE) - return clientInstance.CreateOperatorArray(); -#else - return wnn::CreateOperatorArray(); -#endif // defined(WEBNN_ENABLE_WIRE) + return instance.CreateOperatorArray(); } bool ExampleBase::ParseAndCheckExampleOptions(int argc, const char* argv[]) { @@ -264,6 +247,10 @@ namespace utils { return activationOperand; } + wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context) { + return instance.CreateGraphBuilder(context); + } + wnn::Operand BuildInput(const wnn::GraphBuilder& builder, std::string name, const std::vector& dimensions, @@ -283,7 +270,7 @@ namespace utils { } wnn::Graph Build(const wnn::GraphBuilder& builder, const std::vector& outputs) { - wnn::NamedOperands namedOperands = CreateCppNamedOperands(); + wnn::NamedOperands namedOperands = instance.CreateNamedOperands(); for (auto& output : outputs) { namedOperands.Set(output.name.c_str(), output.operand); } diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index 7c424fd18..fa21b9e0a 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -83,6 +83,7 @@ namespace utils { FusedActivation activation, const void* options = nullptr); + wnn::GraphBuilder CreateGraphBuilder(const wnn::Context& context); wnn::Operand BuildInput(const wnn::GraphBuilder& builder, std::string name, const std::vector& dimensions, @@ -247,7 +248,7 @@ namespace utils { void Compute(const wnn::Graph& graph, const std::vector>& inputs, const std::vector>& outputs) { - if (graph.GetHandle() == nullptr) { + if (graph.Get() == nullptr) { dawn::ErrorLog() << "The graph is invaild."; } @@ -272,7 +273,7 @@ namespace utils { resource.arrayBufferView.buffer = output.resource.data(); resource.arrayBufferView.byteLength = output.resource.size() * sizeof(float); mlOutputs.push_back(resource); - namedOutputs.Set(output.name.c_str(), &mlOutputs.back()); + namedOutputs.SetOutput(output.name.c_str(), &mlOutputs.back()); } graph.Compute(namedInputs, namedOutputs); DoFlush(); diff --git a/examples/SqueezeNet/Main.cpp b/examples/SqueezeNet/Main.cpp index feba69e18..d250e0e7a 100644 --- a/examples/SqueezeNet/Main.cpp +++ b/examples/SqueezeNet/Main.cpp @@ -39,7 +39,7 @@ int main(int argc, const char* argv[]) { } }, &squeezenet); - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(context); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); wnn::Operand output = squeezenet.mLayout == "nchw" ? squeezenet.LoadNchw(builder) : squeezenet.LoadNhwc(builder); diff --git a/generator/webnn_generator.gni b/generator/webnn_generator.gni index 7c76e1fc3..0c213c44a 100644 --- a/generator/webnn_generator.gni +++ b/generator/webnn_generator.gni @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import("//third_party/dawn/generator/generator_lib.gni") +import("//third_party/dawn/generator/dawn_generator.gni") import("../scripts/webnn_overrides_with_defaults.gni") # Dawn used to put autogenerated files in a lot of different places. When we @@ -39,12 +39,16 @@ import("../scripts/webnn_overrides_with_defaults.gni") # disallowed gen directories. webnn_allowed_gen_output_dirs = [ + "src/dawn/", + "src/dawn/native/", "src/webnn/", "src/webnn/native/", "src/webnn/wire/client/", "src/webnn/wire/server/", "src/webnn/wire/", "include/webnn/", + "emscripten-bits/", + "include/dawn/", ] # Template to help invoking Dawn code generators based on generator_lib @@ -77,7 +81,7 @@ template("webnn_generator") { forward_variables_from(invoker, "*") # Set arguments required to find the python libraries for the generator - generator_lib_dir = "${webnn_dawn_root}/generator" + generator_lib_dir = "${dawn_root}/generator" jinja2_path = webnn_jinja2_dir # Force Dawn's autogenerated file structure to mirror exactly the source @@ -87,15 +91,16 @@ template("webnn_generator") { # Make sure that we delete stale autogenerated file in directories that are # no longer used by code generation to avoid include conflicts. - deps = [ "${webnn_root}/generator:remove_stale_autogen_files" ] + deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ] + + template_dir = "${dawn_root}/generator/templates" - template_dir = "${webnn_root}/generator/templates" } } # Helper generator for calling the generator from webnn.json # -# dawn_json_generator("my_target_gen") { +# webnn_json_generator("my_target_gen") { # # Which generator target to output # target = "my_target" # @@ -103,21 +108,20 @@ template("webnn_generator") { # } template("webnn_json_generator") { webnn_generator(target_name) { - script = "${webnn_root}/generator/webnn_json_generator.py" + script = "${dawn_root}/generator/dawn_json_generator.py" # The base arguments for the generator: from this webnn.json, generate this # target using templates in this directory. args = [ - "--webnn-json", + "--dawn-json", rebase_path("${webnn_root}/webnn.json", root_build_dir), "--wire-json", rebase_path("${webnn_root}/webnn_wire.json", root_build_dir), "--targets", invoker.target, - "--dawn-generator-path", - rebase_path("${webnn_root}/third_party/dawn/generator"), ] forward_variables_from(invoker, "*", [ "target" ]) + } } diff --git a/include/webnn/BUILD.gn b/include/webnn/BUILD.gn index 906c76eaf..f21d89562 100644 --- a/include/webnn/BUILD.gn +++ b/include/webnn/BUILD.gn @@ -15,7 +15,7 @@ import("../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") ############################################################################### @@ -25,8 +25,8 @@ import("${webnn_root}/generator/webnn_generator.gni") webnn_json_generator("headers_gen") { target = "headers" outputs = [ - "include/webnn/webnn_proc_table.h", - "include/webnn/webnn.h", + "include/dawn/webnn_proc_table.h", + "include/dawn/webnn.h", ] } @@ -43,7 +43,10 @@ source_set("headers") { webnn_json_generator("cpp_headers_gen") { target = "cpp_headers" - outputs = [ "include/webnn/webnn_cpp.h" ] + outputs = [ + "include/dawn/webnn_cpp.h", + "include/dawn/webnn_cpp_print.h", + ] } source_set("cpp_headers") { @@ -63,18 +66,14 @@ config("public") { include_dirs = [ "${target_gen_dir}/../../include", "${webnn_root}/include", + "${dawn_root}/include", + "${dawn_gen_root}/include", ] if (build_with_chromium) { include_dirs += [ - "${webnn_dawn_root}/include", + "${dawn_root}/include", "${dawn_gen_root}/include", ] - } else { - # TODO: Remove after upgrading webnn infranstructure align with dawn. - include_dirs += [ - "${webnn_dawn_root}/src/include", - "${dawn_gen_root}/src/include", - ] } } diff --git a/include/webnn/EnumClassBitmasks.h b/include/webnn/EnumClassBitmasks.h deleted file mode 100644 index bdef1df3d..000000000 --- a/include/webnn/EnumClassBitmasks.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2017 The Dawn Authors -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef WEBNN_ENUM_CLASS_BITMASKS_H_ -#define WEBNN_ENUM_CLASS_BITMASKS_H_ - -#include - -namespace webnn { - - template - struct IsDawnBitmask { - static constexpr bool enable = false; - }; - - template - struct LowerBitmask { - static constexpr bool enable = false; - }; - - template - struct LowerBitmask::enable>::type> { - static constexpr bool enable = true; - using type = T; - constexpr static T Lower(T t) { - return t; - } - }; - - template - struct BoolConvertible { - using Integral = typename std::underlying_type::type; - - constexpr BoolConvertible(Integral value) : value(value) { - } - constexpr operator bool() const { - return value != 0; - } - constexpr operator T() const { - return static_cast(value); - } - - Integral value; - }; - - template - struct LowerBitmask> { - static constexpr bool enable = true; - using type = T; - static constexpr type Lower(BoolConvertible t) { - return t; - } - }; - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator|(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) | - static_cast(LowerBitmask::Lower(right)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator&(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) & - static_cast(LowerBitmask::Lower(right)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr BoolConvertible::type> operator^(T1 left, T2 right) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return static_cast(LowerBitmask::Lower(left)) ^ - static_cast(LowerBitmask::Lower(right)); - } - - template - constexpr BoolConvertible::type> operator~(T1 t) { - using T = typename LowerBitmask::type; - using Integral = typename std::underlying_type::type; - return ~static_cast(LowerBitmask::Lower(t)); - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator&=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l & r; - return l; - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator|=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l | r; - return l; - } - - template ::enable && - LowerBitmask::enable>::type> - constexpr T& operator^=(T& l, T2 right) { - T r = LowerBitmask::Lower(right); - l = l ^ r; - return l; - } - - template - constexpr bool HasZeroOrOneBits(T value) { - using Integral = typename std::underlying_type::type; - return (static_cast(value) & (static_cast(value) - 1)) == 0; - } - -} // namespace webnn - -#endif // WEBNN_ENUM_CLASS_BITMASKS_H_ diff --git a/include/webnn/native/WebnnNative.h b/include/webnn/native/WebnnNative.h index bcf76ddd4..bacd3956b 100644 --- a/include/webnn/native/WebnnNative.h +++ b/include/webnn/native/WebnnNative.h @@ -44,6 +44,7 @@ namespace webnn::native { WNNContext CreateTestContext(const wnn::ContextOptions* options = nullptr); WNNContext CreateContext(const wnn::ContextOptions* options = nullptr); + WNNGraphBuilder CreateGraphBuilder(const WNNContext context); // Returns the underlying WNNInstance object. WNNInstance Get() const; diff --git a/include/webnn/webnn.h b/include/webnn/webnn.h new file mode 100644 index 000000000..f45f3b601 --- /dev/null +++ b/include/webnn/webnn.h @@ -0,0 +1 @@ +#include "dawn/webnn.h" diff --git a/include/webnn/webnn_cpp.h b/include/webnn/webnn_cpp.h new file mode 100644 index 000000000..666ea57bf --- /dev/null +++ b/include/webnn/webnn_cpp.h @@ -0,0 +1 @@ +#include diff --git a/include/webnn/webnn_proc.h b/include/webnn/webnn_proc.h index 959453125..2bf0228b3 100644 --- a/include/webnn/webnn_proc.h +++ b/include/webnn/webnn_proc.h @@ -16,8 +16,8 @@ #ifndef WEBNN_WEBNN_PROC_H_ #define WEBNN_WEBNN_PROC_H_ +#include "dawn/webnn_proc_table.h" #include "webnn/webnn.h" -#include "webnn/webnn_proc_table.h" #ifdef __cplusplus extern "C" { @@ -28,7 +28,7 @@ extern "C" { // default value of the proctable. Setting the proctable back to null is good practice when you // are done using libdawn_proc since further usage will cause a segfault instead of calling an // unexpected function. -WEBNN_EXPORT void webnnProcSetProcs(const WebnnProcTable* procs); +WNN_EXPORT void webnnProcSetProcs(const WebnnProcTable* procs); #ifdef __cplusplus } // extern "C" diff --git a/include/webnn/webnn_proc_table.h b/include/webnn/webnn_proc_table.h new file mode 100644 index 000000000..390201d09 --- /dev/null +++ b/include/webnn/webnn_proc_table.h @@ -0,0 +1 @@ +#include "dawn/webnn_proc_table.h" diff --git a/include/webnn/webnn_thread_dispatch_proc.h b/include/webnn/webnn_thread_dispatch_proc.h new file mode 100644 index 000000000..bf0fb012f --- /dev/null +++ b/include/webnn/webnn_thread_dispatch_proc.h @@ -0,0 +1,33 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ +#define DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ + +#include "webnn_proc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Call webnnProcSetProcs(&webnnThreadDispatchProcTable) and then use webnnProcSetPerThreadProcs +// to set per-thread procs. +WNN_EXPORT extern WebnnProcTable webnnThreadDispatchProcTable; +WNN_EXPORT void webnnProcSetPerThreadProcs(const WebnnProcTable* procs); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // DAWN_WEBNN_THREAD_DISPATCH_PROC_H_ diff --git a/include/webnn/wire/Wire.h b/include/webnn/wire/Wire.h index 5703754e2..dd6da5dab 100644 --- a/include/webnn/wire/Wire.h +++ b/include/webnn/wire/Wire.h @@ -26,7 +26,10 @@ namespace webnn::wire { class WEBNN_WIRE_EXPORT CommandSerializer { public: - virtual ~CommandSerializer() = default; + CommandSerializer(); + virtual ~CommandSerializer(); + CommandSerializer(const CommandSerializer& rhs) = delete; + CommandSerializer& operator=(const CommandSerializer& rhs) = delete; // Get space for serializing commands. // GetCmdSpace will never be called with a value larger than @@ -35,6 +38,7 @@ namespace webnn::wire { virtual void* GetCmdSpace(size_t size) = 0; virtual bool Flush() = 0; virtual size_t GetMaximumAllocationSize() const = 0; + virtual void OnSerializeError(); }; class WEBNN_WIRE_EXPORT CommandHandler { diff --git a/node/src/GraphBuilder.cpp b/node/src/GraphBuilder.cpp index d04c657bb..392096d50 100644 --- a/node/src/GraphBuilder.cpp +++ b/node/src/GraphBuilder.cpp @@ -79,7 +79,7 @@ namespace node { : Napi::ObjectWrap(info) { Napi::Object object = info[0].As(); node::Context* context = Napi::ObjectWrap::Unwrap(object); - mImpl = wnn::CreateGraphBuilder(context->GetImpl()); + mImpl = utils::CreateGraphBuilder(context->GetImpl()); } Napi::Value GraphBuilder::Constant(const Napi::CallbackInfo& info) { diff --git a/scripts/webnn_overrides_with_defaults.gni b/scripts/webnn_overrides_with_defaults.gni index 2f90d69a9..ab2d22d0a 100644 --- a/scripts/webnn_overrides_with_defaults.gni +++ b/scripts/webnn_overrides_with_defaults.gni @@ -17,8 +17,8 @@ if (!defined(webnn_standalone)) { webnn_standalone = false } -if (!defined(webnn_dawn_root)) { - webnn_dawn_root = "//third_party/dawn" +if (!defined(dawn_root)) { + dawn_root = "//third_party/dawn" } if (!defined(webnn_jinja2_dir)) { @@ -33,3 +33,7 @@ webnn_gen_root = get_path_info("${webnn_root}", "gen_dir") if (!defined(webnn_googletest_dir)) { webnn_googletest_dir = "//third_party/googletest" } + +if (!defined(webnn_abseil_dir)) { + webnn_abseil_dir = "//third_party/abseil-cpp" +} diff --git a/src/webnn/BUILD.gn b/src/webnn/BUILD.gn index 0c025e22b..576dc8144 100644 --- a/src/webnn/BUILD.gn +++ b/src/webnn/BUILD.gn @@ -15,7 +15,7 @@ import("../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") ############################################################################### @@ -24,7 +24,7 @@ import("${webnn_root}/generator/webnn_generator.gni") webnn_json_generator("cpp_gen") { target = "cpp" - outputs = [ "src/webnn/webnn_cpp.cpp" ] + outputs = [ "src/dawn/webnn_cpp.cpp" ] } source_set("cpp") { @@ -41,11 +41,14 @@ source_set("cpp") { webnn_json_generator("proc_gen") { target = "proc" - outputs = [ "src/webnn/webnn_proc.c" ] + outputs = [ + "src/dawn/webnn_proc.c", + "src/dawn/webnn_thread_dispatch_proc.cpp", + ] } dawn_component("webnn_proc") { - DEFINE_PREFIX = "WEBNN" + DEFINE_PREFIX = "WNN" public_deps = [ "${webnn_root}/include/webnn:headers" ] deps = [ ":proc_gen" ] @@ -59,7 +62,10 @@ dawn_component("webnn_proc") { webnn_json_generator("emscripten_bits_gen") { target = "emscripten_bits" outputs = [ - "src/webnn/webnn_struct_info.json", - "src/webnn/library_webnn_enum_tables.js", + "emscripten-bits/library_webnn_enum_tables.js", + "emscripten-bits/webnn.h", + "emscripten-bits/webnn_cpp.cpp", + "emscripten-bits/webnn_cpp.h", + "emscripten-bits/webnn_struct_info.json", ] } diff --git a/src/webnn/common/BUILD.gn b/src/webnn/common/BUILD.gn index 21b0226c6..a967db953 100644 --- a/src/webnn/common/BUILD.gn +++ b/src/webnn/common/BUILD.gn @@ -16,8 +16,8 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//build_overrides/build.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") -import("${webnn_dawn_root}/scripts/dawn_overrides_with_defaults.gni") +import("${dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_overrides_with_defaults.gni") import("${webnn_root}/build_overrides/webnn_features.gni") # Use Chromium's dcheck_always_on when available so that we respect it when @@ -41,11 +41,16 @@ config("internal_config") { include_dirs = [ "${target_gen_dir}/../../../src", "${webnn_root}/src", - "${webnn_dawn_root}/src", - "${webnn_dawn_root}/src/dawn", - "${webnn_root}", + "${dawn_root}/src", + "${dawn_root}/src/dawn", + "${webnn_root}/include", + "${dawn_root}/include", ] + if (build_with_chromium) { + include_dirs += [ "${dawn_gen_root}/include" ] + } + defines = [] if (dawn_always_assert || dcheck_always_on || is_debug || use_fuzzing_engine) { @@ -122,6 +127,7 @@ config("internal_config") { "-Wshadow-field", "-Wstrict-prototypes", "-Wtautological-unsigned-zero-compare", + "-Wno-unused-function", ] # Allow comparison against type limits that might be tautological on 32bit @@ -179,6 +185,14 @@ config("internal_config") { "-Wno-c++17-extensions", ] } + + + if (is_clang && webnn_enable_wire) { + cflags += [ + "-Wno-unused-function", + ] + } + } ############################################################################### @@ -190,42 +204,22 @@ config("internal_config") { # systems we know Dawn is able to compile on. if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) { static_library("common") { - if (build_with_chromium) { - sources = [ - "//third_party/dawn/src/dawn/common/Assert.cpp", - "//third_party/dawn/src/dawn/common/Assert.h", - "//third_party/dawn/src/dawn/common/Compiler.h", - "//third_party/dawn/src/dawn/common/Log.cpp", - "//third_party/dawn/src/dawn/common/Log.h", - "//third_party/dawn/src/dawn/common/Math.cpp", - "//third_party/dawn/src/dawn/common/Math.h", - "//third_party/dawn/src/dawn/common/Platform.h", - "//third_party/dawn/src/dawn/common/RefCounted.cpp", - "//third_party/dawn/src/dawn/common/RefCounted.h", - "//third_party/dawn/src/dawn/common/Result.cpp", - "//third_party/dawn/src/dawn/common/Result.h", - "//third_party/dawn/src/dawn/common/SystemUtils.cpp", - "//third_party/dawn/src/dawn/common/SystemUtils.h", - ] - } else { - # TODO: Remove after upgrading webnn infranstructure align with dawn. - sources = [ - "//third_party/dawn/src/common/Assert.cpp", - "//third_party/dawn/src/common/Assert.h", - "//third_party/dawn/src/common/Compiler.h", - "//third_party/dawn/src/common/Log.cpp", - "//third_party/dawn/src/common/Log.h", - "//third_party/dawn/src/common/Math.cpp", - "//third_party/dawn/src/common/Math.h", - "//third_party/dawn/src/common/Platform.h", - "//third_party/dawn/src/common/RefCounted.cpp", - "//third_party/dawn/src/common/RefCounted.h", - "//third_party/dawn/src/common/Result.cpp", - "//third_party/dawn/src/common/Result.h", - "//third_party/dawn/src/common/SystemUtils.cpp", - "//third_party/dawn/src/common/SystemUtils.h", - ] - } + sources = [ + "${dawn_root}/src/dawn/common/Assert.cpp", + "${dawn_root}/src/dawn/common/Assert.h", + "${dawn_root}/src/dawn/common/Compiler.h", + "${dawn_root}/src/dawn/common/Log.cpp", + "${dawn_root}/src/dawn/common/Log.h", + "${dawn_root}/src/dawn/common/Math.cpp", + "${dawn_root}/src/dawn/common/Math.h", + "${dawn_root}/src/dawn/common/Platform.h", + "${dawn_root}/src/dawn/common/RefCounted.cpp", + "${dawn_root}/src/dawn/common/RefCounted.h", + "${dawn_root}/src/dawn/common/Result.cpp", + "${dawn_root}/src/dawn/common/Result.h", + "${dawn_root}/src/dawn/common/SystemUtils.cpp", + "${dawn_root}/src/dawn/common/SystemUtils.h", + ] public_configs = [ ":internal_config" ] deps = [ diff --git a/src/webnn/native/BUILD.gn b/src/webnn/native/BUILD.gn index 0815b96be..76f9e4340 100644 --- a/src/webnn/native/BUILD.gn +++ b/src/webnn/native/BUILD.gn @@ -16,8 +16,8 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//build_overrides/build.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/build_overrides/webnn_features.gni") import("${webnn_root}/generator/webnn_generator.gni") @@ -30,6 +30,17 @@ if (is_mac) { } } +group("abseil") { + # When build_with_chromium=true we need to include "//third_party/abseil-cpp:absl" while + # it's beneficial to be more specific with standalone Dawn, especially when it comes to + # including it as a dependency in other projects (such as Skia). + if (build_with_chromium) { + public_deps = [ "$webnn_abseil_dir:absl" ] + } else { + public_deps = [ "${webnn_root}/third_party/gn/abseil-cpp:str_format" ] + } +} + config("internal") { configs = [ "${webnn_root}/src/webnn/common:internal_config" ] @@ -51,12 +62,18 @@ config("internal") { webnn_json_generator("utils_gen") { target = "native_utils" outputs = [ + "src/webnn/native/ChainUtils_autogen.cpp", + "src/webnn/native/ChainUtils_autogen.h", + "src/webnn/native/ObjectType_autogen.cpp", + "src/webnn/native/ObjectType_autogen.h", "src/webnn/native/ProcTable.cpp", - "src/webnn/native/webnn_structs_autogen.h", - "src/webnn/native/webnn_structs_autogen.cpp", - "src/webnn/native/ValidationUtils_autogen.h", "src/webnn/native/ValidationUtils_autogen.cpp", + "src/webnn/native/ValidationUtils_autogen.h", + "src/webnn/native/webnn_absl_format_autogen.cpp", + "src/webnn/native/webnn_absl_format_autogen.h", "src/webnn/native/webnn_platform_autogen.h", + "src/webnn/native/wnn_structs_autogen.cpp", + "src/webnn/native/wnn_structs_autogen.h", ] } @@ -89,6 +106,8 @@ source_set("sources") { configs += [ ":internal" ] + public_deps = [ ":abseil" ] + sources = get_target_outputs(":utils_gen") sources += [ @@ -234,7 +253,7 @@ source_set("sources") { include_dirs += [ "${webnn_root}/third_party/microsoft.ai.directml.1.8.2/include" ] if (build_with_chromium) { - include_dirs += [ "${webnn_dawn_root}/src/dawn/native/dml/deps/src" ] + include_dirs += [ "${dawn_root}/src/dawn/native/dml/deps/src" ] } else { include_dirs += [ "${webnn_root}/third_party/DirectML/Libraries" ] } @@ -255,7 +274,7 @@ source_set("sources") { "dawn_wire.dll.lib", ] } - deps += [ "${webnn_dawn_root}/src/dawn/native" ] + deps += [ "${dawn_root}/src/dawn/native" ] } cflags = [ diff --git a/src/webnn/native/Context.cpp b/src/webnn/native/Context.cpp index dc7152502..e490c824c 100644 --- a/src/webnn/native/Context.cpp +++ b/src/webnn/native/Context.cpp @@ -65,7 +65,7 @@ namespace webnn::native { } #endif - void ContextBase::InjectError(wnn::ErrorType type, const char* message) { + void ContextBase::APIInjectError(wnn::ErrorType type, const char* message) { if (ConsumedError(ValidateErrorType(type))) { return; } @@ -80,14 +80,14 @@ namespace webnn::native { HandleError(DAWN_MAKE_ERROR(FromWNNErrorType(type), message)); } - void ContextBase::PushErrorScope(wnn::ErrorFilter filter) { + void ContextBase::APIPushErrorScope(wnn::ErrorFilter filter) { if (ConsumedError(ValidateErrorFilter(filter))) { return; } mCurrentErrorScope = AcquireRef(new ErrorScope(filter, mCurrentErrorScope.Get())); } - bool ContextBase::PopErrorScope(wnn::ErrorCallback callback, void* userdata) { + bool ContextBase::APIPopErrorScope(wnn::ErrorCallback callback, void* userdata) { if (DAWN_UNLIKELY(mCurrentErrorScope.Get() == mRootErrorScope.Get())) { return false; } @@ -97,7 +97,7 @@ namespace webnn::native { return true; } - void ContextBase::SetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata) { + void ContextBase::APISetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata) { mRootErrorScope->SetCallback(callback, userdata); } diff --git a/src/webnn/native/Context.h b/src/webnn/native/Context.h index d01cc16c8..4a5b4e845 100644 --- a/src/webnn/native/Context.h +++ b/src/webnn/native/Context.h @@ -58,11 +58,12 @@ namespace webnn::native { WGPUDevice GetWGPUDevice(); #endif - // Dawn API - void InjectError(wnn::ErrorType type, const char* message); - void PushErrorScope(wnn::ErrorFilter filter); - bool PopErrorScope(wnn::ErrorCallback callback, void* userdata); - void SetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata); + // Webnn API + void APIInjectError(wnn::ErrorType type, const char* message); + void APIPushErrorScope(wnn::ErrorFilter filter); + bool APIPopErrorScope(wnn::ErrorCallback callback, void* userdata); + void APISetUncapturedErrorCallback(wnn::ErrorCallback callback, void* userdata); + ContextOptions GetContextOptions() { return mContextOptions; } diff --git a/src/webnn/native/Error.h b/src/webnn/native/Error.h index 9bcb918c5..7084fa2d3 100644 --- a/src/webnn/native/Error.h +++ b/src/webnn/native/Error.h @@ -16,8 +16,10 @@ #ifndef WEBNN_NATIVE_ERROR_H_ #define WEBNN_NATIVE_ERROR_H_ +#include "absl/strings/str_format.h" #include "common/Result.h" #include "webnn/native/ErrorData.h" +#include "webnn/native/webnn_absl_format_autogen.h" #include @@ -80,11 +82,11 @@ namespace webnn::native { DAWN_MAKE_ERROR(InternalErrorType::Internal, std::string("Unimplemented: ") + MESSAGE) #define DAWN_OUT_OF_MEMORY_ERROR(MESSAGE) DAWN_MAKE_ERROR(InternalErrorType::OutOfMemory, MESSAGE) -#define DAWN_INVALID_IF(EXPR, ...) \ - if (DAWN_UNLIKELY(EXPR)) { \ - return DAWN_MAKE_ERROR(InternalErrorType::Validation, __VA_ARGS__); \ - } \ - for (;;) \ +#define DAWN_INVALID_IF(EXPR, ...) \ + if (DAWN_UNLIKELY(EXPR)) { \ + return DAWN_MAKE_ERROR(InternalErrorType::Validation, absl::StrFormat(__VA_ARGS__)); \ + } \ + for (;;) \ break #define DAWN_CONCAT1(x, y) x##y diff --git a/src/webnn/native/Forward.h b/src/webnn/native/Forward.h index c979553cd..36d959568 100644 --- a/src/webnn/native/Forward.h +++ b/src/webnn/native/Forward.h @@ -23,9 +23,10 @@ class Ref; namespace webnn::native { - class CompilationBase; + class ContextBase; class GraphBase; class GraphBuilderBase; + class InstanceBase; class NamedInputsBase; class NamedOperandsBase; class NamedOutputsBase; diff --git a/src/webnn/native/Graph.cpp b/src/webnn/native/Graph.cpp index d8359d042..ef5d8bd5d 100644 --- a/src/webnn/native/Graph.cpp +++ b/src/webnn/native/Graph.cpp @@ -137,14 +137,14 @@ namespace webnn::native { return CompileImpl(); } - void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + void GraphBase::APICompute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { GetContext()->ConsumedError(ComputeImpl(inputs, outputs)); } - void GraphBase::ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata) { + void GraphBase::APIComputeAsync(NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata) { if (inputs == nullptr || outputs == nullptr) { callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata); } diff --git a/src/webnn/native/Graph.h b/src/webnn/native/Graph.h index 7709a3855..a53f6d7c8 100644 --- a/src/webnn/native/Graph.h +++ b/src/webnn/native/Graph.h @@ -82,11 +82,11 @@ namespace webnn::native { virtual MaybeError Compile(); // Webnn API - void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs); - void ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata); + void APICompute(NamedInputsBase* inputs, NamedOutputsBase* outputs); + void APIComputeAsync(NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata); GraphBase(ContextBase* context, ObjectBase::ErrorTag tag); static GraphBase* MakeError(ContextBase* context); diff --git a/src/webnn/native/GraphBuilder.cpp b/src/webnn/native/GraphBuilder.cpp index 26a4ac08e..a8511f2ac 100644 --- a/src/webnn/native/GraphBuilder.cpp +++ b/src/webnn/native/GraphBuilder.cpp @@ -70,40 +70,41 @@ namespace webnn::native { GraphBuilderBase::GraphBuilderBase(ContextBase* context) : ObjectBase(context) { } - OperandBase* GraphBuilderBase::Abs(OperandBase* input) { + OperandBase* GraphBuilderBase::APIAbs(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kAbs, input)); } - OperandBase* GraphBuilderBase::Add(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIAdd(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kAdd, a, b)); } - OperandBase* GraphBuilderBase::AveragePool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIAveragePool2d(OperandBase* input, + Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kAveragePool2d, input, options)); } - OperandBase* GraphBuilderBase::BatchNorm(OperandBase* input, - OperandBase* mean, - OperandBase* variance, - BatchNormOptions const* options) { + OperandBase* GraphBuilderBase::APIBatchNorm(OperandBase* input, + OperandBase* mean, + OperandBase* variance, + BatchNormOptions const* options) { VALIDATE_FOR_OPERAND(new op::BatchNorm(this, input, mean, variance, options)); } - OperandBase* GraphBuilderBase::Clamp(OperandBase* input, ClampOptions const* options) { + OperandBase* GraphBuilderBase::APIClamp(OperandBase* input, ClampOptions const* options) { VALIDATE_FOR_OPERAND(new op::Clamp(this, input, options)); } - FusionOperatorBase* GraphBuilderBase::ClampOperator(ClampOptions const* options) { + FusionOperatorBase* GraphBuilderBase::APIClampOperator(ClampOptions const* options) { return new op::FusionClamp(this, options); } - OperandBase* GraphBuilderBase::Ceil(OperandBase* input) { + OperandBase* GraphBuilderBase::APICeil(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kCeil, input)); } - OperandBase* GraphBuilderBase::Concat(uint32_t inputsCount, - OperandBase* const* inputs, - uint32_t axis) { + OperandBase* GraphBuilderBase::APIConcat(uint32_t inputsCount, + OperandBase* const* inputs, + uint32_t axis) { std::vector> operandInputs; operandInputs.reserve(inputsCount); for (uint32_t i = 0; i < inputsCount; ++i) { @@ -112,13 +113,13 @@ namespace webnn::native { VALIDATE_FOR_OPERAND(new op::Concat(this, std::move(operandInputs), axis)); } - OperandBase* GraphBuilderBase::Constant(OperandDescriptor const* desc, - ArrayBufferView const* arrayBuffer) { + OperandBase* GraphBuilderBase::APIConstant(OperandDescriptor const* desc, + ArrayBufferView const* arrayBuffer) { VALIDATE_FOR_OPERAND(new op::Constant(this, desc, arrayBuffer)); } - OperandBase* GraphBuilderBase::ConstantWithGpuBuffer(OperandDescriptor const* desc, - GpuBufferView const* gpuBuffer) { + OperandBase* GraphBuilderBase::APIConstantWithGpuBuffer(OperandDescriptor const* desc, + GpuBufferView const* gpuBuffer) { #if defined(WEBNN_ENABLE_GPU_BUFFER) VALIDATE_FOR_OPERAND(new op::Constant(this, desc, gpuBuffer)); #endif @@ -126,104 +127,105 @@ namespace webnn::native { return nullptr; } - OperandBase* GraphBuilderBase::Conv2d(OperandBase* input, - OperandBase* filter, - Conv2dOptions const* options) { + OperandBase* GraphBuilderBase::APIConv2d(OperandBase* input, + OperandBase* filter, + Conv2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Conv2d(this, input, filter, options)); } - OperandBase* GraphBuilderBase::ConvTranspose2d(OperandBase* input, - OperandBase* filter, - ConvTranspose2dOptions const* options) { + OperandBase* GraphBuilderBase::APIConvTranspose2d(OperandBase* input, + OperandBase* filter, + ConvTranspose2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::ConvTranspose2d(this, input, filter, options)); } - OperandBase* GraphBuilderBase::Cos(OperandBase* input) { + OperandBase* GraphBuilderBase::APICos(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kCos, input)); } - OperandBase* GraphBuilderBase::Div(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIDiv(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kDiv, a, b)); } - OperandBase* GraphBuilderBase::Exp(OperandBase* input) { + OperandBase* GraphBuilderBase::APIExp(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kExp, input)); } - OperandBase* GraphBuilderBase::Floor(OperandBase* input) { + OperandBase* GraphBuilderBase::APIFloor(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kFloor, input)); } - OperandBase* GraphBuilderBase::Gemm(OperandBase* a, - OperandBase* b, - GemmOptions const* options) { + OperandBase* GraphBuilderBase::APIGemm(OperandBase* a, + OperandBase* b, + GemmOptions const* options) { VALIDATE_FOR_OPERAND(new op::Gemm(this, a, b, options)); } - OperandArrayBase* GraphBuilderBase::Gru(OperandBase* input, - OperandBase* weight, - OperandBase* recurrentWeight, - int32_t steps, - int32_t hiddenSize, - GruOptions const* options) { + OperandArrayBase* GraphBuilderBase::APIGru(OperandBase* input, + OperandBase* weight, + OperandBase* recurrentWeight, + int32_t steps, + int32_t hiddenSize, + GruOptions const* options) { VALIDATE_ARRAY_OPERAND( new op::Gru(this, input, weight, recurrentWeight, steps, hiddenSize, options)); } - OperandBase* GraphBuilderBase::HardSwish(OperandBase* input) { + OperandBase* GraphBuilderBase::APIHardSwish(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kHardSwish, input)); } - FusionOperatorBase* GraphBuilderBase::HardSwishOperator() { + FusionOperatorBase* GraphBuilderBase::APIHardSwishOperator() { return new op::FusionUnary(this, FusionType::HardSwish); } - OperandBase* GraphBuilderBase::Input(char const* name, OperandDescriptor const* desc) { + OperandBase* GraphBuilderBase::APIInput(char const* name, OperandDescriptor const* desc) { VALIDATE_FOR_OPERAND(new op::Input(this, std::string(name), desc)); } - OperandBase* GraphBuilderBase::InstanceNorm(OperandBase* input, - InstanceNormOptions const* options) { + OperandBase* GraphBuilderBase::APIInstanceNorm(OperandBase* input, + InstanceNormOptions const* options) { VALIDATE_FOR_OPERAND(new op::InstanceNorm(this, input, options)); } - OperandBase* GraphBuilderBase::LeakyRelu(OperandBase* input, LeakyReluOptions const* options) { + OperandBase* GraphBuilderBase::APILeakyRelu(OperandBase* input, + LeakyReluOptions const* options) { VALIDATE_FOR_OPERAND(new op::LeakyRelu(this, input, options)); } - FusionOperatorBase* GraphBuilderBase::LeakyReluOperator(LeakyReluOptions const* options) { + FusionOperatorBase* GraphBuilderBase::APILeakyReluOperator(LeakyReluOptions const* options) { return new op::FusionLeakyRelu(this, options); } - OperandBase* GraphBuilderBase::Log(OperandBase* input) { + OperandBase* GraphBuilderBase::APILog(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kLog, input)); } - OperandBase* GraphBuilderBase::L2Pool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIL2Pool2d(OperandBase* input, Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kL2Pool2d, input, options)); } - OperandBase* GraphBuilderBase::Matmul(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMatmul(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMatMul, a, b)); } - OperandBase* GraphBuilderBase::Max(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMax(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMax, a, b)); } - OperandBase* GraphBuilderBase::MaxPool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::APIMaxPool2d(OperandBase* input, Pool2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pool2d(this, op::Pool2dType::kMaxPool2d, input, options)); } - OperandBase* GraphBuilderBase::Min(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMin(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMin, a, b)); } - OperandBase* GraphBuilderBase::Mul(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIMul(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kMul, a, b)); } - OperandBase* GraphBuilderBase::Neg(OperandBase* input) { + OperandBase* GraphBuilderBase::APINeg(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kNeg, input)); } @@ -233,125 +235,129 @@ namespace webnn::native { // PadOptions const* options) { // VALIDATE_FOR_OPERAND(new op::Pad(this, input, padding, padding_count, options)); // } - OperandBase* GraphBuilderBase::Pad(OperandBase* input, - OperandBase* padding, - PadOptions const* options) { + OperandBase* GraphBuilderBase::APIPad(OperandBase* input, + OperandBase* padding, + PadOptions const* options) { VALIDATE_FOR_OPERAND(new op::Pad(this, input, padding, options)); } - OperandBase* GraphBuilderBase::Pow(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APIPow(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kPower, a, b)); } - OperandBase* GraphBuilderBase::ReduceArgMax(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceArgMax(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceArgMax, input, options)); } - OperandBase* GraphBuilderBase::ReduceArgMin(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceArgMin(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceArgMin, input, options)); } - OperandBase* GraphBuilderBase::ReduceL2(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceL2(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceL2, input, options)); } - OperandBase* GraphBuilderBase::ReduceL1(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceL1(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceL1, input, options)); } - OperandBase* GraphBuilderBase::ReduceMax(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMax(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMax, input, options)); } - OperandBase* GraphBuilderBase::ReduceMean(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMean(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMean, input, options)); } - OperandBase* GraphBuilderBase::ReduceMin(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceMin(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceMin, input, options)); } - OperandBase* GraphBuilderBase::ReduceProduct(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceProduct(OperandBase* input, + ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceProduct, input, options)); } - OperandBase* GraphBuilderBase::ReduceSum(OperandBase* input, ReduceOptions const* options) { + OperandBase* GraphBuilderBase::APIReduceSum(OperandBase* input, ReduceOptions const* options) { VALIDATE_FOR_OPERAND(new op::Reduce(this, op::ReduceType::kReduceSum, input, options)); } - OperandBase* GraphBuilderBase::Relu(OperandBase* input) { + OperandBase* GraphBuilderBase::APIRelu(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kRelu, input)); } - FusionOperatorBase* GraphBuilderBase::ReluOperator() { + FusionOperatorBase* GraphBuilderBase::APIReluOperator() { return new op::FusionUnary(this, FusionType::Relu); } - OperandBase* GraphBuilderBase::Resample2d(OperandBase* input, - Resample2dOptions const* options) { + OperandBase* GraphBuilderBase::APIResample2d(OperandBase* input, + Resample2dOptions const* options) { VALIDATE_FOR_OPERAND(new op::Resample2d(this, input, options)); } - OperandBase* GraphBuilderBase::Reshape(OperandBase* input, - int32_t const* new_shape, - size_t new_shape_count) { + OperandBase* GraphBuilderBase::APIReshape(OperandBase* input, + int32_t const* new_shape, + size_t new_shape_count) { VALIDATE_FOR_OPERAND(new op::Reshape(this, input, new_shape, new_shape_count)); } - OperandBase* GraphBuilderBase::Sigmoid(OperandBase* input) { + OperandBase* GraphBuilderBase::APISigmoid(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSigmoid, input)); } - FusionOperatorBase* GraphBuilderBase::SigmoidOperator() { + FusionOperatorBase* GraphBuilderBase::APISigmoidOperator() { return new op::FusionUnary(this, FusionType::Sigmoid); } - OperandBase* GraphBuilderBase::Sin(OperandBase* input) { + OperandBase* GraphBuilderBase::APISin(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSin, input)); } - OperandBase* GraphBuilderBase::Slice(OperandBase* input, - int32_t const* starts, - uint32_t startsCount, - int32_t const* sizes, - uint32_t sizesCount, - SliceOptions const* options) { + OperandBase* GraphBuilderBase::APISlice(OperandBase* input, + int32_t const* starts, + uint32_t startsCount, + int32_t const* sizes, + uint32_t sizesCount, + SliceOptions const* options) { VALIDATE_FOR_OPERAND( new op::Slice(this, input, starts, startsCount, sizes, sizesCount, options)); } - OperandBase* GraphBuilderBase::Softmax(OperandBase* input) { + OperandBase* GraphBuilderBase::APISoftmax(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kSoftmax, input)); } - OperandArrayBase* GraphBuilderBase::Split(OperandBase* input, - uint32_t const* splits, - uint32_t splitsCount, - SplitOptions const* options) { + OperandArrayBase* GraphBuilderBase::APISplit(OperandBase* input, + uint32_t const* splits, + uint32_t splitsCount, + SplitOptions const* options) { VALIDATE_ARRAY_OPERAND(new op::Split(this, input, splits, splitsCount, options)); } - OperandBase* GraphBuilderBase::Squeeze(OperandBase* input, SqueezeOptions const* options) { + OperandBase* GraphBuilderBase::APISqueeze(OperandBase* input, SqueezeOptions const* options) { VALIDATE_FOR_OPERAND(new op::Squeeze(this, input, options)); } - OperandBase* GraphBuilderBase::Sub(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::APISub(OperandBase* a, OperandBase* b) { VALIDATE_FOR_OPERAND(new op::Binary(this, op::BinaryOpType::kSub, a, b)); } - OperandBase* GraphBuilderBase::Tan(OperandBase* input) { + OperandBase* GraphBuilderBase::APITan(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kTan, input)); } - OperandBase* GraphBuilderBase::Tanh(OperandBase* input) { + OperandBase* GraphBuilderBase::APITanh(OperandBase* input) { VALIDATE_FOR_OPERAND(new op::Unary(this, op::UnaryOpType::kTanh, input)); } - FusionOperatorBase* GraphBuilderBase::TanhOperator() { + FusionOperatorBase* GraphBuilderBase::APITanhOperator() { return new op::FusionUnary(this, FusionType::Tanh); } - OperandBase* GraphBuilderBase::Transpose(OperandBase* input, TransposeOptions const* options) { + OperandBase* GraphBuilderBase::APITranspose(OperandBase* input, + TransposeOptions const* options) { VALIDATE_FOR_OPERAND(new op::Transpose(this, input, options)); } @@ -380,7 +386,7 @@ namespace webnn::native { return std::move(graph); } - GraphBase* GraphBuilderBase::Build(NamedOperandsBase const* namedOperands) { + GraphBase* GraphBuilderBase::APIBuild(NamedOperandsBase const* namedOperands) { Ref result = nullptr; if (GetContext()->ConsumedError(BuildImpl(namedOperands), &result)) { ASSERT(result == nullptr); diff --git a/src/webnn/native/GraphBuilder.h b/src/webnn/native/GraphBuilder.h index 58318a122..cd8ddd326 100644 --- a/src/webnn/native/GraphBuilder.h +++ b/src/webnn/native/GraphBuilder.h @@ -34,86 +34,86 @@ namespace webnn::native { virtual ~GraphBuilderBase() = default; // WebNN API - OperandBase* Abs(OperandBase*); - OperandBase* Add(OperandBase*, OperandBase*); - OperandBase* AveragePool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* BatchNorm(OperandBase*, - OperandBase*, - OperandBase*, - BatchNormOptions const* options); - OperandBase* Clamp(OperandBase*, ClampOptions const* options); - FusionOperatorBase* ClampOperator(ClampOptions const* options); - OperandBase* Ceil(OperandBase*); - OperandBase* Concat(uint32_t inputsCount, OperandBase* const* inputs, uint32_t axis); - OperandBase* Constant(OperandDescriptor const* desc, ArrayBufferView const* arrayBuffer); - OperandBase* ConstantWithGpuBuffer(OperandDescriptor const* desc, - GpuBufferView const* arrayBuffer); - OperandBase* Conv2d(OperandBase*, OperandBase*, Conv2dOptions const* options); - OperandBase* ConvTranspose2d(OperandBase*, - OperandBase*, - ConvTranspose2dOptions const* options); - OperandBase* Cos(OperandBase*); - OperandBase* Div(OperandBase*, OperandBase*); - OperandBase* Exp(OperandBase*); - OperandBase* Floor(OperandBase*); - OperandBase* Gemm(OperandBase*, OperandBase*, GemmOptions const* options); - OperandArrayBase* Gru(OperandBase*, - OperandBase*, - OperandBase*, - int32_t steps, - int32_t hiddenSize, - GruOptions const* options); - OperandBase* HardSwish(OperandBase*); - FusionOperatorBase* HardSwishOperator(); - OperandBase* Input(char const* name, OperandDescriptor const* desc); - OperandBase* InstanceNorm(OperandBase*, InstanceNormOptions const* options); - OperandBase* LeakyRelu(OperandBase*, LeakyReluOptions const* options); - FusionOperatorBase* LeakyReluOperator(LeakyReluOptions const* options); - OperandBase* Log(OperandBase*); - OperandBase* L2Pool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* Matmul(OperandBase* a, OperandBase* b); - OperandBase* Max(OperandBase*, OperandBase*); - OperandBase* MaxPool2d(OperandBase*, Pool2dOptions const* options); - OperandBase* Min(OperandBase*, OperandBase*); - OperandBase* Mul(OperandBase*, OperandBase*); - OperandBase* Neg(OperandBase*); - OperandBase* Pad(OperandBase*, OperandBase*, PadOptions const* options); - OperandBase* Pow(OperandBase*, OperandBase*); - OperandBase* ReduceArgMax(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceArgMin(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceL1(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceL2(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMax(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMean(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceMin(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceProduct(OperandBase*, ReduceOptions const* options); - OperandBase* ReduceSum(OperandBase*, ReduceOptions const* options); - OperandBase* Relu(OperandBase*); - FusionOperatorBase* ReluOperator(); - OperandBase* Resample2d(OperandBase*, Resample2dOptions const* options); - OperandBase* Reshape(OperandBase*, int32_t const*, size_t); - OperandBase* Sigmoid(OperandBase*); - FusionOperatorBase* SigmoidOperator(); - OperandBase* Sin(OperandBase*); - OperandBase* Slice(OperandBase*, - int32_t const* starts, - uint32_t startsCount, - int32_t const* sizes, - uint32_t sizesCount, - SliceOptions const* options); - OperandBase* Softmax(OperandBase*); - OperandArrayBase* Split(OperandBase*, - uint32_t const*, - uint32_t, - SplitOptions const* options); - OperandBase* Squeeze(OperandBase*, SqueezeOptions const* options); - OperandBase* Sub(OperandBase*, OperandBase*); - OperandBase* Tan(OperandBase*); - OperandBase* Tanh(OperandBase*); - FusionOperatorBase* TanhOperator(); - OperandBase* Transpose(OperandBase*, TransposeOptions const* options); + OperandBase* APIAbs(OperandBase*); + OperandBase* APIAdd(OperandBase*, OperandBase*); + OperandBase* APIAveragePool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIBatchNorm(OperandBase*, + OperandBase*, + OperandBase*, + BatchNormOptions const* options); + OperandBase* APIClamp(OperandBase*, ClampOptions const* options); + FusionOperatorBase* APIClampOperator(ClampOptions const* options); + OperandBase* APICeil(OperandBase*); + OperandBase* APIConcat(uint32_t inputsCount, OperandBase* const* inputs, uint32_t axis); + OperandBase* APIConstant(OperandDescriptor const* desc, ArrayBufferView const* arrayBuffer); + OperandBase* APIConstantWithGpuBuffer(OperandDescriptor const* desc, + GpuBufferView const* arrayBuffer); + OperandBase* APIConv2d(OperandBase*, OperandBase*, Conv2dOptions const* options); + OperandBase* APIConvTranspose2d(OperandBase*, + OperandBase*, + ConvTranspose2dOptions const* options); + OperandBase* APICos(OperandBase*); + OperandBase* APIDiv(OperandBase*, OperandBase*); + OperandBase* APIExp(OperandBase*); + OperandBase* APIFloor(OperandBase*); + OperandBase* APIGemm(OperandBase*, OperandBase*, GemmOptions const* options); + OperandArrayBase* APIGru(OperandBase*, + OperandBase*, + OperandBase*, + int32_t steps, + int32_t hiddenSize, + GruOptions const* options); + OperandBase* APIHardSwish(OperandBase*); + FusionOperatorBase* APIHardSwishOperator(); + OperandBase* APIInput(char const* name, OperandDescriptor const* desc); + OperandBase* APIInstanceNorm(OperandBase*, InstanceNormOptions const* options); + OperandBase* APILeakyRelu(OperandBase*, LeakyReluOptions const* options); + FusionOperatorBase* APILeakyReluOperator(LeakyReluOptions const* options); + OperandBase* APILog(OperandBase*); + OperandBase* APIL2Pool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIMatmul(OperandBase* a, OperandBase* b); + OperandBase* APIMax(OperandBase*, OperandBase*); + OperandBase* APIMaxPool2d(OperandBase*, Pool2dOptions const* options); + OperandBase* APIMin(OperandBase*, OperandBase*); + OperandBase* APIMul(OperandBase*, OperandBase*); + OperandBase* APINeg(OperandBase*); + OperandBase* APIPad(OperandBase*, OperandBase*, PadOptions const* options); + OperandBase* APIPow(OperandBase*, OperandBase*); + OperandBase* APIReduceArgMax(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceArgMin(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceL1(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceL2(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMax(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMean(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceMin(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceProduct(OperandBase*, ReduceOptions const* options); + OperandBase* APIReduceSum(OperandBase*, ReduceOptions const* options); + OperandBase* APIRelu(OperandBase*); + FusionOperatorBase* APIReluOperator(); + OperandBase* APIResample2d(OperandBase*, Resample2dOptions const* options); + OperandBase* APIReshape(OperandBase*, int32_t const*, size_t); + OperandBase* APISigmoid(OperandBase*); + FusionOperatorBase* APISigmoidOperator(); + OperandBase* APISin(OperandBase*); + OperandBase* APISlice(OperandBase*, + int32_t const* starts, + uint32_t startsCount, + int32_t const* sizes, + uint32_t sizesCount, + SliceOptions const* options); + OperandBase* APISoftmax(OperandBase*); + OperandArrayBase* APISplit(OperandBase*, + uint32_t const*, + uint32_t, + SplitOptions const* options); + OperandBase* APISqueeze(OperandBase*, SqueezeOptions const* options); + OperandBase* APISub(OperandBase*, OperandBase*); + OperandBase* APITan(OperandBase*); + OperandBase* APITanh(OperandBase*); + FusionOperatorBase* APITanhOperator(); + OperandBase* APITranspose(OperandBase*, TransposeOptions const* options); - GraphBase* Build(NamedOperandsBase const* namedOperands); + GraphBase* APIBuild(NamedOperandsBase const* namedOperands); private: ResultOrError> BuildImpl(NamedOperandsBase const* namedOperands); diff --git a/src/webnn/native/Instance.cpp b/src/webnn/native/Instance.cpp index 2ef8f04ab..d31bb655a 100644 --- a/src/webnn/native/Instance.cpp +++ b/src/webnn/native/Instance.cpp @@ -175,7 +175,7 @@ namespace webnn::native { return mBackends[wnn::BackendType::Null]->CreateContext(options); } - ContextBase* InstanceBase::CreateContext(const ContextOptions* options) { + ContextBase* InstanceBase::APICreateContext(const ContextOptions* options) { if (mBackends.find(wnn::BackendType::DirectML) != mBackends.end()) { return mBackends[wnn::BackendType::DirectML]->CreateContext(options); } else if (mBackends.find(wnn::BackendType::OpenVINO) != mBackends.end()) { @@ -193,7 +193,7 @@ namespace webnn::native { return nullptr; } - ContextBase* InstanceBase::CreateContextWithGpuDevice(const GpuDevice* wnn_device) { + ContextBase* InstanceBase::APICreateContextWithGpuDevice(const GpuDevice* wnn_device) { #if defined(WEBNN_ENABLE_GPU_BUFFER) WGPUDevice device = reinterpret_cast(wnn_device->device); if (mBackends.find(wnn::BackendType::DirectML) != mBackends.end()) { @@ -210,23 +210,23 @@ namespace webnn::native { return nullptr; } - GraphBuilderBase* InstanceBase::CreateGraphBuilder(ContextBase* context) { + GraphBuilderBase* InstanceBase::APICreateGraphBuilder(ContextBase* context) { return new GraphBuilderBase(context); } - NamedInputsBase* InstanceBase::CreateNamedInputs() { + NamedInputsBase* InstanceBase::APICreateNamedInputs() { return new NamedInputsBase(); } - NamedOperandsBase* InstanceBase::CreateNamedOperands() { + NamedOperandsBase* InstanceBase::APICreateNamedOperands() { return new NamedOperandsBase(); } - NamedOutputsBase* InstanceBase::CreateNamedOutputs() { + NamedOutputsBase* InstanceBase::APICreateNamedOutputs() { return new NamedOutputsBase(); } - OperatorArrayBase* InstanceBase::CreateOperatorArray() { + OperatorArrayBase* InstanceBase::APICreateOperatorArray() { return new OperatorArrayBase(); } diff --git a/src/webnn/native/Instance.h b/src/webnn/native/Instance.h index 9a1a826bc..bc0b4fe96 100644 --- a/src/webnn/native/Instance.h +++ b/src/webnn/native/Instance.h @@ -39,13 +39,13 @@ namespace webnn::native { static InstanceBase* Create(const InstanceDescriptor* descriptor = nullptr); // WebNN API - ContextBase* CreateContext(const ContextOptions* options); - ContextBase* CreateContextWithGpuDevice(const GpuDevice* device); - GraphBuilderBase* CreateGraphBuilder(ContextBase* context); - NamedInputsBase* CreateNamedInputs(); - NamedOperandsBase* CreateNamedOperands(); - NamedOutputsBase* CreateNamedOutputs(); - OperatorArrayBase* CreateOperatorArray(); + ContextBase* APICreateContext(const ContextOptions* options); + ContextBase* APICreateContextWithGpuDevice(const GpuDevice* device); + GraphBuilderBase* APICreateGraphBuilder(ContextBase* context); + NamedInputsBase* APICreateNamedInputs(); + NamedOperandsBase* APICreateNamedOperands(); + NamedOutputsBase* APICreateNamedOutputs(); + OperatorArrayBase* APICreateOperatorArray(); ContextBase* CreateTestContext(const ContextOptions* options); diff --git a/src/webnn/native/NamedInputs.h b/src/webnn/native/NamedInputs.h index 4b73dc687..e4d0686b9 100644 --- a/src/webnn/native/NamedInputs.h +++ b/src/webnn/native/NamedInputs.h @@ -40,7 +40,7 @@ namespace webnn::native { } // WebNN API - void Set(char const* name, const Input* input) { + void APISet(char const* name, const Input* input) { mInputs[std::string(name)] = *input; #if defined(WEBNN_ENABLE_WIRE) // Input data type is Arrary Buffer View. diff --git a/src/webnn/native/NamedOperands.h b/src/webnn/native/NamedOperands.h index af6ddbdce..bd51dab20 100644 --- a/src/webnn/native/NamedOperands.h +++ b/src/webnn/native/NamedOperands.h @@ -25,7 +25,7 @@ namespace webnn::native { class NamedOperandsBase : public RefCounted { public: // WebNN API - void Set(char const* name, const OperandBase* operand) { + void APISet(char const* name, const OperandBase* operand) { mNamedOperands[std::string(name)] = operand; } diff --git a/src/webnn/native/NamedOutputs.h b/src/webnn/native/NamedOutputs.h index 11d4bdf4e..dace4dd3d 100644 --- a/src/webnn/native/NamedOutputs.h +++ b/src/webnn/native/NamedOutputs.h @@ -41,7 +41,7 @@ namespace webnn::native { } // WebNN API - void Set(char const* name, const Resource* resource) { + void APISetOutput(char const* name, const Resource* resource) { mOutputs[std::string(name)] = *resource; if (resource->gpuBufferView.buffer != nullptr) { #if defined(WEBNN_ENABLE_GPU_BUFFER) @@ -61,7 +61,7 @@ namespace webnn::native { } } - void Get(char const* name, ArrayBufferView* arrayBuffer) { + void APIGetOutput(char const* name, ArrayBufferView* arrayBuffer) { if (mOutputs.find(std::string(name)) == mOutputs.end()) { return; } diff --git a/src/webnn/native/OperandArray.h b/src/webnn/native/OperandArray.h index 1d74e7b8e..c584396d4 100644 --- a/src/webnn/native/OperandArray.h +++ b/src/webnn/native/OperandArray.h @@ -31,10 +31,10 @@ namespace webnn::native { return new OperandArrayBase(graphBuilder, ObjectBase::kError); } // WebNN API - size_t Size() { + size_t APISize() { return mOperands.size(); } - OperandBase* Get(size_t index) { + OperandBase* APIGetOperand(size_t index) { return mOperands[index].Get(); } diff --git a/src/webnn/native/OperatorArray.h b/src/webnn/native/OperatorArray.h index 686bd3e80..df94f659a 100644 --- a/src/webnn/native/OperatorArray.h +++ b/src/webnn/native/OperatorArray.h @@ -25,15 +25,15 @@ namespace webnn::native { virtual ~OperatorArrayBase() = default; // WebNN API - size_t Size() { + size_t APISize() { return mOperators.size(); } - void Set(FusionOperatorBase* mlOperator) { + void APISetFusionOperator(FusionOperatorBase* mlOperator) { mOperators.push_back(Ref(mlOperator)); } - FusionOperatorBase* Get(size_t index) { + FusionOperatorBase* APIGetFusionOperator(size_t index) { return mOperators[index].Get(); } diff --git a/src/webnn/native/Utils.h b/src/webnn/native/Utils.h index 42b1055e5..206bf6a53 100644 --- a/src/webnn/native/Utils.h +++ b/src/webnn/native/Utils.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_NATIVEUTILS_H_ #define WEBNN_NATIVE_NATIVEUTILS_H_ -#include +#include #include #include diff --git a/src/webnn/native/WebnnNative.cpp b/src/webnn/native/WebnnNative.cpp index 220b9d977..a5c1645a0 100644 --- a/src/webnn/native/WebnnNative.cpp +++ b/src/webnn/native/WebnnNative.cpp @@ -67,7 +67,12 @@ namespace webnn::native { WNNContext Instance::CreateContext(const wnn::ContextOptions* options) { return reinterpret_cast( - mImpl->CreateContext(reinterpret_cast(options))); + mImpl->APICreateContext(reinterpret_cast(options))); + } + + WNNGraphBuilder Instance::CreateGraphBuilder(const WNNContext context) { + return reinterpret_cast( + mImpl->APICreateGraphBuilder(reinterpret_cast(context))); } WNNInstance Instance::Get() const { diff --git a/src/webnn/native/dmlx/GraphDMLX.cpp b/src/webnn/native/dmlx/GraphDMLX.cpp index c048223c2..962053fef 100644 --- a/src/webnn/native/dmlx/GraphDMLX.cpp +++ b/src/webnn/native/dmlx/GraphDMLX.cpp @@ -1842,8 +1842,8 @@ namespace webnn::native::dmlx { fActivation = ::dml::FusedActivation::Sigmoid(); gActivation = ::dml::FusedActivation::Tanh(); } else { - fActivation = CreateFusedActivation(options->activations->Get(0)); - gActivation = CreateFusedActivation(options->activations->Get(1)); + fActivation = CreateFusedActivation(options->activations->APIGetFusionOperator(0)); + gActivation = CreateFusedActivation(options->activations->APIGetFusionOperator(1)); } std::vector<::dml::FusedActivation> activations; if (direction == diff --git a/src/webnn/native/ops/Conv2d.h b/src/webnn/native/ops/Conv2d.h index 4837b3cfb..c668ac7ec 100644 --- a/src/webnn/native/ops/Conv2d.h +++ b/src/webnn/native/ops/Conv2d.h @@ -32,7 +32,7 @@ namespace webnn::native::op { if (options != nullptr && options->bias != nullptr) { mInputs.push_back(options->bias); } - if (options == nullptr || options->padding == nullptr) { + if (options == nullptr || options->paddingCount == 0) { mPadding = std::vector(4, 0); } else { mPadding.assign(options->padding, options->padding + options->paddingCount); @@ -40,7 +40,7 @@ namespace webnn::native::op { mOptions.padding = mPadding.data(); mOptions.paddingCount = mPadding.size(); - if (options == nullptr || options->strides == nullptr) { + if (options == nullptr || options->stridesCount == 0) { mStride = std::vector(2, 1); } else { mStride.assign(options->strides, options->strides + options->stridesCount); @@ -48,7 +48,7 @@ namespace webnn::native::op { mOptions.strides = mStride.data(); mOptions.stridesCount = mStride.size(); - if (options == nullptr || options->dilations == nullptr) { + if (options == nullptr || options->dilationsCount == 0) { mDilations = std::vector(2, 1); } else { mDilations.assign(options->dilations, options->dilations + options->dilationsCount); diff --git a/src/webnn/native/ops/ConvTranspose2d.cpp b/src/webnn/native/ops/ConvTranspose2d.cpp index 736563b82..9b4ccb6a6 100644 --- a/src/webnn/native/ops/ConvTranspose2d.cpp +++ b/src/webnn/native/ops/ConvTranspose2d.cpp @@ -25,7 +25,7 @@ namespace webnn::native::op { OperandBase* filter, ConvTranspose2dOptions const* options) : Conv2dBase(builder, input, filter, options) { - if (options == nullptr || options->outputPadding == nullptr) { + if (options == nullptr || options->outputPaddingCount == 0) { mOutputPadding = std::vector(2, 0); } else { mOutputPadding.assign(options->outputPadding, @@ -34,7 +34,7 @@ namespace webnn::native::op { mOptions.outputPadding = mOutputPadding.data(); mOptions.outputPaddingCount = mOutputPadding.size(); - if (options != nullptr && options->outputSizes != nullptr) { + if (options != nullptr && options->outputSizesCount != 0) { mOutputSizes.assign(options->outputSizes, options->outputSizes + options->outputSizesCount); mOptions.outputSizes = mOutputSizes.data(); @@ -113,7 +113,7 @@ namespace webnn::native::op { } int32_t outputHeight, outputWidth; - if (mOptions.outputSizes != nullptr) { + if (mOptions.outputSizesCount != 0) { outputHeight = mOptions.outputSizes[0]; outputWidth = mOptions.outputSizes[1]; } else { diff --git a/src/webnn/native/ops/Gru.cpp b/src/webnn/native/ops/Gru.cpp index ee980789b..6717fa329 100644 --- a/src/webnn/native/ops/Gru.cpp +++ b/src/webnn/native/ops/Gru.cpp @@ -44,9 +44,10 @@ namespace webnn::native::op { } if (options == nullptr || options->activations == nullptr) { mActivations = AcquireRef(new OperatorArrayBase()); - mActivations->Set( + mActivations->APISetFusionOperator( AcquireRef(new FusionOperatorBase(builder, FusionType::Sigmoid)).Get()); - mActivations->Set(AcquireRef(new FusionOperatorBase(builder, FusionType::Tanh)).Get()); + mActivations->APISetFusionOperator( + AcquireRef(new FusionOperatorBase(builder, FusionType::Tanh)).Get()); } else { mActivations = Ref(mOptions.activations); } @@ -108,7 +109,7 @@ namespace webnn::native::op { } } // The activations parameter - if (GetActivations().Get()->Size() != 2) { + if (GetActivations().Get()->APISize() != 2) { return DAWN_VALIDATION_ERROR("Argument activations is not a sequence of length 2."); } diff --git a/src/webnn/native/ops/Pool2d.cpp b/src/webnn/native/ops/Pool2d.cpp index 63050eb95..d1e662059 100644 --- a/src/webnn/native/ops/Pool2d.cpp +++ b/src/webnn/native/ops/Pool2d.cpp @@ -26,14 +26,14 @@ namespace webnn::native::op { OperandBase* input, Pool2dOptions const* options) : OperatorBase(builder, {input}), mOpType(opType) { - if (options != nullptr && options->windowDimensions != nullptr) { + if (options != nullptr && options->windowDimensionsCount != 0) { mWindowDimensions.assign(options->windowDimensions, options->windowDimensions + options->windowDimensionsCount); mOptions.windowDimensions = mWindowDimensions.data(); mOptions.windowDimensionsCount = mWindowDimensions.size(); } - if (options == nullptr || options->padding == nullptr) { + if (options == nullptr || options->paddingCount == 0) { mPadding = std::vector(4, 0); } else { mPadding.assign(options->padding, options->padding + options->paddingCount); @@ -41,7 +41,7 @@ namespace webnn::native::op { mOptions.padding = mPadding.data(); mOptions.paddingCount = mPadding.size(); - if (options == nullptr || options->strides == nullptr) { + if (options == nullptr || options->stridesCount == 0) { mStride = std::vector(2, 1); } else { mStride.assign(options->strides, options->strides + options->stridesCount); @@ -49,7 +49,7 @@ namespace webnn::native::op { mOptions.strides = mStride.data(); mOptions.stridesCount = mStride.size(); - if (options == nullptr || options->dilations == nullptr) { + if (options == nullptr || options->dilationsCount == 0) { mDilations = std::vector(2, 1); } else { mDilations.assign(options->dilations, options->dilations + options->dilationsCount); @@ -62,7 +62,7 @@ namespace webnn::native::op { mOptions.roundingType = options == nullptr ? wnn::RoundingType::Floor : options->roundingType; - if (options != nullptr && options->outputSizes != nullptr) { + if (options != nullptr && options->outputSizesCount != 0) { mOutputSizes.assign(options->outputSizes, options->outputSizes + options->outputSizesCount); mOptions.outputSizes = mOutputSizes.data(); diff --git a/src/webnn/native/ops/Reduce.cpp b/src/webnn/native/ops/Reduce.cpp index 8da0353b0..6d568efd5 100644 --- a/src/webnn/native/ops/Reduce.cpp +++ b/src/webnn/native/ops/Reduce.cpp @@ -26,7 +26,7 @@ namespace webnn::native::op { ReduceOptions const* options) : OperatorBase(builder, {input}), mOpType(opType) { // If axes are not present, all dimensions are reduced. - if (options == nullptr || options->axes == nullptr) { + if (options == nullptr || options->axesCount == 0) { int32_t rank = input->Shape().size(); mAxes.resize(rank); for (auto i = 0; i < rank; ++i) { diff --git a/src/webnn/native/ops/Resample2d.cpp b/src/webnn/native/ops/Resample2d.cpp index 6c3645354..5dae354d6 100644 --- a/src/webnn/native/ops/Resample2d.cpp +++ b/src/webnn/native/ops/Resample2d.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "webnn/native/ops/Resample2d.h" +#include "common/Log.h" #include @@ -25,13 +26,13 @@ namespace webnn::native::op { : OperatorBase(builder, {input}), mScales({1.0, 1.0}), mSizes({}), mAxes({2, 3}) { mOptions.mode = options == nullptr ? wnn::InterpolationMode::NearestNeighbor : options->mode; - if (options != nullptr && options->scales != nullptr) { + if (options != nullptr && options->scalesCount != 0) { mScales.assign(options->scales, options->scales + options->scalesCount); } - if (options != nullptr && options->sizes != nullptr) { + if (options != nullptr && options->sizesCount != 0) { mSizes.assign(options->sizes, options->sizes + options->sizesCount); } - if (options != nullptr && options->axes != nullptr) { + if (options != nullptr && options->axesCount != 0) { mAxes.assign(options->axes, options->axes + options->axesCount); } } diff --git a/src/webnn/native/ops/Slice.h b/src/webnn/native/ops/Slice.h index 7be33f9af..ffb18ebb9 100644 --- a/src/webnn/native/ops/Slice.h +++ b/src/webnn/native/ops/Slice.h @@ -33,7 +33,7 @@ namespace webnn::native::op { uint32_t sizesCount, SliceOptions const* options) : OperatorBase(builder, {input}) { - if (options != nullptr && options->axes != nullptr) { + if (options != nullptr && options->axesCount != 0) { mAxes.assign(options->axes, options->axes + options->axesCount); } diff --git a/src/webnn/native/ops/Transpose.h b/src/webnn/native/ops/Transpose.h index 497c9afa6..f0f3da044 100644 --- a/src/webnn/native/ops/Transpose.h +++ b/src/webnn/native/ops/Transpose.h @@ -25,7 +25,7 @@ namespace webnn::native::op { public: Transpose(GraphBuilderBase* builder, OperandBase* input, TransposeOptions const* options) : OperatorBase(builder, {input}) { - if (options == nullptr || options->permutation == nullptr) { + if (options == nullptr || options->permutationCount == 0) { int32_t rank = input->Shape().size(); mPermutation.resize(rank); for (auto i = 0; i < rank - 1; i++) { diff --git a/src/webnn/native/webnn_platform.h b/src/webnn/native/webnn_platform.h index 29259ce1c..938f38e3e 100644 --- a/src/webnn/native/webnn_platform.h +++ b/src/webnn/native/webnn_platform.h @@ -22,7 +22,7 @@ // Use our autogenerated version of the webnn structures that point to webnn_native // object types #include -#include +#include namespace webnn::native { // kEnumCount is a constant specifying the number of enums in a WebGPU enum type, diff --git a/src/webnn/tests/BUILD.gn b/src/webnn/tests/BUILD.gn index 587f3e7bc..f2196f8a4 100644 --- a/src/webnn/tests/BUILD.gn +++ b/src/webnn/tests/BUILD.gn @@ -16,7 +16,7 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") import("//testing/test.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/generator/webnn_generator.gni") group("webnn_tests") { @@ -96,10 +96,10 @@ if (!build_with_chromium) { ############################################################################### webnn_json_generator("mock_webnn_gen") { - target = "mock_webnn" + target = "mock_api" outputs = [ - "src/webnn/mock_webnn.h", - "src/webnn/mock_webnn.cpp", + "src/dawn/mock_webnn.h", + "src/dawn/mock_webnn.cpp", ] } @@ -132,8 +132,8 @@ test("webnn_unittests") { "${webnn_root}/src/webnn:cpp", "${webnn_root}/src/webnn:webnn_proc", "${webnn_root}/src/webnn/common", - "${webnn_root}/src/webnn/native:webnn_native", "${webnn_root}/src/webnn/native:sources", + "${webnn_root}/src/webnn/native:webnn_native", "${webnn_root}/src/webnn/utils:webnn_utils", ] @@ -142,7 +142,6 @@ test("webnn_unittests") { sources = get_target_outputs(":mock_webnn_gen") sources += [ - "//third_party/dawn/src/tests/unittests/ResultTests.cpp", "unittests/ErrorTests.cpp", "unittests/ObjectBaseTests.cpp", "unittests/native/ContextMockTests.cpp", diff --git a/src/webnn/tests/end2end/AddTests.cpp b/src/webnn/tests/end2end/AddTests.cpp index 5bedd8848..0b5eb0c3e 100644 --- a/src/webnn/tests/end2end/AddTests.cpp +++ b/src/webnn/tests/end2end/AddTests.cpp @@ -17,7 +17,7 @@ class AddTests : public WebnnTest {}; TEST_F(AddTests, AddConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.5781865, -0.49248728, -0.2162451, -0.13176449, -0.52118045, 1.9125274, 0.6508799, @@ -60,7 +60,7 @@ TEST_F(AddTests, AddConstantAndInput) { } TEST_F(AddTests, AddTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Add(a, b); @@ -102,7 +102,7 @@ TEST_F(AddTests, AddTwoInputs) { } TEST_F(AddTests, AddBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Add(a, b); diff --git a/src/webnn/tests/end2end/BatchNormTests.cpp b/src/webnn/tests/end2end/BatchNormTests.cpp index 95728c327..750aef6f1 100644 --- a/src/webnn/tests/end2end/BatchNormTests.cpp +++ b/src/webnn/tests/end2end/BatchNormTests.cpp @@ -16,7 +16,7 @@ class BatchNormTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/ClampTests.cpp b/src/webnn/tests/end2end/ClampTests.cpp index 6df0b8128..30d5c27e4 100644 --- a/src/webnn/tests/end2end/ClampTests.cpp +++ b/src/webnn/tests/end2end/ClampTests.cpp @@ -20,7 +20,7 @@ class ClampTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, const wnn::ClampOptions* options = nullptr) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); const wnn::Operand b = builder.Clamp(a, options); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/ConcatTests.cpp b/src/webnn/tests/end2end/ConcatTests.cpp index 848285e1a..d4b91d7bd 100644 --- a/src/webnn/tests/end2end/ConcatTests.cpp +++ b/src/webnn/tests/end2end/ConcatTests.cpp @@ -26,7 +26,7 @@ class ConcatTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, bool inputsDefined = true) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); std::vector inputsOperand; inputsOperand.reserve(inputs.size()); size_t index = 0; diff --git a/src/webnn/tests/end2end/Conv2dTests.cpp b/src/webnn/tests/end2end/Conv2dTests.cpp index 296637ab0..785422dd0 100644 --- a/src/webnn/tests/end2end/Conv2dTests.cpp +++ b/src/webnn/tests/end2end/Conv2dTests.cpp @@ -16,7 +16,7 @@ class Conv2dTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp index 022e1b70a..87dfa5a78 100644 --- a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp +++ b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp @@ -16,7 +16,7 @@ class ConvTranspose2dTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/DivTests.cpp b/src/webnn/tests/end2end/DivTests.cpp index 83c22b804..0b455d4f6 100644 --- a/src/webnn/tests/end2end/DivTests.cpp +++ b/src/webnn/tests/end2end/DivTests.cpp @@ -17,7 +17,7 @@ class DivTests : public WebnnTest {}; TEST_F(DivTests, Div) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Div(a, b); @@ -62,7 +62,7 @@ TEST_F(DivTests, Div) { } TEST_F(DivTests, DivBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Div(a, b); diff --git a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp index 00241ee13..909eb2c22 100644 --- a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp +++ b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp @@ -32,7 +32,7 @@ class ElementWiseUnaryTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, const std::vector& shape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", shape); wnn::Operand b; switch (type) { diff --git a/src/webnn/tests/end2end/GemmTests.cpp b/src/webnn/tests/end2end/GemmTests.cpp index 6d2ebfebb..ccbacf854 100644 --- a/src/webnn/tests/end2end/GemmTests.cpp +++ b/src/webnn/tests/end2end/GemmTests.cpp @@ -33,7 +33,7 @@ class GemmTests : public WebnnTest { const std::vector& expectedValue, const Options* options = nullptr, bool constantWeight = false) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", aShape); wnn::Operand b; if (constantWeight) { diff --git a/src/webnn/tests/end2end/GruTests.cpp b/src/webnn/tests/end2end/GruTests.cpp index 39f8aa024..757af2ba2 100644 --- a/src/webnn/tests/end2end/GruTests.cpp +++ b/src/webnn/tests/end2end/GruTests.cpp @@ -16,7 +16,7 @@ class GruTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: @@ -43,7 +43,7 @@ class GruTests : public WebnnTest { const size_t outputSize = Y.Size(); std::vector namedOperands; for (size_t i = 0; i < outputSize; ++i) { - namedOperands.push_back({"gru" + std::to_string(i), Y.Get(i)}); + namedOperands.push_back({"gru" + std::to_string(i), Y.GetOperand(i)}); } const wnn::Graph graph = utils::Build(builder, namedOperands); ASSERT_TRUE(graph); @@ -151,9 +151,9 @@ TEST_F(GruTests, GruWithMultiActivitions) { auto activations = CreateCppOperatorArray(); auto activationSigmoid = utils::CreateActivationOperator(builder, utils::FusedActivation::SIGMOID); - activations.Set(activationSigmoid); + activations.SetFusionOperator(activationSigmoid); auto activationTanh = utils::CreateActivationOperator(builder, utils::FusedActivation::TANH); - activations.Set(activationTanh); + activations.SetFusionOperator(activationTanh); options.activations = activations; const std::vector expectedShape = {numDirections, batchSize, hiddenSize}; diff --git a/src/webnn/tests/end2end/HardSwishTests.cpp b/src/webnn/tests/end2end/HardSwishTests.cpp index 2f54a23d8..f57bd48de 100644 --- a/src/webnn/tests/end2end/HardSwishTests.cpp +++ b/src/webnn/tests/end2end/HardSwishTests.cpp @@ -21,7 +21,7 @@ class HardSwishTests : public WebnnTest { void CheckHardSwish(const std::vector& inputShape, const std::vector& inputBuffer, const std::vector& expectedBuffer) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); const wnn::Operand y = builder.HardSwish(x); const wnn::Graph graph = utils::Build(builder, {{"y", y}}); diff --git a/src/webnn/tests/end2end/InstanceNormTests.cpp b/src/webnn/tests/end2end/InstanceNormTests.cpp index c6d5dc257..184b3a35b 100644 --- a/src/webnn/tests/end2end/InstanceNormTests.cpp +++ b/src/webnn/tests/end2end/InstanceNormTests.cpp @@ -16,7 +16,7 @@ class InstanceNormTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/LeakyReluTests.cpp b/src/webnn/tests/end2end/LeakyReluTests.cpp index 303b51ac1..6acd0a2c4 100644 --- a/src/webnn/tests/end2end/LeakyReluTests.cpp +++ b/src/webnn/tests/end2end/LeakyReluTests.cpp @@ -20,7 +20,7 @@ class LeakyReluTests : public WebnnTest { const std::vector& inputData, const std::vector& expectedValue, float alpha = 0.01) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::LeakyReluOptions options; options.alpha = alpha; diff --git a/src/webnn/tests/end2end/MatMulTests.cpp b/src/webnn/tests/end2end/MatMulTests.cpp index 10f94f618..89ce0a499 100644 --- a/src/webnn/tests/end2end/MatMulTests.cpp +++ b/src/webnn/tests/end2end/MatMulTests.cpp @@ -17,7 +17,7 @@ class MatMulTests : public WebnnTest {}; TEST_F(MatMulTests, MatMul1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {4}); const std::vector bData = {0.8782074, 0.22533207, 0.7134056, 0.04190519}; const wnn::Operand b = @@ -33,7 +33,7 @@ TEST_F(MatMulTests, MatMul1d) { } TEST_F(MatMulTests, MatMul1dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {4}); const std::vector bData = { 0.3093976, -1.2924036, -0.64339244, 1.1423386, 1.5052135, 1.8182521, @@ -52,7 +52,7 @@ TEST_F(MatMulTests, MatMul1dx2d) { } TEST_F(MatMulTests, MatMul2dx1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.25528687, 0.2126722, 0.26320502, 0.8297401}; const wnn::Operand b = @@ -71,7 +71,7 @@ TEST_F(MatMulTests, MatMul2dx1d) { } TEST_F(MatMulTests, MatMul2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.17467105, -1.2045133, -0.02621938, 0.6096196, 1.4499376, 1.3465316, 0.03289436, 1.0754977, @@ -93,7 +93,7 @@ TEST_F(MatMulTests, MatMul2d) { } TEST_F(MatMulTests, MatMul3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-2.7142005, 0.41909233, 0.80572236, 0.19983047, -1.9361104, 1.1919757, 0.61684674, 0.23732206, 0.74679494, 0.4595843, @@ -122,7 +122,7 @@ TEST_F(MatMulTests, MatMul3d) { } TEST_F(MatMulTests, MatMul3dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-0.38534147, -0.18395364, -2.548874, 0.4525641, -0.41875792, 0.57480955, -0.41603103, 0.6973883, @@ -147,7 +147,7 @@ TEST_F(MatMulTests, MatMul3dx2d) { } TEST_F(MatMulTests, MatMul3dx2dGet3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 3, 4}); const std::vector bData = {0.2545374, -1.6150205, -0.64508885, -0.3454305, 0.38700557, 1.3147515, -0.3379386, 1.1804152, @@ -170,7 +170,7 @@ TEST_F(MatMulTests, MatMul3dx2dGet3d) { } TEST_F(MatMulTests, MatMul4d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { -0.45605758, -0.43318668, 0.61509126, -2.2228749, 0.50257015, -0.29311436, @@ -200,7 +200,7 @@ TEST_F(MatMulTests, MatMul4d) { } TEST_F(MatMulTests, MatMul4dx2d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { 0.01829041, -0.73948264, -0.95898634, -0.5105271, 2.1705306, 1.2495605, diff --git a/src/webnn/tests/end2end/MaxTests.cpp b/src/webnn/tests/end2end/MaxTests.cpp index 9bdce0581..09062d2dd 100644 --- a/src/webnn/tests/end2end/MaxTests.cpp +++ b/src/webnn/tests/end2end/MaxTests.cpp @@ -17,7 +17,7 @@ class MaxTests : public WebnnTest {}; TEST_F(MaxTests, MaxConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.00724315, -1.4088361, 0.17466596, 1.1395162, 1.3720452, -0.35610083, -0.5597993, @@ -60,7 +60,7 @@ TEST_F(MaxTests, MaxConstantAndInput) { } TEST_F(MaxTests, MaxTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Max(a, b); @@ -102,7 +102,7 @@ TEST_F(MaxTests, MaxTwoInputs) { } TEST_F(MaxTests, MaxBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Max(a, b); diff --git a/src/webnn/tests/end2end/MinTests.cpp b/src/webnn/tests/end2end/MinTests.cpp index 77f29a99a..11f6abd1c 100644 --- a/src/webnn/tests/end2end/MinTests.cpp +++ b/src/webnn/tests/end2end/MinTests.cpp @@ -17,7 +17,7 @@ class MinTests : public WebnnTest {}; TEST_F(MinTests, MinConstantAndInput) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.3013072, -0.09710764, 0.19347863, 0.57673335, -0.9459303, -0.311303, -0.51731133, @@ -60,7 +60,7 @@ TEST_F(MinTests, MinConstantAndInput) { } TEST_F(MinTests, MinTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Min(a, b); @@ -102,7 +102,7 @@ TEST_F(MinTests, MinTwoInputs) { } TEST_F(MinTests, MinBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Min(a, b); diff --git a/src/webnn/tests/end2end/MulTests.cpp b/src/webnn/tests/end2end/MulTests.cpp index cfd1a0a21..c721e5857 100644 --- a/src/webnn/tests/end2end/MulTests.cpp +++ b/src/webnn/tests/end2end/MulTests.cpp @@ -17,7 +17,7 @@ class MulTests : public WebnnTest {}; TEST_F(MulTests, MulInputAndConstant) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 2.0435283, 0.07213961, -1.1644137, -1.2209045, 0.8982674, 0.21796915, 0.27658972, @@ -66,7 +66,7 @@ TEST_F(MulTests, MulInputAndConstant) { } TEST_F(MulTests, MulTwoInputs) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); wnn::Operand c = builder.Mul(a, b); @@ -114,7 +114,7 @@ TEST_F(MulTests, MulTwoInputs) { } TEST_F(MulTests, MulBroadcast) { - wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136, diff --git a/src/webnn/tests/end2end/PadTests.cpp b/src/webnn/tests/end2end/PadTests.cpp index fab73be21..d9f43e1bf 100644 --- a/src/webnn/tests/end2end/PadTests.cpp +++ b/src/webnn/tests/end2end/PadTests.cpp @@ -23,7 +23,7 @@ class PadTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, wnn::PaddingMode mode = wnn::PaddingMode::Constant) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); const wnn::Operand padding = utils::BuildConstant(builder, paddingShape, paddingData.data(), diff --git a/src/webnn/tests/end2end/Pool2dTests.cpp b/src/webnn/tests/end2end/Pool2dTests.cpp index 39d8f6dc9..d71f117b8 100644 --- a/src/webnn/tests/end2end/Pool2dTests.cpp +++ b/src/webnn/tests/end2end/Pool2dTests.cpp @@ -17,7 +17,7 @@ class Pool2dTests : public WebnnTest {}; TEST_F(Pool2dTests, MaxPool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -32,7 +32,7 @@ TEST_F(Pool2dTests, MaxPool2dDefault) { } TEST_F(Pool2dTests, MaxPool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -48,7 +48,7 @@ TEST_F(Pool2dTests, MaxPool2dNhwc) { } TEST_F(Pool2dTests, MaxPool2dDilationsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -64,7 +64,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsDefault) { } TEST_F(Pool2dTests, MaxPool2dDilationsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -81,7 +81,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsNhwc) { } TEST_F(Pool2dTests, MaxPool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -99,7 +99,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsDefault) { } TEST_F(Pool2dTests, MaxPool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -118,7 +118,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -136,7 +136,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperDefault) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -159,7 +159,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes3x3Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -182,7 +182,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes3x3Nhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes4x4Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -206,7 +206,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes4x4Nhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeFloorNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -229,7 +229,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeFloorNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeCeilNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -253,7 +253,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeCeilNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -275,7 +275,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameLowerNhwc) { } TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -294,7 +294,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperNhwc) { } TEST_F(Pool2dTests, MaxPool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -311,7 +311,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesDefault) { } TEST_F(Pool2dTests, MaxPool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -329,7 +329,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesNhwc) { } TEST_F(Pool2dTests, AveragePool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -344,7 +344,7 @@ TEST_F(Pool2dTests, AveragePool2dDefault) { } TEST_F(Pool2dTests, AveragePool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 4, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -360,7 +360,7 @@ TEST_F(Pool2dTests, AveragePool2dNhwc) { } TEST_F(Pool2dTests, AveragePool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -379,7 +379,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsDefault) { } TEST_F(Pool2dTests, AveragePool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -399,7 +399,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -418,7 +418,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperDefault) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; @@ -438,7 +438,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -461,7 +461,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes3x3Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -484,7 +484,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes3x3Nhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes4x4Nhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -508,7 +508,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes4x4Nhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeFloorNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -531,7 +531,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeFloorNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeCeilNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -555,7 +555,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeCeilNhwc) { } TEST_F(Pool2dTests, AveragePool2dAutoPadSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 7, 7, 1}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -577,7 +577,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameLowerNhwc) { } TEST_F(Pool2dTests, AveragePool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -594,7 +594,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesDefault) { } TEST_F(Pool2dTests, AveragePool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -612,7 +612,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesNhwc) { } TEST_F(Pool2dTests, GlobalAveragePool2dDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 3, 5, 5}); const wnn::Operand y = builder.AveragePool2d(x); const wnn::Graph graph = utils::Build(builder, {{"y", y}}); @@ -636,7 +636,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dDefault) { } TEST_F(Pool2dTests, GlobalAveragePool2dNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 5, 5, 3}); utils::Pool2dOptions options; options.layout = wnn::InputOperandLayout::Nhwc; @@ -662,7 +662,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -677,7 +677,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStrides) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -693,7 +693,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStrides) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -709,7 +709,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -726,7 +726,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsDefault) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes3x3) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -751,7 +751,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes3x3) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes4x4) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -777,7 +777,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes4x4) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeFloor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -802,7 +802,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeFloor) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeCeil) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -828,7 +828,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeCeil) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -846,7 +846,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -863,7 +863,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -881,7 +881,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -898,7 +898,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; diff --git a/src/webnn/tests/end2end/PowTests.cpp b/src/webnn/tests/end2end/PowTests.cpp index 43a11ad43..b2078e925 100644 --- a/src/webnn/tests/end2end/PowTests.cpp +++ b/src/webnn/tests/end2end/PowTests.cpp @@ -17,7 +17,7 @@ class PowTests : public WebnnTest {}; TEST_F(PowTests, Sqrt1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const std::vector bData = {0.5}; const wnn::Operand b = @@ -33,7 +33,7 @@ TEST_F(PowTests, Sqrt1d) { } TEST_F(PowTests, Sqrt3d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = {0.5}; const wnn::Operand b = @@ -67,7 +67,7 @@ TEST_F(PowTests, Sqrt3d) { } TEST_F(PowTests, Pow1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const std::vector bData = {2}; const wnn::Operand b = @@ -83,7 +83,7 @@ TEST_F(PowTests, Pow1d) { } TEST_F(PowTests, PowBroadcastScalar) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3}); const std::vector bData = {2}; const wnn::Operand b = @@ -99,7 +99,7 @@ TEST_F(PowTests, PowBroadcastScalar) { } TEST_F(PowTests, PowBroadcast1d) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {2, 3}); const std::vector bData = {1, 2, 3}; const wnn::Operand b = diff --git a/src/webnn/tests/end2end/ReduceTests.cpp b/src/webnn/tests/end2end/ReduceTests.cpp index 65e5c8e87..53d2ab4ed 100644 --- a/src/webnn/tests/end2end/ReduceTests.cpp +++ b/src/webnn/tests/end2end/ReduceTests.cpp @@ -33,7 +33,7 @@ class ReduceTests : public WebnnTest { const std::vector& expectedValue, const std::vector& axes = {}, bool keepDimensions = false) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::ReduceOptions options; if (!axes.empty()) { diff --git a/src/webnn/tests/end2end/ReluTests.cpp b/src/webnn/tests/end2end/ReluTests.cpp index 4b8afc37e..4f3d04661 100644 --- a/src/webnn/tests/end2end/ReluTests.cpp +++ b/src/webnn/tests/end2end/ReluTests.cpp @@ -17,7 +17,7 @@ class ReluTests : public WebnnTest {}; TEST_F(ReluTests, Relu) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = builder.Relu(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/Resample2dTests.cpp b/src/webnn/tests/end2end/Resample2dTests.cpp index 42024904b..d5cfe9f86 100644 --- a/src/webnn/tests/end2end/Resample2dTests.cpp +++ b/src/webnn/tests/end2end/Resample2dTests.cpp @@ -21,7 +21,7 @@ class Resample2dTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, const wnn::Resample2dOptions* options = nullptr) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand inputOperand = utils::BuildInput(builder, "input", inputShape); const wnn::Operand output = builder.Resample2d(inputOperand, options); const wnn::Graph graph = utils::Build(builder, {{"output", output}}); diff --git a/src/webnn/tests/end2end/ReshapeTests.cpp b/src/webnn/tests/end2end/ReshapeTests.cpp index 110203d28..f86b71075 100644 --- a/src/webnn/tests/end2end/ReshapeTests.cpp +++ b/src/webnn/tests/end2end/ReshapeTests.cpp @@ -19,7 +19,7 @@ class ReshapeTests : public WebnnTest { void TestReshape(const std::vector& oldShape, const std::vector& newShape, const std::vector& expectedShape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", oldShape); const wnn::Operand b = builder.Reshape(a, newShape.data(), newShape.size()); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SigmoidTests.cpp b/src/webnn/tests/end2end/SigmoidTests.cpp index 893904d88..2686ea899 100644 --- a/src/webnn/tests/end2end/SigmoidTests.cpp +++ b/src/webnn/tests/end2end/SigmoidTests.cpp @@ -17,7 +17,7 @@ class SigmoidTests : public WebnnTest {}; TEST_F(SigmoidTests, SigmoidWith1DTensor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3}); const wnn::Operand b = builder.Sigmoid(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); @@ -30,7 +30,7 @@ TEST_F(SigmoidTests, SigmoidWith1DTensor) { } TEST_F(SigmoidTests, SigmoidWith3DTensor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = builder.Sigmoid(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SliceTests.cpp b/src/webnn/tests/end2end/SliceTests.cpp index 759b5c9b7..8c02143a7 100644 --- a/src/webnn/tests/end2end/SliceTests.cpp +++ b/src/webnn/tests/end2end/SliceTests.cpp @@ -16,7 +16,7 @@ class SliceTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + builder = utils::CreateGraphBuilder(GetContext()); } protected: diff --git a/src/webnn/tests/end2end/SoftmaxTests.cpp b/src/webnn/tests/end2end/SoftmaxTests.cpp index cb8e3b89b..2e7211ee9 100644 --- a/src/webnn/tests/end2end/SoftmaxTests.cpp +++ b/src/webnn/tests/end2end/SoftmaxTests.cpp @@ -17,7 +17,7 @@ class SoftmaxTests : public WebnnTest {}; TEST_F(SoftmaxTests, Softmax) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); const wnn::Operand b = builder.Softmax(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/SplitTests.cpp b/src/webnn/tests/end2end/SplitTests.cpp index 69b6b21fd..cf9bfb4d7 100644 --- a/src/webnn/tests/end2end/SplitTests.cpp +++ b/src/webnn/tests/end2end/SplitTests.cpp @@ -29,14 +29,14 @@ class SplitTests : public WebnnTest { const std::vector& splits, const std::vector& expectedArray, int32_t axis = 0) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand input = utils::BuildInput(builder, "input", inputShape); wnn::SplitOptions options = {axis}; const wnn::OperandArray splittedOperands = builder.Split(input, splits.data(), splits.size(), &options); std::vector namedOperands; for (size_t i = 0; i < splittedOperands.Size(); ++i) { - namedOperands.push_back({"split" + std::to_string(i), splittedOperands.Get(i)}); + namedOperands.push_back({"split" + std::to_string(i), splittedOperands.GetOperand(i)}); } const wnn::Graph graph = utils::Build(builder, namedOperands); ASSERT_TRUE(graph); diff --git a/src/webnn/tests/end2end/SqueezeTests.cpp b/src/webnn/tests/end2end/SqueezeTests.cpp index af336a1b0..ef581ad6e 100644 --- a/src/webnn/tests/end2end/SqueezeTests.cpp +++ b/src/webnn/tests/end2end/SqueezeTests.cpp @@ -21,7 +21,7 @@ class SqueezeTests : public WebnnTest { void CheckSqueeze(const std::vector& inputShape, const std::vector& axes, const std::vector& expectedShape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", inputShape); wnn::SqueezeOptions options; if (!axes.empty()) { diff --git a/src/webnn/tests/end2end/SubTests.cpp b/src/webnn/tests/end2end/SubTests.cpp index 16beeb7ab..d9bae408b 100644 --- a/src/webnn/tests/end2end/SubTests.cpp +++ b/src/webnn/tests/end2end/SubTests.cpp @@ -17,7 +17,7 @@ class SubTests : public WebnnTest {}; TEST_F(SubTests, SubTwoInputs) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); const wnn::Operand c = builder.Sub(a, b); @@ -60,7 +60,7 @@ TEST_F(SubTests, SubTwoInputs) { } TEST_F(SubTests, SubBroadcast) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const wnn::Operand b = utils::BuildInput(builder, "b", {5}); const wnn::Operand c = builder.Sub(a, b); diff --git a/src/webnn/tests/end2end/TanhTests.cpp b/src/webnn/tests/end2end/TanhTests.cpp index 62b487ee4..fce86dd84 100644 --- a/src/webnn/tests/end2end/TanhTests.cpp +++ b/src/webnn/tests/end2end/TanhTests.cpp @@ -19,7 +19,7 @@ class TanhTests : public WebnnTest { void TestTanh(const std::vector& inputData, const std::vector& expectedData, const std::vector& shape) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", shape); const wnn::Operand b = builder.Tanh(a); const wnn::Graph graph = utils::Build(builder, {{"b", b}}); diff --git a/src/webnn/tests/end2end/TransposeTests.cpp b/src/webnn/tests/end2end/TransposeTests.cpp index c93e808cc..2da27f0f3 100644 --- a/src/webnn/tests/end2end/TransposeTests.cpp +++ b/src/webnn/tests/end2end/TransposeTests.cpp @@ -21,7 +21,7 @@ class TransposeTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, const std::vector& permutation = {}) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand a = utils::BuildInput(builder, "a", inputShape); wnn::TransposeOptions options; options.permutation = permutation.data(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp index 084c841bd..d205dc251 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp @@ -26,7 +26,7 @@ class MobileNetV2BatchNormNchwTests : public WebnnTest { mobilenetv2.mFused = fused; const std::string nchwPath = kModelPath + "/mobilenetv2_batchnorm_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = mobilenetv2.LoadBatchNormNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp index 71a66ff5f..dcda6514f 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp @@ -26,7 +26,7 @@ class MobileNetV2NchwTests : public WebnnTest { mobilenetv2.mFused = fused; const std::string nchwPath = kModelPath + "/mobilenetv2_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = mobilenetv2.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp index 47842facc..b0652a1e1 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp @@ -27,7 +27,7 @@ class MobileNetV2NhwcTests : public WebnnTest { const std::string nhwcPath = kModelPath + "/mobilenetv2_nhwc/"; mobilenetv2.mWeightsPath = nhwcPath + "weights/"; mobilenetv2.mLayout = "nhwc"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = mobilenetv2.LoadNhwc(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/ResNetNchw.cpp b/src/webnn/tests/end2end/models/ResNetNchw.cpp index 54e48122a..f21e9673d 100644 --- a/src/webnn/tests/end2end/models/ResNetNchw.cpp +++ b/src/webnn/tests/end2end/models/ResNetNchw.cpp @@ -26,7 +26,7 @@ class ResNetNchwTests : public WebnnTest { resnet.mFused = fused; const std::string nchwPath = kModelPath + "/resnet50v2_nchw/"; resnet.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = resnet.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/ResNetNhwc.cpp b/src/webnn/tests/end2end/models/ResNetNhwc.cpp index ad0fcb23c..bf1a22995 100644 --- a/src/webnn/tests/end2end/models/ResNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/ResNetNhwc.cpp @@ -26,7 +26,7 @@ class ResNetNhwcTests : public WebnnTest { resnet.mFused = fused; const std::string nhwcPath = kModelPath + "/resnet50v2_nhwc/"; resnet.mWeightsPath = nhwcPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = resnet.LoadNhwc(builder, true); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp index 9430f604f..18d2b6b24 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp @@ -26,7 +26,7 @@ class SqueezeNetNchwTests : public WebnnTest { squeezenet.mFused = fused; const std::string nchwPath = kModelPath + "/squeezenet1.1_nchw/"; squeezenet.mWeightsPath = nchwPath + "weights/"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = squeezenet.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp index adfa1d6ff..4ef54b70a 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp @@ -27,7 +27,7 @@ class SqueezeNetNhwcTests : public WebnnTest { const std::string nhwcPath = kModelPath + "/squeezenet1.0_nhwc/"; squeezenet.mWeightsPath = nhwcPath + "weights/"; squeezenet.mLayout = "nhwc"; - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); wnn::Operand output = squeezenet.LoadNhwc(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); diff --git a/src/webnn/tests/unittests/ObjectBaseTests.cpp b/src/webnn/tests/unittests/ObjectBaseTests.cpp index 3d11a3602..6c9f93b72 100644 --- a/src/webnn/tests/unittests/ObjectBaseTests.cpp +++ b/src/webnn/tests/unittests/ObjectBaseTests.cpp @@ -23,11 +23,11 @@ class Object : public wnn::ObjectBase { using ObjectBase::ObjectBase; using ObjectBase::operator=; - static void WebnnReference(int* handle) { + static void WNNReference(int* handle) { ASSERT_LE(0, *handle); *handle += 1; } - static void WebnnRelease(int* handle) { + static void WNNRelease(int* handle) { ASSERT_LT(0, *handle); *handle -= 1; } @@ -54,14 +54,14 @@ TEST(ObjectBase, AcquireConstruction) { ASSERT_EQ(0, refcount); } -// Test .GetHandle(). +// Test .Get(). TEST(ObjectBase, Get) { int refcount = 1; { Object obj1(&refcount); ASSERT_EQ(2, refcount); - ASSERT_EQ(&refcount, obj1.GetHandle()); + ASSERT_EQ(&refcount, obj1.Get()); } ASSERT_EQ(1, refcount); } @@ -74,7 +74,7 @@ TEST(ObjectBase, Release) { ASSERT_EQ(2, refcount); ASSERT_EQ(&refcount, obj.Release()); - ASSERT_EQ(nullptr, obj.GetHandle()); + ASSERT_EQ(nullptr, obj.Get()); ASSERT_EQ(2, refcount); } ASSERT_EQ(2, refcount); @@ -98,8 +98,8 @@ TEST(ObjectBase, CopyConstructor) { Object source(&refcount); Object destination(source); - ASSERT_EQ(source.GetHandle(), &refcount); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), &refcount); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(3, refcount); destination = Object(); @@ -114,8 +114,8 @@ TEST(ObjectBase, CopyAssignment) { Object destination; destination = source; - ASSERT_EQ(source.GetHandle(), &refcount); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), &refcount); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(3, refcount); destination = Object(); @@ -132,7 +132,7 @@ TEST(ObjectBase, CopyAssignmentSelf) { Object* objPtr = &obj; obj = *objPtr; - ASSERT_EQ(obj.GetHandle(), &refcount); + ASSERT_EQ(obj.Get(), &refcount); ASSERT_EQ(refcount, 2); } @@ -142,8 +142,8 @@ TEST(ObjectBase, MoveConstructor) { Object source(&refcount); Object destination(std::move(source)); - ASSERT_EQ(source.GetHandle(), nullptr); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), nullptr); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(2, refcount); destination = Object(); @@ -158,8 +158,8 @@ TEST(ObjectBase, MoveAssignment) { Object destination; destination = std::move(source); - ASSERT_EQ(source.GetHandle(), nullptr); - ASSERT_EQ(destination.GetHandle(), &refcount); + ASSERT_EQ(source.Get(), nullptr); + ASSERT_EQ(destination.Get(), &refcount); ASSERT_EQ(2, refcount); destination = Object(); @@ -176,14 +176,14 @@ TEST(ObjectBase, MoveAssignmentSelf) { Object* objPtr = &obj; obj = std::move(*objPtr); - ASSERT_EQ(obj.GetHandle(), &refcount); + ASSERT_EQ(obj.Get(), &refcount); ASSERT_EQ(refcount, 2); } // Test the constructor using nullptr TEST(ObjectBase, NullptrConstructor) { Object obj(nullptr); - ASSERT_EQ(obj.GetHandle(), nullptr); + ASSERT_EQ(obj.Get(), nullptr); } // Test assigning nullptr to the object diff --git a/src/webnn/tests/unittests/validation/GraphValidationTests.cpp b/src/webnn/tests/unittests/validation/GraphValidationTests.cpp index 586537f0c..485b233d3 100644 --- a/src/webnn/tests/unittests/validation/GraphValidationTests.cpp +++ b/src/webnn/tests/unittests/validation/GraphValidationTests.cpp @@ -40,13 +40,17 @@ class GraphValidationTest : public ValidationTest { // Test the simple success case. TEST_F(GraphValidationTest, BuildGraphSuccess) { - wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); - namedOperands.Set("output", mOutput); - mBuilder.Build(namedOperands); + // TODO::Use instance->CreateNamedOperands instead of wnn::CreateNamedOperands + // that is removed. + // wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); + // namedOperands.Set("output", mOutput); + // mBuilder.Build(namedOperands); } // Create model with null nameOperands TEST_F(GraphValidationTest, BuildGraphError) { - wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); - DAWN_ASSERT(mBuilder.Build(namedOperands) == nullptr); + // TODO::Use instance->CreateNamedOperands instead of wnn::CreateNamedOperands + // that is removed. + // wnn::NamedOperands namedOperands = wnn::CreateNamedOperands(); + // DAWN_ASSERT(mBuilder.Build(namedOperands) == nullptr); } diff --git a/src/webnn/tests/unittests/validation/ValidationTest.cpp b/src/webnn/tests/unittests/validation/ValidationTest.cpp index bec699be3..cb5efe68a 100644 --- a/src/webnn/tests/unittests/validation/ValidationTest.cpp +++ b/src/webnn/tests/unittests/validation/ValidationTest.cpp @@ -29,7 +29,7 @@ void ValidationTest::SetUp() { ASSERT_TRUE(context != nullptr); mContext = wnn::Context::Acquire(context); mContext.SetUncapturedErrorCallback(ErrorCallback, this); - mBuilder = wnn::CreateGraphBuilder(mContext); + mBuilder = wnn::GraphBuilder::Acquire(instance->CreateGraphBuilder(mContext.Get())); } ValidationTest::~ValidationTest() { diff --git a/src/webnn/utils/BUILD.gn b/src/webnn/utils/BUILD.gn index 0defcc9b6..b9638e33d 100644 --- a/src/webnn/utils/BUILD.gn +++ b/src/webnn/utils/BUILD.gn @@ -11,16 +11,16 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. import("../../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_features.gni") +import("${dawn_root}/scripts/dawn_features.gni") import("${webnn_root}/build_overrides/webnn_features.gni") ############################################################################### # Utils for tests and samples -############################################################################### +############################################################################### static_library("webnn_utils") { configs += [ "${webnn_root}/src/webnn/common:internal_config" ] diff --git a/src/webnn/wire/BUILD.gn b/src/webnn/wire/BUILD.gn index 37219331a..4365b3215 100644 --- a/src/webnn/wire/BUILD.gn +++ b/src/webnn/wire/BUILD.gn @@ -15,7 +15,7 @@ import("../../../scripts/webnn_overrides_with_defaults.gni") -import("${webnn_dawn_root}/scripts/dawn_component.gni") +import("${dawn_root}/scripts/dawn_component.gni") import("${webnn_root}/generator/webnn_generator.gni") # Public webnn_wire headers so they can be publically visible for @@ -35,8 +35,8 @@ webnn_json_generator("gen") { target = "wire" outputs = [ "src/webnn/wire/ObjectType_autogen.h", - "src/webnn/wire/WireCmd_autogen.h", "src/webnn/wire/WireCmd_autogen.cpp", + "src/webnn/wire/WireCmd_autogen.h", "src/webnn/wire/client/ApiObjects_autogen.h", "src/webnn/wire/client/ApiProcs_autogen.cpp", "src/webnn/wire/client/ClientBase_autogen.h", @@ -60,14 +60,18 @@ dawn_component("webnn_wire") { configs = [ "${webnn_root}/src/webnn/common:internal_config" ] sources = get_target_outputs(":gen") sources += [ + "BufferConsumer.h", + "BufferConsumer_impl.h", "ChunkedCommandHandler.cpp", "ChunkedCommandHandler.h", "ChunkedCommandSerializer.cpp", "ChunkedCommandSerializer.h", "WireClient.cpp", + "Wire.cpp", "WireDeserializeAllocator.cpp", "WireDeserializeAllocator.h", "WireServer.cpp", + "WireResult.h", "client/ApiObjects.h", "client/Client.cpp", "client/Client.h", @@ -107,7 +111,7 @@ dawn_component("webnn_wire") { if (is_component_build) { libs = [ "dawn_wire.dll.lib" ] } - deps += [ "${webnn_dawn_root}/src/dawn/wire" ] + deps += [ "${dawn_root}/src/dawn/wire" ] } # Make headers publicly visible diff --git a/src/webnn/wire/BufferConsumer.h b/src/webnn/wire/BufferConsumer.h new file mode 100644 index 000000000..6fd631a83 --- /dev/null +++ b/src/webnn/wire/BufferConsumer.h @@ -0,0 +1,85 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_BUFFERCONSUMER_H_ +#define WEBNN_WIRE_BUFFERCONSUMER_H_ + +#include + +#include "webnn/wire/WireResult.h" + +namespace webnn::wire { + + // BufferConsumer is a utility class that allows reading bytes from a buffer + // while simultaneously decrementing the amount of remaining space by exactly + // the amount read. It helps prevent bugs where incrementing a pointer and + // decrementing a size value are not kept in sync. + // BufferConsumer also contains bounds checks to prevent reading out-of-bounds. + template + class BufferConsumer { + static_assert(sizeof(BufferT) == 1, + "BufferT must be 1-byte, but may have const/volatile qualifiers."); + + public: + BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) { + } + + BufferT* Buffer() const { + return mBuffer; + } + size_t AvailableSize() const { + return mSize; + } + + protected: + template + WireResult NextN(N count, T** data); + + template + WireResult Next(T** data); + + template + WireResult Peek(T** data); + + private: + BufferT* mBuffer; + size_t mSize; + }; + + class SerializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Next; + using BufferConsumer::NextN; + }; + + class DeserializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Peek; + + template + WireResult ReadN(N count, const volatile T** data) { + return NextN(count, data); + } + + template + WireResult Read(const volatile T** data) { + return Next(data); + } + }; + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_BUFFERCONSUMER_H_ diff --git a/src/webnn/wire/BufferConsumer_impl.h b/src/webnn/wire/BufferConsumer_impl.h new file mode 100644 index 000000000..3259d75b2 --- /dev/null +++ b/src/webnn/wire/BufferConsumer_impl.h @@ -0,0 +1,73 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ +#define WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ + +#include "webnn/wire/BufferConsumer.h" + +#include +#include + +namespace webnn::wire { + + template + template + WireResult BufferConsumer::Peek(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::Next(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += sizeof(T); + mSize -= sizeof(T); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::NextN(N count, T** data) { + static_assert(std::is_unsigned::value, "|count| argument of NextN must be unsigned."); + + constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); + if (count > kMaxCountWithoutOverflows) { + return WireResult::FatalError; + } + + // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|. + size_t totalSize = sizeof(T) * count; + if (totalSize > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += totalSize; + mSize -= totalSize; + return WireResult::Success; + } + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ diff --git a/src/webnn/wire/ChunkedCommandSerializer.h b/src/webnn/wire/ChunkedCommandSerializer.h index f3c0a6347..3cf4d62fa 100644 --- a/src/webnn/wire/ChunkedCommandSerializer.h +++ b/src/webnn/wire/ChunkedCommandSerializer.h @@ -23,16 +23,17 @@ #include #include #include +#include namespace webnn::wire { class ChunkedCommandSerializer { public: - ChunkedCommandSerializer(CommandSerializer* serializer); + explicit ChunkedCommandSerializer(CommandSerializer* serializer); template void SerializeCommand(const Cmd& cmd) { - SerializeCommand(cmd, 0, [](char*) {}); + SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -41,15 +42,16 @@ namespace webnn::wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer); + [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer); }, extraSize, std::forward(SerializeExtraSize)); } template void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) { - SerializeCommand(cmd, objectIdProvider, 0, [](char*) {}); + SerializeCommand(cmd, objectIdProvider, 0, + [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -59,8 +61,9 @@ namespace webnn::wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [&objectIdProvider](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer, objectIdProvider); + [&objectIdProvider](const Cmd& cmd, size_t requiredSize, + SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider); }, extraSize, std::forward(SerializeExtraSize)); } @@ -77,8 +80,12 @@ namespace webnn::wire { if (requiredSize <= mMaxAllocationSize) { char* allocatedBuffer = static_cast(mSerializer->GetCmdSpace(requiredSize)); if (allocatedBuffer != nullptr) { - SerializeCmd(cmd, requiredSize, allocatedBuffer); - SerializeExtraSize(allocatedBuffer + commandSize); + SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize); + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { + mSerializer->OnSerializeError(); + } } return; } @@ -87,8 +94,13 @@ namespace webnn::wire { if (!cmdSpace) { return; } - SerializeCmd(cmd, requiredSize, cmdSpace.get()); - SerializeExtraSize(cmdSpace.get() + commandSize); + SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize); + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { + mSerializer->OnSerializeError(); + return; + } SerializeChunkedCommand(cmdSpace.get(), requiredSize); } diff --git a/src/webnn/wire/Wire.cpp b/src/webnn/wire/Wire.cpp new file mode 100644 index 000000000..44754eafa --- /dev/null +++ b/src/webnn/wire/Wire.cpp @@ -0,0 +1,24 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "webnn/wire/Wire.h" + +namespace webnn::wire { + +CommandSerializer::CommandSerializer() = default; +CommandSerializer::~CommandSerializer() = default; + +void CommandSerializer::OnSerializeError() {} + +} // namespace webnn::wire diff --git a/src/webnn/wire/WireResult.h b/src/webnn/wire/WireResult.h new file mode 100644 index 000000000..85e765eb3 --- /dev/null +++ b/src/webnn/wire/WireResult.h @@ -0,0 +1,38 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_WIRERESULT_H_ +#define WEBNN_WIRE_WIRERESULT_H_ + +#include "common/Compiler.h" + +namespace webnn::wire { + + enum class [[nodiscard]] WireResult{ + Success, + FatalError, + }; + +// Macro to simplify error handling, similar to DAWN_TRY but for WireResult. +#define WIRE_TRY(EXPR) \ + do { \ + WireResult exprResult = EXPR; \ + if (DAWN_UNLIKELY(exprResult != WireResult::Success)) { \ + return exprResult; \ + } \ + } while (0) + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_WIRERESULT_H_ diff --git a/src/webnn/wire/client/GraphBuilder.cpp b/src/webnn/wire/client/GraphBuilder.cpp index c6a6fad31..940d45d57 100644 --- a/src/webnn/wire/client/GraphBuilder.cpp +++ b/src/webnn/wire/client/GraphBuilder.cpp @@ -25,7 +25,7 @@ namespace webnn::wire::client { GraphBuilderConstantInternalCmd cmd; cmd.graphBuilderId = this->id; cmd.desc = desc; - cmd.buffer = static_cast(value->buffer); + cmd.arrayBuffer = static_cast(value->buffer); cmd.byteLength = value->byteLength; cmd.byteOffset = value->byteOffset; @@ -43,7 +43,7 @@ namespace webnn::wire::client { GraphBuilderConstantWithGpuBufferInternalCmd cmd; cmd.graphBuilderId = this->id; cmd.desc = desc; - cmd.buffer = static_cast(value->buffer); + cmd.arrayBuffer = static_cast(value->buffer); cmd.id = value->id; cmd.generation = value->generation; cmd.byteLength = value->size; diff --git a/src/webnn/wire/client/NamedInputs.cpp b/src/webnn/wire/client/NamedInputs.cpp index 07675bc2f..26192f7b3 100644 --- a/src/webnn/wire/client/NamedInputs.cpp +++ b/src/webnn/wire/client/NamedInputs.cpp @@ -26,7 +26,7 @@ namespace webnn::wire::client { // Input type is ArrayBufferView WNNArrayBufferView arrayBufferView = input->resource.arrayBufferView; if (arrayBufferView.buffer != nullptr) { - cmd.buffer = static_cast(arrayBufferView.buffer); + cmd.arrayBuffer = static_cast(arrayBufferView.buffer); cmd.byteLength = arrayBufferView.byteLength; cmd.byteOffset = arrayBufferView.byteOffset; } else { diff --git a/src/webnn/wire/client/NamedOutputs.cpp b/src/webnn/wire/client/NamedOutputs.cpp index 38a437b50..9ef0893d6 100644 --- a/src/webnn/wire/client/NamedOutputs.cpp +++ b/src/webnn/wire/client/NamedOutputs.cpp @@ -19,9 +19,9 @@ namespace webnn::wire::client { - void NamedOutputs::Set(char const* name, WNNResource const* resource) { + void NamedOutputs::SetOutput(char const* name, WNNResource const* resource) { // The type of output data is WNNArrayBufferView. - NamedOutputsSetCmd cmd = {}; + NamedOutputsSetOutputCmd cmd = {}; cmd.namedOutputsId = this->id; cmd.name = name; WNNArrayBufferView arrayBufferView = resource->arrayBufferView; @@ -40,8 +40,16 @@ namespace webnn::wire::client { client->SerializeCommand(cmd); } - void NamedOutputs::Get(char const* name, WNNArrayBufferView const* resource) { - UNREACHABLE(); + void NamedOutputs::GetOutput(char const* name, WNNArrayBufferView const* resource) { + NamedOutputsGetOutputCmd cmd = {}; + cmd.namedOutputsId = this->id; + cmd.name = name; + if (resource->buffer != nullptr) { + cmd.arrayBuffer = static_cast(resource->buffer); + cmd.byteLength = resource->byteLength; + cmd.byteOffset = resource->byteOffset; + } + client->SerializeCommand(cmd); } bool NamedOutputs::OutputResult(char const* name, diff --git a/src/webnn/wire/client/NamedOutputs.h b/src/webnn/wire/client/NamedOutputs.h index f57d27545..9d7fdfa8b 100644 --- a/src/webnn/wire/client/NamedOutputs.h +++ b/src/webnn/wire/client/NamedOutputs.h @@ -29,8 +29,8 @@ namespace webnn::wire::client { public: using ObjectBase::ObjectBase; - void Set(char const* name, WNNResource const* resource); - void Get(char const* name, WNNArrayBufferView const* resource); + void SetOutput(char const* name, WNNResource const* resource); + void GetOutput(char const* name, WNNArrayBufferView const* resource); bool OutputResult(char const* name, uint8_t const* buffer, size_t byteLength, diff --git a/src/webnn/wire/server/Server.cpp b/src/webnn/wire/server/Server.cpp index ee205bb63..d5dea8feb 100644 --- a/src/webnn/wire/server/Server.cpp +++ b/src/webnn/wire/server/Server.cpp @@ -158,28 +158,6 @@ namespace webnn::wire::server { return true; } - bool Server::DoCreateGraphBuilder(ObjectId contextId, ObjectHandle result) { - auto* context = ContextObjects().Get(contextId); - if (context == nullptr) { - return false; - } - - // Create and register the GraphBuilder object. - auto* resultData = GraphBuilderObjects().Allocate(result.id); - if (resultData == nullptr) { - return false; - } - resultData->generation = result.generation; - resultData->contextInfo = context->contextInfo; - if (resultData->contextInfo != nullptr) { - if (!TrackContextChild(resultData->contextInfo, ObjectType::GraphBuilder, result.id)) { - return false; - } - } - resultData->handle = mProcs.createGraphBuilder(context->handle); - return true; - } - #if defined(WEBNN_ENABLE_GPU_BUFFER) WGPUDevice Server::GetWGPUDevice(uint32_t id, uint32_t generation) { return mDawnWireServer->GetDevice(id, generation); diff --git a/src/webnn/wire/server/ServerGraph.cpp b/src/webnn/wire/server/ServerGraph.cpp index b80ade192..2a8633382 100644 --- a/src/webnn/wire/server/ServerGraph.cpp +++ b/src/webnn/wire/server/ServerGraph.cpp @@ -25,7 +25,7 @@ namespace webnn::wire::server { } for (auto& name : mOutputNamesMap[outputsId]) { WNNArrayBufferView arrayBuffer = {}; - mProcs.namedOutputsGet(namedOutputs->handle, name.data(), &arrayBuffer); + mProcs.namedOutputsGetOutput(namedOutputs->handle, name.data(), &arrayBuffer); if (arrayBuffer.buffer == nullptr) { return false; } @@ -34,7 +34,7 @@ namespace webnn::wire::server { ReturnGraphComputeResultCmd cmd; cmd.namedOutputs = ObjectHandle{outputsId, namedOutputs->generation}; cmd.name = name.data(); - cmd.buffer = static_cast(arrayBuffer.buffer); + cmd.arrayBuffer = static_cast(arrayBuffer.buffer); cmd.byteLength = arrayBuffer.byteLength; cmd.byteOffset = arrayBuffer.byteOffset; SerializeCommand(cmd); diff --git a/src/webnn/wire/server/ServerNamedOutputs.cpp b/src/webnn/wire/server/ServerNamedOutputs.cpp index b16b2ce0f..ff4100881 100644 --- a/src/webnn/wire/server/ServerNamedOutputs.cpp +++ b/src/webnn/wire/server/ServerNamedOutputs.cpp @@ -17,7 +17,7 @@ namespace webnn::wire::server { - bool Server::DoNamedOutputsSet(ObjectId namedOutputsId, + bool Server::DoNamedOutputsSetOutput(ObjectId namedOutputsId, char const* name, size_t byteLength, size_t byteOffset, @@ -50,9 +50,27 @@ namespace webnn::wire::server { outputNames.push_back(std::string(name)); } } - mProcs.namedOutputsSet(namedOutputs->handle, name, &resource); + mProcs.namedOutputsSetOutput(namedOutputs->handle, name, &resource); return true; } + bool Server::DoNamedOutputsGetOutput(ObjectId namedOutputsId, + char const* name, + uint8_t const* buffer, + size_t byteLength, + size_t byteOffset) { + auto* namedOutputs = NamedOutputsObjects().Get(namedOutputsId); + if (namedOutputs == nullptr) { + return false; + } + + WNNArrayBufferView arrayBuffer = {}; + arrayBuffer.buffer = const_cast(static_cast(buffer)); + arrayBuffer.byteLength = byteLength; + arrayBuffer.byteOffset = byteOffset; + mProcs.namedOutputsGetOutput(namedOutputs->handle, name, &arrayBuffer); + + return true; + } } // namespace webnn::wire::server diff --git a/third_party/gn/abseil-cpp/BUILD.gn b/third_party/gn/abseil-cpp/BUILD.gn new file mode 100644 index 000000000..9229d7228 --- /dev/null +++ b/third_party/gn/abseil-cpp/BUILD.gn @@ -0,0 +1,170 @@ +# Copyright 2021 The Dawn Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import("../../../scripts/webnn_overrides_with_defaults.gni") + +import("${dawn_root}/scripts/dawn_features.gni") + +config("absl_config") { + if (webnn_standalone && is_clang) { + cflags = [ + # Allow the use of enable_if() + "-Wno-gcc-compat", + ] + } + + include_dirs = [ "${webnn_abseil_dir}" ] +} + +template("absl_source_set") { + source_set(target_name) { + forward_variables_from(invoker, "*") + + if (!defined(public_configs)) { + public_configs = [] + } + public_configs += [ ":absl_config" ] + } +} + +# +# absl/base +# + +absl_source_set("log_severity") { + sources = [ "${webnn_abseil_dir}/absl/base/log_severity.cc" ] + public = [ "${webnn_abseil_dir}/absl/base/log_severity.h" ] +} + +absl_source_set("raw_logging_internal") { + sources = [ "${webnn_abseil_dir}/absl/base/internal/raw_logging.cc" ] + public = [ "${webnn_abseil_dir}/absl/base/internal/raw_logging.h" ] + public_deps = [ ":log_severity" ] + visibility = [ ":*" ] +} + +absl_source_set("throw_delegate") { + sources = [ "${webnn_abseil_dir}/absl/base/internal/throw_delegate.cc" ] + public = [ "${webnn_abseil_dir}/absl/base/internal/throw_delegate.h" ] + public_deps = [ ":raw_logging_internal" ] + visibility = [ ":*" ] +} + +# +# absl/numeric +# + +absl_source_set("int128") { + sources = [ + "${webnn_abseil_dir}/absl/numeric/int128.cc", + "${webnn_abseil_dir}/absl/numeric/int128_have_intrinsic.inc", + "${webnn_abseil_dir}/absl/numeric/int128_no_intrinsic.inc", + ] + public = [ "${webnn_abseil_dir}/absl/numeric/int128.h" ] +} + +# +# absl/strings +# + +absl_source_set("strings") { + sources = [ + "${webnn_abseil_dir}/absl/strings/ascii.cc", + "${webnn_abseil_dir}/absl/strings/charconv.cc", + "${webnn_abseil_dir}/absl/strings/escaping.cc", + "${webnn_abseil_dir}/absl/strings/internal/charconv_bigint.cc", + "${webnn_abseil_dir}/absl/strings/internal/charconv_bigint.h", + "${webnn_abseil_dir}/absl/strings/internal/charconv_parse.cc", + "${webnn_abseil_dir}/absl/strings/internal/charconv_parse.h", + "${webnn_abseil_dir}/absl/strings/internal/memutil.cc", + "${webnn_abseil_dir}/absl/strings/internal/memutil.h", + "${webnn_abseil_dir}/absl/strings/internal/stl_type_traits.h", + "${webnn_abseil_dir}/absl/strings/internal/str_join_internal.h", + "${webnn_abseil_dir}/absl/strings/internal/str_split_internal.h", + "${webnn_abseil_dir}/absl/strings/match.cc", + "${webnn_abseil_dir}/absl/strings/numbers.cc", + "${webnn_abseil_dir}/absl/strings/str_cat.cc", + "${webnn_abseil_dir}/absl/strings/str_replace.cc", + "${webnn_abseil_dir}/absl/strings/str_split.cc", + "${webnn_abseil_dir}/absl/strings/string_view.cc", + "${webnn_abseil_dir}/absl/strings/substitute.cc", + ] + public = [ + "${webnn_abseil_dir}/absl/strings/ascii.h", + "${webnn_abseil_dir}/absl/strings/charconv.h", + "${webnn_abseil_dir}/absl/strings/escaping.h", + "${webnn_abseil_dir}/absl/strings/internal/string_constant.h", + "${webnn_abseil_dir}/absl/strings/match.h", + "${webnn_abseil_dir}/absl/strings/numbers.h", + "${webnn_abseil_dir}/absl/strings/str_cat.h", + "${webnn_abseil_dir}/absl/strings/str_join.h", + "${webnn_abseil_dir}/absl/strings/str_replace.h", + "${webnn_abseil_dir}/absl/strings/str_split.h", + "${webnn_abseil_dir}/absl/strings/string_view.h", + "${webnn_abseil_dir}/absl/strings/strip.h", + "${webnn_abseil_dir}/absl/strings/substitute.h", + ] + deps = [ + ":int128", + ":raw_logging_internal", + ":strings_internal", + ":throw_delegate", + ] +} + +absl_source_set("strings_internal") { + sources = [ + "${webnn_abseil_dir}/absl/strings/internal/escaping.cc", + "${webnn_abseil_dir}/absl/strings/internal/ostringstream.cc", + "${webnn_abseil_dir}/absl/strings/internal/utf8.cc", + ] + public = [ + "${webnn_abseil_dir}/absl/strings/internal/char_map.h", + "${webnn_abseil_dir}/absl/strings/internal/escaping.h", + "${webnn_abseil_dir}/absl/strings/internal/ostringstream.h", + "${webnn_abseil_dir}/absl/strings/internal/resize_uninitialized.h", + "${webnn_abseil_dir}/absl/strings/internal/utf8.h", + ] + deps = [ ":raw_logging_internal" ] +} + +absl_source_set("str_format") { + public = [ "${webnn_abseil_dir}/absl/strings/str_format.h" ] + deps = [ ":str_format_internal" ] +} + +absl_source_set("str_format_internal") { + sources = [ + "${webnn_abseil_dir}/absl/strings/internal/str_format/arg.cc", + "${webnn_abseil_dir}/absl/strings/internal/str_format/bind.cc", + "${webnn_abseil_dir}/absl/strings/internal/str_format/extension.cc", + "${webnn_abseil_dir}/absl/strings/internal/str_format/float_conversion.cc", + "${webnn_abseil_dir}/absl/strings/internal/str_format/output.cc", + "${webnn_abseil_dir}/absl/strings/internal/str_format/parser.cc", + ] + public = [ + "${webnn_abseil_dir}/absl/strings/internal/str_format/arg.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/bind.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/checker.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/extension.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/float_conversion.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/output.h", + "${webnn_abseil_dir}/absl/strings/internal/str_format/parser.h", + ] + visibility = [ ":*" ] + deps = [ + ":int128", + ":strings", + ] +} diff --git a/webnn.json b/webnn.json index ef4546481..82507f903 100644 --- a/webnn.json +++ b/webnn.json @@ -15,6 +15,21 @@ "See the License for the specific language governing permissions and", "limitations under the License." ], + + "_metadata": { + "api": "WebNN", + "c_prefix": "WNN", + "namespace": "wnn", + "native_namespace": "webnn native", + "wire_namespace": "webnn wire", + "proc_table_prefix": "Webnn" + }, + + "proc": { + "category": "function pointer", + "returns": "void", + "args": [] + }, "instance descriptor": { "category": "structure", "members": [] @@ -145,7 +160,7 @@ ] }, "error callback": { - "category": "callback", + "category": "function pointer", "args": [ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*"}, @@ -184,7 +199,7 @@ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"} ], - "TODO": "enga@: Make this a Dawn extension" + "_TODO": "enga@: Make this a Dawn extension" }, { "name": "set uncaptured error callback", @@ -251,13 +266,14 @@ "methods": [ { "name": "size", - "returns": "size_t" + "returns": "uint64_t", + "args": [] }, { - "name": "get", + "name": "get operand", "returns": "operand", "args": [ - {"name": "index", "type": "size_t"} + {"name": "index", "type": "uint64_t"} ] } ] @@ -267,20 +283,21 @@ "methods": [ { "name": "size", - "returns": "size_t" + "returns": "uint64_t", + "args": [] }, { - "name": "set", + "name": "set fusion operator", "returns": "void", "args": [ - {"name": "operator", "type": "fusion operator"} + {"name": "fusion operator", "type": "fusion operator"} ] }, { - "name": "get", + "name": "get fusion operator", "returns": "fusion operator", "args": [ - {"name": "index", "type": "size_t"} + {"name": "index", "type": "uint64_t"} ] } ] @@ -366,11 +383,11 @@ "category": "structure", "members": [ {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "groups", "type": "int32_t", "default": 1}, {"name": "input layout", "type": "input operand layout", "default": "nchw"}, @@ -383,15 +400,15 @@ "category": "structure", "members": [ {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "output padding count", "type": "uint32_t", "default": 0}, - {"name": "output padding", "type": "int32_t", "annotation": "const*", "length": "output padding count", "optional": true}, + {"name": "output padding", "type": "int32_t", "annotation": "const*", "length": "output padding count", "default": "nullptr"}, {"name": "output sizes count", "type": "uint32_t", "default": 0}, - {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "optional": true}, + {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "groups", "type": "int32_t", "default": 1}, {"name": "input layout", "type": "input operand layout", "default": "nchw"}, @@ -404,7 +421,7 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "gru options": { @@ -431,18 +448,18 @@ "category": "structure", "members": [ {"name": "window dimensions count", "type": "uint32_t", "default": 0}, - {"name": "window dimensions", "type": "int32_t", "annotation": "const*", "length": "window dimensions count", "optional": true}, + {"name": "window dimensions", "type": "int32_t", "annotation": "const*", "length": "window dimensions count", "default": "nullptr"}, {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "layout", "type": "input operand layout", "default": "nchw"}, {"name": "rounding type", "type": "rounding type", "default": "floor"}, {"name": "output sizes count", "type": "uint32_t", "default": 0}, - {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "optional": true} + {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "default": "nullptr"} ] }, "gemm options": { @@ -465,7 +482,7 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count","optional": true}, + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"}, {"name": "keepDimensions", "type": "bool", "default": "false"} ] }, @@ -474,11 +491,11 @@ "members": [ {"name": "mode", "type": "interpolation mode", "default": "nearest neighbor"}, {"name": "scales count", "type": "uint32_t", "default": 0}, - {"name": "scales", "type": "float", "annotation": "const*", "length": "scales count", "optional": true}, + {"name": "scales", "type": "float", "annotation": "const*", "length": "scales count", "default": "nullptr" }, {"name": "sizes count", "type": "uint32_t", "default": 0}, - {"name": "sizes", "type": "int32_t", "annotation": "const*", "length": "sizes count", "optional": true}, + {"name": "sizes", "type": "int32_t", "annotation": "const*", "length": "sizes count", "default": "nullptr"}, {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "split options": { @@ -491,14 +508,14 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "transpose options": { "category": "structure", "members": [ {"name": "permutation count", "type": "uint32_t", "default": 0}, - {"name": "permutation", "type": "int32_t", "annotation": "const*", "length": "permutation count", "optional": true} + {"name": "permutation", "type": "int32_t", "annotation": "const*", "length": "permutation count", "default": "nullptr"} ] }, "batchNorm options": { @@ -721,7 +738,8 @@ }, { "name": "hard swish operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "gemm", @@ -861,7 +879,8 @@ }, { "name": "relu operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "reshape", @@ -881,7 +900,8 @@ }, { "name": "sigmoid operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "slice", @@ -930,7 +950,8 @@ }, { "name": "tanh operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "transpose", @@ -1002,7 +1023,7 @@ "category": "structure", "members": [ {"name": "resource", "type": "resource"}, - {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "optional": true}, + {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "default": "nullptr"}, {"name": "dimensions count", "type": "uint32_t", "default": 0} ] }, @@ -1048,14 +1069,14 @@ "category": "object", "methods": [ { - "name": "set", + "name": "set output", "args": [ {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "resource", "type": "resource", "annotation": "const*"} ] }, { - "name": "get", + "name": "get output", "args": [ {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "resource", "type": "array buffer view", "annotation": "*"} @@ -1064,7 +1085,7 @@ ] }, "compute async callback": { - "category": "callback", + "category": "function pointer", "args": [ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, @@ -1093,5 +1114,11 @@ ] } ] + }, + "s type": { + "category": "enum", + "values": [ + {"value": 0, "name": "invalid", "valid": false} + ] } } diff --git a/webnn_wire.json b/webnn_wire.json index c947002ca..1c012f7ca 100644 --- a/webnn_wire.json +++ b/webnn_wire.json @@ -23,19 +23,19 @@ "graph builder constant internal": [ {"name": "graph builder id", "type": "ObjectId"}, {"name": "desc", "type": "operand descriptor", "annotation": "const*"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "result", "type": "ObjectHandle", "handle_type": "operand"} ], "graph builder constant with gpu buffer internal": [ {"name": "graph builder id", "type": "ObjectId"}, {"name": "desc", "type": "operand descriptor", "annotation": "const*"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "optional": true}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "optional": true}, {"name": "id", "type": "uint32_t", "default": 0}, {"name": "generation", "type": "uint32_t", "default": 0}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "result", "type": "ObjectHandle", "handle_type": "operand"} ], "graph builder gru internal": [ @@ -83,29 +83,32 @@ "named inputs set": [ {"name": "named inputs id", "type": "ObjectId"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "optional": true}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "default": "nullptr"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "gpu buffer id", "type": "uint32_t", "default": 0}, {"name": "gpu buffer generation", "type": "uint32_t", "default": 0}, - {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "optional": true}, + {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "default": "nullptr"}, {"name": "dimensions count", "type": "uint32_t", "default": 0} ], - "named outputs set": [ + "named outputs set output": [ {"name": "named outputs id", "type": "ObjectId"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "gpu buffer id", "type": "uint32_t", "default": 0}, {"name": "gpu buffer generation", "type": "uint32_t", "default": 0} ], + "named outputs get output": [ + {"name": "named outputs id", "type": "ObjectId"}, + {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "default": "nullptr"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0} + ], "destroy object": [ {"name": "object type", "type": "ObjectType"}, {"name": "object id", "type": "ObjectId"} - ], - "create graph builder": [ - {"name": "context", "type": "ObjectId"}, - {"name": "result", "type": "ObjectHandle", "handle_type": "graph builder"} ] }, "return commands": { @@ -118,9 +121,9 @@ "graph compute result": [ {"name": "named outputs", "type": "ObjectHandle", "handle_type": "named outputs"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0} + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0} ], "graph compute async callback": [ { "name": "graph", "type": "ObjectHandle", "handle_type": "graph" }, @@ -146,8 +149,8 @@ "GraphBuilderSplit", "InstanceCreateContextWithGpuDevice", "NamedInputsSet", - "NamedOutputsSet", - "NamedOutputsGet", + "NamedOutputsSetOutput", + "NamedOutputsGetOutput", "OperandArraySize", "OperatorArraySize", "GraphComputeAsync",