diff --git a/lib/backends/webgl/ops/conv-pack.ts b/lib/backends/webgl/ops/conv-pack.ts index 414bf64d..6d00df15 100644 --- a/lib/backends/webgl/ops/conv-pack.ts +++ b/lib/backends/webgl/ops/conv-pack.ts @@ -15,6 +15,10 @@ import {WebGLReshapePacked} from './reshape-packed'; export class WebGLConvPacked extends Conv { protected artifacts: Artifact[]; protected programInfo: ProgramInfo[]; + private kernelReshape = new WebGLReshapePacked(); + private im2col: WebGLIm2ColPacked; + private matmul = new WebGLMatMulPacked(); + private outputReshape = new WebGLReshapePacked(); run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { const programManager = inferenceHandler.session.programManager; @@ -34,30 +38,31 @@ export class WebGLConvPacked extends Conv { this.kernelShape}, pads:${this.pads}, strides:${this.strides}`); const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); - const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); - const matmul = new WebGLMatMulPacked(); - const reshape = new WebGLReshapePacked(); + if (this.im2col === undefined) { + this.im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); + } // shape for kernel reshape const shape = new Tensor([2], 'int32', undefined, undefined, new Int32Array([kshape[0], kshape[1] * kshape[2] * kshape[3]])); if (!this.artifacts) { this.artifacts = []; this.programInfo = []; - this.programInfo[0] = im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]); + this.programInfo[0] = this.im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]); this.artifacts[0] = programManager.build(this.programInfo[0]); - this.programInfo[1] = reshape.createProgramInfo(inferenceHandler, [inputs[1], shape]); + this.programInfo[1] = this.kernelReshape.createProgramInfo(inferenceHandler, [inputs[1], shape]); this.artifacts[1] = programManager.build(this.programInfo[1]); } // run im2col - const runDataIm2col = im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]); + const runDataIm2col = this.im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[0], runDataIm2col); programManager.run(this.artifacts[0], runDataIm2col); const im2colOutput = runDataIm2col.outputTextureData.tensor; // reshape kernel - const runDataKernelReshape = reshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]); + const runDataKernelReshape = + this.kernelReshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[1], runDataKernelReshape); programManager.run(this.artifacts[1], runDataKernelReshape); const kernelReshaped = runDataKernelReshape.outputTextureData.tensor; @@ -66,11 +71,11 @@ export class WebGLConvPacked extends Conv { const hasBias = (inputs.length === 3); assert(this.artifacts.length > 1, () => 'expect at least 2 artifacts created'); if (this.artifacts.length === 2) { - this.programInfo[2] = matmul.createProgramInfo( + this.programInfo[2] = this.matmul.createProgramInfo( inferenceHandler, hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]); this.artifacts[2] = programManager.build(this.programInfo[2]); } - const runDataMatmul = matmul.createRunData( + const runDataMatmul = this.matmul.createRunData( inferenceHandler, this.programInfo[2], hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[2], runDataMatmul); @@ -84,11 +89,11 @@ export class WebGLConvPacked extends Conv { assert(this.artifacts.length > 2, () => 'expect at least 3 artifacts created'); if (this.artifacts.length === 3) { - this.programInfo[3] = reshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]); + this.programInfo[3] = this.outputReshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]); this.artifacts[3] = programManager.build(this.programInfo[3]); } const runDataOutputReshape = - reshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]); + this.outputReshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]); inferenceHandler.checkAndUpdateTextureForm(this.artifacts[3], runDataOutputReshape); programManager.run(this.artifacts[3], runDataOutputReshape); return [runDataOutputReshape.outputTextureData.tensor]; diff --git a/lib/backends/webgl/ops/im2col-pack.ts b/lib/backends/webgl/ops/im2col-pack.ts index 0681a1fd..e6b944d5 100644 --- a/lib/backends/webgl/ops/im2col-pack.ts +++ b/lib/backends/webgl/ops/im2col-pack.ts @@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator { if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) { offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]}; - d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} ); + d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]}); if(d0 < ${xshape[rowDim]} && d0 >= 0) { - offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${ - this.pads[0]}.); - d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.))); + offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]}; + d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]}); if(d1 < ${xshape[colDim]} && d1 >= 0) { diff --git a/lib/backends/webgl/ops/reshape-packed.ts b/lib/backends/webgl/ops/reshape-packed.ts index 9cdc6346..a2cc99da 100644 --- a/lib/backends/webgl/ops/reshape-packed.ts +++ b/lib/backends/webgl/ops/reshape-packed.ts @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor'; import {ShapeUtil} from '../../../util'; import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; -import {ProgramInfo, RunData, WebGLOperator} from '../types'; +import {ProgramInfo, RunData, TextureData, WebGLOperator} from '../types'; import {TextureLayout} from '../types'; import {unpackFromChannel} from './packing_utils'; @@ -32,14 +32,18 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { // the same between input shape and output shape, the packed reshape can be // treated as no-op. const originInputShape = inputs[0].dims; - const inputShape3D = processDims3D(inputs[0].dims); + this.inputShape3D = processDims3D(inputs[0].dims); let inputLayout: TextureLayout; - if (originInputShape.length === 3) { - inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true); - } else { + inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true); + if (originInputShape.length !== 3) { + const originalInputLayout = inputLayout; // if originShape is not a 3D shape, create texture layout from the processed shape. - inputLayout = - handler.createTextureLayoutFromShape(inputShape3D, 4, inputShape3D, {isPacked: true, reverseWH: true}); + inputLayout = handler.createTextureLayoutFromShape( + this.inputShape3D, 4, this.inputShape3D, {isPacked: true, reverseWH: true}); + // if the processed input shape produces texture layout differnt from the original + // one, the run data has to use the processed (3D) input shape later. + this.needSqueezeInputData = + (inputLayout.height !== originalInputLayout.height) || (inputLayout.width !== originalInputLayout.width); } this.outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData); @@ -86,9 +90,10 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { const glsl = getGlsl(handler.session.backend.glContext.version); const shaderSource = ` - ${getReshapedInputCoords(inputShape3D)} + ${getReshapedInputCoords(this.inputShape3D)} ${getFlattenedIndexFrom3D(squeezedOutputShape)} ${unpackFromChannel()} + void main() { ivec3 rc = getOutputCoords(); @@ -99,7 +104,6 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { int cols = ${squeezedOutputShape[1]}; ${mainLoop} - ${glsl.output} = result; } `; @@ -115,8 +119,26 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { }; } createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData { - const inputTDs = - [handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))]; + let inputTDs: [TextureData]; + const originalInputLayout = handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false); + const originalInputTD = handler.getOrCreateTextureData(inputs[0], originalInputLayout, false); + + if (this.needSqueezeInputData) { + const squeezedInputLayout: TextureLayout = { + channels: 1, + height: originalInputLayout.height, + width: originalInputLayout.width, + shape: this.inputShape3D, + strides: ShapeUtil.computeStrides(this.inputShape3D), + unpackedShape: this.inputShape3D, + }; + const squeezedInputTD = + handler.createSharedTextureData(squeezedInputLayout, inputs[0].type, originalInputTD.texture); + inputTDs = [squeezedInputTD]; + + } else { + inputTDs = [originalInputTD]; + } let outputLayout = this.originalOutputLayout; if (outputLayout === undefined) { const originInputShape = inputs[0].dims; @@ -133,6 +155,8 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator { } protected outputShape: ReadonlyArray; private originalOutputLayout: TextureLayout; + private inputShape3D: [number, number, number]; + private needSqueezeInputData = false; } function processDims3D(shape: readonly number[]|ReadonlyArray|Tensor.IntegerType): [number, number, number] { diff --git a/test/unittests/backends/webgl/test_reshape_packed.ts b/test/unittests/backends/webgl/test_reshape_packed.ts index 1f7bf3d8..5f30cf60 100644 --- a/test/unittests/backends/webgl/test_reshape_packed.ts +++ b/test/unittests/backends/webgl/test_reshape_packed.ts @@ -156,5 +156,10 @@ function getTestData(): TestData[] { inputShape: [2, 2, 2, 4], outputShape: [2, 1, 4, 4], }, + { + elementCount: 18432, + inputShape: [512, 36, 1, 1], + outputShape: [512, 36], + }, ]; }