From 5ec9ec8156e8320847239986decdbc1ac025be96 Mon Sep 17 00:00:00 2001 From: Ningxin Hu Date: Mon, 25 Mar 2024 16:37:14 +0800 Subject: [PATCH] Support CPU backend for LeNet sample (#205) * Revert "Disable WebNN CPU for LeNet (#204)" This reverts commit cc9559c30656b684f37e45cbfa071b8cf0cf6064. * Support CPU backend for LeNet sample The changes of this PR include: 1. Support nhwc conv2d and pool2d for LeNet. 2. Derive permuteData function from transformers.js that is used to permute the filter data. 3. Use gemm to replace matmul, because XNNPACK matmul doesn't support 2D inputs. 4. Fix some issues of UI. --- common/utils.js | 44 ++++++++++++++ lenet/index.html | 6 +- lenet/lenet.js | 85 +++++++++++++++++++-------- lenet/main.js | 146 ++++++++++++++++++++++++++--------------------- 4 files changed, 187 insertions(+), 94 deletions(-) diff --git a/common/utils.js b/common/utils.js index 135012d1..627ddf73 100644 --- a/common/utils.js +++ b/common/utils.js @@ -440,3 +440,47 @@ export function computePadding2DForAutoPad( return [beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth]; } + +// This function derives from Transformer.js `permute_data()` function: +// https://github.com/xenova/transformers.js/blob/main/src/utils/maths.js#L98 +// which is in Apache License 2.0 +// https://github.com/xenova/transformers.js/blob/main/LICENSE +/** + * Helper method to permute a `AnyTypedArray` directly + * @template {AnyTypedArray} T + * @param {T} array + * @param {number[]} dims + * @param {number[]} axes + * @return {[T, number[]]} The permuted array and the new shape. 
+ */ +export function permuteData(array, dims, axes) { + // Calculate the new shape of the permuted array + // and the stride of the original array + const shape = new Array(axes.length); + const stride = new Array(axes.length); + + for (let i = axes.length - 1, s = 1; i >= 0; --i) { + stride[i] = s; + shape[i] = dims[axes[i]]; + s *= shape[i]; + } + + // Precompute inverse mapping of stride + const invStride = axes.map((_, i) => stride[axes.indexOf(i)]); + + // Create the permuted array with the new shape + // @ts-ignore + const permutedData = new array.constructor(array.length); + + // Permute the original array to the new array + for (let i = 0; i < array.length; ++i) { + let newIndex = 0; + for (let j = dims.length - 1, k = i; j >= 0; --j) { + newIndex += (k % dims[j]) * invStride[j]; + k = Math.floor(k / dims[j]); + } + permutedData[newIndex] = array[i]; + } + + return [permutedData, shape]; +} diff --git a/lenet/index.html b/lenet/index.html index a2dd7695..cf6b5e8d 100644 --- a/lenet/index.html +++ b/lenet/index.html @@ -34,9 +34,9 @@ - + diff --git a/lenet/lenet.js b/lenet/lenet.js index 2cc2c54f..6dcd66c9 100644 --- a/lenet/lenet.js +++ b/lenet/lenet.js @@ -1,14 +1,18 @@ 'use strict'; -import {getBufferFromUrl, sizeOfShape} from '../common/utils.js'; +import {getBufferFromUrl, sizeOfShape, permuteData} from '../common/utils.js'; export class LeNet { - constructor(url) { + constructor(url, layout) { this.context_ = null; this.url_ = url; this.graph_ = null; this.builder_ = null; + this.layout_ = layout; + this.nchwToNhwcPermutation_ = [0, 2, 3, 1]; + this.nhwcToNchwPermutation_ = [0, 3, 1, 2]; + this.oihwToOhwiPermutation_ = [0, 2, 3, 1]; } async load(contextOptions) { @@ -20,25 +24,38 @@ export class LeNet { this.context_ = await navigator.ml.createContext(contextOptions); this.builder_ = new MLGraphBuilder(this.context_); - const inputShape = [1, 1, 28, 28]; - const input = this.builder_.input('input', { - type: 'float32', + const inputShape = /* nchw */ 
[1, 1, 28, 28]; + let input = this.builder_.input('input', { dataType: 'float32', dimensions: inputShape, }); - const conv1FitlerShape = [20, 1, 5, 5]; + if (this.layout_ === 'nhwc') { + input = this.builder_.transpose( + input, {permutation: this.nchwToNhwcPermutation_}); + } + + const conv1Options = {}; + if (this.layout_ === 'nhwc') { + conv1Options.inputLayout = 'nhwc'; + conv1Options.filterLayout = 'ohwi'; + } + let conv1FitlerShape = /* oihw */ [20, 1, 5, 5]; let byteOffset = 0; - const conv1FilterData = new Float32Array( + let conv1FilterData = new Float32Array( arrayBuffer, byteOffset, sizeOfShape(conv1FitlerShape)); + if (this.layout_ === 'nhwc') { + [conv1FilterData, conv1FitlerShape] = + permuteData( + conv1FilterData, conv1FitlerShape, this.oihwToOhwiPermutation_); + } const conv1Filter = this.builder_.constant( - {type: 'float32', dataType: 'float32', dimensions: conv1FitlerShape}, + {dataType: 'float32', dimensions: conv1FitlerShape}, conv1FilterData); byteOffset += sizeOfShape(conv1FitlerShape) * Float32Array.BYTES_PER_ELEMENT; - const conv1 = this.builder_.conv2d(input, conv1Filter); - const add1BiasShape = [1, 20, 1, 1]; + const add1BiasShape = [20]; const add1BiasData = new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add1BiasShape)); const add1Bias = this.builder_.constant( @@ -46,36 +63,54 @@ export class LeNet { add1BiasData, ); byteOffset += sizeOfShape(add1BiasShape) * Float32Array.BYTES_PER_ELEMENT; - const add1 = this.builder_.add(conv1, add1Bias); + conv1Options.bias = add1Bias; + + const conv1 = this.builder_.conv2d(input, conv1Filter, conv1Options); const pool1WindowShape = [2, 2]; const pool1Strides = [2, 2]; const pool1 = - this.builder_.maxPool2d(add1, {windowDimensions: pool1WindowShape, - strides: pool1Strides}); + this.builder_.maxPool2d(conv1, {windowDimensions: pool1WindowShape, + strides: pool1Strides, layout: this.layout_}); - const conv2FilterShape = [50, 20, 5, 5]; + const conv2Options = {}; + if (this.layout_ === 
'nhwc') { + conv2Options.inputLayout = 'nhwc'; + conv2Options.filterLayout = 'ohwi'; + } + let conv2FilterShape = /* oihw */ [50, 20, 5, 5]; + let conv2FilterData = new Float32Array( + arrayBuffer, byteOffset, sizeOfShape(conv2FilterShape)); + if (this.layout_ === 'nhwc') { + [conv2FilterData, conv2FilterShape] = + permuteData( + conv2FilterData, conv2FilterShape, this.oihwToOhwiPermutation_); + } const conv2Filter = this.builder_.constant( {type: 'float32', dataType: 'float32', dimensions: conv2FilterShape}, - new Float32Array( - arrayBuffer, byteOffset, sizeOfShape(conv2FilterShape)), - ); + conv2FilterData); byteOffset += sizeOfShape(conv2FilterShape) * Float32Array.BYTES_PER_ELEMENT; - const conv2 = this.builder_.conv2d(pool1, conv2Filter); - const add2BiasShape = [1, 50, 1, 1]; + const add2BiasShape = [50]; const add2Bias = this.builder_.constant( {type: 'float32', dataType: 'float32', dimensions: add2BiasShape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add2BiasShape))); byteOffset += sizeOfShape(add2BiasShape) * Float32Array.BYTES_PER_ELEMENT; - const add2 = this.builder_.add(conv2, add2Bias); + conv2Options.bias = add2Bias; + + const conv2 = this.builder_.conv2d(pool1, conv2Filter, conv2Options); const pool2WindowShape = [2, 2]; const pool2Strides = [2, 2]; - const pool2 = - this.builder_.maxPool2d(add2, {windowDimensions: pool2WindowShape, - strides: pool2Strides}); + let pool2 = + this.builder_.maxPool2d(conv2, {windowDimensions: pool2WindowShape, + strides: pool2Strides, layout: this.layout_}); + + if (this.layout_ === 'nhwc') { + pool2 = this.builder_.transpose( + pool2, {permutation: this.nhwcToNchwPermutation_}); + } const reshape1Shape = [1, 800]; const reshape1 = this.builder_.reshape(pool2, reshape1Shape); @@ -89,7 +124,7 @@ export class LeNet { new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul1Shape))); byteOffset += sizeOfShape(matmul1Shape) * Float32Array.BYTES_PER_ELEMENT; const matmul1WeightsTransposed = 
this.builder_.transpose(matmul1Weights); - const matmul1 = this.builder_.matmul(reshape1, matmul1WeightsTransposed); + const matmul1 = this.builder_.gemm(reshape1, matmul1WeightsTransposed); const add3BiasShape = [1, 500]; const add3Bias = this.builder_.constant( @@ -109,7 +144,7 @@ export class LeNet { new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul2Shape))); byteOffset += sizeOfShape(matmul2Shape) * Float32Array.BYTES_PER_ELEMENT; const matmul2WeightsTransposed = this.builder_.transpose(matmul2Weights); - const matmul2 = this.builder_.matmul(reshape2, matmul2WeightsTransposed); + const matmul2 = this.builder_.gemm(reshape2, matmul2WeightsTransposed); const add4BiasShape = [1, 10]; const add4Bias = this.builder_.constant( diff --git a/lenet/main.js b/lenet/main.js index c1fe388b..cc064527 100644 --- a/lenet/main.js +++ b/lenet/main.js @@ -17,6 +17,19 @@ digitCanvas.setAttribute('height', 28); digitCanvas.setAttribute('width', 28); digitCanvas.style.backgroundColor = 'black'; const digitContext = digitCanvas.getContext('2d'); +const pen = new Pen(visualCanvas); +let lenet; +let numRuns; + +function clearInferenceResult() { + inferenceTimeElement.innerHTML = ''; + for (let i = 0; i < 3; ++i) { + const labelElement = document.getElementById(`label${i}`); + const probElement = document.getElementById(`prob${i}`); + labelElement.innerHTML = ''; + probElement.innerHTML = ''; + } +} $('#backendBtns .btn').on('change', async () => { await main(); @@ -49,25 +62,20 @@ function getMedianValue(array) { (array[array.length / 2 - 1] + array[array.length / 2]) / 2; } -function clearResult() { - for (let i = 0; i < 3; ++i) { - const labelElement = document.getElementById(`label${i}`); - const probElement = document.getElementById(`prob${i}`); - labelElement.innerHTML = ''; - probElement.innerHTML = ''; - } -} - async function main() { + buildTimeElement.innerHTML = ''; + predictButton.setAttribute('disabled', true); + clearInferenceResult(); const [backend, 
deviceType] = $('input[name="backend"]:checked').attr('id').split('_'); await utils.setBackend(backend, deviceType); drawNextDigitFromMnist(); - const pen = new Pen(visualCanvas); const weightUrl = utils.weightsOrigin() + '/test-data/models/lenet_nchw/weights/lenet.bin'; - const lenet = new LeNet(weightUrl); - const [numRuns, powerPreference, numThreads] = utils.getUrlParams(); + const layout = deviceType === 'cpu' ? 'nhwc' : 'nchw'; + lenet = new LeNet(weightUrl, layout); + const [localNumRuns, powerPreference, numThreads] = utils.getUrlParams(); + numRuns = localNumRuns; try { const contextOptions = {deviceType}; if (powerPreference) { @@ -93,62 +101,68 @@ async function main() { console.log(error); addAlert(error.message); } - predictButton.addEventListener('click', async function(e) { - try { - let start; - let inferenceTime; - const inferenceTimeArray = []; - const input = getInputFromCanvas(); - let outputBuffer = new Float32Array(utils.sizeOfShape([1, 10])); - - // Do warm up - let results = await lenet.compute(input, outputBuffer); - - for (let i = 0; i < numRuns; i++) { - start = performance.now(); - results = await lenet.compute( - results.inputs.input, results.outputs.output); - inferenceTime = performance.now() - start; - console.log(`execution elapsed time: ${inferenceTime.toFixed(2)} ms`); - inferenceTimeArray.push(inferenceTime); - } - - if (numRuns === 1) { - inferenceTimeElement.innerHTML = 'Execution Time: ' + - `${inferenceTime.toFixed(2)} ms`; - } else { - const medianInferenceTime = getMedianValue(inferenceTimeArray); - console.log(`median execution elapsed time: ` + - `${medianInferenceTime.toFixed(2)} ms`); - inferenceTimeElement.innerHTML = `Median Execution Time(${numRuns}` + - ` runs): ` + - `${medianInferenceTime.toFixed(2)} ms`; - } - - outputBuffer = results.outputs.output; - const classes = topK(Array.from(outputBuffer)); - classes.forEach((c, i) => { - console.log(`\tlabel: ${c.label}, probability: ${c.prob}%`); - const labelElement = 
document.getElementById(`label${i}`); - const probElement = document.getElementById(`prob${i}`); - labelElement.innerHTML = `${c.label}`; - probElement.innerHTML = `${c.prob}%`; - }); - } catch (error) { - console.log(error); - addAlert(error.message); +} + +predictButton.addEventListener('click', async function(e) { + clearInferenceResult(); + predictButton.setAttribute('disabled', true); + try { + let start; + let inferenceTime; + const inferenceTimeArray = []; + const input = getInputFromCanvas(); + let outputBuffer = new Float32Array(utils.sizeOfShape([1, 10])); + + // Do warm up + let results = await lenet.compute(input, outputBuffer); + + for (let i = 0; i < numRuns; i++) { + start = performance.now(); + results = await lenet.compute( + results.inputs.input, results.outputs.output); + inferenceTime = performance.now() - start; + console.log(`execution elapsed time: ${inferenceTime.toFixed(2)} ms`); + inferenceTimeArray.push(inferenceTime); } - }); - nextButton.addEventListener('click', () => { - drawNextDigitFromMnist(); - clearResult(); - }); - clearButton.addEventListener('click', () => { - pen.clear(); - clearResult(); - }); -} + if (numRuns === 1) { + inferenceTimeElement.innerHTML = 'Execution Time: ' + + `${inferenceTime.toFixed(2)} ms`; + } else { + const medianInferenceTime = getMedianValue(inferenceTimeArray); + console.log(`median execution elapsed time: ` + + `${medianInferenceTime.toFixed(2)} ms`); + inferenceTimeElement.innerHTML = `Median Execution Time(${numRuns}` + + ` runs): ` + + `${medianInferenceTime.toFixed(2)} ms`; + } + + outputBuffer = results.outputs.output; + const classes = topK(Array.from(outputBuffer)); + classes.forEach((c, i) => { + console.log(`\tlabel: ${c.label}, probability: ${c.prob}%`); + const labelElement = document.getElementById(`label${i}`); + const probElement = document.getElementById(`prob${i}`); + labelElement.innerHTML = `${c.label}`; + probElement.innerHTML = `${c.prob}%`; + }); + + 
predictButton.removeAttribute('disabled'); + } catch (error) { + console.log(error); + addAlert(error.message); + } +}); + +nextButton.addEventListener('click', () => { + drawNextDigitFromMnist(); + clearInferenceResult(); +}); + +clearButton.addEventListener('click', () => { + pen.clear(); + clearInferenceResult(); +}); function topK(probs, k = 3) { const sorted = probs.map((prob, index) => [prob, index]).sort((a, b) => {