Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
fixed depth to space test failure
Browse files Browse the repository at this point in the history
  • Loading branch information
Tixxx committed May 10, 2021
1 parent 347dbdf commit d2dc883
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 100 deletions.
2 changes: 1 addition & 1 deletion lib/backends/webgl/glsl-coordinate-lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ export class CoordsGlslLib extends GlslLib {
const result: {[name: string]: GlslLibRoutine} = {};
this.context.programInfo.samplers.forEach((name, i) => {
const layout = programInfo.inputLayouts[i];
const shape = layout.shape;
const shape = layout.unpackedShape.length > 0 ? layout.unpackedShape : layout.shape;
const rank = shape.length;
let funcName = `_${name}`;
result[funcName] = new GlslLibRoutine(
Expand Down
77 changes: 45 additions & 32 deletions lib/backends/webgl/ops/depth-to-space.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {Attribute} from '../../../attribute';
import {DepthToSpace} from '../../../ops/depth-to-space';
import {Tensor} from '../../../tensor';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData} from '../types';
import {Artifact, ProgramInfo, RunData} from '../types';

import {reshape} from './reshape';
import {WebGLTranspose} from './transpose';

export class WebGLDepthToSpace extends DepthToSpace {
protected transposeProgramInfo: ProgramInfo;

protected transposeArtifact: Artifact;

run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
return inferenceHandler.run(this, inputs);
}
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const inputLayout = handler.getOrCreateTextureLayout(inputs[0]);
const outputShape = this.getOutShape(inputs[0]);
const glsl = getGlsl(handler.session.backend.glContext.version);
const shaderSource = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[1];
int h = coords[2];
int w = coords[3];
const programManager = inferenceHandler.session.programManager;
const transposePerm = this.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3];
const firstReshapeShape = this.mode === 'DCR' ?
[
inputs[0].dims[0], this.blocksize, this.blocksize, inputs[0].dims[1] / this.blocksizeSqr, inputs[0].dims[2],
inputs[0].dims[3]
] :
[
inputs[0].dims[0], inputs[0].dims[1] / this.blocksizeSqr, this.blocksize, this.blocksize, inputs[0].dims[2],
inputs[0].dims[3]
];

const transpose = new WebGLTranspose();
const attributes = new Attribute(undefined);
attributes.set('perm', 'ints', transposePerm);
transpose.initialize(attributes);

int in_h = h / ${this.blocksize};
int offset_h = imod(h, ${this.blocksize});
int in_w = w / ${this.blocksize};
int offset_w = imod(w, ${this.blocksize});
int offset_d = (offset_h * ${this.blocksize} + offset_w) *
${outputShape[1]};
int in_depth = d + offset_d;
// First reshape

float result = getX(b, in_depth, in_h, in_w);
${glsl.output} = vec4(result, 0, 0, 0);
const firstReshapedTensor = reshape(inferenceHandler, inputs[0], firstReshapeShape);

// transpose
if (!this.transposeProgramInfo) {
this.transposeProgramInfo = transpose.createProgramInfo(inferenceHandler, [firstReshapedTensor]);
this.transposeArtifact = programManager.build(this.transposeProgramInfo);
}
`;
return {
inputLayouts: [inputLayout],
outputLayout: handler.createTextureLayoutFromShape(outputShape),
samplers: ['X'],
shaderSource,
hasMain: true
};
const runDataTranspose =
transpose.createRunData(inferenceHandler, this.transposeProgramInfo, [firstReshapedTensor]);
inferenceHandler.checkAndUpdateTextureForm(this.transposeArtifact, runDataTranspose);
programManager.run(this.transposeArtifact, runDataTranspose);
const transposeOutput = runDataTranspose.outputTextureData.tensor;

// Second reshape
const result = reshape(inferenceHandler, transposeOutput, [
inputs[0].dims[0], inputs[0].dims[1] / this.blocksizeSqr, inputs[0].dims[2] * this.blocksize,
inputs[0].dims[3] * this.blocksize
]);
return [result];
}

protected getOutShape(input: Tensor): number[] {
const batchSize = input.dims[0];
const inputDepth = input.dims[1];
Expand Down
147 changes: 80 additions & 67 deletions test/unittests/backends/webgl/test_depth_to_space.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ import {WebGLInferenceHandler} from '../../../../lib/backends/webgl/inference-ha
import {WebGLDepthToSpace} from '../../../../lib/backends/webgl/ops/depth-to-space';
import {Profiler} from '../../../../lib/instrument';
import {Tensor} from '../../../../lib/tensor';
import {ShapeUtil} from '../../../../lib/util';

import {createAscendingArray} from './test_utils';
import {createTextureFromArray, generateArrayForUnpackedTexture} from './test_utils';

let backend: Backend|undefined;
let sessionhandler: SessionHandler|undefined;
Expand All @@ -22,8 +20,6 @@ describe('#UnitTest# - unpacked WebGLDepthToSpace - Tensor WebGLDepthToSpace', (
before('Initialize Context', async () => {
const profiler = Profiler.create();
backend = await Backend('webgl');
// Explicitly set to true to trigger packed version
(backend as WebGLBackend).pack = false;
sessionhandler = backend.createSessionHandler({profiler});
inferenceHandler = sessionhandler.createInferenceHandler();
});
Expand All @@ -50,57 +46,30 @@ describe('#UnitTest# - unpacked WebGLDepthToSpace - Tensor WebGLDepthToSpace', (
const attributes = new Attribute(undefined);
const blocksize = testData.blocksize;
attributes.set('blocksize', 'int', blocksize);
attributes.set('mode', 'string', testData.mode as string);

op.initialize(attributes);
const elementCount = testData.elementCount;
const inputTensorShape = testData.inputShape;
const inputTextureShape = testData.inputTextureShape;

// create input data and tensor.
const inputData = testData.rawInput ? testData.rawInput : createAscendingArray(elementCount);
const inputTensorA = new Tensor(inputTensorShape, 'float32', undefined, undefined, inputData);

// manually creat packed texture from inputTensor, and insert in cache
const gl = webglInferenceHandler.session.textureManager.glContext.gl;
webglInferenceHandler.session.textureManager.glContext.checkError();
const webglTextureA = createTextureFromArray(
webglInferenceHandler.session.textureManager.glContext,
generateArrayForUnpackedTexture(testData.rawInput ? testData.rawInput : inputData), gl.RGBA,
inputTextureShape[0], inputTextureShape[1]);

webglInferenceHandler.session.textureManager.glContext.checkError();
const textureDataA = {
width: inputTextureShape[0],
height: inputTextureShape[1],
channels: 1 as const,
isPacked: false,
shape: inputTextureShape,
strides: ShapeUtil.computeStrides(inputTextureShape),
unpackedShape: inputTensorShape,
tensor: inputTensorA,
texture: webglTextureA!
};

webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA);

// compile shader code
const programInfo = op.createProgramInfo(inferenceHandler! as WebGLInferenceHandler, [inputTensorA]);

const artifact = webglInferenceHandler.session.programManager.build(programInfo);
webglInferenceHandler.session.programManager.setArtifact(op, artifact);

// run kernal and get output
const runData = op.createRunData(webglInferenceHandler, artifact.programInfo, [inputTensorA]);
webglInferenceHandler.session.programManager.run(artifact, runData);
const result = runData.outputTextureData.tensor.data;

const result = op.run(webglInferenceHandler, [inputTensorA]);

webglInferenceHandler.session.textureManager.glContext.checkError();
// verify result.
const expectedOutput = testData.expectedOutput;
expect(result).to.not.equal(null);
expect(result[0].data).to.not.equal(null);

expect(result).to.have.lengthOf(elementCount);
expect(result).to.deep.equal(expectedOutput);
expect(result[0].data).to.have.lengthOf(elementCount);
expect(result[0].data).to.deep.equal(expectedOutput);
});
}
});
Expand All @@ -127,36 +96,80 @@ function getTestData(): TestData[] {
outputTextureShape: [4, 2],
rawInput: new Float32Array([0., 9., 18., 27., 36., 45., 54., 63.]),
expectedOutput: new Float32Array([0., 18., 36., 54., 9., 27., 45., 63.]),
mode: 'DCR',
},
{
elementCount: 16,
blocksize: 2,
inputShape: [1, 8, 1, 2],
outputShape: [1, 2, 4, 2],
inputTextureShape: [1, 16],
outputTextureShape: [8, 2],
rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]),
expectedOutput: new Float32Array([0, 18, 1, 19, 36, 54, 37, 55, 9, 27, 10, 28, 45, 63, 46, 64]),
mode: 'DCR',
},

{
elementCount: 48,
blocksize: 2,
inputShape: [1, 8, 2, 3],
outputShape: [1, 2, 4, 6],
inputTextureShape: [16, 3],
outputTextureShape: [8, 6],
rawInput: new Float32Array([
0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21.,
22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46.,
47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68.
]),
expectedOutput: new Float32Array([
0., 18., 1., 19., 2., 20., 36., 54., 37., 55., 38., 56., 3., 21., 4., 22.,
5., 23., 39., 57., 40., 58., 41., 59., 9., 27., 10., 28., 11., 29., 45., 63.,
46., 64., 47., 65., 12., 30., 13., 31., 14., 32., 48., 66., 49., 67., 50., 68.
]),
mode: 'DCR',
},
{
elementCount: 8,
blocksize: 2,
inputShape: [1, 8, 1, 1],
outputShape: [1, 2, 2, 2],
inputTextureShape: [8, 1],
outputTextureShape: [4, 2],
rawInput: new Float32Array([0., 9., 18., 27., 36., 45., 54., 63.]),
expectedOutput: new Float32Array([0, 9, 18, 27, 36, 45, 54, 63]),
mode: 'CRD',
},
{
elementCount: 16,
blocksize: 2,
inputShape: [1, 8, 1, 2],
outputShape: [1, 2, 4, 2],
inputTextureShape: [1, 16],
outputTextureShape: [8, 2],
rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]),
expectedOutput: new Float32Array([0, 9, 1, 10, 18, 27, 19, 28, 36, 45, 37, 46, 54, 63, 55, 64]),
mode: 'CRD',
},

{
elementCount: 48,
blocksize: 2,
inputShape: [1, 8, 2, 3],
outputShape: [1, 2, 4, 6],
inputTextureShape: [16, 3],
outputTextureShape: [8, 6],
rawInput: new Float32Array([
0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21.,
22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46.,
47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68.
]),
expectedOutput: new Float32Array([
0., 9., 1., 10., 2., 11., 18., 27., 19., 28., 20., 29., 3., 12., 4., 13.,
5., 14., 21., 30., 22., 31., 23., 32., 36., 45., 37., 46., 38., 47., 54., 63.,
55., 64., 56., 65., 39., 48., 40., 49., 41., 50., 57., 66., 58., 67., 59., 68.
]),
mode: 'CRD',
},
// {
// elementCount: 16,
// blocksize: 2,
// inputShape: [1, 8, 2, 1],
// outputShape: [1, 2, 4, 2],
// inputTextureShape: [16, 1],
// outputTextureShape: [8, 2],
// rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]),
// expectedOutput: new
// Float32Array([0., 18., 1., 19., 36., 54., 37., 55., 9., 27., 10., 28., 45., 63., 46., 64.]),
// },

// {
// elementCount: 48,
// blocksize: 2,
// inputShape: [1, 8, 2, 3],
// outputShape: [1, 2, 4, 6],
// inputTextureShape: [16, 3],
// outputTextureShape: [8, 6],
// rawInput: new Float32Array([
// 0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21.,
// 22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46.,
// 47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68.
// ]),
// expectedOutput: new Float32Array([
// 0., 18., 1., 19., 2., 20., 36., 54., 37., 55., 38., 56., 3., 21., 4., 22.,
// 5., 23., 39., 57., 40., 58., 41., 59., 9., 27., 10., 28., 11., 29., 45., 63.,
// 46., 64., 47., 65., 12., 30., 13., 31., 14., 32., 48., 66., 49., 67., 50., 68.
// ]),
// },
];
}

0 comments on commit d2dc883

Please sign in to comment.