Skip to content

Commit

Permalink
[js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596)
Browse files Browse the repository at this point in the history
This is used in sam-h-decoder-f16.

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
axinging authored and fs-eire committed Mar 15, 2024
1 parent c8e7e8b commit 12101de
Showing 1 changed file with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';

import {biasSnippet, typeSnippet} from './activation_util';
import {biasSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';

const conv2dTransposeCommonSnippet =
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
const type = typeSnippet(innerElementSize, 'f32');
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
Expand All @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
return vec4<f32>(v0, v1, v2, v3);
return ${type}(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
Expand Down Expand Up @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${bias.type.value} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
{name: 'pads', type: 'i32', length: pads.length}
];
appendActivationUniforms(attributes, uniforms);
const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
if (elemType !== 'f16' && elemType !== 'f32') {
throw new Error(`elemType ${elemType} is not supported.`);
}
return `
${utilFunctions('uniforms.result_strides')}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
${
isVec4 ? makeMatMulPackedVec4Source(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`;
};

Expand Down

0 comments on commit 12101de

Please sign in to comment.