Skip to content

Commit

Permalink
Support CPU backend for LeNet sample (webmachinelearning#205)
Browse files Browse the repository at this point in the history
* Revert "Disable WebNN CPU for LeNet (webmachinelearning#204)"

This reverts commit cc9559c.

* Support CPU backend for LeNet sample

The changes of this PR include:
1. Support nhwc conv2d and pool2d for LeNet.
2. Derive permuteData function from transformers.js that is used to
   permute the filter data.
3. Use gemm to replace matmul, because XNNPACK matmul doesn't support 2D
   inputs.
4. Fix some isses of UI.
  • Loading branch information
huningxin authored and Honry committed May 15, 2024
1 parent f5d60ca commit eb7a1be
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 94 deletions.
44 changes: 44 additions & 0 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,47 @@ export function computePadding2DForAutoPad(
return [beginningPaddingHeight, endingPaddingHeight,
beginningPaddingWidth, endingPaddingWidth];
}

// This function derives from Transformer.js `permute_data()` function:
// https://github.com/xenova/transformers.js/blob/main/src/utils/maths.js#L98
// which is in Apache License 2.0
// https://github.com/xenova/transformers.js/blob/main/LICENSE
/**
* Helper method to permute a `AnyTypedArray` directly
* @template {AnyTypedArray} T
* @param {T} array
* @param {number[]} dims
* @param {number[]} axes
* @return {[T, number[]]} The permuted array and the new shape.
*/
export function permuteData(array, dims, axes) {
// Calculate the new shape of the permuted array
// and the stride of the original array
const shape = new Array(axes.length);
const stride = new Array(axes.length);

for (let i = axes.length - 1, s = 1; i >= 0; --i) {
stride[i] = s;
shape[i] = dims[axes[i]];
s *= shape[i];
}

// Precompute inverse mapping of stride
const invStride = axes.map((_, i) => stride[axes.indexOf(i)]);

// Create the permuted array with the new shape
// @ts-ignore
const permutedData = new array.constructor(array.length);

// Permute the original array to the new array
for (let i = 0; i < array.length; ++i) {
let newIndex = 0;
for (let j = dims.length - 1, k = i; j >= 0; --j) {
newIndex += (k % dims[j]) * invStride[j];
k = Math.floor(k / dims[j]);
}
permutedData[newIndex] = array[i];
}

return [permutedData, shape];
}
6 changes: 3 additions & 3 deletions lenet/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
</label>
<!-- <label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off" disabled>WebNN (CPU)
</label> -->
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
</label>
Expand Down
85 changes: 60 additions & 25 deletions lenet/lenet.js
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@

'use strict';

import {getBufferFromUrl, sizeOfShape} from '../common/utils.js';
import {getBufferFromUrl, sizeOfShape, permuteData} from '../common/utils.js';

export class LeNet {
constructor(url) {
constructor(url, layout) {
this.context_ = null;
this.url_ = url;
this.graph_ = null;
this.builder_ = null;
this.layout_ = layout;
this.nchwToNhwcPermutation_ = [0, 2, 3, 1];
this.nhwcToNchwPermutation_ = [0, 3, 1, 2];
this.oihwToOhwiPermutation_ = [0, 2, 3, 1];
}

async load(contextOptions) {
Expand All @@ -20,62 +24,93 @@ export class LeNet {

this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
const inputShape = [1, 1, 28, 28];
const input = this.builder_.input('input', {
type: 'float32',
const inputShape = /* nchw */ [1, 1, 28, 28];
let input = this.builder_.input('input', {
dataType: 'float32',
dimensions: inputShape,
});

const conv1FitlerShape = [20, 1, 5, 5];
if (this.layout_ === 'nhwc') {
input = this.builder_.transpose(
input, {permutation: this.nchwToNhwcPermutation_});
}

const conv1Options = {};
if (this.layout_ === 'nhwc') {
conv1Options.inputLayout = 'nhwc';
conv1Options.filterLayout = 'ohwi';
}
let conv1FitlerShape = /* oihw */ [20, 1, 5, 5];
let byteOffset = 0;
const conv1FilterData = new Float32Array(
let conv1FilterData = new Float32Array(
arrayBuffer, byteOffset, sizeOfShape(conv1FitlerShape));
if (this.layout_ === 'nhwc') {
[conv1FilterData, conv1FitlerShape] =
permuteData(
conv1FilterData, conv1FitlerShape, this.oihwToOhwiPermutation_);
}
const conv1Filter = this.builder_.constant(
{type: 'float32', dataType: 'float32', dimensions: conv1FitlerShape},
{dataType: 'float32', dimensions: conv1FitlerShape},
conv1FilterData);
byteOffset +=
sizeOfShape(conv1FitlerShape) * Float32Array.BYTES_PER_ELEMENT;
const conv1 = this.builder_.conv2d(input, conv1Filter);

const add1BiasShape = [1, 20, 1, 1];
const add1BiasShape = [20];
const add1BiasData =
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add1BiasShape));
const add1Bias = this.builder_.constant(
{type: 'float32', dataType: 'float32', dimensions: add1BiasShape},
add1BiasData,
);
byteOffset += sizeOfShape(add1BiasShape) * Float32Array.BYTES_PER_ELEMENT;
const add1 = this.builder_.add(conv1, add1Bias);
conv1Options.bias = add1Bias;

const conv1 = this.builder_.conv2d(input, conv1Filter, conv1Options);

const pool1WindowShape = [2, 2];
const pool1Strides = [2, 2];
const pool1 =
this.builder_.maxPool2d(add1, {windowDimensions: pool1WindowShape,
strides: pool1Strides});
this.builder_.maxPool2d(conv1, {windowDimensions: pool1WindowShape,
strides: pool1Strides, layout: this.layout_});

const conv2FilterShape = [50, 20, 5, 5];
const conv2Options = {};
if (this.layout_ === 'nhwc') {
conv2Options.inputLayout = 'nhwc';
conv2Options.filterLayout = 'ohwi';
}
let conv2FilterShape = /* oihw */ [50, 20, 5, 5];
let conv2FilterData = new Float32Array(
arrayBuffer, byteOffset, sizeOfShape(conv2FilterShape));
if (this.layout_ === 'nhwc') {
[conv2FilterData, conv2FilterShape] =
permuteData(
conv2FilterData, conv2FilterShape, this.oihwToOhwiPermutation_);
}
const conv2Filter = this.builder_.constant(
{type: 'float32', dataType: 'float32', dimensions: conv2FilterShape},
new Float32Array(
arrayBuffer, byteOffset, sizeOfShape(conv2FilterShape)),
);
conv2FilterData);
byteOffset +=
sizeOfShape(conv2FilterShape) * Float32Array.BYTES_PER_ELEMENT;
const conv2 = this.builder_.conv2d(pool1, conv2Filter);

const add2BiasShape = [1, 50, 1, 1];
const add2BiasShape = [50];
const add2Bias = this.builder_.constant(
{type: 'float32', dataType: 'float32', dimensions: add2BiasShape},
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add2BiasShape)));
byteOffset += sizeOfShape(add2BiasShape) * Float32Array.BYTES_PER_ELEMENT;
const add2 = this.builder_.add(conv2, add2Bias);
conv2Options.bias = add2Bias;

const conv2 = this.builder_.conv2d(pool1, conv2Filter, conv2Options);

const pool2WindowShape = [2, 2];
const pool2Strides = [2, 2];
const pool2 =
this.builder_.maxPool2d(add2, {windowDimensions: pool2WindowShape,
strides: pool2Strides});
let pool2 =
this.builder_.maxPool2d(conv2, {windowDimensions: pool2WindowShape,
strides: pool2Strides, layout: this.layout_});

if (this.layout_ === 'nhwc') {
pool2 = this.builder_.transpose(
pool2, {permutation: this.nhwcToNchwPermutation_});
}

const reshape1Shape = [1, 800];
const reshape1 = this.builder_.reshape(pool2, reshape1Shape);
Expand All @@ -89,7 +124,7 @@ export class LeNet {
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul1Shape)));
byteOffset += sizeOfShape(matmul1Shape) * Float32Array.BYTES_PER_ELEMENT;
const matmul1WeightsTransposed = this.builder_.transpose(matmul1Weights);
const matmul1 = this.builder_.matmul(reshape1, matmul1WeightsTransposed);
const matmul1 = this.builder_.gemm(reshape1, matmul1WeightsTransposed);

const add3BiasShape = [1, 500];
const add3Bias = this.builder_.constant(
Expand All @@ -109,7 +144,7 @@ export class LeNet {
new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul2Shape)));
byteOffset += sizeOfShape(matmul2Shape) * Float32Array.BYTES_PER_ELEMENT;
const matmul2WeightsTransposed = this.builder_.transpose(matmul2Weights);
const matmul2 = this.builder_.matmul(reshape2, matmul2WeightsTransposed);
const matmul2 = this.builder_.gemm(reshape2, matmul2WeightsTransposed);

const add4BiasShape = [1, 10];
const add4Bias = this.builder_.constant(
Expand Down
146 changes: 80 additions & 66 deletions lenet/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ digitCanvas.setAttribute('height', 28);
digitCanvas.setAttribute('width', 28);
digitCanvas.style.backgroundColor = 'black';
const digitContext = digitCanvas.getContext('2d');
const pen = new Pen(visualCanvas);
let lenet;
let numRuns;

function clearInferenceResult() {
inferenceTimeElement.innerHTML = '';
for (let i = 0; i < 3; ++i) {
const labelElement = document.getElementById(`label${i}`);
const probElement = document.getElementById(`prob${i}`);
labelElement.innerHTML = '';
probElement.innerHTML = '';
}
}

$('#backendBtns .btn').on('change', async () => {
await main();
Expand Down Expand Up @@ -49,25 +62,20 @@ function getMedianValue(array) {
(array[array.length / 2 - 1] + array[array.length / 2]) / 2;
}

function clearResult() {
for (let i = 0; i < 3; ++i) {
const labelElement = document.getElementById(`label${i}`);
const probElement = document.getElementById(`prob${i}`);
labelElement.innerHTML = '';
probElement.innerHTML = '';
}
}

async function main() {
buildTimeElement.innerHTML = '';
predictButton.setAttribute('disabled', true);
clearInferenceResult();
const [backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
await utils.setBackend(backend, deviceType);
drawNextDigitFromMnist();
const pen = new Pen(visualCanvas);
const weightUrl = utils.weightsOrigin() +
'/test-data/models/lenet_nchw/weights/lenet.bin';
const lenet = new LeNet(weightUrl);
const [numRuns, powerPreference, numThreads] = utils.getUrlParams();
const layout = deviceType === 'cpu' ? 'nhwc' : 'nchw';
lenet = new LeNet(weightUrl, layout);
const [localNumRuns, powerPreference, numThreads] = utils.getUrlParams();
numRuns = localNumRuns;
try {
const contextOptions = {deviceType};
if (powerPreference) {
Expand All @@ -93,62 +101,68 @@ async function main() {
console.log(error);
addAlert(error.message);
}
predictButton.addEventListener('click', async function(e) {
try {
let start;
let inferenceTime;
const inferenceTimeArray = [];
const input = getInputFromCanvas();
let outputBuffer = new Float32Array(utils.sizeOfShape([1, 10]));

// Do warm up
let results = await lenet.compute(input, outputBuffer);

for (let i = 0; i < numRuns; i++) {
start = performance.now();
results = await lenet.compute(
results.inputs.input, results.outputs.output);
inferenceTime = performance.now() - start;
console.log(`execution elapsed time: ${inferenceTime.toFixed(2)} ms`);
inferenceTimeArray.push(inferenceTime);
}

if (numRuns === 1) {
inferenceTimeElement.innerHTML = 'Execution Time: ' +
`<span class='text-primary'>${inferenceTime.toFixed(2)}</span> ms`;
} else {
const medianInferenceTime = getMedianValue(inferenceTimeArray);
console.log(`median execution elapsed time: ` +
`${medianInferenceTime.toFixed(2)} ms`);
inferenceTimeElement.innerHTML = `Median Execution Time(${numRuns}` +
` runs): <span class='text-primary'>` +
`${medianInferenceTime.toFixed(2)}</span> ms`;
}

outputBuffer = results.outputs.output;
const classes = topK(Array.from(outputBuffer));
classes.forEach((c, i) => {
console.log(`\tlabel: ${c.label}, probability: ${c.prob}%`);
const labelElement = document.getElementById(`label${i}`);
const probElement = document.getElementById(`prob${i}`);
labelElement.innerHTML = `${c.label}`;
probElement.innerHTML = `${c.prob}%`;
});
} catch (error) {
console.log(error);
addAlert(error.message);
}

predictButton.addEventListener('click', async function(e) {
clearInferenceResult();
predictButton.setAttribute('disabled', true);
try {
let start;
let inferenceTime;
const inferenceTimeArray = [];
const input = getInputFromCanvas();
let outputBuffer = new Float32Array(utils.sizeOfShape([1, 10]));

// Do warm up
let results = await lenet.compute(input, outputBuffer);

for (let i = 0; i < numRuns; i++) {
start = performance.now();
results = await lenet.compute(
results.inputs.input, results.outputs.output);
inferenceTime = performance.now() - start;
console.log(`execution elapsed time: ${inferenceTime.toFixed(2)} ms`);
inferenceTimeArray.push(inferenceTime);
}
});
nextButton.addEventListener('click', () => {
drawNextDigitFromMnist();
clearResult();
});

clearButton.addEventListener('click', () => {
pen.clear();
clearResult();
});
}
if (numRuns === 1) {
inferenceTimeElement.innerHTML = 'Execution Time: ' +
`<span class='text-primary'>${inferenceTime.toFixed(2)}</span> ms`;
} else {
const medianInferenceTime = getMedianValue(inferenceTimeArray);
console.log(`median execution elapsed time: ` +
`${medianInferenceTime.toFixed(2)} ms`);
inferenceTimeElement.innerHTML = `Median Execution Time(${numRuns}` +
` runs): <span class='text-primary'>` +
`${medianInferenceTime.toFixed(2)}</span> ms`;
}

outputBuffer = results.outputs.output;
const classes = topK(Array.from(outputBuffer));
classes.forEach((c, i) => {
console.log(`\tlabel: ${c.label}, probability: ${c.prob}%`);
const labelElement = document.getElementById(`label${i}`);
const probElement = document.getElementById(`prob${i}`);
labelElement.innerHTML = `${c.label}`;
probElement.innerHTML = `${c.prob}%`;
});

predictButton.removeAttribute('disabled');
} catch (error) {
console.log(error);
addAlert(error.message);
}
});

nextButton.addEventListener('click', () => {
drawNextDigitFromMnist();
clearInferenceResult();
});

clearButton.addEventListener('click', () => {
pen.clear();
clearInferenceResult();
});

function topK(probs, k = 3) {
const sorted = probs.map((prob, index) => [prob, index]).sort((a, b) => {
Expand Down

0 comments on commit eb7a1be

Please sign in to comment.