diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 80a80b4c18619..48b9aeb6a1666 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,12 +23,12 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {tensorTypeToWsglStorageType} from '../common' import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -import { tensorTypeToWsglStorageType } from '../common' const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, @@ -39,7 +39,7 @@ const conv2dCommonSnippet = case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -125,8 +125,10 @@ const conv2dCommonSnippet = const sampleW = `${getWSnippet(innerElementSizeW)}`; const resType = typeSnippet(innerElementSize, dataType); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..b3043b0e14d82 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common' import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,9 +191,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { ...metadata, outputs: [{ @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 2d6067fdbfa49..82f8c82291f4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,14 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import { - getBroadcastDims, - IndicesHelper, - inputVariable, - outputVariable, - ShaderHelper, - tensorTypeToWsglStorageType -} from '../common'; +import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -78,7 +71,7 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = export const makeMatMulPackedVec4Source = (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -187,8 +180,8 @@ const readDataFromSubASnippet = (transposeA: boolean) => // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, - sequentialAccessByThreads = false): string => { + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 3274969970a91..2b418e65e7c3b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index a9642d85ede8f..7afc3ce1b9d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index bceaf244987c4..7dadf9a6205ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -9,7 +8,6 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {InternalActivationAttributes} from './fuse-utils'; - const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] :