Merge pull request #284 from bbernhar/super_resolution
Add SuperResolution example and test
fujunwei authored Aug 1, 2022
2 parents bce6861 + 7c4030c commit 6b27f08
Showing 5 changed files with 201 additions and 3 deletions.
113 changes: 113 additions & 0 deletions examples/SuperResolution/SuperResolution.cpp
@@ -0,0 +1,113 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "examples/SuperResolution/SuperResolution.h"

SuperResolution::SuperResolution() : ExampleBase() {
}

const wnn::Operand SuperResolution::BuildConstantFromNpy(const wnn::GraphBuilder& builder,
const std::string& path) {
const cnpy::NpyArray data = cnpy::npy_load(path);
mConstants.push_back(data.data_holder);  // Keep the npy data alive for the lifetime of the graph build.
return utils::BuildConstant(builder, data.shape, data.data<float>(), data.num_bytes());
}

const wnn::Operand SuperResolution::BuildConv(const wnn::GraphBuilder& builder,
const wnn::Operand& input,
int32_t convIndex,
bool relu,
utils::Conv2dOptions* options,
const std::string& biasName) {
std::string prefix = mLayout == "nchw" ? mWeightsPath + "conv" : mWeightsPath + "Const_";
std::string suffix = mLayout == "nchw" ? "_weight.npy" : ".npy";
const std::string weightsPath = prefix + std::to_string(convIndex) + suffix;
const wnn::Operand convWeights = BuildConstantFromNpy(builder, weightsPath);

// TODO: Figure out correct "channels last" path suffix.
prefix = mLayout == "nchw" ? mWeightsPath + "conv" : mWeightsPath + "super_resolution_";
if (mLayout == "nchw") {
prefix.append(std::to_string(convIndex));
}

const std::string biasPath = prefix + biasName + "_bias.npy";
const wnn::Operand convBias = BuildConstantFromNpy(builder, biasPath);

const wnn::Conv2dOptions* conv2dOptions = options != nullptr ? options->AsPtr() : nullptr;

if (!mFused) {
const wnn::Operand conv2d = builder.Conv2d(input, convWeights, conv2dOptions);
// The non-fused path applies the bias explicitly: reshape it to {1, -1, 1, 1} so it
// broadcasts over the NCHW channel dimension, then add it to the convolution output.
const std::vector<int32_t> biasShape = {1, -1, 1, 1};
const wnn::Operand reshapedBias = builder.Reshape(convBias, biasShape.data(), biasShape.size());
const wnn::Operand add = builder.Add(conv2d, reshapedBias);
if (relu) {
return builder.Relu(add);
}
return add;
}

// Fused
utils::Conv2dOptions fusedOptions;
if (options != nullptr) {
fusedOptions = *options;
}
fusedOptions.bias = convBias;

if (relu) {
fusedOptions.activation = builder.ReluOperator();
}

return builder.Conv2d(input, convWeights, fusedOptions.AsPtr());
}
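
With the NCHW layout, the string logic above resolves the weight and bias files as, for example (an illustrative trace for convIndex == 1 and the default empty biasName; the concrete directory is the one set by the end2end test added in this commit):

// weights: .../super_resolution_nchw/weights/conv1_weight.npy
// bias:    .../super_resolution_nchw/weights/conv1_bias.npy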

const wnn::Operand SuperResolution::LoadNchw(const wnn::GraphBuilder& builder, bool softmax) {
const wnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 224, 224});

utils::Conv2dOptions conv1Options;
conv1Options.strides = {1, 1};
conv1Options.padding = {2, 2, 2, 2};
conv1Options.dilations = {1, 1};
const wnn::Operand conv1 =
BuildConv(builder, input, /*convIndex*/ 1, /*relu*/ true, &conv1Options);

utils::Conv2dOptions conv2Options;
conv2Options.strides = {1, 1};
conv2Options.padding = {1, 1, 1, 1};
conv2Options.dilations = {1, 1};
const wnn::Operand conv2 =
BuildConv(builder, conv1, /*convIndex*/ 2, /*relu*/ true, &conv2Options);

utils::Conv2dOptions conv3Options;
conv3Options.strides = {1, 1};
conv3Options.padding = {1, 1, 1, 1};
conv3Options.dilations = {1, 1};
const wnn::Operand conv3 =
BuildConv(builder, conv2, /*convIndex*/ 3, /*relu*/ true, &conv3Options);

utils::Conv2dOptions conv4Options;
conv4Options.strides = {1, 1};
conv4Options.padding = {1, 1, 1, 1};
conv4Options.dilations = {1, 1};
const wnn::Operand conv4 =
BuildConv(builder, conv3, /*convIndex*/ 4, /*relu*/ false, &conv4Options);

const std::vector<int32_t> newShape1 = {-1, 1, 3, 3, 224, 224};
const wnn::Operand reshape1 = builder.Reshape(conv4, newShape1.data(), newShape1.size());

wnn::TransposeOptions transpose1Options;
std::vector<int32_t> permutation = {0, 1, 4, 2, 5, 3};
transpose1Options.permutation = permutation.data();
transpose1Options.permutationCount = permutation.size();
const wnn::Operand transpose1 = builder.Transpose(reshape1, &transpose1Options);

const std::vector<int32_t> newShape2 = {-1, 1, 672, 672};
return builder.Reshape(transpose1, newShape2.data(), newShape2.size());
}
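
The reshape → transpose → reshape tail of LoadNchw is the depth-to-space (pixel shuffle) step of the super-resolution model: conv4's nine output channels are rearranged into a 3× spatial upscale, taking the 1×1×224×224 input to a 1×1×672×672 output. A minimal reference sketch of the index mapping, not part of the commit (DepthToSpace is an illustrative standalone helper; the batch dimension of 1 is omitted):

#include <cstddef>
#include <vector>

// output[c][h*r + i][w*r + j] = input[c*r*r + i*r + j][h][w],
// matching the {0, 1, 4, 2, 5, 3} permutation used above.
std::vector<float> DepthToSpace(const std::vector<float>& input,
                                size_t channels,
                                size_t r,
                                size_t height,
                                size_t width) {
    std::vector<float> output(channels * height * r * width * r);
    for (size_t c = 0; c < channels; ++c) {
        for (size_t i = 0; i < r; ++i) {
            for (size_t j = 0; j < r; ++j) {
                for (size_t h = 0; h < height; ++h) {
                    for (size_t w = 0; w < width; ++w) {
                        const size_t src = (((c * r + i) * r + j) * height + h) * width + w;
                        const size_t dst = (c * height * r + h * r + i) * width * r + w * r + j;
                        output[dst] = input[src];
                    }
                }
            }
        }
    }
    return output;
}

With channels = 1, r = 3, and height = width = 224, this reproduces the 672×672 layout the graph returns.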
39 changes: 39 additions & 0 deletions examples/SuperResolution/SuperResolution.h
@@ -0,0 +1,39 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <webnn/webnn.h>
#include <webnn/webnn_cpp.h>

#include "examples/SampleUtils.h"

class SuperResolution : public ExampleBase {
public:
SuperResolution();
~SuperResolution() override = default;

const wnn::Operand LoadNchw(const wnn::GraphBuilder& builder, bool softmax);

private:
const wnn::Operand BuildConstantFromNpy(const wnn::GraphBuilder& builder,
const std::string& path);

const wnn::Operand BuildConv(const wnn::GraphBuilder& builder,
const wnn::Operand& input,
int32_t convIndex,
bool relu,
utils::Conv2dOptions* options,
const std::string& biasName = "");

std::vector<SHARED_DATA_TYPE> mConstants;
};
4 changes: 2 additions & 2 deletions src/webnn/native/dmlx/GraphDMLX.cpp
@@ -1270,7 +1270,7 @@ namespace webnn::native::dmlx {
DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end());
::dml::Expression input = mExpression.at(inputOperand);
auto newShape = reshape->GetNewShape();
if (newShape.size() > DML_TENSOR_DIMENSION_COUNT_MAX) {
if (newShape.size() > DML_TENSOR_DIMENSION_COUNT_MAX1) {
return DAWN_INTERNAL_ERROR("The size of new shape is not supported by DML.");
}
::dml::TensorDimensions newSizes(newShape.size());
@@ -1424,7 +1424,7 @@ namespace webnn::native::dmlx {
DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end());
::dml::Expression input = mExpression.at(inputOperand);
std::vector<int32_t> permutation = transpose->GetPermutation();
if (permutation.size() > DML_TENSOR_DIMENSION_COUNT_MAX) {
if (permutation.size() > DML_TENSOR_DIMENSION_COUNT_MAX1) {
return DAWN_INTERNAL_ERROR("The size of permutation is not supported by DML.");
}

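For context on these two changes: the SuperResolution graph added by this commit reshapes conv4's output to six dimensions before transposing it,

// From examples/SuperResolution/SuperResolution.cpp above:
const std::vector<int32_t> newShape1 = {-1, 1, 3, 3, 224, 224};  // six dimensions

which exceeds the older five-dimension limit named DML_TENSOR_DIMENSION_COUNT_MAX, so both checks now compare against the extended DML_TENSOR_DIMENSION_COUNT_MAX1 limit available at newer DirectML feature levels.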
5 changes: 4 additions & 1 deletion src/webnn/tests/BUILD.gn
@@ -132,8 +132,8 @@ test("webnn_unittests") {
"${webnn_root}/src/webnn:cpp",
"${webnn_root}/src/webnn:webnn_proc",
"${webnn_root}/src/webnn/common",
"${webnn_root}/src/webnn/native:webnn_native",
"${webnn_root}/src/webnn/native:sources",
"${webnn_root}/src/webnn/native:webnn_native",
"${webnn_root}/src/webnn/utils:webnn_utils",
]

@@ -205,6 +205,8 @@ source_set("webnn_end2end_tests_sources") {
"${webnn_root}/examples/ResNet/ResNet.h",
"${webnn_root}/examples/SqueezeNet/SqueezeNet.cpp",
"${webnn_root}/examples/SqueezeNet/SqueezeNet.h",
"${webnn_root}/examples/SuperResolution/SuperResolution.cpp",
"${webnn_root}/examples/SuperResolution/SuperResolution.h",
"WebnnTest.cpp",
"WebnnTest.h",
"end2end/AddTests.cpp",
@@ -249,6 +251,7 @@ source_set("webnn_end2end_tests_sources") {
"end2end/models/ResNetNhwc.cpp",
"end2end/models/SqueezeNetNchw.cpp",
"end2end/models/SqueezeNetNhwc.cpp",
"end2end/models/SuperResolutionNchw.cpp",
]

# Validation tests that need OS windows live in end2end tests.
43 changes: 43 additions & 0 deletions src/webnn/tests/end2end/models/SuperResolutionNchw.cpp
@@ -0,0 +1,43 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "examples/SuperResolution/SuperResolution.h"
#include "webnn/tests/WebnnTest.h"

static const std::string kModelPath = WEBNN_END2END_TEST_MODEL_PATH;

class SuperResolutionNchwTests : public WebnnTest {
public:
void TestSuperResolutionNchw(const std::string& inputFile,
const std::string& expectedFile,
bool fused = true) {
SuperResolution superresolution;
superresolution.mFused = fused;
const std::string nchwPath = kModelPath + "/super_resolution_nchw/";
superresolution.mWeightsPath = nchwPath + "weights/";
const wnn::GraphBuilder builder = wnn::CreateGraphBuilder(GetContext());
wnn::Operand output = superresolution.LoadNchw(builder, false);
wnn::Graph graph = utils::Build(builder, {{"output", output}});
const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile);
const std::vector<float> inputData = inputNpy.as_vec<float>();
std::vector<float> result(utils::SizeOfShape({1, 1, 672, 672}));
utils::Compute(graph, {{"input", inputData}}, {{"output", result}});
const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile);
EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec<float>()));
}
};

TEST_F(SuperResolutionNchwTests, NchwTest0) {
TestSuperResolutionNchw("0/input_0.npy", "0/output_0.npy", false);
}
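
NchwTest0 exercises only the non-fused path; a companion case for the fused path (the parameter's default) could follow the same pattern, sketched below with a hypothetical test name that is not part of this commit:

TEST_F(SuperResolutionNchwTests, NchwTest0Fused) {
TestSuperResolutionNchw("0/input_0.npy", "0/output_0.npy", /*fused*/ true);
}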
