Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the MobileNet NCHW .js to fetch in parallel #200

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 39 additions & 41 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,52 @@ export class MobileNetV2Nchw {
async buildConv_(input, name, relu6 = true, options = {}) {
const prefix = this.weightsUrl_ + 'conv_' + name;
const weightsName = prefix + '_weight.npy';
const weights =
await buildConstantByNpy(this.builder_, weightsName);
const weights = buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias =
await buildConstantByNpy(this.builder_, biasName);
options.bias = bias;
const bias = buildConstantByNpy(this.builder_, biasName);
options.bias = await bias;
if (relu6) {
// TODO: Set clamp activation to options once it's supported in
// WebNN DML backend.
// Implement `clip` by `clamp` of WebNN API
if (this.deviceType_ == 'gpu') {
return this.builder_.clamp(
this.builder_.conv2d(input, weights, options),
this.builder_.conv2d(await input, await weights, options),
{minValue: 0, maxValue: 6});
} else {
options.activation = this.builder_.clamp({minValue: 0, maxValue: 6});
}
}
return this.builder_.conv2d(input, weights, options);
return this.builder_.conv2d(await input, await weights, options);
}

async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'gemm_' + name;
const weightsName = prefix + '_weight.npy';
const weights = await buildConstantByNpy(this.builder_, weightsName);
const weights = buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias = await buildConstantByNpy(this.builder_, biasName);
const options = {c: bias, bTranspose: true};
return this.builder_.gemm(input, weights, options);
const bias = buildConstantByNpy(this.builder_, biasName);
const options = {c: await bias, bTranspose: true};
return this.builder_.gemm(await input, await weights, options);
}

async buildLinearBottleneck_(
input, convNameArray, group, stride, shortcut = true) {
const conv1x1Relu6 = await this.buildConv_(input, convNameArray[0]);
const conv1x1Relu6 = this.buildConv_(await input, convNameArray[0]);
const options = {
padding: [1, 1, 1, 1],
groups: group,
strides: [stride, stride],
};
const dwise3x3Relu6 = await this.buildConv_(
conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = await this.buildConv_(
dwise3x3Relu6, convNameArray[2], false);
const dwise3x3Relu6 = this.buildConv_(
await conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = this.buildConv_(
await dwise3x3Relu6, convNameArray[2], false);

if (shortcut) {
return this.builder_.add(input, conv1x1Linear);
return this.builder_.add(await input, await conv1x1Linear);
}
return conv1x1Linear;
return await conv1x1Linear;
}

async load(contextOptions) {
Expand All @@ -84,49 +82,49 @@ export class MobileNetV2Nchw {
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
const conv0 = await this.buildConv_(
const conv0 = this.buildConv_(
data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]});
const conv1 = await this.buildConv_(
const conv1 = this.buildConv_(
conv0, '2', true, {padding: [1, 1, 1, 1], groups: 32});
const conv2 = await this.buildConv_(conv1, '4', false);
const bottleneck0 = await this.buildLinearBottleneck_(
const conv2 = this.buildConv_(conv1, '4', false);
const bottleneck0 = this.buildLinearBottleneck_(
conv2, ['5', '7', '9'], 96, 2, false);
const bottleneck1 = await this.buildLinearBottleneck_(
const bottleneck1 = this.buildLinearBottleneck_(
bottleneck0, ['10', '12', '14'], 144, 1);
const bottleneck2 = await this.buildLinearBottleneck_(
const bottleneck2 = this.buildLinearBottleneck_(
bottleneck1, ['16', '18', '20'], 144, 2, false);
const bottleneck3 = await this.buildLinearBottleneck_(
const bottleneck3 = this.buildLinearBottleneck_(
bottleneck2, ['21', '23', '25'], 192, 1);
const bottleneck4 = await this.buildLinearBottleneck_(
const bottleneck4 = this.buildLinearBottleneck_(
bottleneck3, ['27', '29', '31'], 192, 1);
const bottleneck5 = await this.buildLinearBottleneck_(
const bottleneck5 = this.buildLinearBottleneck_(
bottleneck4, ['33', '35', '37'], 192, 2, false);
const bottleneck6 = await this.buildLinearBottleneck_(
const bottleneck6 = this.buildLinearBottleneck_(
bottleneck5, ['38', '40', '42'], 384, 1);
const bottleneck7 = await this.buildLinearBottleneck_(
const bottleneck7 = this.buildLinearBottleneck_(
bottleneck6, ['44', '46', '48'], 384, 1);
const bottleneck8 = await this.buildLinearBottleneck_(
const bottleneck8 = this.buildLinearBottleneck_(
bottleneck7, ['50', '52', '54'], 384, 1);
const bottleneck9 = await this.buildLinearBottleneck_(
const bottleneck9 = this.buildLinearBottleneck_(
bottleneck8, ['56', '58', '60'], 384, 1, false);
const bottleneck10 = await this.buildLinearBottleneck_(
const bottleneck10 = this.buildLinearBottleneck_(
bottleneck9, ['61', '63', '65'], 576, 1);
const bottleneck11 = await this.buildLinearBottleneck_(
const bottleneck11 = this.buildLinearBottleneck_(
bottleneck10, ['67', '69', '71'], 576, 1);
const bottleneck12 = await this.buildLinearBottleneck_(
const bottleneck12 = this.buildLinearBottleneck_(
bottleneck11, ['73', '75', '77'], 576, 2, false);
const bottleneck13 = await this.buildLinearBottleneck_(
const bottleneck13 = this.buildLinearBottleneck_(
bottleneck12, ['78', '80', '82'], 960, 1);
const bottleneck14 = await this.buildLinearBottleneck_(
const bottleneck14 = this.buildLinearBottleneck_(
bottleneck13, ['84', '86', '88'], 960, 1);
const bottleneck15 = await this.buildLinearBottleneck_(
const bottleneck15 = this.buildLinearBottleneck_(
bottleneck14, ['90', '92', '94'], 960, 1, false);

const conv3 = await this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(conv3);
const conv3 = this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(await conv3);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(gemm);
const gemm = this.buildGemm_(reshape, '104');
return await this.builder_.softmax(await gemm);
}

async build(outputOperand) {
Expand Down