Skip to content

Commit

Permalink
Revert unintended changes to where.js
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Feb 23, 2024
1 parent 4d06b60 commit f020517
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

const createWhereOpProgramShader =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
typeOutput: number) => {
const outputSize = ShapeUtil.size(dimsOutput);
const vecSize = Math.ceil(outputSize / 4);

const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);

let assignment: string;
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
Expand All @@ -27,21 +24,21 @@ const createWhereOpProgramShader =
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
} else {
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
const expressionA = `aData[indexA${x}][componentA${x}]`;
const expressionB = `bData[indexB${x}][componentB${x}]`;
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(cData[indexC${x}] & (0xffu << (componentC${x} * 8)))`;
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let indexC${x} = offsetC${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
let componentC${x} = offsetC${x} % 4u;
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
let index_a${x} = offset_a${x} / 4u;
let index_b${x} = offset_b${x} / 4u;
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
Expand All @@ -52,21 +49,21 @@ const createWhereOpProgramShader =
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
} else {
assignment = `
${singleAssignment('outputData[global_idx]', 0)}
${singleAssignment('outputData[global_idx]', 1)}
${singleAssignment('outputData[global_idx]', 2)}
${singleAssignment('outputData[global_idx]', 3)}
${singleAssignment('output_data[global_idx]', 0)}
${singleAssignment('output_data[global_idx]', 1)}
${singleAssignment('output_data[global_idx]', 2)}
${singleAssignment('output_data[global_idx]', 3)}
`;
}
}

return `
${shaderHelper.declareVariables(c, a, b, output)}
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
${assignment}
}`;
};
Expand All @@ -91,13 +88,18 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
outputSize = ShapeUtil.size(outputShape);
}

const vecSize = Math.ceil(outputSize / 4);

return {
name: 'Where',
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
getShaderSource: (shaderHelper) =>
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
programUniforms:
[{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)],
}),
};
};
Expand Down

0 comments on commit f020517

Please sign in to comment.