Skip to content

Commit

Permalink
Use type instead of rank
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 9, 2023
1 parent d282082 commit 165ce99
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,23 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
if (enableInputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(input.dims));
}
const inputDependencies: ProgramInputTensorInfoDependency[] = [enableInputShapesUniforms ? 'rank' : 'dims'];
const dataType = x.type.value;
const inputDependencies: ProgramInputTensorInfoDependency[] = [enableInputShapesUniforms ? 'type' : 'dims'];
const valueType = x.type.value;
// 6.2.4 in wgsl spec
const threadMaxDecl = tensorTypeToWsglStorageType(input.dataType) === 'f32' ?
`var threadMax = ${dataType}(-3.402823e+38f);` :
`var threadMax = ${dataType}(-65504.0h);`;
`var threadMax = ${valueType}(-3.402823e+38f);` :
`var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
var<workgroup> rowMaxShared : ${valueType};
var<workgroup> rowSumShared : ${valueType};
var<workgroup> threadShared : array<${valueType}, ${WG}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
let index = row * row_stride + col;
result[index] = value;
}
Expand Down Expand Up @@ -106,12 +106,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = ${dataType}(${maxVector('threadShared[0]', components)});
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
}
workgroupBarrier();
// find the rows sum
var threadSum = ${dataType}(0.0);
var threadSum = ${valueType}(0.0);
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
Expand All @@ -126,7 +126,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = ${dataType}(${sumVector('threadShared[0]', components)});
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
}
workgroupBarrier();
Expand Down

0 comments on commit 165ce99

Please sign in to comment.