-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #226 from mingmingtasd/npu_fp16
Add NPU device type and three fp16 models for image classification
- Loading branch information
Showing
9 changed files
with
566 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
// Lint configuration override: declares extra global identifiers as read-only. | ||
// NOTE(review): presumably an ESLint config ('globals' map with 'readonly' values) — confirm. | ||
// 'tf' is presumably the global exposed by a TensorFlow.js <script> include — confirm. | ||
module.exports = { | ||
globals: { | ||
'MLGraphBuilder': 'readonly', | ||
'tf': 'readonly', | ||
}, | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
'use strict'; | ||
|
||
import {buildConstantByNpy, weightsOrigin} from '../common/utils.js'; | ||
|
||
/**
 * EfficientNet fp16 model with 'nchw' input layout.
 *
 * Typical usage, as far as this class shows: `load()` creates the WebNN
 * context and builder and returns the output operand, `build()` compiles
 * that operand into a graph, and `compute()` runs the graph.
 */
export class EfficientNetFP16Nchw {
  constructor() {
    this.context_ = null;
    this.builder_ = null;
    this.graph_ = null;
    // Base URL that every weight/bias .npy file name is appended to.
    this.weightsUrl_ = weightsOrigin() +
        '/test-data/models/efficientnet_fp16_nchw_optimized/weights/';
    this.inputOptions = {
      mean: [0.485, 0.456, 0.406],
      std: [0.229, 0.224, 0.225],
      norm: true,
      inputLayout: 'nchw',
      labelUrl: './labels/labels1000.txt',
      inputDimensions: [1, 3, 224, 224],
    };
    this.outputDimensions = [1, 1000];
  }

  /**
   * Builds a conv2d whose float16 filter and bias are loaded from .npy files.
   *
   * @param {*} input Input operand, or a promise resolving to one.
   * @param {string} name Convolution suffix used in the weight file name.
   * @param {string} blockName Block index used in the file name; the empty
   *     string selects the block-less 'conv<name>' files.
   * @param {boolean} [clip=false] When true, the conv2d output is clamped
   *     to the range [0, 6].
   * @param {Object} [options={}] Extra conv2d options. This object is not
   *     modified; the loaded bias is merged into a copy.
   * @return {Promise<*>} The conv2d (or clamp) output operand.
   */
  async buildConv_(input, name, blockName, clip = false, options = {}) {
    const prefix = blockName !== '' ?
        this.weightsUrl_ + 'block' + blockName + '_conv' + name :
        this.weightsUrl_ + 'conv' + name;
    // Await everything together so that a failure in one of the pending
    // promises cannot leave the others as unhandled rejections.
    const [inputOperand, weight, bias] = await Promise.all([
      input,
      buildConstantByNpy(this.builder_, prefix + '_w.npy', 'float16'),
      buildConstantByNpy(this.builder_, prefix + '_b.npy', 'float16'),
    ]);
    // Copy instead of writing `bias` into the caller-owned options object.
    const convOptions = Object.assign({}, options, {bias});
    const conv = this.builder_.conv2d(inputOperand, weight, convOptions);
    if (clip) {
      return this.builder_.clamp(conv, {minValue: 0, maxValue: 6});
    }
    return conv;
  }

  /**
   * Builds the gemm layer from the 'dense<name>' weight and bias files.
   *
   * @param {*} input Input operand, or a promise resolving to one.
   * @param {string} name Dense layer suffix used in the weight file name.
   * @return {Promise<*>} The gemm output operand.
   */
  async buildGemm_(input, name) {
    const prefix = this.weightsUrl_ + 'dense' + name;
    const [inputOperand, weight, bias] = await Promise.all([
      input,
      buildConstantByNpy(this.builder_, prefix + '_w.npy', 'float16'),
      buildConstantByNpy(this.builder_, prefix + '_b.npy', 'float16'),
    ]);
    // The bias is reshaped to [1, 1000] before being passed as gemm's `c`.
    const options = {c: this.builder_.reshape(bias, [1, 1000])};
    return this.builder_.gemm(inputOperand, weight, options);
  }

  /**
   * Builds one residual block: three chained convolutions (the first two
   * clamped) whose result is added to the block input.
   *
   * @param {*} input Input operand, or a promise resolving to one.
   * @param {string} blockName Block index used in the weight file names.
   * @param {number} group `groups` value of the second convolution.
   * @param {number} [pad=1] Padding applied on all four sides of the second
   *     convolution.
   * @return {Promise<*>} The output operand of the add.
   */
  async buildBottleneck_(input, blockName, group, pad = 1) {
    const conv1 = this.buildConv_(input, '0', blockName, true);
    const conv2 = this.buildConv_(conv1, '1', blockName, true,
        {groups: group, padding: [pad, pad, pad, pad]});
    const conv3 = this.buildConv_(conv2, '2', blockName);
    const [residual, shortcut] = await Promise.all([conv3, input]);
    return this.builder_.add(residual, shortcut);
  }

  /**
   * Chains several bottleneck blocks, feeding each output into the next.
   *
   * @param {*} input Input operand, or a promise resolving to one.
   * @param {Array<string>} blockNames Block indices, in build order.
   * @param {number} group `groups` value forwarded to every block.
   * @param {number} [pad=1] Padding forwarded to every block.
   * @return {Promise<*>} Output of the last block, or `input` itself when
   *     `blockNames` is empty.
   */
  async buildBottlenecks_(input, blockNames, group, pad = 1) {
    let result = input;
    for (const blockName of blockNames) {
      result = await this.buildBottleneck_(result, blockName, group, pad);
    }
    return result;
  }

  /**
   * Creates the WebNN context and graph builder, then describes the whole
   * network. The graph takes a float32 'input', casts it to float16 for the
   * network body and casts the result back to float32.
   *
   * @param {Object} [contextOptions] Options forwarded to
   *     `navigator.ml.createContext()`. May be omitted.
   * @return {Promise<*>} The float32 output operand.
   */
  async load(contextOptions) {
    this.context_ = await navigator.ml.createContext(contextOptions);
    this.builder_ = new MLGraphBuilder(this.context_);
    let data = this.builder_.input('input', {
      dataType: 'float32',
      dimensions: this.inputOptions.inputDimensions,
    });
    data = this.builder_.cast(data, 'float16');
    // Block 0
    const conv1 = this.buildConv_(
        data, '0', '0', true, {padding: [0, 1, 0, 1], strides: [2, 2]});
    const conv2 = this.buildConv_(conv1, '1', '0', true,
        {groups: 32, padding: [1, 1, 1, 1]});
    const conv3 = this.buildConv_(conv2, '2', '0');

    // Block 1
    const conv4 = this.buildConv_(conv3, '0', '1', true);
    const conv5 = this.buildConv_(conv4, '1', '1', true,
        {groups: 144, padding: [0, 1, 0, 1], strides: [2, 2]});
    const conv6 = this.buildConv_(conv5, '2', '1');

    // Block 2~4
    const bottleneck4 = this.buildBottlenecks_(conv6,
        ['2', '3', '4'], 192);

    // Block 5
    const conv7 = this.buildConv_(bottleneck4, '0', '5', true);
    const conv8 = this.buildConv_(conv7, '1', '5', true,
        {groups: 192, padding: [1, 2, 1, 2], strides: [2, 2]});
    const conv9 = this.buildConv_(conv8, '2', '5');

    // Block 6~8
    const bottleneck8 = this.buildBottlenecks_(conv9,
        ['6', '7', '8'], 336, 2);

    // Block 9
    const conv10 = this.buildConv_(bottleneck8, '0', '9', true);
    const conv11 = this.buildConv_(conv10, '1', '9', true,
        {groups: 336, padding: [0, 1, 0, 1], strides: [2, 2]});
    const conv12 = this.buildConv_(conv11, '2', '9');

    // Block 10~14
    const bottleneck14 = this.buildBottlenecks_(conv12,
        ['10', '11', '12', '13', '14'], 672);

    // Block 15
    const conv13 = this.buildConv_(bottleneck14, '0', '15', true);
    const conv14 = this.buildConv_(conv13, '1', '15', true,
        {groups: 672, padding: [2, 2, 2, 2]});
    const conv15 = this.buildConv_(conv14, '2', '15');

    // Block 16~20
    // NOTE(review): this is the only intermediate stage that is awaited, so
    // loading of the later weights starts only once blocks 0~20 are built.
    // Kept as in the original — confirm whether this is intentional.
    const bottleneck20 = await this.buildBottlenecks_(conv15,
        ['16', '17', '18', '19', '20'], 960, 2);

    // Block 21
    const conv16 = this.buildConv_(bottleneck20, '0', '21', true);
    const conv17 = this.buildConv_(conv16, '1', '21', true,
        {groups: 960, padding: [1, 2, 1, 2], strides: [2, 2]});
    const conv18 = this.buildConv_(conv17, '2', '21');

    // Block 22~28
    const bottleneck28 = this.buildBottlenecks_(conv18,
        ['22', '23', '24', '25', '26', '27', '28'], 1632, 2);

    // Block 29
    const conv19 = this.buildConv_(bottleneck28, '0', '29', true);
    const conv20 = this.buildConv_(conv19, '1', '29', true,
        {groups: 1632, padding: [1, 1, 1, 1]});
    const conv21 = this.buildConv_(conv20, '2', '29');

    const conv22 = this.buildConv_(conv21, '0', '', true);
    const pool1 = this.builder_.averagePool2d(await conv22);
    const reshape = this.builder_.reshape(pool1, [1, 1280]);
    const gemm = await this.buildGemm_(reshape, '0');
    // Guard the lookup: the options are optional for createContext(), so
    // load() must not throw when they are omitted.
    const deviceType = contextOptions ? contextOptions.deviceType : undefined;
    if (deviceType === 'npu') {
      // NOTE(review): softmax is skipped for the 'npu' device type —
      // presumably because it is unsupported there; confirm.
      return this.builder_.cast(gemm, 'float32');
    }
    const softmax = this.builder_.softmax(gemm);
    return this.builder_.cast(softmax, 'float32');
  }

  /**
   * Compiles the graph that produces `outputOperand` under the name 'output'.
   *
   * @param {*} outputOperand Operand returned by `load()`.
   */
  async build(outputOperand) {
    this.graph_ = await this.builder_.build({'output': outputOperand});
  }

  // Release the constant tensors of a model
  dispose() {
    // dispose() is only available in webnn-polyfill
    if (this.graph_ !== null && 'dispose' in this.graph_) {
      this.graph_.dispose();
    }
  }

  /**
   * Runs the compiled graph.
   *
   * @param {*} inputBuffer Buffer bound to the graph's 'input'.
   * @param {*} outputBuffer Buffer bound to the graph's 'output'.
   * @return {Promise<*>} Whatever `context.compute()` resolves to.
   */
  async compute(inputBuffer, outputBuffer) {
    const inputs = {'input': inputBuffer};
    const outputs = {'output': outputBuffer};
    const results = await this.context_.compute(this.graph_, inputs, outputs);
    return results;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,9 @@ | |
<label class="btn btn-outline-info custom" name="webnn"> | ||
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU) | ||
</label> | ||
<label class="btn btn-outline-info custom" name="webnn"> | ||
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU) | ||
</label> | ||
</div> | ||
</div> | ||
</div> | ||
|
@@ -61,21 +64,43 @@ | |
</div> | ||
</div> | ||
</div> --> | ||
<div class="row mb-2 align-items-center"> | ||
<div class="col-1 col-md-1"> | ||
<span>Data Type</span> | ||
</div> | ||
<div class="col-md-auto"> | ||
<div class="btn-group-toggle" data-toggle="buttons" id="dataTypeBtns"> | ||
<label class="btn btn-outline-info" id="float32Label" active> | ||
<input type="radio" name="layout" id="float32" autocomplete="off" checked>Float32 | ||
</label> | ||
<label class="btn btn-outline-info" id="float16Label"> | ||
<input type="radio" name="layout" id="float16" autocomplete="off">Float16 | ||
</label> | ||
</div> | ||
</div> | ||
</div> | ||
<div class="row align-items-center"> | ||
<div class="col col-md-1"> | ||
<span>Model</span> | ||
</div> | ||
<div class="col-md-auto"> | ||
<div class="btn-group-toggle" data-toggle="buttons" id="modelBtns"> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2 | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="resnet50" autocomplete="off">ResNet V2 50 | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2 | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="resnet50v2" autocomplete="off">ResNet 50 V2 | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="resnet50v1" autocomplete="off">ResNet 50 V1 | ||
</label> | ||
<label class="btn btn-outline-info"> | ||
<input type="radio" name="model" id="efficientnet" autocomplete="off">EfficientNet | ||
</label> | ||
|
||
</div> | ||
</div> | ||
</div> | ||
|
@@ -213,6 +238,9 @@ <h2 class="text-uppercase text-info">No model selected</h2> | |
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js" | ||
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN" | ||
crossorigin="anonymous"></script> | ||
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js" | ||
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8=" | ||
crossorigin="anonymous"></script> | ||
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js" | ||
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV" | ||
crossorigin="anonymous"></script> | ||
|
Oops, something went wrong.