Skip to content

Commit

Permalink
Gather, resize and softmax fixes for fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 13, 2023
1 parent 9e1ce2b commit 8c58b3c
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 86 deletions.
9 changes: 4 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ const createBiasAddProgramInfo = (metadata: ProgramMetadata, inputs: readonly Te
// since channel number can be only 320/640/1280, it's always divisible by 4
const outputSize = ShapeUtil.size(outputShape) / 4;

const dataType = inputs[0].dataType;
const input = inputVariable('input', dataType, outputShape, 4);
const bias = inputVariable('bias', dataType, [channels], 4);
const residual = inputVariable('residual', dataType, outputShape, 4);
const output = outputVariable('output', dataType, outputShape, 4);
const input = inputVariable('input', inputs[0].dataType, outputShape, 4);
const bias = inputVariable('bias', inputs[1].dataType, [channels], 4);
const residual = inputVariable('residual', inputs[1].dataType, outputShape, 4);
const output = outputVariable('output', inputs[0].dataType, outputShape, 4);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const channels = ${channels}u / 4;
Expand Down
50 changes: 33 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import { DataType } from '../../../wasm-common'
import { TensorView } from '../../tensor'
import { ShapeUtil } from '../../util'
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'
import { ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata } from '../types'

import {ShaderHelper} from './common';
import { getMaxComponents, inputVariable, outputVariable, ShaderHelper } from './common'

export interface GatherAttributes extends AttributeWithCacheKey {
axis: number;
Expand All @@ -32,17 +32,34 @@ const createGatherProgramInfo =

const inputDataType = inputs[0].dataType;
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
let elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;

// for f16, when the block size is odd, we gather single f16 values;
// when it's even, we gather u32 values instead
let gatherType = DataType.uint32;
if (inputDataType === DataType.float16) {
if (block % 2 === 0) {
elementSize = 2;
} else {
gatherType = DataType.float16;
}
}
const blockSize = elementSize * block;
const components = getMaxComponents(blockSize);

const input = inputVariable('input', gatherType, inputShape, components);
const indices = inputVariable('inputIndices', DataType.int32, indicesShape);
const output = outputVariable('output', gatherType, outputShape, components);

const M = ShapeUtil.sizeToDimension(inputShape, axis);
const N = ShapeUtil.size(indicesShape);
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
const gatheredBatchElements = N * block * elementSize;
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize / components;
const gatheredBatchElements = N * block * elementSize / components;
const axisDimLimit = inputShape[axis];

const inputSize = ShapeUtil.size(inputShape) * elementSize;
const outputSize = ShapeUtil.size(outputShape) * elementSize;
const inputSize = ShapeUtil.size(inputShape) * elementSize / components;
const outputSize = ShapeUtil.size(outputShape) * elementSize / components;

const totalGathers = M * N;
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
Expand All @@ -52,10 +69,9 @@ const createGatherProgramInfo =
const N: u32 = ${N};
const elementSize: u32 = ${elementSize};
const indicesElementSize: u32 = ${indicesElementSize};
const blockSize = ${blockSize / components};
@group(0) @binding(0) var<storage, read> input : array<u32>;
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
${shaderHelper.declareVariables(input, indices, output)}
${shaderHelper.mainStart()}
let batch: u32 = global_idx / N;
Expand All @@ -68,15 +84,15 @@ const createGatherProgramInfo =
idx = idx + ${axisDimLimit};
}
let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
let dstOffset = dstOffsetBatch + i * ${blockSize};
let srcOffset = srcOffsetBatch + u32(idx) * blockSize;
let dstOffset = dstOffsetBatch + i * blockSize;
if (srcOffset >= ${inputSize}) {
return;
}
if (dstOffset >= ${outputSize}) {
return;
}
for (var j: u32 = 0; j < ${blockSize}; j++) {
for (var j: u32 = 0; j < blockSize; j++) {
output[dstOffset + j] = input[srcOffset + j];
}
}`;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ const createInstanceNormNHWCProgramInfo =
attributes: InstanceNormAttributes) => {
const xShape = inputs[0].dims;
const outputShape = xShape;
const outputSize = ShapeUtil.size(outputShape);
const N = xShape[0];
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;

const components = getMaxComponents(C);
const outputSize = ShapeUtil.size(outputShape) / components;
const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

Expand Down
103 changes: 52 additions & 51 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ const validateInputs =
}
};

const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string =>
'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\
lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' +
const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
`fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
Expand All @@ -127,12 +127,12 @@ const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: Coor
return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
}';
case 'tf_crop_and_resize':
return 'if (lengthResized > 1) { \
return `if (lengthResized > 1) { \
return roiStart * (lengthOriginal - 1) + \
(xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
} else { \
return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \
}';
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
}`;
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
Expand All @@ -147,8 +147,8 @@ const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: Coor
})() +
'}';

const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string =>
'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => {
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string =>
`fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => {
switch (nearestMode) {
case 'round_prefer_ceil':
return 'if (fract(xOriginal) == 0.5) { \
Expand Down Expand Up @@ -248,20 +248,19 @@ const adjustOutputShape =
const calculateOriginalIndicesFromOutputIndices =
(output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
roi: readonly number[]): string => `
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<f32, ${
outputShape.length}> {
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${output.type.value}, ${outputShape.length}> {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<f32, ${outputShape.length}>;
const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
if (scales[i] == 1.0) {
originalIndices[i] = f32(outputIndex);
originalIndices[i] = ${output.type.value}(outputIndex);
} else {
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i],
${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
}
}
return originalIndices;
Expand All @@ -273,21 +272,21 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')});
const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')});
var inputIndices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex: u32;
if (scales[i] == 1.0) {
inputIndex = outputIndex;
} else {
var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) {
var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i],
${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) {
if (original_idx < 0) {
inputIndex = 0;
} else if (original_idx > (f32(inputShape[i]) - 1)) {
} else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) {
inputIndex = inputShape[i] - 1;
} else {
inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
Expand Down Expand Up @@ -318,8 +317,9 @@ const bilinearInterpolation =
scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => {
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 {
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
var inputIndices: ${input.type.indices};
inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));
Expand All @@ -330,10 +330,10 @@ const bilinearInterpolation =
return input[${input.indicesToOffset('inputIndices')}];
}
fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
var row:f32 = originalIndices[${heightIdx}];
var col:f32 = originalIndices[${widthIdx}];
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
Expand All @@ -350,14 +350,14 @@ const bilinearInterpolation =
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var x11: f32 = getInputValue(batch, channel, row1, col1);
var x12: f32 = getInputValue(batch, channel, row1, col2);
var x21: f32 = getInputValue(batch, channel, row2, col1);
var x22: f32 = getInputValue(batch, channel, row2, col2);
var dx1: f32 = row - f32(row1);
var dx2: f32 = f32(row2 ) - row;
var dy1 = col - f32(col1);
var dy2 = f32(col2) - col;
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
Expand All @@ -367,24 +367,24 @@ const bicubicInterpolation =
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];

const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
return `
fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${
output.type.indices}) -> f32 {
output.type.indices}) -> ${dType} {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]},
f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: f32 = originalIdx - floor(originalIdx);
var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]},
${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
return ${extrapolationValue};
}
var data: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: f32 = originalIdx + f32(i);
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
Expand All @@ -407,12 +407,12 @@ const bicubicInterpolation =
return `
${createCubicInterpolationFunction(heightIdx)};
${createCubicInterpolationFunction(widthIdx)};
fn getCubicInterpolationCoefs(s: f32) -> array<f32, 4> {
fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> {
var absS = abs(s);
var coeffs: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: f32 = 1.0 - absS;
var twoMinusAbsS: f32 = 2.0 - absS;
var onePlusAbsS: f32 = 1.0 + absS;
var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: ${dType} = 1.0 - absS;
var twoMinusAbsS: ${dType} = 2.0 - absS;
var onePlusAbsS: ${dType} = 1.0 + absS;
coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA};
coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
Expand All @@ -422,12 +422,12 @@ const bicubicInterpolation =
return coeffs;
}
fn cubicInterpolation1D(x: array<f32, 4>, coefs: array<f32, 4>) -> f32 {
var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3];
fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} {
var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
}
fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var inputIndices: ${input.type.indices} = outputIndices;
return colCubicInterpolation(inputIndices, outputIndices);
}
Expand All @@ -453,14 +453,15 @@ const createResizeProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const dataType = input.type.value;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)};
${(() => {
switch (attributes.mode) {
case 'nearest':
return `
${checkInputIndices(input, inputShape)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
${
calculateInputIndicesFromOutputIndices(
input, output, inputShape, outputShape, scales, roi, useExtrapolation)};
Expand Down
Loading

0 comments on commit 8c58b3c

Please sign in to comment.