format
fs-eire committed Aug 15, 2024
1 parent c67c7d8 commit 5ec24f5
Showing 4 changed files with 42 additions and 42 deletions.
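
Every hunk below makes the same mechanical edit: the last parameter or argument of a multi-line list gains a trailing comma (a few hunks also re-indent nested ternaries). There is no behavioral change. This pattern matches a Prettier run with `trailingComma: "all"`, the default since Prettier 3. As a sketch only, under the assumption that the repository formats with Prettier, a minimal config producing this style could be:

```js
// .prettierrc.js — hypothetical sketch; the repo's actual formatter settings are not shown on this page.
module.exports = {
  trailingComma: 'all', // always emit trailing commas in multi-line lists (the change in every hunk below)
  singleQuote: true, // matches the single-quoted strings visible in the diff
  printWidth: 120, // assumption; the real width is not derivable from this diff
};
```

One practical payoff of trailing commas: appending a parameter later touches one line instead of two, so future diffs stay smaller.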
16 changes: 8 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -48,7 +48,7 @@ const conv2dCommonSnippet = (
innerElementSizeX = 4,
innerElementSizeW = 4,
innerElementSize = 4,
-  dataType = 'f32'
+  dataType = 'f32',
): string => {
const getXSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
@@ -133,10 +133,10 @@ const conv2dCommonSnippet = (
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`
: fitInner && fitBOuter
-      ? `
+        ? `
let col = colIn * ${innerElementSizeX};
${readXSnippet}`
-      : `
+        : `
let col = colIn * ${innerElementSizeX};
if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
${readXSnippet}
@@ -182,7 +182,7 @@ export const createConv2DMatMulProgramInfo = (
dimInner: number,
hasBias: boolean,
sequentialAccessByThreads: boolean,
-  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
+  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const isChannelsLast = attributes.format === 'NHWC';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
@@ -258,7 +258,7 @@ export const createConv2DMatMulProgramInfo = (
'x',
inputs[0].dataType,
inputs[0].dims.length,
-    innerElementSize === 3 ? 1 : innerElementSize
+    innerElementSize === 3 ? 1 : innerElementSize,
);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];
@@ -289,7 +289,7 @@ export const createConv2DMatMulProgramInfo = (
elementsSize[0],
elementsSize[1],
elementsSize[2],
-        t
+        t,
)}
${conv2dCommonSnippet(
isChannelsLast,
@@ -301,7 +301,7 @@
elementsSize[0],
elementsSize[1],
elementsSize[2],
-        t
+        t,
)}
${
isVec4
@@ -315,7 +315,7 @@
tileInner,
false,
undefined,
-          sequentialAccessByThreads
+          sequentialAccessByThreads,
)
}`;
};
24 changes: 12 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -92,7 +92,7 @@ export const makeMatMulPackedVec4Source = (
transposeA = false,
tileInner = 32,
splitK = false,
-  splitedDimInner = 32
+  splitedDimInner = 32,
): string => {
const tileAOuter = workgroupSize[1] * workPerThread[1];
const tileBOuter = workgroupSize[0] * workPerThread[0];
@@ -212,7 +212,7 @@ export const makeMatMulPackedSource = (
tileInner = 32,
splitK = false,
splitedDimInner = 32,
-  sequentialAccessByThreads = false
+  sequentialAccessByThreads = false,
): string => {
const tileAOuter = workPerThread[1] * workgroupSize[1];
const tileBOuter = workPerThread[0] * workgroupSize[0];
@@ -223,7 +223,7 @@
!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0)
) {
throw new Error(
-      `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`
+      `tileAHight ${tileAHight} must be divisible by workgroupSize[1]${workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`,
);
}
const rowPerThreadA = tileAHight / workgroupSize[1];
@@ -374,7 +374,7 @@ const matMulReadWriteFnSource = (
applyActivation: string,
variables: IndicesHelper[],
batchShapes: Array<readonly number[]>,
-  isChannelsLast = false
+  isChannelsLast = false,
): string => {
const [batchAShape, batchBShape, batchShape] = batchShapes;
const [batchVariable, aVariable, bVariable, outputVariable] = variables;
@@ -411,9 +411,9 @@ const matMulReadWriteFnSource = (
};
const source = `
fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet(
-    component,
-    dataType
-  )} {
+      component,
+      dataType,
+    )} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
@@ -425,9 +425,9 @@
}
fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet(
-    component,
-    dataType
-  )} {
+      component,
+      dataType,
+    )} {
var value = ${typeSnippet(component, dataType)}(0.0);
let col = colIn * ${component};
if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
@@ -462,7 +462,7 @@ export const createMatmulProgramInfo = (
outputShape: readonly number[],
reshapedOutputShape?: readonly number[],
isChannelsLast = false /* only used for conv2dByMatMul*/,
-  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
+  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const aShape = inputs[0].dims;
const bShape = inputs[1].dims;
@@ -533,7 +533,7 @@ export const createMatmulProgramInfo = (
applyActivation,
[batchDims, A, B, output],
[outerDimsA, outerDimsB, outerDims],
-    isChannelsLast
+    isChannelsLast,
);
return `
${shaderHelper
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -25,7 +25,7 @@ import { appendActivationUniforms, appendActivationUniformsData, getActivationSn
export const createGroupedConvProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvAttributes,
-  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
+  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
const processBias = hasBias ? 'value += b[output_channel];' : '';
@@ -40,7 +40,7 @@ export const createGroupedConvProgramInfo = (
attributes.dilations,
attributes.pads,
attributes.strides,
-    isChannelLast
+    isChannelLast,
);
const outputSize = ShapeUtil.size(outputShape);

@@ -145,7 +145,7 @@ export const createGroupedConvVectorizeProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvAttributes,
outputShape: readonly number[],
-  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
+  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
const components = getMaxComponents(outputShape[3]);
38 changes: 19 additions & 19 deletions js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -20,7 +20,7 @@ export const calculateOutputShape = (
dilations: readonly number[],
adjustPads: readonly number[],
strides: readonly number[],
-  isChannelLast: boolean
+  isChannelLast: boolean,
): number[] => {
const batchSize = inputShape[0];
const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4);
@@ -30,7 +30,7 @@ export const calculateOutputShape = (
const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
const outputShape = inputSpatialShapeWithPad.map((v, i) =>
-    Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])
+    Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]),
);
outputShape.splice(0, 0, batchSize);
outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels);
@@ -117,7 +117,7 @@ const getAdjustedConvAttributes = <T extends ConvAttributes>(attributes: T, inpu
kernelShape,
pads,
attributes.format === 'NHWC',
-    attributes.autoPad
+    attributes.autoPad,
);

// always return a new object so does not modify the original attributes
@@ -156,7 +156,7 @@ const conv2d = (
context: ComputeContext,
inputs: readonly TensorView[],
attributes: ConvAttributes,
-  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]
+  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): void => {
// check attributes

@@ -181,7 +181,7 @@
attributes.dilations,
attributes.pads,
attributes.strides,
-      isChannelsLast
+      isChannelsLast,
);
const transposedWeight =
(context.kernelCustomData.wT as TensorView | undefined) ??
@@ -198,7 +198,7 @@
}
context.compute(
createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
-      { inputs: convInputs }
+      { inputs: convInputs },
);
} else {
context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
@@ -219,7 +219,7 @@
attributes.dilations,
attributes.pads,
attributes.strides,
-      isChannelsLast
+      isChannelsLast,
);
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
@@ -289,9 +289,9 @@ const conv2d = (
outputShape,
matmulOutputShape,
isChannelsLast,
-          squeezeOutputShapeFunction
+          squeezeOutputShapeFunction,
),
-        { inputs: matmulInputs }
+        { inputs: matmulInputs },
);
} else {
context.compute(
@@ -301,9 +301,9 @@
outputShape,
matmulOutputShape,
isChannelsLast,
-          squeezeOutputShapeFunction
+          squeezeOutputShapeFunction,
),
-        { inputs: matmulInputs }
+        { inputs: matmulInputs },
);
}
return;
@@ -344,9 +344,9 @@ const conv2d = (
dimInner,
hasBias,
sequentialAccessByThreads,
-      squeezeOutputShapeFunction
+      squeezeOutputShapeFunction,
),
-    { inputs: convInputs }
+    { inputs: convInputs },
);
};

@@ -359,7 +359,7 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
? // [N, W, C] -> [N, H=1, W, C]
[context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]]
: // [N, C, W] -> [N, C, H=1, W]
-        [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]
+        [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]],
),
//[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW]
context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]),
@@ -373,10 +373,10 @@
const kernelShape = [1].concat(attributes.kernelShape);
const adjustedAttributes = getAdjustedConvAttributes(
{ ...attributes, pads, strides, dilations, kernelShape },
-    inputs
+    inputs,
);
conv2d(context, inputs, adjustedAttributes, (outputShape) =>
-    isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]]
+    isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]],
);
};

@@ -391,7 +391,7 @@ const conv3d = (context: ComputeContext, inputs: readonly TensorView[], attribut
attributes.dilations as number | [number, number, number],
pads as string | number[],
false,
-    format
+    format,
);
context.compute(
createConv3DNaiveProgramInfo(
@@ -400,8 +400,8 @@
convInfo.outShape,
[convInfo.filterDepth, convInfo.filterHeight, convInfo.filterWidth],
[convInfo.padInfo.front, convInfo.padInfo.top, convInfo.padInfo.left],
-      format
-    )
+      format,
+    ),
);
};
