From f060b5a670f7da5bd84a23d9e266ea5f37ae3588 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Tue, 12 Sep 2023 17:03:35 +0400 Subject: [PATCH 1/4] FP16 Conv, ConvTranspose and MatMul --- .../webgpu/ops/3rd-party/activation_util.ts | 10 ++-- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 38 +++++++------- .../ops/3rd-party/matmul_packed_webgpu.ts | 50 +++++++++++-------- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 10 ---- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 10 ---- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 -- .../providers/js/js_execution_provider.cc | 24 ++++----- .../core/providers/js/operators/conv.cc | 48 ++++++++---------- .../core/providers/js/operators/conv.h | 2 +- .../providers/js/operators/conv_transpose.cc | 47 ++++++++--------- .../providers/js/operators/conv_transpose.h | 2 +- .../core/providers/js/operators/matmul.cc | 4 +- 12 files changed, 113 insertions(+), 137 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index dd4f13e76ee04..22b91d680a9b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -21,16 +21,16 @@ export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; -export const typeSnippet = (component: number) => { +export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: - return 'f32'; + return dataType; case 2: - return 'vec2'; + return `vec2<${dataType}>`; case 3: - return 'vec3'; + return `vec3<${dataType}>`; case 4: - return 'vec4'; + return `vec4<${dataType}>`; default: throw new Error(`${component}-component is not supported.`); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 02507ad802b36..f07f0bbb84ee6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -28,11 +28,12 @@ import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; +import { tensorTypeToWsglStorageType } from '../common' const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4): string => { + innerElementSize = 4, dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -92,7 +93,7 @@ const conv2dCommonSnippet = let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; let xCh = ${col} % inChannels; - var resData = ${typeSnippet(innerElementSizeX)}(0.0); + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { @@ -110,7 +111,7 @@ const conv2dCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`) : + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : (fitInner && fitBOuter ? ` let col = colIn * ${innerElementSizeX}; ${readXSnippet}` : @@ -119,13 +120,13 @@ const conv2dCommonSnippet = if (row < dimInner && col < dimBOuter) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const resType = typeSnippet(innerElementSize, dataType); + const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { @@ -190,23 +191,24 @@ export const createConv2DMatMulProgramInfo = const fitInner = dimInner % tileInner === 0; const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? 'vec4' : 'f32'}>;` + `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, + `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` ]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } - fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -222,7 +224,7 @@ export const createConv2DMatMulProgramInfo = // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; ${declareInputs.join('')} @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + isVec4 ? `vec4<${t}>` : t}>; //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -240,12 +242,12 @@ export const createConv2DMatMulProgramInfo = ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2])} + elementsSize[1], elementsSize[2], t)} ${ isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index ab4f608451101..e70226e55ee79 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,14 @@ import {TensorView} from '../../../tensor'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; +import { + getBroadcastDims, + IndicesHelper, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglStorageType +} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -70,8 +77,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = }; export const makeMatMulPackedVec4Source = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -90,8 +97,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -115,7 +122,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; @@ -179,8 +186,9 @@ const readDataFromSubASnippet = (transposeA: boolean) => // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -222,7 +230,7 @@ export const makeMatMulPackedSource = workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; @@ -283,7 +291,7 @@ for (var t = 0; t < numTiles; t = t + 1) { workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][tileCol + inner]; @@ -309,8 +317,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { `; return ` - var mm_Asub : array, ${tileAHight}>; - var mm_Bsub : array, ${tileInner}>; + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; const tileInner = ${tileInner}; @@ -324,7 +332,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, rowPerThread>; // Without this initialization strange values show up in acc. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { @@ -347,6 +355,7 @@ const matMulReadWriteFnSource = const outputVariable = variables[5]; const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); const getAIndices = () => { const aRank = aVariable.shape.length; const batchRank = batchVariable.shape.length; @@ -377,8 +386,8 @@ const matMulReadWriteFnSource = }; const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimAOuter && col < dimInner) { @@ -389,8 +398,8 @@ const matMulReadWriteFnSource = } fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimInner && col < dimBOuter) { @@ -400,7 +409,7 @@ const matMulReadWriteFnSource = return value; } - fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) { + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; if (row < dimAOuter && col < dimBOuter) { var value = valueIn; @@ -444,6 +453,7 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); @@ -466,8 +476,8 @@ export const createMatmulProgramInfo = ${declareFunctions} ${activationFunction} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ${batchDims.impl()}`; return { ...metadata, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index acdfd7e40f258..7503f664dfc13 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; @@ -197,15 +196,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('ConvTranspose input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('ConvTranspose input(bias) should be float tensor'); - } }; const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 3a83b1c2de6c1..f205d4a06b176 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('Conv input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('Conv input(bias) should be float tensor'); - } }; const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index e4dae00db6305..8ed41bc09480d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -35,10 +34,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { throw new Error('shared dimension does not match.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; export const matMul = (context: ComputeContext): void => { diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6b548921cdc8c..00998b8559e64 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -232,18 +232,18 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Uns class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm); @@ -486,18 +486,18 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index c7c9f7f7c3f0e..2e07124dcd901 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -9,33 +9,27 @@ namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + Conv, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Conv, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 22f7721276677..fdf3e5b6c6b66 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -template +template class Conv : public JsKernel { public: Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 1a2fc99eada6a..2228343e1e6e3 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -7,33 +7,28 @@ #include "conv_transpose.h" namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index a5aeae8646373..dbe4c188380eb 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -9,7 +9,7 @@ #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/matmul.cc b/onnxruntime/core/providers/js/operators/matmul.cc index ddfbb454def07..6e6f906f7b42c 100644 --- a/onnxruntime/core/providers/js/operators/matmul.cc +++ b/onnxruntime/core/providers/js/operators/matmul.cc @@ -9,11 +9,11 @@ namespace js { JSEP_KERNEL_IMPL(MatMul, MatMul) ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); } // namespace js From 673b9ec8908a1c26194c335a16778d3ed9910da1 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Sat, 16 Sep 2023 00:48:24 +0400 Subject: [PATCH 2/4] Conflict fixes --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 10 ++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 47 ++++++++++--------- .../ops/3rd-party/matmul_packed_webgpu.ts | 15 ++---- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 1 - js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 1 - js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 2 - 6 files changed, 35 insertions(+), 41 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 80a80b4c18619..48b9aeb6a1666 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,12 +23,12 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {tensorTypeToWsglStorageType} from '../common' import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -import { tensorTypeToWsglStorageType } from '../common' const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, @@ -39,7 +39,7 @@ const conv2dCommonSnippet = case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -125,8 +125,10 @@ const conv2dCommonSnippet = const sampleW = `${getWSnippet(innerElementSizeW)}`; const resType = typeSnippet(innerElementSize, dataType); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..b3043b0e14d82 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common' import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,9 +191,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { ...metadata, outputs: [{ @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 2d6067fdbfa49..82f8c82291f4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,14 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import { - getBroadcastDims, - IndicesHelper, - inputVariable, - outputVariable, - ShaderHelper, - tensorTypeToWsglStorageType -} from '../common'; +import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -78,7 +71,7 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = export const makeMatMulPackedVec4Source = (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -187,8 +180,8 @@ const readDataFromSubASnippet = (transposeA: boolean) => // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, - transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, - sequentialAccessByThreads = false): string => { + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 3274969970a91..2b418e65e7c3b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index a9642d85ede8f..7afc3ce1b9d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index bceaf244987c4..7dadf9a6205ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -9,7 +8,6 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {InternalActivationAttributes} from './fuse-utils'; - const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : From 2c19a4b2f13a8c04c452cef0e4d729889e607f8d Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Wed, 20 Sep 2023 16:36:25 +0400 Subject: [PATCH 3/4] Linter fixes --- js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 2 +- .../lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 48b9aeb6a1666..e6d4039d8131b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,7 +23,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {tensorTypeToWsglStorageType} from '../common' +import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index b3043b0e14d82..aa912ff4ad48f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,7 +21,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common' +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = From 22e545825f272f484ff5ea2ee571a9a45a50ac13 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 29 Sep 2023 15:02:20 -0700 Subject: [PATCH 4/4] fix error caused by conflict resolve --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 3925e1cb4f564..f41d0d058a624 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -32,6 +32,7 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dTransposeCommonSnippet = (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -89,10 +90,10 @@ const conv2dTransposeCommonSnippet = let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } let iXR = i32(xR); let iXC = i32(xC); @@ -105,13 +106,13 @@ const conv2dTransposeCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);` : + return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; if (row < dimInner && col < dimBOuter) { ${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);`; + return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; @@ -125,21 +126,21 @@ const conv2dTransposeCommonSnippet = let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} } - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); `; const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} - fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } - fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleW : sampleA} } - fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; if (row < dimAOuter && col < dimBOuter) { var value = valueInput; @@ -234,10 +235,10 @@ export const createConv2DTransposeMatMulProgramInfo = ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} ${ - isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` }; };