diff --git a/common/component/component.js b/common/component/component.js index 977b4f43..7bea3523 100644 --- a/common/component/component.js +++ b/common/component/component.js @@ -606,6 +606,18 @@ $(document).ready(async () => { "title", "WebNN is supported, disable WebNN Polyfill." ); + // Disable WebNN NPU backend if failed to find a capable NPU adapter. + try { + await navigator.ml.createContext({deviceType: 'npu'}); + } catch (error) { + $('#webnn_npu').parent().addClass('disabled'); + $('#webnn_npu').parent().addClass('btn-outline-secondary'); + $('#webnn_npu').parent().removeClass('btn-outline-info'); + $('#webnn_npu').parent().attr( + "title", + "Unable to find a capable NPU adapter." + ); + } } } $("#webnnstatus").html("supported").addClass("webnn-status-true"); diff --git a/common/ui.js b/common/ui.js index ec6a6cc7..1cc69fc1 100644 --- a/common/ui.js +++ b/common/ui.js @@ -97,6 +97,23 @@ export function handleClick(cssSelectors, disabled = true) { } } +/** + * Handle button UI, disable or enable the button. + * @param {String} selector, css selector. + * @param {Boolean} disabled, disable or enable the button. + */ +export function handleBtnUI(selector, disabled = true) { + if (disabled) { + $(selector).addClass('disabled'); + $(selector).addClass('btn-outline-secondary'); + $(selector).removeClass('btn-outline-info'); + } else { + $(selector).removeClass('disabled'); + $(selector).removeClass('btn-outline-secondary'); + $(selector).addClass('btn-outline-info'); + } +} + /** * Show flexible alert messages * @param {String} msg, alert message. diff --git a/image_classification/.eslintrc.js b/image_classification/.eslintrc.js index 41955769..c02d313a 100644 --- a/image_classification/.eslintrc.js +++ b/image_classification/.eslintrc.js @@ -1,5 +1,6 @@ module.exports = { globals: { 'MLGraphBuilder': 'readonly', + 'tf': 'readonly', }, }; diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js new file mode 100644 index 00000000..9f0d87f4 --- /dev/null +++ b/image_classification/efficientnet_fp16_nchw.js @@ -0,0 +1,176 @@ +'use strict'; + +import {buildConstantByNpy, weightsOrigin} from '../common/utils.js'; + +// EfficientNet fp16 model with 'nchw' input layout +export class EfficientNetFP16Nchw { + constructor() { + this.context_ = null; + this.builder_ = null; + this.graph_ = null; + this.weightsUrl_ = weightsOrigin() + + '/test-data/models/efficientnet_fp16_nchw_optimized/weights/'; + this.inputOptions = { + mean: [0.485, 0.456, 0.406], + std: [0.229, 0.224, 0.225], + norm: true, + inputLayout: 'nchw', + labelUrl: './labels/labels1000.txt', + inputDimensions: [1, 3, 224, 224], + }; + this.outputDimensions = [1, 1000]; + } + + async buildConv_(input, name, blockName, clip = false, options = {}) { + let prefix = ''; + if (blockName !== '') { + prefix = this.weightsUrl_ + 'block' + blockName + '_conv' + + name; + } else { + prefix = this.weightsUrl_ + 'conv' + name; + } + const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy', + 'float16'); + options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy', + 'float16'); + if (clip) { + return this.builder_.clamp( + this.builder_.conv2d(await input, await weight, options), + {minValue: 0, maxValue: 6}); + } + return this.builder_.conv2d(await input, await weight, options); + } + + async buildGemm_(input, name) { + const prefix = this.weightsUrl_ + 'dense' + name; + const weightName = prefix + '_w.npy'; + const weight = buildConstantByNpy(this.builder_, weightName, + 'float16'); + const biasName = prefix + '_b.npy'; + const bias = buildConstantByNpy(this.builder_, biasName, + 'float16'); + const options = + {c: this.builder_.reshape(await bias, [1, 1000])}; + return await this.builder_.gemm(await input, await weight, options); + } + + async buildBottleneck_(input, blockName, group, pad = 1) { + const conv1 = this.buildConv_(input, '0', blockName, true); + const conv2 = this.buildConv_(conv1, '1', blockName, true, + {groups: group, padding: [pad, pad, pad, pad]}); + const conv3 = this.buildConv_(conv2, '2', blockName); + return this.builder_.add(await conv3, await input); + } + + async buildBottlenecks_(input, blockNames, group, pad = 1) { + let result = input; + for (let i = 0; i < blockNames.length; i++) { + const bottleneck = await this.buildBottleneck_(result, blockNames[i], + group, pad); + result = bottleneck; + } + return result; + } + + async load(contextOptions) { + this.context_ = await navigator.ml.createContext(contextOptions); + this.builder_ = new MLGraphBuilder(this.context_); + let data = this.builder_.input('input', { + dataType: 'float32', + dimensions: this.inputOptions.inputDimensions, + }); + data = this.builder_.cast(data, 'float16'); + // Block 0 + const conv1 = this.buildConv_( + data, '0', '0', true, {padding: [0, 1, 0, 1], strides: [2, 2]}); + const conv2 = this.buildConv_(conv1, '1', '0', true, + {groups: 32, padding: [1, 1, 1, 1]}); + const conv3 = this.buildConv_(conv2, '2', '0'); + + // Block 1 + const conv4 = this.buildConv_(conv3, '0', '1', true); + const conv5 = this.buildConv_(conv4, '1', '1', true, + {groups: 144, padding: [0, 1, 0, 1], strides: [2, 2]}); + const conv6 = this.buildConv_(conv5, '2', '1'); + + // Block 2~4 + const bottleneck4 = this.buildBottlenecks_(conv6, + ['2', '3', '4'], 192); + + // Block 5 + const conv7 = this.buildConv_(bottleneck4, '0', '5', true); + const conv8 = this.buildConv_(conv7, '1', '5', true, + {groups: 192, padding: [1, 2, 1, 2], strides: [2, 2]}); + const conv9 = this.buildConv_(conv8, '2', '5'); + + // Block 6~8 + const bottleneck8 = this.buildBottlenecks_(conv9, + ['6', '7', '8'], 336, 2); + + // Block 9 + const conv10 = this.buildConv_(bottleneck8, '0', '9', true); + const conv11 = this.buildConv_(conv10, '1', '9', true, + {groups: 336, padding: [0, 1, 0, 1], strides: [2, 2]}); + const conv12 = this.buildConv_(conv11, '2', '9'); + + // Block 10~14 + const bottleneck14 = this.buildBottlenecks_(conv12, + ['10', '11', '12', '13', '14'], 672); + + // Block 15 + const conv13 = this.buildConv_(bottleneck14, '0', '15', true); + const conv14 = this.buildConv_(conv13, '1', '15', true, + {groups: 672, padding: [2, 2, 2, 2]}); + const conv15 = this.buildConv_(conv14, '2', '15'); + + // Block 16~20 + const bottleneck20 = await this.buildBottlenecks_(conv15, + ['16', '17', '18', '19', '20'], 960, 2); + + // Block 21 + const conv16 = this.buildConv_(bottleneck20, '0', '21', true); + const conv17 = this.buildConv_(conv16, '1', '21', true, + {groups: 960, padding: [1, 2, 1, 2], strides: [2, 2]}); + const conv18 = this.buildConv_(conv17, '2', '21'); + + // Block 22~28 + const bottleneck28 = this.buildBottlenecks_(conv18, + ['22', '23', '24', '25', '26', '27', '28'], 1632, 2); + + // Block 29 + const conv19 = this.buildConv_(bottleneck28, '0', '29', true); + const conv20 = this.buildConv_(conv19, '1', '29', true, + {groups: 1632, padding: [1, 1, 1, 1]}); + const conv21 = this.buildConv_(conv20, '2', '29'); + + const conv22 = this.buildConv_(conv21, '0', '', true); + const pool1 = this.builder_.averagePool2d(await conv22); + const reshape = this.builder_.reshape(pool1, [1, 1280]); + const gemm = this.buildGemm_(reshape, '0'); + if (contextOptions.deviceType === 'npu') { + return this.builder_.cast(await gemm, 'float32'); + } else { + const softmax = this.builder_.softmax(await gemm); + return this.builder_.cast(softmax, 'float32'); + } + } + + async build(outputOperand) { + this.graph_ = await this.builder_.build({'output': outputOperand}); + } + + // Release the constant tensors of a model + dispose() { + // dispose() is only available in webnn-polyfill + if (this.graph_ !== null && 'dispose' in this.graph_) { + this.graph_.dispose(); + } + } + + async compute(inputBuffer, outputBuffer) { + const inputs = {'input': inputBuffer}; + const outputs = {'output': outputBuffer}; + const results = await this.context_.compute(this.graph_, inputs, outputs); + return results; + } +} diff --git a/image_classification/index.html b/image_classification/index.html index 94394826..fc9a0d43 100644 --- a/image_classification/index.html +++ b/image_classification/index.html @@ -43,6 +43,9 @@ + @@ -61,21 +64,43 @@ --> +
+
+ Data Type +
+
+
+ + +
+
+
Model
- - - + + + + + +
@@ -213,6 +238,9 @@

No model selected

+ diff --git a/image_classification/main.js b/image_classification/main.js index 5e7d74b5..ce8839b0 100644 --- a/image_classification/main.js +++ b/image_classification/main.js @@ -1,5 +1,7 @@ 'use strict'; +import {ResNet50V1FP16Nchw} from './resnet50v1_fp16_nchw.js'; +import {EfficientNetFP16Nchw} from './efficientnet_fp16_nchw.js'; import {MobileNetV2Nchw} from './mobilenet_nchw.js'; import {MobileNetV2Nhwc} from './mobilenet_nhwc.js'; import {SqueezeNetNchw} from './squeezenet_nchw.js'; @@ -15,7 +17,9 @@ const imgElement = document.getElementById('feedElement'); imgElement.src = './images/test.jpg'; const camElement = document.getElementById('feedMediaElement'); let modelName = ''; +let modelId = ''; let layout = 'nhwc'; +let dataType = 'float32'; let instanceType = modelName + layout; let rafReq; let isFirstTimeLoad = true; @@ -35,6 +39,41 @@ let lastBackend = ''; let stopRender = true; let isRendering = false; const disabledSelectors = ['#tabs > li', '.btn']; +const modelIds = [ + 'mobilenet', + 'squeezenet', + 'resnet50v2', + 'resnet50v1', + 'efficientnet', +]; +const modelList = { + 'cpu': { + 'float32': [ + 'mobilenet', + 'squeezenet', + 'resnet50v2', + ], + }, + 'gpu': { + 'float32': [ + 'mobilenet', + 'squeezenet', + 'resnet50v2', + ], + 'float16': [ + 'efficientnet', + 'mobilenet', + 'resnet50v1', + ], + }, + 'npu': { + 'float16': [ + 'efficientnet', + 'mobilenet', + 'resnet50v1', + ], + }, +}; async function fetchLabels(url) { const response = await fetch(url); @@ -42,6 +81,26 @@ async function fetchLabels(url) { return data.split('\n'); } +function displayAvailableModels(modelList, deviceType, dataType) { + let models = []; + if (dataType == '') { + models = models.concat(modelList[deviceType]['float32']); + models = models.concat(modelList[deviceType]['float16']); + } else { + models = models.concat(modelList[deviceType][dataType]); + } + // Remove duplicate ids. + models = [...new Set(models)]; + // Display available models. + for (const model of modelIds) { + if (models.includes(model)) { + $(`#${model}`).parent().show(); + } else { + $(`#${model}`).parent().hide(); + } + } +} + $(document).ready(async () => { $('.icdisplay').hide(); if (await utils.isWebNN()) { @@ -56,14 +115,39 @@ $('#backendBtns .btn').on('change', async (e) => { await stopCamRender(); } layout = utils.getDefaultLayout($(e.target).attr('id')); - await main(); + [backend, deviceType] = $(e.target).attr('id').split('_'); + // Only show the supported models for each deviceType. Now fp16 nchw models + // are only supported on gpu/npu. + if (deviceType == 'gpu') { + ui.handleBtnUI('#float16Label', false); + ui.handleBtnUI('#float32Label', false); + displayAvailableModels(modelList, deviceType, dataType); + } else if (deviceType == 'npu') { + ui.handleBtnUI('#float16Label', false); + ui.handleBtnUI('#float32Label', true); + displayAvailableModels(modelList, deviceType, 'float16'); + } else { + ui.handleBtnUI('#float16Label', true); + ui.handleBtnUI('#float32Label', false); + displayAvailableModels(modelList, deviceType, 'float32'); + } + + // Uncheck selected model + if (modelId != '') { + $(`#${modelId}`).parent().removeClass('active'); + } }); $('#modelBtns .btn').on('change', async (e) => { if (inputType === 'camera') { await stopCamRender(); } - modelName = $(e.target).attr('id'); + modelId = $(e.target).attr('id'); + modelName = modelId; + if (dataType == 'float16') { + modelName += 'fp16'; + } + await main(); }); @@ -75,6 +159,16 @@ $('#modelBtns .btn').on('change', async (e) => { // await main(); // }); +$('#dataTypeBtns .btn').on('change', async (e) => { + dataType = $(e.target).attr('id'); + displayAvailableModels(modelList, deviceType, dataType); + // Uncheck selected model + if (modelId != '') { + $(`#${modelId}`).parent().removeClass('active'); + } +}); + + // Click trigger to do inference with element $('#img').click(async () => { if (inputType === 'camera') { @@ -154,6 +248,18 @@ async function renderCamStream() { // Get top 3 classes of labels from output buffer function getTopClasses(buffer, labels) { + // Currently we need to fallback softmax to tf.softmax because + // NPU dosen't support softmax. + // TODO: Remove this workaround once NPU supports softmax. + if (deviceType === 'npu') { + // Softmax + buffer = tf.tidy(() => { + const a = + tf.tensor(buffer, netInstance.outputDimensions, 'float32'); + const b = tf.softmax(a); + return b.dataSync(); + }); + } const probs = Array.from(buffer); const indexes = probs.map((prob, index) => [prob, index]); const sorted = indexes.sort((a, b) => { @@ -216,12 +322,15 @@ function showPerfResult(medianComputeTime = undefined) { function constructNetObject(type) { const netObject = { + 'mobilenetfp16nchw': new MobileNetV2Nchw('float16'), + 'resnet50v1fp16nchw': new ResNet50V1FP16Nchw(), + 'efficientnetfp16nchw': new EfficientNetFP16Nchw(), 'mobilenetnchw': new MobileNetV2Nchw(), 'mobilenetnhwc': new MobileNetV2Nhwc(), 'squeezenetnchw': new SqueezeNetNchw(), 'squeezenetnhwc': new SqueezeNetNhwc(), - 'resnet50nchw': new ResNet50V2Nchw(), - 'resnet50nhwc': new ResNet50V2Nhwc(), + 'resnet50v2nchw': new ResNet50V2Nchw(), + 'resnet50v2nhwc': new ResNet50V2Nhwc(), }; return netObject[type]; @@ -230,8 +339,6 @@ function constructNetObject(type) { async function main() { try { if (modelName === '') return; - [backend, deviceType] = - $('input[name="backend"]:checked').attr('id').split('_'); ui.handleClick(disabledSelectors, true); if (isFirstTimeLoad) $('#hint').hide(); let start; @@ -245,7 +352,7 @@ async function main() { // Set backend and device await utils.setBackend(backend, deviceType); lastdeviceType = lastdeviceType != deviceType ? - deviceType : lastdeviceType; + deviceType : lastdeviceType; lastBackend = lastBackend != backend ? backend : lastBackend; } if (netInstance !== null) { diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js index c62cb51e..6b52350b 100644 --- a/image_classification/mobilenet_nchw.js +++ b/image_classification/mobilenet_nchw.js @@ -4,13 +4,21 @@ import {buildConstantByNpy, weightsOrigin} from '../common/utils.js'; // MobileNet V2 model with 'nchw' input layout export class MobileNetV2Nchw { - constructor() { + constructor(dataType = 'float32') { this.context_ = null; this.deviceType_ = null; this.builder_ = null; this.graph_ = null; - this.weightsUrl_ = weightsOrigin() + - '/test-data/models/mobilenetv2_nchw/weights/'; + this.dataType_ = dataType; + this.weightsUrl_ = weightsOrigin(); + if (this.dataType_ === 'float32') { + this.weightsUrl_ += '/test-data/models/mobilenetv2_nchw/weights/'; + } else if (this.dataType_ === 'float16') { + this.weightsUrl_ += + '/test-data/models/mobilenetv2_fp16_nchw_optimized/weights/'; + } else { + throw new Error(`Unsupported dataType: ${this.dataType_}`); + } this.inputOptions = { mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], @@ -23,17 +31,27 @@ export class MobileNetV2Nchw { } async buildConv_(input, name, relu6 = true, options = {}) { - const prefix = this.weightsUrl_ + 'conv_' + name; - const weightsName = prefix + '_weight.npy'; - const weights = buildConstantByNpy(this.builder_, weightsName); - const biasName = prefix + '_bias.npy'; - const bias = buildConstantByNpy(this.builder_, biasName); - options.bias = await bias; + let weights; + if (this.dataType_ === 'float32') { + weights = buildConstantByNpy(this.builder_, + `${this.weightsUrl_}conv_${name}_weight.npy`); + options.bias = await buildConstantByNpy(this.builder_, + `${this.weightsUrl_}conv_${name}_bias.npy`); + } else { + weights = buildConstantByNpy(this.builder_, + `${this.weightsUrl_}w${name}.npy`, this.dataType_); + // Only node 97 has no bias input + if (name !== '97') { + options.bias = await buildConstantByNpy(this.builder_, + `${this.weightsUrl_}b${name}.npy`, this.dataType_); + } + } + if (relu6) { // TODO: Set clamp activation to options once it's supported in // WebNN DML backend. // Implement `clip` by `clamp` of WebNN API - if (this.deviceType_ == 'gpu') { + if (this.deviceType_ == 'gpu' || this.deviceType_ == 'npu') { return this.builder_.clamp( this.builder_.conv2d(await input, await weights, options), {minValue: 0, maxValue: 6}); @@ -47,9 +65,11 @@ export class MobileNetV2Nchw { async buildGemm_(input, name) { const prefix = this.weightsUrl_ + 'gemm_' + name; const weightsName = prefix + '_weight.npy'; - const weights = buildConstantByNpy(this.builder_, weightsName); + const weights = buildConstantByNpy(this.builder_, weightsName, + this.dataType_); const biasName = prefix + '_bias.npy'; - const bias = buildConstantByNpy(this.builder_, biasName); + const bias = buildConstantByNpy(this.builder_, biasName, + this.dataType_); const options = {c: await bias, bTranspose: true}; return this.builder_.gemm(await input, await weights, options); } @@ -63,25 +83,27 @@ export class MobileNetV2Nchw { strides: [stride, stride], }; const dwise3x3Relu6 = this.buildConv_( - await conv1x1Relu6, convNameArray[1], true, options); + conv1x1Relu6, convNameArray[1], true, options); const conv1x1Linear = this.buildConv_( - await dwise3x3Relu6, convNameArray[2], false); + dwise3x3Relu6, convNameArray[2], false); if (shortcut) { return this.builder_.add(await input, await conv1x1Linear); } - return await conv1x1Linear; + return conv1x1Linear; } async load(contextOptions) { this.context_ = await navigator.ml.createContext(contextOptions); this.deviceType_ = contextOptions.deviceType; this.builder_ = new MLGraphBuilder(this.context_); - const data = this.builder_.input('input', { - type: 'float32', + let data = this.builder_.input('input', { dataType: 'float32', dimensions: this.inputOptions.inputDimensions, }); + if (this.dataType_ === 'float16') { + data = this.builder_.cast(data, 'float16'); + } const conv0 = this.buildConv_( data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]}); const conv1 = this.buildConv_( @@ -121,10 +143,23 @@ export class MobileNetV2Nchw { bottleneck14, ['90', '92', '94'], 960, 1, false); const conv3 = this.buildConv_(bottleneck15, '95', true); - const pool = this.builder_.averagePool2d(await conv3); - const reshape = this.builder_.reshape(pool, [1, 1280]); - const gemm = this.buildGemm_(reshape, '104'); - return await this.builder_.softmax(await gemm); + if (this.dataType_ == 'float32') { + const pool = this.builder_.averagePool2d(await conv3); + const reshape = this.builder_.reshape(pool, [1, 1280]); + const gemm = this.buildGemm_(reshape, '104'); + return this.builder_.softmax(await gemm); + } else { + const conv4 = this.buildConv_(await conv3, '97', false, + {groups: 1280, strides: [7, 7]}); + const conv5 = this.buildConv_(await conv4, '104', false); + const reshape = this.builder_.reshape(await conv5, [1, 1000]); + if (contextOptions.deviceType === 'npu') { + return this.builder_.cast(reshape, 'float32'); + } else { + const softmax = this.builder_.softmax(reshape); + return this.builder_.cast(softmax, 'float32'); + } + } } async build(outputOperand) { diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js new file mode 100644 index 00000000..58ff60be --- /dev/null +++ b/image_classification/resnet50v1_fp16_nchw.js @@ -0,0 +1,152 @@ +'use strict'; + +import {buildConstantByNpy, weightsOrigin} from '../common/utils.js'; + +// ResNet50 V1 fp16 model with 'nchw' input layout +export class ResNet50V1FP16Nchw { + constructor() { + this.context_ = null; + this.builder_ = null; + this.graph_ = null; + this.weightsUrl_ = weightsOrigin() + + '/test-data/models/resnet50v1_fp16_nchw_optimized/weights/'; + this.inputOptions = { + mean: [0.485, 0.456, 0.406], + std: [0.229, 0.224, 0.225], + norm: true, + inputLayout: 'nchw', + labelUrl: './labels/labels1000.txt', + inputDimensions: [1, 3, 224, 224], + }; + this.outputDimensions = [1, 1000]; + } + + async buildConv_(input, name, stageName, relu, options = undefined) { + let prefix = ''; + if (stageName !== '') { + prefix = this.weightsUrl_ + 'stage' + stageName + '_conv' + + name; + } else { + prefix = this.weightsUrl_ + 'conv' + name; + } + const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy', + 'float16'); + options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy', + 'float16'); + if (relu) { + options.activation = this.builder_.relu(); + } + + return this.builder_.conv2d(await input, await weight, options); + } + + async buildGemm_(input, name) { + const prefix = this.weightsUrl_ + 'dense' + name; + const weightName = prefix + '_w.npy'; + const weight = buildConstantByNpy(this.builder_, weightName, + 'float16'); + const biasName = prefix + '_b.npy'; + const bias = buildConstantByNpy(this.builder_, biasName, + 'float16'); + const options = + {c: this.builder_.reshape(await bias, [1, 1000]), bTranspose: true}; + return this.builder_.gemm(await input, await weight, options); + } + + async buildBottleneck_( + input, stageName, nameIndex, downsample = false, stride = 1) { + let residual = input; + let strides = [1, 1]; + + if (downsample) { + strides = [stride, stride]; + } + const conv1 = this.buildConv_(input, nameIndex, + stageName, true, {strides}); + const conv2 = this.buildConv_(conv1, parseInt(nameIndex) + 1, + stageName, true, {padding: [1, 1, 1, 1]}); + const conv3 = this.buildConv_(conv2, + parseInt(nameIndex) + 2, stageName, false, {}); + if (downsample) { + residual = this.buildConv_( + input, parseInt(nameIndex) + 3, stageName, false, {strides}); + } + const add = this.builder_.add(await conv3, await residual); + return this.builder_.relu(add); + } + + async load(contextOptions) { + this.context_ = await navigator.ml.createContext(contextOptions); + this.builder_ = new MLGraphBuilder(this.context_); + let data = this.builder_.input('input', { + dataType: 'float32', + dimensions: this.inputOptions.inputDimensions, + }); + data = this.builder_.cast(data, 'float16'); + const conv1 = await this.buildConv_( + data, '0', '', true, {padding: [3, 3, 3, 3], strides: [2, 2]}); + const pool1 = this.builder_.maxPool2d(conv1, + {windowDimensions: [3, 3], padding: [1, 1, 1, 1], strides: [2, 2]}); + + // Stage 1 + const bottleneck1 = this.buildBottleneck_(pool1, '1', '0', true); + const bottleneck2 = this.buildBottleneck_(bottleneck1, '1', '4'); + const bottleneck3 = this.buildBottleneck_(bottleneck2, '1', '7'); + + // Stage 2 + const bottleneck4 = this.buildBottleneck_(bottleneck3, '2', '0', + true, 2); + const bottleneck5 = this.buildBottleneck_(bottleneck4, '2', '4'); + const bottleneck6 = this.buildBottleneck_(bottleneck5, '2', '7'); + const bottleneck7 = this.buildBottleneck_(bottleneck6, '2', '10'); + + // Stage 3 + const bottleneck8 = this.buildBottleneck_(bottleneck7, '3', '0', + true, 2); + const bottleneck9 = this.buildBottleneck_(bottleneck8, '3', '4'); + const bottleneck10 = this.buildBottleneck_(bottleneck9, '3', '7'); + const bottleneck11 = this.buildBottleneck_(bottleneck10, '3', '10'); + const bottleneck12 = this.buildBottleneck_(bottleneck11, '3', '13'); + const bottleneck13 = this.buildBottleneck_(bottleneck12, '3', '16'); + + // Stage 4 + const bottleneck14 = this.buildBottleneck_(bottleneck13, '4', '0', + true, 2); + const bottleneck15 = this.buildBottleneck_(bottleneck14, '4', '4'); + const bottleneck16 = this.buildBottleneck_(bottleneck15, '4', '7'); + + const pool2 = this.builder_.averagePool2d(await bottleneck16); + const reshape = this.builder_.reshape(pool2, [1, 2048]); + const gemm = this.buildGemm_(reshape, '0'); + if (contextOptions.deviceType === 'npu') { + return this.builder_.cast(await gemm, 'float32'); + } else { + const softmax = this.builder_.softmax(await gemm); + return this.builder_.cast(softmax, 'float32'); + } + } + + async build(outputOperand) { + this.graph_ = this.builder_.build({'output': outputOperand}); + } + + // Release the constant tensors of a model + dispose() { + // dispose() is only available in webnn-polyfill + if (this.graph_ !== null && 'dispose' in this.graph_) { + this.graph_.dispose(); + } + } + + // Release the constant tensors of a model + async compute(inputBuffer, outputBuffer) { + const inputs = {'input': inputBuffer}; + const outputs = {'output': outputBuffer}; + const results = await this.context_.compute( + await this.graph_, + inputs, + outputs, + ); + return results; + } +} diff --git a/test-data b/test-data index 045017d3..75634f2e 160000 --- a/test-data +++ b/test-data @@ -1 +1 @@ -Subproject commit 045017d38ea0133807fa26af9e5b030147cb2314 +Subproject commit 75634f2e6ac2be244eab703c9dd7497a0fa23ab6