diff --git a/common/component/component.js b/common/component/component.js
index 977b4f43..7bea3523 100644
--- a/common/component/component.js
+++ b/common/component/component.js
@@ -606,6 +606,18 @@ $(document).ready(async () => {
"title",
"WebNN is supported, disable WebNN Polyfill."
);
+ // Disable the WebNN NPU backend when no capable NPU adapter is found.
+ try {
+ await navigator.ml.createContext({deviceType: 'npu'});
+ } catch (error) {
+ $('#webnn_npu').parent().addClass('disabled');
+ $('#webnn_npu').parent().addClass('btn-outline-secondary');
+ $('#webnn_npu').parent().removeClass('btn-outline-info');
+ $('#webnn_npu').parent().attr(
+ "title",
+ "Unable to find a capable NPU adapter."
+ );
+ }
}
}
$("#webnnstatus").html("supported").addClass("webnn-status-true");
diff --git a/common/ui.js b/common/ui.js
index ec6a6cc7..1cc69fc1 100644
--- a/common/ui.js
+++ b/common/ui.js
@@ -97,6 +97,23 @@ export function handleClick(cssSelectors, disabled = true) {
}
}
+/**
+ * Handle button UI: disable or enable the button.
+ * @param {String} selector, CSS selector.
+ * @param {Boolean} disabled, disable or enable the button.
+ */
+export function handleBtnUI(selector, disabled = true) {
+ if (disabled) {
+ $(selector).addClass('disabled');
+ $(selector).addClass('btn-outline-secondary');
+ $(selector).removeClass('btn-outline-info');
+ } else {
+ $(selector).removeClass('disabled');
+ $(selector).removeClass('btn-outline-secondary');
+ $(selector).addClass('btn-outline-info');
+ }
+}
+
/**
* Show flexible alert messages
* @param {String} msg, alert message.
diff --git a/image_classification/.eslintrc.js b/image_classification/.eslintrc.js
index 41955769..c02d313a 100644
--- a/image_classification/.eslintrc.js
+++ b/image_classification/.eslintrc.js
@@ -1,5 +1,6 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
+ 'tf': 'readonly',
},
};
diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js
new file mode 100644
index 00000000..9f0d87f4
--- /dev/null
+++ b/image_classification/efficientnet_fp16_nchw.js
@@ -0,0 +1,176 @@
+'use strict';
+
+import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
+
+// EfficientNet fp16 model with 'nchw' input layout
+export class EfficientNetFP16Nchw {
+ constructor() {
+ this.context_ = null;
+ this.builder_ = null;
+ this.graph_ = null;
+ this.weightsUrl_ = weightsOrigin() +
+ '/test-data/models/efficientnet_fp16_nchw_optimized/weights/';
+ this.inputOptions = {
+ mean: [0.485, 0.456, 0.406],
+ std: [0.229, 0.224, 0.225],
+ norm: true,
+ inputLayout: 'nchw',
+ labelUrl: './labels/labels1000.txt',
+ inputDimensions: [1, 3, 224, 224],
+ };
+ this.outputDimensions = [1, 1000];
+ }
+
+ async buildConv_(input, name, blockName, clip = false, options = {}) {
+ let prefix = '';
+ if (blockName !== '') {
+ prefix = this.weightsUrl_ + 'block' + blockName + '_conv' +
+ name;
+ } else {
+ prefix = this.weightsUrl_ + 'conv' + name;
+ }
+ const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
+ 'float16');
+ options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
+ 'float16');
+ if (clip) {
+ return this.builder_.clamp(
+ this.builder_.conv2d(await input, await weight, options),
+ {minValue: 0, maxValue: 6});
+ }
+ return this.builder_.conv2d(await input, await weight, options);
+ }
+
+ async buildGemm_(input, name) {
+ const prefix = this.weightsUrl_ + 'dense' + name;
+ const weightName = prefix + '_w.npy';
+ const weight = buildConstantByNpy(this.builder_, weightName,
+ 'float16');
+ const biasName = prefix + '_b.npy';
+ const bias = buildConstantByNpy(this.builder_, biasName,
+ 'float16');
+ const options =
+ {c: this.builder_.reshape(await bias, [1, 1000])};
+ return await this.builder_.gemm(await input, await weight, options);
+ }
+
+ async buildBottleneck_(input, blockName, group, pad = 1) {
+ const conv1 = this.buildConv_(input, '0', blockName, true);
+ const conv2 = this.buildConv_(conv1, '1', blockName, true,
+ {groups: group, padding: [pad, pad, pad, pad]});
+ const conv3 = this.buildConv_(conv2, '2', blockName);
+ return this.builder_.add(await conv3, await input);
+ }
+
+ async buildBottlenecks_(input, blockNames, group, pad = 1) {
+ let result = input;
+ for (let i = 0; i < blockNames.length; i++) {
+ const bottleneck = await this.buildBottleneck_(result, blockNames[i],
+ group, pad);
+ result = bottleneck;
+ }
+ return result;
+ }
+
+ async load(contextOptions) {
+ this.context_ = await navigator.ml.createContext(contextOptions);
+ this.builder_ = new MLGraphBuilder(this.context_);
+ let data = this.builder_.input('input', {
+ dataType: 'float32',
+ dimensions: this.inputOptions.inputDimensions,
+ });
+ data = this.builder_.cast(data, 'float16');
+ // Block 0
+ const conv1 = this.buildConv_(
+ data, '0', '0', true, {padding: [0, 1, 0, 1], strides: [2, 2]});
+ const conv2 = this.buildConv_(conv1, '1', '0', true,
+ {groups: 32, padding: [1, 1, 1, 1]});
+ const conv3 = this.buildConv_(conv2, '2', '0');
+
+ // Block 1
+ const conv4 = this.buildConv_(conv3, '0', '1', true);
+ const conv5 = this.buildConv_(conv4, '1', '1', true,
+ {groups: 144, padding: [0, 1, 0, 1], strides: [2, 2]});
+ const conv6 = this.buildConv_(conv5, '2', '1');
+
+ // Block 2~4
+ const bottleneck4 = this.buildBottlenecks_(conv6,
+ ['2', '3', '4'], 192);
+
+ // Block 5
+ const conv7 = this.buildConv_(bottleneck4, '0', '5', true);
+ const conv8 = this.buildConv_(conv7, '1', '5', true,
+ {groups: 192, padding: [1, 2, 1, 2], strides: [2, 2]});
+ const conv9 = this.buildConv_(conv8, '2', '5');
+
+ // Block 6~8
+ const bottleneck8 = this.buildBottlenecks_(conv9,
+ ['6', '7', '8'], 336, 2);
+
+ // Block 9
+ const conv10 = this.buildConv_(bottleneck8, '0', '9', true);
+ const conv11 = this.buildConv_(conv10, '1', '9', true,
+ {groups: 336, padding: [0, 1, 0, 1], strides: [2, 2]});
+ const conv12 = this.buildConv_(conv11, '2', '9');
+
+ // Block 10~14
+ const bottleneck14 = this.buildBottlenecks_(conv12,
+ ['10', '11', '12', '13', '14'], 672);
+
+ // Block 15
+ const conv13 = this.buildConv_(bottleneck14, '0', '15', true);
+ const conv14 = this.buildConv_(conv13, '1', '15', true,
+ {groups: 672, padding: [2, 2, 2, 2]});
+ const conv15 = this.buildConv_(conv14, '2', '15');
+
+ // Block 16~20
+ const bottleneck20 = await this.buildBottlenecks_(conv15,
+ ['16', '17', '18', '19', '20'], 960, 2);
+
+ // Block 21
+ const conv16 = this.buildConv_(bottleneck20, '0', '21', true);
+ const conv17 = this.buildConv_(conv16, '1', '21', true,
+ {groups: 960, padding: [1, 2, 1, 2], strides: [2, 2]});
+ const conv18 = this.buildConv_(conv17, '2', '21');
+
+ // Block 22~28
+ const bottleneck28 = this.buildBottlenecks_(conv18,
+ ['22', '23', '24', '25', '26', '27', '28'], 1632, 2);
+
+ // Block 29
+ const conv19 = this.buildConv_(bottleneck28, '0', '29', true);
+ const conv20 = this.buildConv_(conv19, '1', '29', true,
+ {groups: 1632, padding: [1, 1, 1, 1]});
+ const conv21 = this.buildConv_(conv20, '2', '29');
+
+ const conv22 = this.buildConv_(conv21, '0', '', true);
+ const pool1 = this.builder_.averagePool2d(await conv22);
+ const reshape = this.builder_.reshape(pool1, [1, 1280]);
+ const gemm = this.buildGemm_(reshape, '0');
+ if (contextOptions.deviceType === 'npu') {
+ return this.builder_.cast(await gemm, 'float32');
+ } else {
+ const softmax = this.builder_.softmax(await gemm);
+ return this.builder_.cast(softmax, 'float32');
+ }
+ }
+
+ async build(outputOperand) {
+ this.graph_ = await this.builder_.build({'output': outputOperand});
+ }
+
+ // Release the constant tensors of a model
+ dispose() {
+ // dispose() is only available in webnn-polyfill
+ if (this.graph_ !== null && 'dispose' in this.graph_) {
+ this.graph_.dispose();
+ }
+ }
+
+ async compute(inputBuffer, outputBuffer) {
+ const inputs = {'input': inputBuffer};
+ const outputs = {'output': outputBuffer};
+ const results = await this.context_.compute(this.graph_, inputs, outputs);
+ return results;
+ }
+}
diff --git a/image_classification/index.html b/image_classification/index.html
index 94394826..fc9a0d43 100644
--- a/image_classification/index.html
+++ b/image_classification/index.html
@@ -43,6 +43,9 @@
+
@@ -61,21 +64,43 @@
-->
+
@@ -213,6 +238,9 @@ No model selected
+
diff --git a/image_classification/main.js b/image_classification/main.js
index 5e7d74b5..ce8839b0 100644
--- a/image_classification/main.js
+++ b/image_classification/main.js
@@ -1,5 +1,7 @@
'use strict';
+import {ResNet50V1FP16Nchw} from './resnet50v1_fp16_nchw.js';
+import {EfficientNetFP16Nchw} from './efficientnet_fp16_nchw.js';
import {MobileNetV2Nchw} from './mobilenet_nchw.js';
import {MobileNetV2Nhwc} from './mobilenet_nhwc.js';
import {SqueezeNetNchw} from './squeezenet_nchw.js';
@@ -15,7 +17,9 @@ const imgElement = document.getElementById('feedElement');
imgElement.src = './images/test.jpg';
const camElement = document.getElementById('feedMediaElement');
let modelName = '';
+let modelId = '';
let layout = 'nhwc';
+let dataType = 'float32';
let instanceType = modelName + layout;
let rafReq;
let isFirstTimeLoad = true;
@@ -35,6 +39,41 @@ let lastBackend = '';
let stopRender = true;
let isRendering = false;
const disabledSelectors = ['#tabs > li', '.btn'];
+const modelIds = [
+ 'mobilenet',
+ 'squeezenet',
+ 'resnet50v2',
+ 'resnet50v1',
+ 'efficientnet',
+];
+const modelList = {
+ 'cpu': {
+ 'float32': [
+ 'mobilenet',
+ 'squeezenet',
+ 'resnet50v2',
+ ],
+ },
+ 'gpu': {
+ 'float32': [
+ 'mobilenet',
+ 'squeezenet',
+ 'resnet50v2',
+ ],
+ 'float16': [
+ 'efficientnet',
+ 'mobilenet',
+ 'resnet50v1',
+ ],
+ },
+ 'npu': {
+ 'float16': [
+ 'efficientnet',
+ 'mobilenet',
+ 'resnet50v1',
+ ],
+ },
+};
async function fetchLabels(url) {
const response = await fetch(url);
@@ -42,6 +81,26 @@ async function fetchLabels(url) {
return data.split('\n');
}
+function displayAvailableModels(modelList, deviceType, dataType) {
+ let models = [];
+ if (dataType == '') {
+ models = models.concat(modelList[deviceType]['float32']);
+ models = models.concat(modelList[deviceType]['float16']);
+ } else {
+ models = models.concat(modelList[deviceType][dataType]);
+ }
+ // Remove duplicate ids.
+ models = [...new Set(models)];
+ // Display available models.
+ for (const model of modelIds) {
+ if (models.includes(model)) {
+ $(`#${model}`).parent().show();
+ } else {
+ $(`#${model}`).parent().hide();
+ }
+ }
+}
+
$(document).ready(async () => {
$('.icdisplay').hide();
if (await utils.isWebNN()) {
@@ -56,14 +115,39 @@ $('#backendBtns .btn').on('change', async (e) => {
await stopCamRender();
}
layout = utils.getDefaultLayout($(e.target).attr('id'));
- await main();
+ [backend, deviceType] = $(e.target).attr('id').split('_');
+ // Only show the supported models for each deviceType. Now fp16 nchw models
+ // are only supported on gpu/npu.
+ if (deviceType == 'gpu') {
+ ui.handleBtnUI('#float16Label', false);
+ ui.handleBtnUI('#float32Label', false);
+ displayAvailableModels(modelList, deviceType, dataType);
+ } else if (deviceType == 'npu') {
+ ui.handleBtnUI('#float16Label', false);
+ ui.handleBtnUI('#float32Label', true);
+ displayAvailableModels(modelList, deviceType, 'float16');
+ } else {
+ ui.handleBtnUI('#float16Label', true);
+ ui.handleBtnUI('#float32Label', false);
+ displayAvailableModels(modelList, deviceType, 'float32');
+ }
+
+ // Uncheck selected model
+ if (modelId != '') {
+ $(`#${modelId}`).parent().removeClass('active');
+ }
});
$('#modelBtns .btn').on('change', async (e) => {
if (inputType === 'camera') {
await stopCamRender();
}
- modelName = $(e.target).attr('id');
+ modelId = $(e.target).attr('id');
+ modelName = modelId;
+ if (dataType == 'float16') {
+ modelName += 'fp16';
+ }
+
await main();
});
@@ -75,6 +159,16 @@ $('#modelBtns .btn').on('change', async (e) => {
// await main();
// });
+$('#dataTypeBtns .btn').on('change', async (e) => {
+ dataType = $(e.target).attr('id');
+ displayAvailableModels(modelList, deviceType, dataType);
+ // Uncheck selected model
+ if (modelId != '') {
+ $(`#${modelId}`).parent().removeClass('active');
+ }
+});
+
+
// Click trigger to do inference with element
$('#img').click(async () => {
if (inputType === 'camera') {
@@ -154,6 +248,18 @@ async function renderCamStream() {
// Get top 3 classes of labels from output buffer
function getTopClasses(buffer, labels) {
+ // Currently we need to fall back to tf.softmax because the
+ // NPU backend doesn't support softmax.
+ // TODO: Remove this workaround once NPU supports softmax.
+ if (deviceType === 'npu') {
+ // Softmax
+ buffer = tf.tidy(() => {
+ const a =
+ tf.tensor(buffer, netInstance.outputDimensions, 'float32');
+ const b = tf.softmax(a);
+ return b.dataSync();
+ });
+ }
const probs = Array.from(buffer);
const indexes = probs.map((prob, index) => [prob, index]);
const sorted = indexes.sort((a, b) => {
@@ -216,12 +322,15 @@ function showPerfResult(medianComputeTime = undefined) {
function constructNetObject(type) {
const netObject = {
+ 'mobilenetfp16nchw': new MobileNetV2Nchw('float16'),
+ 'resnet50v1fp16nchw': new ResNet50V1FP16Nchw(),
+ 'efficientnetfp16nchw': new EfficientNetFP16Nchw(),
'mobilenetnchw': new MobileNetV2Nchw(),
'mobilenetnhwc': new MobileNetV2Nhwc(),
'squeezenetnchw': new SqueezeNetNchw(),
'squeezenetnhwc': new SqueezeNetNhwc(),
- 'resnet50nchw': new ResNet50V2Nchw(),
- 'resnet50nhwc': new ResNet50V2Nhwc(),
+ 'resnet50v2nchw': new ResNet50V2Nchw(),
+ 'resnet50v2nhwc': new ResNet50V2Nhwc(),
};
return netObject[type];
@@ -230,8 +339,6 @@ function constructNetObject(type) {
async function main() {
try {
if (modelName === '') return;
- [backend, deviceType] =
- $('input[name="backend"]:checked').attr('id').split('_');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
let start;
@@ -245,7 +352,7 @@ async function main() {
// Set backend and device
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
- deviceType : lastdeviceType;
+ deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
}
if (netInstance !== null) {
diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js
index c62cb51e..6b52350b 100644
--- a/image_classification/mobilenet_nchw.js
+++ b/image_classification/mobilenet_nchw.js
@@ -4,13 +4,21 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
// MobileNet V2 model with 'nchw' input layout
export class MobileNetV2Nchw {
- constructor() {
+ constructor(dataType = 'float32') {
this.context_ = null;
this.deviceType_ = null;
this.builder_ = null;
this.graph_ = null;
- this.weightsUrl_ = weightsOrigin() +
- '/test-data/models/mobilenetv2_nchw/weights/';
+ this.dataType_ = dataType;
+ this.weightsUrl_ = weightsOrigin();
+ if (this.dataType_ === 'float32') {
+ this.weightsUrl_ += '/test-data/models/mobilenetv2_nchw/weights/';
+ } else if (this.dataType_ === 'float16') {
+ this.weightsUrl_ +=
+ '/test-data/models/mobilenetv2_fp16_nchw_optimized/weights/';
+ } else {
+ throw new Error(`Unsupported dataType: ${this.dataType_}`);
+ }
this.inputOptions = {
mean: [0.485, 0.456, 0.406],
std: [0.229, 0.224, 0.225],
@@ -23,17 +31,27 @@ export class MobileNetV2Nchw {
}
async buildConv_(input, name, relu6 = true, options = {}) {
- const prefix = this.weightsUrl_ + 'conv_' + name;
- const weightsName = prefix + '_weight.npy';
- const weights = buildConstantByNpy(this.builder_, weightsName);
- const biasName = prefix + '_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
- options.bias = await bias;
+ let weights;
+ if (this.dataType_ === 'float32') {
+ weights = buildConstantByNpy(this.builder_,
+ `${this.weightsUrl_}conv_${name}_weight.npy`);
+ options.bias = await buildConstantByNpy(this.builder_,
+ `${this.weightsUrl_}conv_${name}_bias.npy`);
+ } else {
+ weights = buildConstantByNpy(this.builder_,
+ `${this.weightsUrl_}w${name}.npy`, this.dataType_);
+ // Only node 97 has no bias input
+ if (name !== '97') {
+ options.bias = await buildConstantByNpy(this.builder_,
+ `${this.weightsUrl_}b${name}.npy`, this.dataType_);
+ }
+ }
+
if (relu6) {
// TODO: Set clamp activation to options once it's supported in
// WebNN DML backend.
// Implement `clip` by `clamp` of WebNN API
- if (this.deviceType_ == 'gpu') {
+ if (this.deviceType_ == 'gpu' || this.deviceType_ == 'npu') {
return this.builder_.clamp(
this.builder_.conv2d(await input, await weights, options),
{minValue: 0, maxValue: 6});
@@ -47,9 +65,11 @@ export class MobileNetV2Nchw {
async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'gemm_' + name;
const weightsName = prefix + '_weight.npy';
- const weights = buildConstantByNpy(this.builder_, weightsName);
+ const weights = buildConstantByNpy(this.builder_, weightsName,
+ this.dataType_);
const biasName = prefix + '_bias.npy';
- const bias = buildConstantByNpy(this.builder_, biasName);
+ const bias = buildConstantByNpy(this.builder_, biasName,
+ this.dataType_);
const options = {c: await bias, bTranspose: true};
return this.builder_.gemm(await input, await weights, options);
}
@@ -63,25 +83,27 @@ export class MobileNetV2Nchw {
strides: [stride, stride],
};
const dwise3x3Relu6 = this.buildConv_(
- await conv1x1Relu6, convNameArray[1], true, options);
+ conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = this.buildConv_(
- await dwise3x3Relu6, convNameArray[2], false);
+ dwise3x3Relu6, convNameArray[2], false);
if (shortcut) {
return this.builder_.add(await input, await conv1x1Linear);
}
- return await conv1x1Linear;
+ return conv1x1Linear;
}
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.deviceType_ = contextOptions.deviceType;
this.builder_ = new MLGraphBuilder(this.context_);
- const data = this.builder_.input('input', {
- type: 'float32',
+ let data = this.builder_.input('input', {
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
+ if (this.dataType_ === 'float16') {
+ data = this.builder_.cast(data, 'float16');
+ }
const conv0 = this.buildConv_(
data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]});
const conv1 = this.buildConv_(
@@ -121,10 +143,23 @@ export class MobileNetV2Nchw {
bottleneck14, ['90', '92', '94'], 960, 1, false);
const conv3 = this.buildConv_(bottleneck15, '95', true);
- const pool = this.builder_.averagePool2d(await conv3);
- const reshape = this.builder_.reshape(pool, [1, 1280]);
- const gemm = this.buildGemm_(reshape, '104');
- return await this.builder_.softmax(await gemm);
+ if (this.dataType_ == 'float32') {
+ const pool = this.builder_.averagePool2d(await conv3);
+ const reshape = this.builder_.reshape(pool, [1, 1280]);
+ const gemm = this.buildGemm_(reshape, '104');
+ return this.builder_.softmax(await gemm);
+ } else {
+ const conv4 = this.buildConv_(await conv3, '97', false,
+ {groups: 1280, strides: [7, 7]});
+ const conv5 = this.buildConv_(await conv4, '104', false);
+ const reshape = this.builder_.reshape(await conv5, [1, 1000]);
+ if (contextOptions.deviceType === 'npu') {
+ return this.builder_.cast(reshape, 'float32');
+ } else {
+ const softmax = this.builder_.softmax(reshape);
+ return this.builder_.cast(softmax, 'float32');
+ }
+ }
}
async build(outputOperand) {
diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js
new file mode 100644
index 00000000..58ff60be
--- /dev/null
+++ b/image_classification/resnet50v1_fp16_nchw.js
@@ -0,0 +1,152 @@
+'use strict';
+
+import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';
+
+// ResNet50 V1 fp16 model with 'nchw' input layout
+export class ResNet50V1FP16Nchw {
+ constructor() {
+ this.context_ = null;
+ this.builder_ = null;
+ this.graph_ = null;
+ this.weightsUrl_ = weightsOrigin() +
+ '/test-data/models/resnet50v1_fp16_nchw_optimized/weights/';
+ this.inputOptions = {
+ mean: [0.485, 0.456, 0.406],
+ std: [0.229, 0.224, 0.225],
+ norm: true,
+ inputLayout: 'nchw',
+ labelUrl: './labels/labels1000.txt',
+ inputDimensions: [1, 3, 224, 224],
+ };
+ this.outputDimensions = [1, 1000];
+ }
+
+ async buildConv_(input, name, stageName, relu, options = undefined) {
+ let prefix = '';
+ if (stageName !== '') {
+ prefix = this.weightsUrl_ + 'stage' + stageName + '_conv' +
+ name;
+ } else {
+ prefix = this.weightsUrl_ + 'conv' + name;
+ }
+ const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
+ 'float16');
+ options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
+ 'float16');
+ if (relu) {
+ options.activation = this.builder_.relu();
+ }
+
+ return this.builder_.conv2d(await input, await weight, options);
+ }
+
+ async buildGemm_(input, name) {
+ const prefix = this.weightsUrl_ + 'dense' + name;
+ const weightName = prefix + '_w.npy';
+ const weight = buildConstantByNpy(this.builder_, weightName,
+ 'float16');
+ const biasName = prefix + '_b.npy';
+ const bias = buildConstantByNpy(this.builder_, biasName,
+ 'float16');
+ const options =
+ {c: this.builder_.reshape(await bias, [1, 1000]), bTranspose: true};
+ return this.builder_.gemm(await input, await weight, options);
+ }
+
+ async buildBottleneck_(
+ input, stageName, nameIndex, downsample = false, stride = 1) {
+ let residual = input;
+ let strides = [1, 1];
+
+ if (downsample) {
+ strides = [stride, stride];
+ }
+ const conv1 = this.buildConv_(input, nameIndex,
+ stageName, true, {strides});
+ const conv2 = this.buildConv_(conv1, parseInt(nameIndex) + 1,
+ stageName, true, {padding: [1, 1, 1, 1]});
+ const conv3 = this.buildConv_(conv2,
+ parseInt(nameIndex) + 2, stageName, false, {});
+ if (downsample) {
+ residual = this.buildConv_(
+ input, parseInt(nameIndex) + 3, stageName, false, {strides});
+ }
+ const add = this.builder_.add(await conv3, await residual);
+ return this.builder_.relu(add);
+ }
+
+ async load(contextOptions) {
+ this.context_ = await navigator.ml.createContext(contextOptions);
+ this.builder_ = new MLGraphBuilder(this.context_);
+ let data = this.builder_.input('input', {
+ dataType: 'float32',
+ dimensions: this.inputOptions.inputDimensions,
+ });
+ data = this.builder_.cast(data, 'float16');
+ const conv1 = await this.buildConv_(
+ data, '0', '', true, {padding: [3, 3, 3, 3], strides: [2, 2]});
+ const pool1 = this.builder_.maxPool2d(conv1,
+ {windowDimensions: [3, 3], padding: [1, 1, 1, 1], strides: [2, 2]});
+
+ // Stage 1
+ const bottleneck1 = this.buildBottleneck_(pool1, '1', '0', true);
+ const bottleneck2 = this.buildBottleneck_(bottleneck1, '1', '4');
+ const bottleneck3 = this.buildBottleneck_(bottleneck2, '1', '7');
+
+ // Stage 2
+ const bottleneck4 = this.buildBottleneck_(bottleneck3, '2', '0',
+ true, 2);
+ const bottleneck5 = this.buildBottleneck_(bottleneck4, '2', '4');
+ const bottleneck6 = this.buildBottleneck_(bottleneck5, '2', '7');
+ const bottleneck7 = this.buildBottleneck_(bottleneck6, '2', '10');
+
+ // Stage 3
+ const bottleneck8 = this.buildBottleneck_(bottleneck7, '3', '0',
+ true, 2);
+ const bottleneck9 = this.buildBottleneck_(bottleneck8, '3', '4');
+ const bottleneck10 = this.buildBottleneck_(bottleneck9, '3', '7');
+ const bottleneck11 = this.buildBottleneck_(bottleneck10, '3', '10');
+ const bottleneck12 = this.buildBottleneck_(bottleneck11, '3', '13');
+ const bottleneck13 = this.buildBottleneck_(bottleneck12, '3', '16');
+
+ // Stage 4
+ const bottleneck14 = this.buildBottleneck_(bottleneck13, '4', '0',
+ true, 2);
+ const bottleneck15 = this.buildBottleneck_(bottleneck14, '4', '4');
+ const bottleneck16 = this.buildBottleneck_(bottleneck15, '4', '7');
+
+ const pool2 = this.builder_.averagePool2d(await bottleneck16);
+ const reshape = this.builder_.reshape(pool2, [1, 2048]);
+ const gemm = this.buildGemm_(reshape, '0');
+ if (contextOptions.deviceType === 'npu') {
+ return this.builder_.cast(await gemm, 'float32');
+ } else {
+ const softmax = this.builder_.softmax(await gemm);
+ return this.builder_.cast(softmax, 'float32');
+ }
+ }
+
+ async build(outputOperand) {
+ this.graph_ = this.builder_.build({'output': outputOperand});
+ }
+
+ // Release the constant tensors of a model
+ dispose() {
+ // dispose() is only available in webnn-polyfill
+ if (this.graph_ !== null && 'dispose' in this.graph_) {
+ this.graph_.dispose();
+ }
+ }
+
+ // Run inference with the input and output buffers
+ async compute(inputBuffer, outputBuffer) {
+ const inputs = {'input': inputBuffer};
+ const outputs = {'output': outputBuffer};
+ const results = await this.context_.compute(
+ await this.graph_,
+ inputs,
+ outputs,
+ );
+ return results;
+ }
+}
diff --git a/test-data b/test-data
index 045017d3..75634f2e 160000
--- a/test-data
+++ b/test-data
@@ -1 +1 @@
-Subproject commit 045017d38ea0133807fa26af9e5b030147cb2314
+Subproject commit 75634f2e6ac2be244eab703c9dd7497a0fa23ab6