From d2dc883fa8b2c35ba8df94e49a920f1d4ca5cf65 Mon Sep 17 00:00:00 2001 From: Tixxx Date: Mon, 10 May 2021 10:00:37 -0700 Subject: [PATCH] fixed depth to space test failure --- lib/backends/webgl/glsl-coordinate-lib.ts | 2 +- lib/backends/webgl/ops/depth-to-space.ts | 77 +++++---- .../backends/webgl/test_depth_to_space.ts | 147 ++++++++++-------- 3 files changed, 126 insertions(+), 100 deletions(-) diff --git a/lib/backends/webgl/glsl-coordinate-lib.ts b/lib/backends/webgl/glsl-coordinate-lib.ts index dcdf6acc..6ad3d018 100644 --- a/lib/backends/webgl/glsl-coordinate-lib.ts +++ b/lib/backends/webgl/glsl-coordinate-lib.ts @@ -1275,7 +1275,7 @@ export class CoordsGlslLib extends GlslLib { const result: {[name: string]: GlslLibRoutine} = {}; this.context.programInfo.samplers.forEach((name, i) => { const layout = programInfo.inputLayouts[i]; - const shape = layout.shape; + const shape = layout.unpackedShape.length > 0 ? layout.unpackedShape : layout.shape; const rank = shape.length; let funcName = `_${name}`; result[funcName] = new GlslLibRoutine( diff --git a/lib/backends/webgl/ops/depth-to-space.ts b/lib/backends/webgl/ops/depth-to-space.ts index 9391a9ae..93b53690 100644 --- a/lib/backends/webgl/ops/depth-to-space.ts +++ b/lib/backends/webgl/ops/depth-to-space.ts @@ -1,48 +1,61 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +import {Attribute} from '../../../attribute'; import {DepthToSpace} from '../../../ops/depth-to-space'; import {Tensor} from '../../../tensor'; -import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, RunData} from '../types'; +import {Artifact, ProgramInfo, RunData} from '../types'; + +import {reshape} from './reshape'; +import {WebGLTranspose} from './transpose'; export class WebGLDepthToSpace extends DepthToSpace { + protected transposeProgramInfo: ProgramInfo; + + protected transposeArtifact: Artifact; + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { - return inferenceHandler.run(this, inputs); - } - createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { - const inputLayout = handler.getOrCreateTextureLayout(inputs[0]); - const outputShape = this.getOutShape(inputs[0]); - const glsl = getGlsl(handler.session.backend.glContext.version); - const shaderSource = ` - void main() { - ivec4 coords = getOutputCoords(); - int b = coords[0]; - int d = coords[1]; - int h = coords[2]; - int w = coords[3]; + const programManager = inferenceHandler.session.programManager; + const transposePerm = this.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3]; + const firstReshapeShape = this.mode === 'DCR' ? + [ + inputs[0].dims[0], this.blocksize, this.blocksize, inputs[0].dims[1] / this.blocksizeSqr, inputs[0].dims[2], + inputs[0].dims[3] + ] : + [ + inputs[0].dims[0], inputs[0].dims[1] / this.blocksizeSqr, this.blocksize, this.blocksize, inputs[0].dims[2], + inputs[0].dims[3] + ]; + + const transpose = new WebGLTranspose(); + const attributes = new Attribute(undefined); + attributes.set('perm', 'ints', transposePerm); + transpose.initialize(attributes); - int in_h = h / ${this.blocksize}; - int offset_h = imod(h, ${this.blocksize}); - int in_w = w / ${this.blocksize}; - int offset_w = imod(w, ${this.blocksize}); - int offset_d = (offset_h * ${this.blocksize} + offset_w) * - ${outputShape[1]}; - int in_depth = d + offset_d; + // First reshape - float result = getX(b, in_depth, in_h, in_w); - ${glsl.output} = vec4(result, 0, 0, 0); + const firstReshapedTensor = reshape(inferenceHandler, inputs[0], firstReshapeShape); + + // transpose + if (!this.transposeProgramInfo) { + this.transposeProgramInfo = transpose.createProgramInfo(inferenceHandler, [firstReshapedTensor]); + this.transposeArtifact = programManager.build(this.transposeProgramInfo); } - `; - return { - inputLayouts: [inputLayout], - outputLayout: handler.createTextureLayoutFromShape(outputShape), - samplers: ['X'], - shaderSource, - hasMain: true - }; + const runDataTranspose = + transpose.createRunData(inferenceHandler, this.transposeProgramInfo, [firstReshapedTensor]); + inferenceHandler.checkAndUpdateTextureForm(this.transposeArtifact, runDataTranspose); + programManager.run(this.transposeArtifact, runDataTranspose); + const transposeOutput = runDataTranspose.outputTextureData.tensor; + + // Second reshape + const result = reshape(inferenceHandler, transposeOutput, [ + inputs[0].dims[0], inputs[0].dims[1] / this.blocksizeSqr, inputs[0].dims[2] * this.blocksize, + inputs[0].dims[3] * this.blocksize + ]); + return [result]; } + protected getOutShape(input: Tensor): number[] { const batchSize = input.dims[0]; const inputDepth = input.dims[1]; diff --git a/test/unittests/backends/webgl/test_depth_to_space.ts b/test/unittests/backends/webgl/test_depth_to_space.ts index 5795f6aa..0569d709 100644 --- a/test/unittests/backends/webgl/test_depth_to_space.ts +++ b/test/unittests/backends/webgl/test_depth_to_space.ts @@ -9,10 +9,8 @@ import {WebGLInferenceHandler} from '../../../../lib/backends/webgl/inference-ha import {WebGLDepthToSpace} from '../../../../lib/backends/webgl/ops/depth-to-space'; import {Profiler} from '../../../../lib/instrument'; import {Tensor} from '../../../../lib/tensor'; -import {ShapeUtil} from '../../../../lib/util'; import {createAscendingArray} from './test_utils'; -import {createTextureFromArray, generateArrayForUnpackedTexture} from './test_utils'; let backend: Backend|undefined; let sessionhandler: SessionHandler|undefined; @@ -22,8 +20,6 @@ describe('#UnitTest# - unpacked WebGLDepthToSpace - Tensor WebGLDepthToSpace', ( before('Initialize Context', async () => { const profiler = Profiler.create(); backend = await Backend('webgl'); - // Explicitly set to true to trigger packed version - (backend as WebGLBackend).pack = false; sessionhandler = backend.createSessionHandler({profiler}); inferenceHandler = sessionhandler.createInferenceHandler(); }); @@ -50,57 +46,30 @@ describe('#UnitTest# - unpacked WebGLDepthToSpace - Tensor WebGLDepthToSpace', ( const attributes = new Attribute(undefined); const blocksize = testData.blocksize; attributes.set('blocksize', 'int', blocksize); + attributes.set('mode', 'string', testData.mode as string); op.initialize(attributes); const elementCount = testData.elementCount; const inputTensorShape = testData.inputShape; - const inputTextureShape = testData.inputTextureShape; // create input data and tensor. const inputData = testData.rawInput ? testData.rawInput : createAscendingArray(elementCount); const inputTensorA = new Tensor(inputTensorShape, 'float32', undefined, undefined, inputData); // manually creat packed texture from inputTensor, and insert in cache - const gl = webglInferenceHandler.session.textureManager.glContext.gl; webglInferenceHandler.session.textureManager.glContext.checkError(); - const webglTextureA = createTextureFromArray( - webglInferenceHandler.session.textureManager.glContext, - generateArrayForUnpackedTexture(testData.rawInput ? testData.rawInput : inputData), gl.RGBA, - inputTextureShape[0], inputTextureShape[1]); webglInferenceHandler.session.textureManager.glContext.checkError(); - const textureDataA = { - width: inputTextureShape[0], - height: inputTextureShape[1], - channels: 1 as const, - isPacked: false, - shape: inputTextureShape, - strides: ShapeUtil.computeStrides(inputTextureShape), - unpackedShape: inputTensorShape, - tensor: inputTensorA, - texture: webglTextureA! - }; - - webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA); - - // compile shader code - const programInfo = op.createProgramInfo(inferenceHandler! as WebGLInferenceHandler, [inputTensorA]); - - const artifact = webglInferenceHandler.session.programManager.build(programInfo); - webglInferenceHandler.session.programManager.setArtifact(op, artifact); - - // run kernal and get output - const runData = op.createRunData(webglInferenceHandler, artifact.programInfo, [inputTensorA]); - webglInferenceHandler.session.programManager.run(artifact, runData); - const result = runData.outputTextureData.tensor.data; + + const result = op.run(webglInferenceHandler, [inputTensorA]); webglInferenceHandler.session.textureManager.glContext.checkError(); // verify result. const expectedOutput = testData.expectedOutput; - expect(result).to.not.equal(null); + expect(result[0].data).to.not.equal(null); - expect(result).to.have.lengthOf(elementCount); - expect(result).to.deep.equal(expectedOutput); + expect(result[0].data).to.have.lengthOf(elementCount); + expect(result[0].data).to.deep.equal(expectedOutput); }); } }); @@ -127,36 +96,80 @@ function getTestData(): TestData[] { outputTextureShape: [4, 2], rawInput: new Float32Array([0., 9., 18., 27., 36., 45., 54., 63.]), expectedOutput: new Float32Array([0., 18., 36., 54., 9., 27., 45., 63.]), + mode: 'DCR', + }, + { + elementCount: 16, + blocksize: 2, + inputShape: [1, 8, 1, 2], + outputShape: [1, 2, 4, 2], + inputTextureShape: [1, 16], + outputTextureShape: [8, 2], + rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]), + expectedOutput: new Float32Array([0, 18, 1, 19, 36, 54, 37, 55, 9, 27, 10, 28, 45, 63, 46, 64]), + mode: 'DCR', + }, + + { + elementCount: 48, + blocksize: 2, + inputShape: [1, 8, 2, 3], + outputShape: [1, 2, 4, 6], + inputTextureShape: [16, 3], + outputTextureShape: [8, 6], + rawInput: new Float32Array([ + 0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21., + 22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46., + 47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68. + ]), + expectedOutput: new Float32Array([ + 0., 18., 1., 19., 2., 20., 36., 54., 37., 55., 38., 56., 3., 21., 4., 22., + 5., 23., 39., 57., 40., 58., 41., 59., 9., 27., 10., 28., 11., 29., 45., 63., + 46., 64., 47., 65., 12., 30., 13., 31., 14., 32., 48., 66., 49., 67., 50., 68. + ]), + mode: 'DCR', + }, + { + elementCount: 8, + blocksize: 2, + inputShape: [1, 8, 1, 1], + outputShape: [1, 2, 2, 2], + inputTextureShape: [8, 1], + outputTextureShape: [4, 2], + rawInput: new Float32Array([0., 9., 18., 27., 36., 45., 54., 63.]), + expectedOutput: new Float32Array([0, 9, 18, 27, 36, 45, 54, 63]), + mode: 'CRD', + }, + { + elementCount: 16, + blocksize: 2, + inputShape: [1, 8, 1, 2], + outputShape: [1, 2, 4, 2], + inputTextureShape: [1, 16], + outputTextureShape: [8, 2], + rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]), + expectedOutput: new Float32Array([0, 9, 1, 10, 18, 27, 19, 28, 36, 45, 37, 46, 54, 63, 55, 64]), + mode: 'CRD', + }, + + { + elementCount: 48, + blocksize: 2, + inputShape: [1, 8, 2, 3], + outputShape: [1, 2, 4, 6], + inputTextureShape: [16, 3], + outputTextureShape: [8, 6], + rawInput: new Float32Array([ + 0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21., + 22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46., + 47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68. + ]), + expectedOutput: new Float32Array([ + 0., 9., 1., 10., 2., 11., 18., 27., 19., 28., 20., 29., 3., 12., 4., 13., + 5., 14., 21., 30., 22., 31., 23., 32., 36., 45., 37., 46., 38., 47., 54., 63., + 55., 64., 56., 65., 39., 48., 40., 49., 41., 50., 57., 66., 58., 67., 59., 68. + ]), + mode: 'CRD', }, - // { - // elementCount: 16, - // blocksize: 2, - // inputShape: [1, 8, 2, 1], - // outputShape: [1, 2, 4, 2], - // inputTextureShape: [16, 1], - // outputTextureShape: [8, 2], - // rawInput: new Float32Array([0., 1., 9., 10, 18., 19, 27., 28., 36., 37., 45., 46., 54., 55., 63., 64.]), - // expectedOutput: new - // Float32Array([0., 18., 1., 19., 36., 54., 37., 55., 9., 27., 10., 28., 45., 63., 46., 64.]), - // }, - - // { - // elementCount: 48, - // blocksize: 2, - // inputShape: [1, 8, 2, 3], - // outputShape: [1, 2, 4, 6], - // inputTextureShape: [16, 3], - // outputTextureShape: [8, 6], - // rawInput: new Float32Array([ - // 0., 1., 2., 3., 4., 5., 9., 10., 11., 12., 13., 14., 18., 19., 20., 21., - // 22., 23., 27., 28., 29., 30., 31., 32., 36., 37., 38., 39., 40., 41., 45., 46., - // 47., 48., 49., 50., 54., 55., 56., 57., 58., 59., 63., 64., 65., 66., 67., 68. - // ]), - // expectedOutput: new Float32Array([ - // 0., 18., 1., 19., 2., 20., 36., 54., 37., 55., 38., 56., 3., 21., 4., 22., - // 5., 23., 39., 57., 40., 58., 41., 59., 9., 27., 10., 28., 11., 29., 45., 63., - // 46., 64., 47., 65., 12., 30., 13., 31., 14., 32., 48., 66., 49., 67., 50., 68. - // ]), - // }, ]; }