diff --git a/DEPS b/DEPS index 4470b983a..09d555211 100644 --- a/DEPS +++ b/DEPS @@ -8,7 +8,8 @@ gclient_gn_args = [ vars = { 'chromium_git': 'https://chromium.googlesource.com', - 'dawn_git': 'https://github.com/fujunwei', + # 'dawn_git': 'https://github.com/fujunwei', + 'dawn_git': 'https://github.com/lisa0314', 'github_git': 'https://github.com', 'dawn_standalone': True, @@ -45,9 +46,15 @@ deps = { # Dependencies required for code generator and infrastructure code. 'third_party/dawn': { - 'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895' + # 'url': '{dawn_git}/dawn.git@f4c84e239bf8b5b2c4733d68ca38e1e9049fd895' + 'url': '{dawn_git}/dawn.git@5e6f6fbfcb038e7a0f7857cda186a8771c6eba05' }, + 'third_party/abseil-cpp': { + 'url': '{chromium_git}/chromium/src/third_party/abseil-cpp@789af048b388657987c59d4da406859034fe310f', + 'condition': 'dawn_standalone', + }, + # Dependencies required for backends. 'third_party/DirectML': { 'url': '{github_git}/microsoft/DirectML.git@c3f16a701beeeefc9ce5b67c71b554a6903c0f67', @@ -136,7 +143,7 @@ deps = { # Jinja2 and MarkupSafe for the code generator 'third_party/jinja2': { - 'url': '{chromium_git}/chromium/src/third_party/jinja2@a82a4944a7f2496639f34a89c9923be5908b80aa', + 'url': '{chromium_git}/chromium/src/third_party/jinja2@ee69aa00ee8536f61db6a451f3858745cf587de6', 'condition': 'dawn_standalone', }, 'third_party/markupsafe': { diff --git a/examples/MobileNetV2/Main.cpp b/examples/MobileNetV2/Main.cpp index bec9759e1..94ef00563 100644 --- a/examples/MobileNetV2/Main.cpp +++ b/examples/MobileNetV2/Main.cpp @@ -39,9 +39,10 @@ int main(int argc, const char* argv[]) { } }, &mobilevetv2); + wnn::GraphBuilder builder = utils::CreateGraphBuilder(context); - wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNCHW(builder) - : mobilevetv2.LoadNHWC(builder); + wnn::Operand output = mobilevetv2.mLayout == "nchw" ? mobilevetv2.LoadNchw(builder) + : mobilevetv2.LoadNhwc(builder); // Build the graph. const std::chrono::time_point compilationStartTime = diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index 2da6009b6..18175ead5 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -62,12 +62,14 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { // Choose whether to use the backend procs and context directly, or set up the wire. WNNContext context = nullptr; WebnnProcTable procs; + WNNInstance wnnInstance; + switch (cmdBufType) { case CmdBufType::None: procs = backendProcs; context = backendContext; - instance = wnn::Instance(nativeInstance->Get()); + wnnInstance = nativeInstance->Get(); break; case CmdBufType::Terrible: { @@ -100,7 +102,8 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { instanceReservation.generation); // Keep the reference instread of using Acquire. // TODO:: make the instance in the client as singleton object. 
- instance = wnn::Instance(instanceReservation.instance); + wnnInstance = instanceReservation.instance; + break; #endif } default: @@ -108,6 +111,7 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) { DAWN_ASSERT(0); } webnnProcSetProcs(&procs); + instance = wnn::Instance(wnnInstance); return instance.CreateContext(options); ; } diff --git a/generator/webnn_generator.gni b/generator/webnn_generator.gni index 764387136..0c213c44a 100644 --- a/generator/webnn_generator.gni +++ b/generator/webnn_generator.gni @@ -94,6 +94,7 @@ template("webnn_generator") { deps = [ "${dawn_root}/generator:remove_stale_autogen_files" ] template_dir = "${dawn_root}/generator/templates" + } } @@ -121,5 +122,6 @@ template("webnn_json_generator") { ] forward_variables_from(invoker, "*", [ "target" ]) + } } diff --git a/include/webnn/BUILD.gn b/include/webnn/BUILD.gn index 38cd18e30..f21d89562 100644 --- a/include/webnn/BUILD.gn +++ b/include/webnn/BUILD.gn @@ -66,6 +66,8 @@ config("public") { include_dirs = [ "${target_gen_dir}/../../include", "${webnn_root}/include", + "${dawn_root}/include", + "${dawn_gen_root}/include", ] if (build_with_chromium) { diff --git a/include/webnn/wire/Wire.h b/include/webnn/wire/Wire.h index 5703754e2..dd6da5dab 100644 --- a/include/webnn/wire/Wire.h +++ b/include/webnn/wire/Wire.h @@ -26,7 +26,10 @@ namespace webnn::wire { class WEBNN_WIRE_EXPORT CommandSerializer { public: - virtual ~CommandSerializer() = default; + CommandSerializer(); + virtual ~CommandSerializer(); + CommandSerializer(const CommandSerializer& rhs) = delete; + CommandSerializer& operator=(const CommandSerializer& rhs) = delete; // Get space for serializing commands. // GetCmdSpace will never be called with a value larger than @@ -35,6 +38,7 @@ namespace webnn::wire { virtual void* GetCmdSpace(size_t size) = 0; virtual bool Flush() = 0; virtual size_t GetMaximumAllocationSize() const = 0; + virtual void OnSerializeError(); }; class WEBNN_WIRE_EXPORT CommandHandler { diff --git a/src/webnn/common/BUILD.gn b/src/webnn/common/BUILD.gn index 872983b6d..a967db953 100644 --- a/src/webnn/common/BUILD.gn +++ b/src/webnn/common/BUILD.gn @@ -127,6 +127,7 @@ config("internal_config") { "-Wshadow-field", "-Wstrict-prototypes", "-Wtautological-unsigned-zero-compare", + "-Wno-unused-function", ] # Allow comparison against type limits that might be tautological on 32bit @@ -184,6 +185,14 @@ config("internal_config") { "-Wno-c++17-extensions", ] } + + + if (is_clang && webnn_enable_wire) { + cflags += [ + "-Wno-unused-function", + ] + } + } ############################################################################### diff --git a/src/webnn/native/ops/Conv2d.h b/src/webnn/native/ops/Conv2d.h index 4837b3cfb..c668ac7ec 100644 --- a/src/webnn/native/ops/Conv2d.h +++ b/src/webnn/native/ops/Conv2d.h @@ -32,7 +32,7 @@ namespace webnn::native::op { if (options != nullptr && options->bias != nullptr) { mInputs.push_back(options->bias); } - if (options == nullptr || options->padding == nullptr) { + if (options == nullptr || options->paddingCount == 0) { mPadding = std::vector(4, 0); } else { mPadding.assign(options->padding, options->padding + options->paddingCount); @@ -40,7 +40,7 @@ namespace webnn::native::op { mOptions.padding = mPadding.data(); mOptions.paddingCount = mPadding.size(); - if (options == nullptr || options->strides == nullptr) { + if (options == nullptr || options->stridesCount == 0) { mStride = std::vector(2, 1); } else { mStride.assign(options->strides, options->strides + options->stridesCount); 
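The op hunks above and below all switch the "is this option present?" test from a pointer check to a count check. A minimal sketch of the shared pattern, assuming a hypothetical helper name that is not part of this patch:

#include <cstdint>
#include <vector>

// Sketch only: a count of 0 now means "option not supplied", regardless of
// whether the corresponding pointer happens to be null.
template <typename T>
std::vector<T> AssignOrDefault(const T* values, uint32_t count, std::vector<T> fallback) {
    return count == 0 ? std::move(fallback) : std::vector<T>(values, values + count);
}

// Usage mirroring the Conv2d constructor: default to zero padding / unit
// strides when the option is absent, otherwise copy the caller's array.
// mPadding = AssignOrDefault(options ? options->padding : nullptr,
//                            options ? options->paddingCount : 0u,
//                            std::vector<int32_t>(4, 0));
// mStride  = AssignOrDefault(options ? options->strides : nullptr,
//                            options ? options->stridesCount : 0u,
//                            std::vector<int32_t>(2, 1));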
@@ -48,7 +48,7 @@ namespace webnn::native::op { mOptions.strides = mStride.data(); mOptions.stridesCount = mStride.size(); - if (options == nullptr || options->dilations == nullptr) { + if (options == nullptr || options->dilationsCount == 0) { mDilations = std::vector(2, 1); } else { mDilations.assign(options->dilations, options->dilations + options->dilationsCount); diff --git a/src/webnn/native/ops/ConvTranspose2d.cpp b/src/webnn/native/ops/ConvTranspose2d.cpp index 736563b82..9b4ccb6a6 100644 --- a/src/webnn/native/ops/ConvTranspose2d.cpp +++ b/src/webnn/native/ops/ConvTranspose2d.cpp @@ -25,7 +25,7 @@ namespace webnn::native::op { OperandBase* filter, ConvTranspose2dOptions const* options) : Conv2dBase(builder, input, filter, options) { - if (options == nullptr || options->outputPadding == nullptr) { + if (options == nullptr || options->outputPaddingCount == 0) { mOutputPadding = std::vector(2, 0); } else { mOutputPadding.assign(options->outputPadding, @@ -34,7 +34,7 @@ namespace webnn::native::op { mOptions.outputPadding = mOutputPadding.data(); mOptions.outputPaddingCount = mOutputPadding.size(); - if (options != nullptr && options->outputSizes != nullptr) { + if (options != nullptr && options->outputSizesCount != 0) { mOutputSizes.assign(options->outputSizes, options->outputSizes + options->outputSizesCount); mOptions.outputSizes = mOutputSizes.data(); @@ -113,7 +113,7 @@ namespace webnn::native::op { } int32_t outputHeight, outputWidth; - if (mOptions.outputSizes != nullptr) { + if (mOptions.outputSizesCount != 0) { outputHeight = mOptions.outputSizes[0]; outputWidth = mOptions.outputSizes[1]; } else { diff --git a/src/webnn/native/ops/Pool2d.cpp b/src/webnn/native/ops/Pool2d.cpp index 63050eb95..d1e662059 100644 --- a/src/webnn/native/ops/Pool2d.cpp +++ b/src/webnn/native/ops/Pool2d.cpp @@ -26,14 +26,14 @@ namespace webnn::native::op { OperandBase* input, Pool2dOptions const* options) : OperatorBase(builder, {input}), mOpType(opType) { - if (options != nullptr && options->windowDimensions != nullptr) { + if (options != nullptr && options->windowDimensionsCount != 0) { mWindowDimensions.assign(options->windowDimensions, options->windowDimensions + options->windowDimensionsCount); mOptions.windowDimensions = mWindowDimensions.data(); mOptions.windowDimensionsCount = mWindowDimensions.size(); } - if (options == nullptr || options->padding == nullptr) { + if (options == nullptr || options->paddingCount == 0) { mPadding = std::vector(4, 0); } else { mPadding.assign(options->padding, options->padding + options->paddingCount); @@ -41,7 +41,7 @@ namespace webnn::native::op { mOptions.padding = mPadding.data(); mOptions.paddingCount = mPadding.size(); - if (options == nullptr || options->strides == nullptr) { + if (options == nullptr || options->stridesCount == 0) { mStride = std::vector(2, 1); } else { mStride.assign(options->strides, options->strides + options->stridesCount); @@ -49,7 +49,7 @@ namespace webnn::native::op { mOptions.strides = mStride.data(); mOptions.stridesCount = mStride.size(); - if (options == nullptr || options->dilations == nullptr) { + if (options == nullptr || options->dilationsCount == 0) { mDilations = std::vector(2, 1); } else { mDilations.assign(options->dilations, options->dilations + options->dilationsCount); @@ -62,7 +62,7 @@ namespace webnn::native::op { mOptions.roundingType = options == nullptr ? 
wnn::RoundingType::Floor : options->roundingType; - if (options != nullptr && options->outputSizes != nullptr) { + if (options != nullptr && options->outputSizesCount != 0) { mOutputSizes.assign(options->outputSizes, options->outputSizes + options->outputSizesCount); mOptions.outputSizes = mOutputSizes.data(); diff --git a/src/webnn/native/ops/Reduce.cpp b/src/webnn/native/ops/Reduce.cpp index 8da0353b0..6d568efd5 100644 --- a/src/webnn/native/ops/Reduce.cpp +++ b/src/webnn/native/ops/Reduce.cpp @@ -26,7 +26,7 @@ namespace webnn::native::op { ReduceOptions const* options) : OperatorBase(builder, {input}), mOpType(opType) { // If axes are not present, all dimensions are reduced. - if (options == nullptr || options->axes == nullptr) { + if (options == nullptr || options->axesCount == 0) { int32_t rank = input->Shape().size(); mAxes.resize(rank); for (auto i = 0; i < rank; ++i) { diff --git a/src/webnn/native/ops/Resample2d.cpp b/src/webnn/native/ops/Resample2d.cpp index 6c3645354..5dae354d6 100644 --- a/src/webnn/native/ops/Resample2d.cpp +++ b/src/webnn/native/ops/Resample2d.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "webnn/native/ops/Resample2d.h" +#include "common/Log.h" #include @@ -25,13 +26,13 @@ namespace webnn::native::op { : OperatorBase(builder, {input}), mScales({1.0, 1.0}), mSizes({}), mAxes({2, 3}) { mOptions.mode = options == nullptr ? wnn::InterpolationMode::NearestNeighbor : options->mode; - if (options != nullptr && options->scales != nullptr) { + if (options != nullptr && options->scalesCount != 0) { mScales.assign(options->scales, options->scales + options->scalesCount); } - if (options != nullptr && options->sizes != nullptr) { + if (options != nullptr && options->sizesCount != 0) { mSizes.assign(options->sizes, options->sizes + options->sizesCount); } - if (options != nullptr && options->axes != nullptr) { + if (options != nullptr && options->axesCount != 0) { mAxes.assign(options->axes, options->axes + options->axesCount); } } diff --git a/src/webnn/native/ops/Slice.h b/src/webnn/native/ops/Slice.h index 7be33f9af..ffb18ebb9 100644 --- a/src/webnn/native/ops/Slice.h +++ b/src/webnn/native/ops/Slice.h @@ -33,7 +33,7 @@ namespace webnn::native::op { uint32_t sizesCount, SliceOptions const* options) : OperatorBase(builder, {input}) { - if (options != nullptr && options->axes != nullptr) { + if (options != nullptr && options->axesCount != 0) { mAxes.assign(options->axes, options->axes + options->axesCount); } diff --git a/src/webnn/native/ops/Transpose.h b/src/webnn/native/ops/Transpose.h index 497c9afa6..f0f3da044 100644 --- a/src/webnn/native/ops/Transpose.h +++ b/src/webnn/native/ops/Transpose.h @@ -25,7 +25,7 @@ namespace webnn::native::op { public: Transpose(GraphBuilderBase* builder, OperandBase* input, TransposeOptions const* options) : OperatorBase(builder, {input}) { - if (options == nullptr || options->permutation == nullptr) { + if (options == nullptr || options->permutationCount == 0) { int32_t rank = input->Shape().size(); mPermutation.resize(rank); for (auto i = 0; i < rank - 1; i++) { diff --git a/src/webnn/tests/end2end/Pool2dTests.cpp b/src/webnn/tests/end2end/Pool2dTests.cpp index b85c273c7..d71f117b8 100644 --- a/src/webnn/tests/end2end/Pool2dTests.cpp +++ b/src/webnn/tests/end2end/Pool2dTests.cpp @@ -662,7 +662,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = 
utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -677,7 +677,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStrides) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -693,7 +693,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStrides) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -709,7 +709,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; @@ -726,7 +726,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsDefault) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes3x3) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -751,7 +751,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes3x3) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes4x4) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -777,7 +777,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes4x4) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeFloor) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -802,7 +802,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeFloor) { } TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeCeil) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 7, 7}); utils::Pool2dOptions options; options.windowDimensions = {4, 4}; @@ -828,7 +828,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeCeil) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -846,7 +846,7 
@@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -863,7 +863,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -881,7 +881,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperNhwc) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerDefault) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 2, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; @@ -898,7 +898,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerDefault) { } TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerNhwc) { - const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext()); + const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); const wnn::Operand x = utils::BuildInput(builder, "x", {1, 2, 4, 1}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; diff --git a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp index 42ca325ff..d205dc251 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp @@ -27,7 +27,7 @@ class MobileNetV2BatchNormNchwTests : public WebnnTest { const std::string nchwPath = kModelPath + "/mobilenetv2_batchnorm_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadBatchNormNCHW(builder, false); + wnn::Operand output = mobilenetv2.LoadBatchNormNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp index 45f016e3a..dcda6514f 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp @@ -27,7 +27,7 @@ class MobileNetV2NchwTests : public WebnnTest { const std::string nchwPath = kModelPath + "/mobilenetv2_nchw/"; mobilenetv2.mWeightsPath = nchwPath + "weights/"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadNCHW(builder, false); + wnn::Operand output = mobilenetv2.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp index 
8c3e1195e..b0652a1e1 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp @@ -28,7 +28,7 @@ class MobileNetV2NhwcTests : public WebnnTest { mobilenetv2.mWeightsPath = nhwcPath + "weights/"; mobilenetv2.mLayout = "nhwc"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = mobilenetv2.LoadNHWC(builder); + wnn::Operand output = mobilenetv2.LoadNhwc(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/ResNetNchw.cpp b/src/webnn/tests/end2end/models/ResNetNchw.cpp index 90a6ccae6..f21e9673d 100644 --- a/src/webnn/tests/end2end/models/ResNetNchw.cpp +++ b/src/webnn/tests/end2end/models/ResNetNchw.cpp @@ -27,7 +27,7 @@ class ResNetNchwTests : public WebnnTest { const std::string nchwPath = kModelPath + "/resnet50v2_nchw/"; resnet.mWeightsPath = nchwPath + "weights/"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = resnet.LoadNCHW(builder, false); + wnn::Operand output = resnet.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/ResNetNhwc.cpp b/src/webnn/tests/end2end/models/ResNetNhwc.cpp index 22edf774a..bf1a22995 100644 --- a/src/webnn/tests/end2end/models/ResNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/ResNetNhwc.cpp @@ -27,7 +27,7 @@ class ResNetNhwcTests : public WebnnTest { const std::string nhwcPath = kModelPath + "/resnet50v2_nhwc/"; resnet.mWeightsPath = nhwcPath + "weights/"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = resnet.LoadNHWC(builder, false); + wnn::Operand output = resnet.LoadNhwc(builder, true); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp index 0ef2291fa..18d2b6b24 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp @@ -27,7 +27,7 @@ class SqueezeNetNchwTests : public WebnnTest { const std::string nchwPath = kModelPath + "/squeezenet1.1_nchw/"; squeezenet.mWeightsPath = nchwPath + "weights/"; const wnn::GraphBuilder builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = squeezenet.LoadNCHW(builder, false); + wnn::Operand output = squeezenet.LoadNchw(builder, false); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp index 8dcae576c..4ef54b70a 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp @@ -28,7 +28,7 @@ class SqueezeNetNhwcTests : public WebnnTest { squeezenet.mWeightsPath = nhwcPath + "weights/"; squeezenet.mLayout = "nhwc"; const wnn::GraphBuilder 
builder = utils::CreateGraphBuilder(GetContext()); - wnn::Operand output = squeezenet.LoadNHWC(builder); + wnn::Operand output = squeezenet.LoadNhwc(builder); wnn::Graph graph = utils::Build(builder, {{"output", output}}); const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); diff --git a/src/webnn/wire/BUILD.gn b/src/webnn/wire/BUILD.gn index b4a9b8e7e..4365b3215 100644 --- a/src/webnn/wire/BUILD.gn +++ b/src/webnn/wire/BUILD.gn @@ -60,14 +60,18 @@ dawn_component("webnn_wire") { configs = [ "${webnn_root}/src/webnn/common:internal_config" ] sources = get_target_outputs(":gen") sources += [ + "BufferConsumer.h", + "BufferConsumer_impl.h", "ChunkedCommandHandler.cpp", "ChunkedCommandHandler.h", "ChunkedCommandSerializer.cpp", "ChunkedCommandSerializer.h", "WireClient.cpp", + "Wire.cpp", "WireDeserializeAllocator.cpp", "WireDeserializeAllocator.h", "WireServer.cpp", + "WireResult.h", "client/ApiObjects.h", "client/Client.cpp", "client/Client.h", diff --git a/src/webnn/wire/BufferConsumer.h b/src/webnn/wire/BufferConsumer.h new file mode 100644 index 000000000..6fd631a83 --- /dev/null +++ b/src/webnn/wire/BufferConsumer.h @@ -0,0 +1,85 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_BUFFERCONSUMER_H_ +#define WEBNN_WIRE_BUFFERCONSUMER_H_ + +#include + +#include "webnn/wire/WireResult.h" + +namespace webnn::wire { + + // BufferConsumer is a utility class that allows reading bytes from a buffer + // while simultaneously decrementing the amount of remaining space by exactly + // the amount read. It helps prevent bugs where incrementing a pointer and + // decrementing a size value are not kept in sync. + // BufferConsumer also contains bounds checks to prevent reading out-of-bounds. 
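A minimal usage sketch of the two concrete consumers declared below (illustrative only; in practice the generated wire serializers are the callers, and the function here is invented for the example):

#include <cstdint>
#include "webnn/wire/BufferConsumer_impl.h"

// Write one uint32_t through a SerializeBuffer, then read it back through a
// DeserializeBuffer; every step is bounds-checked and reports a WireResult.
webnn::wire::WireResult RoundTrip(char* storage, size_t size, uint32_t value) {
    webnn::wire::SerializeBuffer writer(storage, size);
    uint32_t* slot = nullptr;
    if (writer.Next(&slot) != webnn::wire::WireResult::Success) {
        return webnn::wire::WireResult::FatalError;  // fewer than 4 bytes available
    }
    *slot = value;

    webnn::wire::DeserializeBuffer reader(storage, size);
    const volatile uint32_t* readBack = nullptr;
    if (reader.Read(&readBack) != webnn::wire::WireResult::Success) {
        return webnn::wire::WireResult::FatalError;
    }
    return *readBack == value ? webnn::wire::WireResult::Success
                              : webnn::wire::WireResult::FatalError;
}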
+ template + class BufferConsumer { + static_assert(sizeof(BufferT) == 1, + "BufferT must be 1-byte, but may have const/volatile qualifiers."); + + public: + BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) { + } + + BufferT* Buffer() const { + return mBuffer; + } + size_t AvailableSize() const { + return mSize; + } + + protected: + template + WireResult NextN(N count, T** data); + + template + WireResult Next(T** data); + + template + WireResult Peek(T** data); + + private: + BufferT* mBuffer; + size_t mSize; + }; + + class SerializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Next; + using BufferConsumer::NextN; + }; + + class DeserializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Peek; + + template + WireResult ReadN(N count, const volatile T** data) { + return NextN(count, data); + } + + template + WireResult Read(const volatile T** data) { + return Next(data); + } + }; + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_BUFFERCONSUMER_H_ diff --git a/src/webnn/wire/BufferConsumer_impl.h b/src/webnn/wire/BufferConsumer_impl.h new file mode 100644 index 000000000..3259d75b2 --- /dev/null +++ b/src/webnn/wire/BufferConsumer_impl.h @@ -0,0 +1,73 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ +#define WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ + +#include "webnn/wire/BufferConsumer.h" + +#include +#include + +namespace webnn::wire { + + template + template + WireResult BufferConsumer::Peek(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::Next(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += sizeof(T); + mSize -= sizeof(T); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::NextN(N count, T** data) { + static_assert(std::is_unsigned::value, "|count| argument of NextN must be unsigned."); + + constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); + if (count > kMaxCountWithoutOverflows) { + return WireResult::FatalError; + } + + // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|. 
+ size_t totalSize = sizeof(T) * count; + if (totalSize > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += totalSize; + mSize -= totalSize; + return WireResult::Success; + } + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_BUFFERCONSUMER_IMPL_H_ diff --git a/src/webnn/wire/ChunkedCommandSerializer.h b/src/webnn/wire/ChunkedCommandSerializer.h index f3c0a6347..3cf4d62fa 100644 --- a/src/webnn/wire/ChunkedCommandSerializer.h +++ b/src/webnn/wire/ChunkedCommandSerializer.h @@ -23,16 +23,17 @@ #include #include #include +#include namespace webnn::wire { class ChunkedCommandSerializer { public: - ChunkedCommandSerializer(CommandSerializer* serializer); + explicit ChunkedCommandSerializer(CommandSerializer* serializer); template void SerializeCommand(const Cmd& cmd) { - SerializeCommand(cmd, 0, [](char*) {}); + SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -41,15 +42,16 @@ namespace webnn::wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer); + [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer); }, extraSize, std::forward(SerializeExtraSize)); } template void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) { - SerializeCommand(cmd, objectIdProvider, 0, [](char*) {}); + SerializeCommand(cmd, objectIdProvider, 0, + [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -59,8 +61,9 @@ namespace webnn::wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [&objectIdProvider](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer, objectIdProvider); + [&objectIdProvider](const Cmd& cmd, size_t requiredSize, + SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider); }, extraSize, std::forward(SerializeExtraSize)); } @@ -77,8 +80,12 @@ namespace webnn::wire { if (requiredSize <= mMaxAllocationSize) { char* allocatedBuffer = static_cast(mSerializer->GetCmdSpace(requiredSize)); if (allocatedBuffer != nullptr) { - SerializeCmd(cmd, requiredSize, allocatedBuffer); - SerializeExtraSize(allocatedBuffer + commandSize); + SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize); + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { + mSerializer->OnSerializeError(); + } } return; } @@ -87,8 +94,13 @@ namespace webnn::wire { if (!cmdSpace) { return; } - SerializeCmd(cmd, requiredSize, cmdSpace.get()); - SerializeExtraSize(cmdSpace.get() + commandSize); + SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize); + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { + mSerializer->OnSerializeError(); + return; + } SerializeChunkedCommand(cmdSpace.get(), requiredSize); } diff --git a/src/webnn/wire/Wire.cpp b/src/webnn/wire/Wire.cpp new file mode 100644 index 000000000..44754eafa --- /dev/null +++ b/src/webnn/wire/Wire.cpp @@ -0,0 +1,24 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 
2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "webnn/wire/Wire.h" + +namespace webnn::wire { + +CommandSerializer::CommandSerializer() = default; +CommandSerializer::~CommandSerializer() = default; + +void CommandSerializer::OnSerializeError() {} + +} // namespace webnn::wire diff --git a/src/webnn/wire/WireResult.h b/src/webnn/wire/WireResult.h new file mode 100644 index 000000000..85e765eb3 --- /dev/null +++ b/src/webnn/wire/WireResult.h @@ -0,0 +1,38 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef WEBNN_WIRE_WIRERESULT_H_ +#define WEBNN_WIRE_WIRERESULT_H_ + +#include "common/Compiler.h" + +namespace webnn::wire { + + enum class [[nodiscard]] WireResult{ + Success, + FatalError, + }; + +// Macro to simplify error handling, similar to DAWN_TRY but for WireResult. 
+#define WIRE_TRY(EXPR) \ + do { \ + WireResult exprResult = EXPR; \ + if (DAWN_UNLIKELY(exprResult != WireResult::Success)) { \ + return exprResult; \ + } \ + } while (0) + +} // namespace webnn::wire + +#endif // WEBNN_WIRE_WIRERESULT_H_ diff --git a/src/webnn/wire/client/GraphBuilder.cpp b/src/webnn/wire/client/GraphBuilder.cpp index c6a6fad31..940d45d57 100644 --- a/src/webnn/wire/client/GraphBuilder.cpp +++ b/src/webnn/wire/client/GraphBuilder.cpp @@ -25,7 +25,7 @@ namespace webnn::wire::client { GraphBuilderConstantInternalCmd cmd; cmd.graphBuilderId = this->id; cmd.desc = desc; - cmd.buffer = static_cast(value->buffer); + cmd.arrayBuffer = static_cast(value->buffer); cmd.byteLength = value->byteLength; cmd.byteOffset = value->byteOffset; @@ -43,7 +43,7 @@ namespace webnn::wire::client { GraphBuilderConstantWithGpuBufferInternalCmd cmd; cmd.graphBuilderId = this->id; cmd.desc = desc; - cmd.buffer = static_cast(value->buffer); + cmd.arrayBuffer = static_cast(value->buffer); cmd.id = value->id; cmd.generation = value->generation; cmd.byteLength = value->size; diff --git a/src/webnn/wire/client/NamedInputs.cpp b/src/webnn/wire/client/NamedInputs.cpp index 07675bc2f..26192f7b3 100644 --- a/src/webnn/wire/client/NamedInputs.cpp +++ b/src/webnn/wire/client/NamedInputs.cpp @@ -26,7 +26,7 @@ namespace webnn::wire::client { // Input type is ArrayBufferView WNNArrayBufferView arrayBufferView = input->resource.arrayBufferView; if (arrayBufferView.buffer != nullptr) { - cmd.buffer = static_cast(arrayBufferView.buffer); + cmd.arrayBuffer = static_cast(arrayBufferView.buffer); cmd.byteLength = arrayBufferView.byteLength; cmd.byteOffset = arrayBufferView.byteOffset; } else { diff --git a/src/webnn/wire/client/NamedOutputs.cpp b/src/webnn/wire/client/NamedOutputs.cpp index 38a437b50..9ef0893d6 100644 --- a/src/webnn/wire/client/NamedOutputs.cpp +++ b/src/webnn/wire/client/NamedOutputs.cpp @@ -19,9 +19,9 @@ namespace webnn::wire::client { - void NamedOutputs::Set(char const* name, WNNResource const* resource) { + void NamedOutputs::SetOutput(char const* name, WNNResource const* resource) { // The type of output data is WNNArrayBufferView. 
- NamedOutputsSetCmd cmd = {}; + NamedOutputsSetOutputCmd cmd = {}; cmd.namedOutputsId = this->id; cmd.name = name; WNNArrayBufferView arrayBufferView = resource->arrayBufferView; @@ -40,8 +40,16 @@ namespace webnn::wire::client { client->SerializeCommand(cmd); } - void NamedOutputs::Get(char const* name, WNNArrayBufferView const* resource) { - UNREACHABLE(); + void NamedOutputs::GetOutput(char const* name, WNNArrayBufferView const* resource) { + NamedOutputsGetOutputCmd cmd = {}; + cmd.namedOutputsId = this->id; + cmd.name = name; + if (resource->buffer != nullptr) { + cmd.arrayBuffer = static_cast(resource->buffer); + cmd.byteLength = resource->byteLength; + cmd.byteOffset = resource->byteOffset; + } + client->SerializeCommand(cmd); } bool NamedOutputs::OutputResult(char const* name, diff --git a/src/webnn/wire/client/NamedOutputs.h b/src/webnn/wire/client/NamedOutputs.h index f57d27545..9d7fdfa8b 100644 --- a/src/webnn/wire/client/NamedOutputs.h +++ b/src/webnn/wire/client/NamedOutputs.h @@ -29,8 +29,8 @@ namespace webnn::wire::client { public: using ObjectBase::ObjectBase; - void Set(char const* name, WNNResource const* resource); - void Get(char const* name, WNNArrayBufferView const* resource); + void SetOutput(char const* name, WNNResource const* resource); + void GetOutput(char const* name, WNNArrayBufferView const* resource); bool OutputResult(char const* name, uint8_t const* buffer, size_t byteLength, diff --git a/src/webnn/wire/server/Server.cpp b/src/webnn/wire/server/Server.cpp index ee205bb63..d5dea8feb 100644 --- a/src/webnn/wire/server/Server.cpp +++ b/src/webnn/wire/server/Server.cpp @@ -158,28 +158,6 @@ namespace webnn::wire::server { return true; } - bool Server::DoCreateGraphBuilder(ObjectId contextId, ObjectHandle result) { - auto* context = ContextObjects().Get(contextId); - if (context == nullptr) { - return false; - } - - // Create and register the GraphBuilder object. 
- auto* resultData = GraphBuilderObjects().Allocate(result.id); - if (resultData == nullptr) { - return false; - } - resultData->generation = result.generation; - resultData->contextInfo = context->contextInfo; - if (resultData->contextInfo != nullptr) { - if (!TrackContextChild(resultData->contextInfo, ObjectType::GraphBuilder, result.id)) { - return false; - } - } - resultData->handle = mProcs.createGraphBuilder(context->handle); - return true; - } - #if defined(WEBNN_ENABLE_GPU_BUFFER) WGPUDevice Server::GetWGPUDevice(uint32_t id, uint32_t generation) { return mDawnWireServer->GetDevice(id, generation); diff --git a/src/webnn/wire/server/ServerGraph.cpp b/src/webnn/wire/server/ServerGraph.cpp index b80ade192..2a8633382 100644 --- a/src/webnn/wire/server/ServerGraph.cpp +++ b/src/webnn/wire/server/ServerGraph.cpp @@ -25,7 +25,7 @@ namespace webnn::wire::server { } for (auto& name : mOutputNamesMap[outputsId]) { WNNArrayBufferView arrayBuffer = {}; - mProcs.namedOutputsGet(namedOutputs->handle, name.data(), &arrayBuffer); + mProcs.namedOutputsGetOutput(namedOutputs->handle, name.data(), &arrayBuffer); if (arrayBuffer.buffer == nullptr) { return false; } @@ -34,7 +34,7 @@ namespace webnn::wire::server { ReturnGraphComputeResultCmd cmd; cmd.namedOutputs = ObjectHandle{outputsId, namedOutputs->generation}; cmd.name = name.data(); - cmd.buffer = static_cast(arrayBuffer.buffer); + cmd.arrayBuffer = static_cast(arrayBuffer.buffer); cmd.byteLength = arrayBuffer.byteLength; cmd.byteOffset = arrayBuffer.byteOffset; SerializeCommand(cmd); diff --git a/src/webnn/wire/server/ServerNamedOutputs.cpp b/src/webnn/wire/server/ServerNamedOutputs.cpp index b16b2ce0f..ff4100881 100644 --- a/src/webnn/wire/server/ServerNamedOutputs.cpp +++ b/src/webnn/wire/server/ServerNamedOutputs.cpp @@ -17,7 +17,7 @@ namespace webnn::wire::server { - bool Server::DoNamedOutputsSet(ObjectId namedOutputsId, + bool Server::DoNamedOutputsSetOutput(ObjectId namedOutputsId, char const* name, size_t byteLength, size_t byteOffset, @@ -50,9 +50,27 @@ namespace webnn::wire::server { outputNames.push_back(std::string(name)); } } - mProcs.namedOutputsSet(namedOutputs->handle, name, &resource); + mProcs.namedOutputsSetOutput(namedOutputs->handle, name, &resource); return true; } + bool Server::DoNamedOutputsGetOutput(ObjectId namedOutputsId, + char const* name, + uint8_t const* buffer, + size_t byteLength, + size_t byteOffset) { + auto* namedOutputs = NamedOutputsObjects().Get(namedOutputsId); + if (namedOutputs == nullptr) { + return false; + } + + WNNArrayBufferView arrayBuffer = {}; + arrayBuffer.buffer = const_cast(static_cast(buffer)); + arrayBuffer.byteLength = byteLength; + arrayBuffer.byteOffset = byteOffset; + mProcs.namedOutputsGetOutput(namedOutputs->handle, name, &arrayBuffer); + + return true; + } } // namespace webnn::wire::server diff --git a/webnn.json b/webnn.json index c31803d0d..82507f903 100644 --- a/webnn.json +++ b/webnn.json @@ -199,7 +199,7 @@ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"} ], - "TODO": "enga@: Make this a Dawn extension" + "_TODO": "enga@: Make this a Dawn extension" }, { "name": "set uncaptured error callback", @@ -266,13 +266,14 @@ "methods": [ { "name": "size", - "returns": "size_t" + "returns": "uint64_t", + "args": [] }, { "name": "get operand", "returns": "operand", "args": [ - {"name": "index", "type": "size_t"} + {"name": "index", "type": "uint64_t"} ] } ] @@ -282,7 +283,8 @@ "methods": [ { "name": "size", - 
"returns": "size_t" + "returns": "uint64_t", + "args": [] }, { "name": "set fusion operator", @@ -295,7 +297,7 @@ "name": "get fusion operator", "returns": "fusion operator", "args": [ - {"name": "index", "type": "size_t"} + {"name": "index", "type": "uint64_t"} ] } ] @@ -381,11 +383,11 @@ "category": "structure", "members": [ {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "groups", "type": "int32_t", "default": 1}, {"name": "input layout", "type": "input operand layout", "default": "nchw"}, @@ -398,15 +400,15 @@ "category": "structure", "members": [ {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "output padding count", "type": "uint32_t", "default": 0}, - {"name": "output padding", "type": "int32_t", "annotation": "const*", "length": "output padding count", "optional": true}, + {"name": "output padding", "type": "int32_t", "annotation": "const*", "length": "output padding count", "default": "nullptr"}, {"name": "output sizes count", "type": "uint32_t", "default": 0}, - {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "optional": true}, + {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "groups", "type": "int32_t", "default": 1}, {"name": "input layout", "type": "input operand layout", "default": "nchw"}, @@ -419,7 +421,7 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "gru options": { @@ -446,18 +448,18 @@ "category": "structure", 
"members": [ {"name": "window dimensions count", "type": "uint32_t", "default": 0}, - {"name": "window dimensions", "type": "int32_t", "annotation": "const*", "length": "window dimensions count", "optional": true}, + {"name": "window dimensions", "type": "int32_t", "annotation": "const*", "length": "window dimensions count", "default": "nullptr"}, {"name": "padding count", "type": "uint32_t", "default": 0}, - {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "optional": true}, + {"name": "padding", "type": "int32_t", "annotation": "const*", "length": "padding count", "default": "nullptr"}, {"name": "strides count", "type": "uint32_t", "default": 0}, - {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, + {"name": "strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "default": "nullptr"}, {"name": "dilations count", "type": "uint32_t", "default": 0}, - {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, + {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "default": "nullptr"}, {"name": "auto pad", "type": "auto pad", "default": "explicit"}, {"name": "layout", "type": "input operand layout", "default": "nchw"}, {"name": "rounding type", "type": "rounding type", "default": "floor"}, {"name": "output sizes count", "type": "uint32_t", "default": 0}, - {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "optional": true} + {"name": "output sizes", "type": "int32_t", "annotation": "const*", "length": "output sizes count", "default": "nullptr"} ] }, "gemm options": { @@ -480,7 +482,7 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count","optional": true}, + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"}, {"name": "keepDimensions", "type": "bool", "default": "false"} ] }, @@ -489,11 +491,11 @@ "members": [ {"name": "mode", "type": "interpolation mode", "default": "nearest neighbor"}, {"name": "scales count", "type": "uint32_t", "default": 0}, - {"name": "scales", "type": "float", "annotation": "const*", "length": "scales count", "optional": true}, + {"name": "scales", "type": "float", "annotation": "const*", "length": "scales count", "default": "nullptr" }, {"name": "sizes count", "type": "uint32_t", "default": 0}, - {"name": "sizes", "type": "int32_t", "annotation": "const*", "length": "sizes count", "optional": true}, + {"name": "sizes", "type": "int32_t", "annotation": "const*", "length": "sizes count", "default": "nullptr"}, {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "split options": { @@ -506,14 +508,14 @@ "category": "structure", "members": [ {"name": "axes count", "type": "uint32_t", "default": 0}, - {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "optional": true} + {"name": "axes", "type": "int32_t", "annotation": "const*", "length": "axes count", "default": "nullptr"} ] }, "transpose options": { "category": "structure", "members": [ {"name": "permutation count", 
"type": "uint32_t", "default": 0}, - {"name": "permutation", "type": "int32_t", "annotation": "const*", "length": "permutation count", "optional": true} + {"name": "permutation", "type": "int32_t", "annotation": "const*", "length": "permutation count", "default": "nullptr"} ] }, "batchNorm options": { @@ -736,7 +738,8 @@ }, { "name": "hard swish operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "gemm", @@ -876,7 +879,8 @@ }, { "name": "relu operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "reshape", @@ -896,7 +900,8 @@ }, { "name": "sigmoid operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "slice", @@ -945,7 +950,8 @@ }, { "name": "tanh operator", - "returns": "fusion operator" + "returns": "fusion operator", + "args": [] }, { "name": "transpose", @@ -1017,7 +1023,7 @@ "category": "structure", "members": [ {"name": "resource", "type": "resource"}, - {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "optional": true}, + {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "default": "nullptr"}, {"name": "dimensions count", "type": "uint32_t", "default": 0} ] }, diff --git a/webnn_wire.json b/webnn_wire.json index c947002ca..1c012f7ca 100644 --- a/webnn_wire.json +++ b/webnn_wire.json @@ -23,19 +23,19 @@ "graph builder constant internal": [ {"name": "graph builder id", "type": "ObjectId"}, {"name": "desc", "type": "operand descriptor", "annotation": "const*"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "result", "type": "ObjectHandle", "handle_type": "operand"} ], "graph builder constant with gpu buffer internal": [ {"name": "graph builder id", "type": "ObjectId"}, {"name": "desc", "type": "operand descriptor", "annotation": "const*"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "optional": true}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "optional": true}, {"name": "id", "type": "uint32_t", "default": 0}, {"name": "generation", "type": "uint32_t", "default": 0}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "result", "type": "ObjectHandle", "handle_type": "operand"} ], "graph builder gru internal": [ @@ -83,29 +83,32 @@ "named inputs set": [ {"name": "named inputs id", "type": "ObjectId"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "optional": true}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "default": "nullptr"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "gpu buffer id", "type": "uint32_t", "default": 0}, {"name": "gpu buffer 
generation", "type": "uint32_t", "default": 0}, - {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "optional": true}, + {"name": "dimensions", "type": "int32_t", "annotation": "const*", "length": "dimensions count", "default": "nullptr"}, {"name": "dimensions count", "type": "uint32_t", "default": 0} ], - "named outputs set": [ + "named outputs set output": [ {"name": "named outputs id", "type": "ObjectId"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0}, {"name": "gpu buffer id", "type": "uint32_t", "default": 0}, {"name": "gpu buffer generation", "type": "uint32_t", "default": 0} ], + "named outputs get output": [ + {"name": "named outputs id", "type": "ObjectId"}, + {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length", "default": "nullptr"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0} + ], "destroy object": [ {"name": "object type", "type": "ObjectType"}, {"name": "object id", "type": "ObjectId"} - ], - "create graph builder": [ - {"name": "context", "type": "ObjectId"}, - {"name": "result", "type": "ObjectHandle", "handle_type": "graph builder"} ] }, "return commands": { @@ -118,9 +121,9 @@ "graph compute result": [ {"name": "named outputs", "type": "ObjectHandle", "handle_type": "named outputs"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, - {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, - {"name": "byte length", "type": "size_t"}, - {"name": "byte offset", "type": "size_t", "default": 0} + {"name": "array buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, + {"name": "byte length", "type": "uint64_t"}, + {"name": "byte offset", "type": "uint64_t", "default": 0} ], "graph compute async callback": [ { "name": "graph", "type": "ObjectHandle", "handle_type": "graph" }, @@ -146,8 +149,8 @@ "GraphBuilderSplit", "InstanceCreateContextWithGpuDevice", "NamedInputsSet", - "NamedOutputsSet", - "NamedOutputsGet", + "NamedOutputsSetOutput", + "NamedOutputsGetOutput", "OperandArraySize", "OperatorArraySize", "GraphComputeAsync",