Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
swap last 2 dims in packed concat for reversed WH (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tixxx authored May 3, 2021
1 parent 0c998cc commit 94d3843
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 43 deletions.
13 changes: 7 additions & 6 deletions lib/backends/webgl/glsl-coordinate-lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export class CoordsGlslLib extends GlslLib {
* Generates code for packed output sampler.
*/
protected getPackedOutputSamplingSnippet(outputLayout: TextureLayout): {[name: string]: GlslLibRoutine;} {
const outShape = outputLayout.shape;
const outShape = outputLayout.unpackedShape;
const outTexShape = [outputLayout.width, outputLayout.height];
const result: {[name: string]: GlslLibRoutine} = {};
const funcName = 'getOutputCoords';
Expand Down Expand Up @@ -231,7 +231,7 @@ export class CoordsGlslLib extends GlslLib {

const packedTexShape = texShape;
// texels needed to accommodate a logical row
const texelsInLogicalRow = shape[1];
const texelsInLogicalRow = Math.ceil(shape[1] / 2);

/**
* getOutputCoords
Expand Down Expand Up @@ -264,8 +264,9 @@ export class CoordsGlslLib extends GlslLib {
*/
protected getOutputPacked3DCoords(shape: [number, number, number], texShape: [number, number]): GlslLibRoutine {
const packedTexShape = [texShape[0], texShape[1]];
const texelsInLogicalRow = shape[2];
const texelsInBatch = texelsInLogicalRow * shape[1];
const texelsInLogicalRow = Math.ceil(shape[2] / 2);

const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
const source = `
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(TexCoords.xy *
Expand All @@ -291,8 +292,8 @@ export class CoordsGlslLib extends GlslLib {
protected getOutputPackedNDCoords(shape: ReadonlyArray<number>, texShape: [number, number]): GlslLibRoutine {
const packedTexShape = [texShape[0], texShape[1]];

const texelsInLogicalRow = shape[shape.length - 1];
const texelsInBatch = texelsInLogicalRow * shape[shape.length - 2];
const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
let texelsInBatchN = texelsInBatch;
let batches = ``;
let coords = 'b, r, c';
Expand Down
14 changes: 10 additions & 4 deletions lib/backends/webgl/ops/concat_packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
const unpackChannel = unpackFromChannel();

const shapes = inputs.map(i => i.dims);
const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
const allGlChannels = ['x', 'y', 'z', 'w', 'u', 'v'];
const channels = allGlChannels.slice(0, rank);
const offsets: number[] = new Array(shapes.length - 1);
const samplers = inputs.map((v, i) => `X${i}`);

Expand Down Expand Up @@ -88,6 +89,10 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
void main() {
${dtype} coords = getOutputCoords();
int lastDim = coords.${allGlChannels[rank - 1]};
coords.${allGlChannels[rank - 1]} = coords.${allGlChannels[rank - 2]};
coords.${allGlChannels[rank - 2]} = lastDim;
vec4 result = vec4(getValue(${coords}), 0., 0., 0.);
${coords[rank - 1]} = ${coords[rank - 1]} + 1;
Expand All @@ -110,8 +115,9 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
`;

return {
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
outputLayout: handler.createTextureLayoutFromShape(outputShape),
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t, 4, true, t.dims, true)),
outputLayout:
handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true}),
samplers,
shaderSource,
hasMain: true,
Expand All @@ -120,7 +126,7 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
};
}
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i], true));
return {
inputTextureDatas: inputTDs,
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
Expand Down
49 changes: 16 additions & 33 deletions test/unittests/backends/webgl/test_concat_packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
describe(`Test concat ${JSON.stringify(testData)}`, () => {});
it(`Test packed concat kernel `, () => {
const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler;
// webglInferenceHandler.session.pack = false;

// TODO support WebGl 1.0
if (webglInferenceHandler.session.textureManager.glContext.version === 1) {
Expand All @@ -55,7 +56,6 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
const elementCount = testData.elementCount;
const inputTensorShape = testData.inputShape;
const inputTextureShape = testData.inputTextureShape;
const outputTensorShape = testData.outputShape;

// create input data and tensor. The input data will be used to verify if the output tensor contains the
// same value but possibly different order depending on our packing algorithm.
Expand All @@ -65,6 +65,7 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {

// manually creat packed texture from inputTensor, and insert in cache
const gl = webglInferenceHandler.session.textureManager.glContext.gl;

webglInferenceHandler.session.textureManager.glContext.checkError();
const webglTextureA = createTextureFromArray(
webglInferenceHandler.session.textureManager.glContext, testData.rawInput ? testData.rawInput : inputData,
Expand All @@ -82,7 +83,7 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
isPacked: true,
shape: packedShape,
strides: ShapeUtil.computeStrides(packedShape),
unpackedShape: outputTensorShape,
unpackedShape: inputTensorShape,
tensor: inputTensorA,
texture: webglTextureA!
};
Expand All @@ -93,13 +94,13 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
isPacked: true,
shape: packedShape,
strides: ShapeUtil.computeStrides(packedShape),
unpackedShape: outputTensorShape,
unpackedShape: inputTensorShape,
tensor: inputTensorB,
texture: webglTextureB!
};

webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA);
webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB);
webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA, true);
webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB, true);

// compile shader code
const programInfo =
Expand All @@ -117,7 +118,6 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
// verify result.
const expectedOutput = testData.expectedOutput;
expect(result).to.not.equal(null);

expect(result).to.have.lengthOf(elementCount * 2);

expect(result).to.deep.equal(expectedOutput);
Expand Down Expand Up @@ -167,32 +167,15 @@ function getTestData(): TestData[] {
outputShape: [4, 4],
inputTextureShape: [2, 1],
outputTextureShape: [2, 2],
expectedOutput: new Float32Array([
1,
2,
5,
6,
3,
4,
7,
8,
1,
2,
5,
6,
3,
4,
7,
8,
]),
expectedOutput: new Float32Array([1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8]),
},
{
elementCount: 8,
axis: 1,
inputShape: [2, 4],
outputShape: [2, 8],
inputTextureShape: [2, 1],
outputTextureShape: [2, 4],
outputTextureShape: [4, 2],
expectedOutput: new Float32Array([
1,
2,
Expand Down Expand Up @@ -291,8 +274,8 @@ function getTestData(): TestData[] {
outputTextureShape: [8, 4],
expectedOutput: new Float32Array([
1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 1, 2, 5, 6, 3, 4,
7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 25, 26, 29, 30, 27, 28, 31, 32, 25, 26, 29, 30,
27, 28, 31, 32, 25, 26, 29, 30, 27, 28, 31, 32, 25, 26, 29, 30, 27, 28, 31, 32
7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30,
27, 28, 31, 32, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32
])
},

Expand All @@ -304,9 +287,9 @@ function getTestData(): TestData[] {
inputTextureShape: [2, 4],
outputTextureShape: [8, 4],
expectedOutput: new Float32Array([
1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 17, 18, 21, 22, 19, 20,
23, 24, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 25, 26, 29, 30,
27, 28, 31, 32, 25, 26, 29, 30, 27, 28, 31, 32, 25, 26, 29, 30, 27, 28, 31, 32
1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12,
15, 16, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 17, 18, 21, 22,
19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 25, 26, 29, 30, 27, 28, 31, 32
])
},
{
Expand All @@ -317,9 +300,9 @@ function getTestData(): TestData[] {
inputTextureShape: [2, 4],
outputTextureShape: [8, 4],
expectedOutput: new Float32Array([
1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 17, 18, 21, 22, 17, 18,
21, 22, 19, 20, 23, 24, 19, 20, 23, 24, 25, 26, 29, 30, 25, 26, 29, 30, 27, 28, 31, 32,
27, 28, 31, 32, 25, 26, 29, 30, 25, 26, 29, 30, 27, 28, 31, 32, 27, 28, 31, 32
1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 9, 10, 13, 14, 9, 10,
13, 14, 11, 12, 15, 16, 11, 12, 15, 16, 17, 18, 21, 22, 17, 18, 21, 22, 19, 20, 23, 24,
19, 20, 23, 24, 25, 26, 29, 30, 25, 26, 29, 30, 27, 28, 31, 32, 27, 28, 31, 32
])
},
];
Expand Down

0 comments on commit 94d3843

Please sign in to comment.