diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts index 9421d7433a14f..637c2419fbd93 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -32,11 +32,10 @@ const createBiasAddProgramInfo = (metadata: ProgramMetadata, inputs: readonly Te // since channel number can be only 320/640/1280, it's always divisable by 4 const outputSize = ShapeUtil.size(outputShape) / 4; - const dataType = inputs[0].dataType; - const input = inputVariable('input', dataType, outputShape, 4); - const bias = inputVariable('bias', dataType, [channels], 4); - const residual = inputVariable('residual', dataType, outputShape, 4); - const output = outputVariable('output', dataType, outputShape, 4); + const input = inputVariable('input', inputs[0].dataType, outputShape, 4); + const bias = inputVariable('bias', inputs[1].dataType, [channels], 4); + const residual = inputVariable('residual', inputs[1].dataType, outputShape, 4); + const output = outputVariable('output', inputs[0].dataType, outputShape, 4); const getShaderSource = (shaderHelper: ShaderHelper) => ` const channels = ${channels}u / 4; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index a915a4bbd969c..34fbd681637fe 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; -import {TensorView} from '../../tensor'; -import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; +import { DataType } from '../../../wasm-common' +import { TensorView } from '../../tensor' +import { ShapeUtil } from '../../util' +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key' +import { ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata } from '../types' -import {ShaderHelper} from './common'; +import { getMaxComponents, inputVariable, outputVariable, ShaderHelper } from './common' export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -32,17 +32,34 @@ const createGatherProgramInfo = const inputDataType = inputs[0].dataType; const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); - const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; + let elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; + + // for f16 when block size is odd, we'll use single f16 + // when it's odd just one u32 + let gatherType = DataType.uint32; + if (inputDataType === DataType.float16) { + if (block % 2 === 0) { + elementSize = 2; + } else { + gatherType = DataType.float16; + } + } const blockSize = elementSize * block; + const components = getMaxComponents(blockSize); + + const input = inputVariable('input', gatherType, inputShape, components); + const indices = inputVariable('inputIndices', DataType.int32, indicesShape); + const output = outputVariable('output', gatherType, outputShape, components); + const M = ShapeUtil.sizeToDimension(inputShape, axis); const N = ShapeUtil.size(indicesShape); - const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; - const gatheredBatchElements = N * block * elementSize; + const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize / components; + const gatheredBatchElements = N * block * elementSize / components; const axisDimLimit = inputShape[axis]; - const inputSize = ShapeUtil.size(inputShape) * elementSize; - const outputSize = ShapeUtil.size(outputShape) * elementSize; + const inputSize = ShapeUtil.size(inputShape) * elementSize / components; + const outputSize = ShapeUtil.size(outputShape) * elementSize / components; const totalGathers = M * N; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits @@ -52,10 +69,9 @@ const createGatherProgramInfo = const N: u32 = ${N}; const elementSize: u32 = ${elementSize}; const indicesElementSize: u32 = ${indicesElementSize}; + const blockSize = ${blockSize / components}; - @group(0) @binding(0) var input : array; - @group(0) @binding(1) var inputIndices : array; - @group(0) @binding(2) var output: array; + ${shaderHelper.declareVariables(input, indices, output)} ${shaderHelper.mainStart()} let batch: u32 = global_idx / N; @@ -68,15 +84,15 @@ const createGatherProgramInfo = idx = idx + ${axisDimLimit}; } - let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; - let dstOffset = dstOffsetBatch + i * ${blockSize}; + let srcOffset = srcOffsetBatch + u32(idx) * blockSize; + let dstOffset = dstOffsetBatch + i * blockSize; if (srcOffset >= ${inputSize}) { return; } if (dstOffset >= ${outputSize}) { return; } - for (var j: u32 = 0; j < ${blockSize}; j++) { + for (var j: u32 = 0; j < blockSize; j++) { output[dstOffset + j] = input[srcOffset + j]; } }`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 97f7e84d79ded..6c681a95b45cd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -195,12 +195,12 @@ const createInstanceNormNHWCProgramInfo = attributes: InstanceNormAttributes) => { const xShape = inputs[0].dims; const outputShape = xShape; - const outputSize = ShapeUtil.size(outputShape); const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; const components = getMaxComponents(C); + const outputSize = ShapeUtil.size(outputShape) / components; const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 1d0b8229a76f7..fd0f374357358 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -105,9 +105,9 @@ const validateInputs = } }; -const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => - 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\ - lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + +const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode, dType: string): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, + lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { switch (coordinateTransferMode) { case 'asymmetric': @@ -127,12 +127,12 @@ const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: Coor return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ }'; case 'tf_crop_and_resize': - return 'if (lengthResized > 1) { \ + return `if (lengthResized > 1) { \ return roiStart * (lengthOriginal - 1) + \ (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ } else { \ - return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \ - }'; + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + }`; case 'half_pixel_symmetric': return [ 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', @@ -147,8 +147,8 @@ const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: Coor })() + '}'; -const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => - 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => { +const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { switch (nearestMode) { case 'round_prefer_ceil': return 'if (fract(xOriginal) == 0.5) { \ @@ -248,20 +248,19 @@ const adjustOutputShape = const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array { + fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${output.type.value}, ${outputShape.length}> { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array; + const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')}); + const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')}); + var originalIndices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; if (scales[i] == 1.0) { - originalIndices[i] = f32(outputIndex); + originalIndices[i] = ${output.type.value}(outputIndex); } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i], + ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); } } return originalIndices; @@ -273,8 +272,8 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); + const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')}); + const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')}); var inputIndices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; @@ -282,12 +281,12 @@ const calculateInputIndicesFromOutputIndices = if (scales[i] == 1.0) { inputIndex = outputIndex; } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) { + var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i], + ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) { if (original_idx < 0) { inputIndex = 0; - } else if (original_idx > (f32(inputShape[i]) - 1)) { + } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) { inputIndex = inputShape[i] - 1; } else { inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); @@ -318,8 +317,9 @@ const bilinearInterpolation = scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); + const dType = input.type.value; return ` - fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 { + fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var inputIndices: ${input.type.indices}; inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); @@ -330,10 +330,10 @@ const bilinearInterpolation = return input[${input.indicesToOffset('inputIndices')}]; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); - var row:f32 = originalIndices[${heightIdx}]; - var col:f32 = originalIndices[${widthIdx}]; + var row:${dType} = originalIndices[${heightIdx}]; + var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; @@ -350,14 +350,14 @@ const bilinearInterpolation = channel = u32(originalIndices[${channelIdx}]); batch = u32(originalIndices[${batchIdx}]); } - var x11: f32 = getInputValue(batch, channel, row1, col1); - var x12: f32 = getInputValue(batch, channel, row1, col2); - var x21: f32 = getInputValue(batch, channel, row2, col1); - var x22: f32 = getInputValue(batch, channel, row2, col2); - var dx1: f32 = row - f32(row1); - var dx2: f32 = f32(row2 ) - row; - var dy1 = col - f32(col1); - var dy2 = f32(col2) - col; + var x11: ${dType} = getInputValue(batch, channel, row1, col1); + var x12: ${dType} = getInputValue(batch, channel, row1, col2); + var x21: ${dType} = getInputValue(batch, channel, row2, col1); + var x22: ${dType} = getInputValue(batch, channel, row2, col2); + var dx1: ${dType} = row - ${dType}(row1); + var dx2: ${dType} = ${dType}(row2) - row; + var dy1 = col - ${dType}(col1); + var dy2 = ${dType}(col2) - col; return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -367,24 +367,24 @@ const bicubicInterpolation = scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; - + const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ - output.type.indices}) -> f32 { + output.type.indices}) -> ${dType} { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]}, - f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); - var fractOriginalIdx: f32 = originalIdx - floor(originalIdx); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]}, + ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) { return ${extrapolationValue}; } - var data: array = array(0.0, 0.0, 0.0, 0.0); + var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); for (var i: i32 = -1; i < 3; i++) { - var ${direction}: f32 = originalIdx + f32(i); + var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { if (${excludeOutside}) { coefs[i + 1] = 0.0; @@ -407,12 +407,12 @@ const bicubicInterpolation = return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; - fn getCubicInterpolationCoefs(s: f32) -> array { + fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { var absS = abs(s); - var coeffs: array = array(0.0, 0.0, 0.0, 0.0); - var oneMinusAbsS: f32 = 1.0 - absS; - var twoMinusAbsS: f32 = 2.0 - absS; - var onePlusAbsS: f32 = 1.0 + absS; + var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); + var oneMinusAbsS: ${dType} = 1.0 - absS; + var twoMinusAbsS: ${dType} = 2.0 - absS; + var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; @@ -422,12 +422,12 @@ const bicubicInterpolation = return coeffs; } - fn cubicInterpolation1D(x: array, coefs: array) -> f32 { - var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3]; + fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} { + var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var inputIndices: ${input.type.indices} = outputIndices; return colCubicInterpolation(inputIndices, outputIndices); } @@ -453,14 +453,15 @@ const createResizeProgramInfo = const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const dataType = input.type.value; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; + ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { case 'nearest': return ` ${checkInputIndices(input, inputShape)}; - ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)}; + ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( input, output, inputShape, outputShape, scales, roi, useExtrapolation)}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index b2a8285172d5a..b1f44ee8bd119 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -44,14 +44,10 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut const cols = shape[axis]; const rows = outputSize / cols; - // 6.2.4 in wgsl spec - const threadMaxDecl = dataType === 'f32' - ? 'var threadMax: f32 = -3.402823e+38f;' - : 'var threadMax: f16 = -65504.0h;'; const getShaderSource = (_shaderHelper: ShaderHelper) => ` - var rowMaxShared : ${dataType}; - var rowSumShared : ${dataType}; - var threadShared : array<${dataType}, ${WG}>; + var rowMaxShared : f32; + var rowSumShared : f32; + var threadShared : array; @group(0) @binding(0) var x : array<${dataType}>; @group(0) @binding(1) var result : array<${dataType}>; @@ -76,10 +72,10 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut let row_stride : i32 = ${cols}; // find the rows max - ${threadMaxDecl} + var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec for (var col = lindex; col < cols; col += wg) { let value = getValue(row, col, row_stride); - threadMax = max(threadMax, value); + threadMax = max(threadMax, f32(value)); } if (lindex < cols) { threadShared[lindex] = threadMax; @@ -100,9 +96,9 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut workgroupBarrier(); // find the rows sum - var threadSum: ${dataType} = 0.0; + var threadSum: f32 = 0.0; for (var col = lindex; col < cols; col += wg) { - let subExp = exp(getValue(row, col, row_stride) - rowMaxShared); + let subExp = exp(f32(getValue(row, col, row_stride)) - rowMaxShared); threadSum += subExp; } threadShared[lindex] = threadSum; @@ -121,7 +117,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut // calculate final value for each element in the row for (var col = lindex; col < cols; col += wg) { - let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared; + let value = exp(getValue(row, col, row_stride) - ${dataType}(rowMaxShared)) / ${dataType}(rowSumShared); setValue(row, col, row_stride, value); } }`;