Skip to content

Commit

Permalink
[js/webgpu] make RunFunction return void (microsoft#15669)
Browse files Browse the repository at this point in the history
### Description
make `RunFunction` return `void`.

the return value is meaningless in the OpResolveRule context. Allows any
JavaScript error to be caught and returns non-zero return value from
`computeKernel()`
  • Loading branch information
fs-eire authored Apr 25, 2023
1 parent 5c4f5bb commit 069950d
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 83 deletions.
6 changes: 5 additions & 1 deletion web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ export class WebGpuBackend {

this.temporaryData = [];
try {
return kernelEntry(context, attributes[1]);
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`);
return 1; // ORT_FAIL
} finally {
for (const data of this.temporaryData) {
this.gpuDataManager.release(data.id);
Expand Down
2 changes: 1 addition & 1 deletion web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {ComputeContext} from './types';

export type RunFunction = (context: ComputeContext, attribute?: unknown) => number;
export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;
export type ParseAttributeFunction = (attributeRaw: unknown) => unknown;
export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction];

Expand Down
15 changes: 5 additions & 10 deletions web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,19 @@ const createBinaryOpProgramInfoLoader =
};
};

export const add = (context: ComputeContext): number => {
export const add = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`));
return 0;
};

export const div = (context: ComputeContext): number => {
export const div = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`));
return 0;
};

export const mul = (context: ComputeContext): number => {
export const mul = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`));
return 0;
};

export const pow = (context: ComputeContext): number => {
export const pow = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(
context.inputs, 'Pow', ({scalar: (a, b) => `pow_f32(${a},${b})`, vector: (a, b) => `pow_vf32(${a},${b})`}), `
fn pow_f32(a : f32, b : f32) -> f32 {
Expand All @@ -204,10 +201,8 @@ export const pow = (context: ComputeContext): number => {
return vec4<f32>(pow_f32(a.x, b.x), pow_f32(a.y, b.y), pow_f32(a.z, b.z), pow_f32(a.w, b.w));
}
`));
return 0;
};

export const sub = (context: ComputeContext): number => {
export const sub = (context: ComputeContext): void => {
context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`));
return 0;
};
3 changes: 1 addition & 2 deletions web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ const createConcatProgramInfoLoader =
return {...metadata, get: () => createConcatProgramInfo(metadata, inputs, attributes.axis)};
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): number => {
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
context.compute(createConcatProgramInfoLoader(context.inputs, attributes));
return 0;
};
19 changes: 10 additions & 9 deletions web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ export const parseConvAttributes = (attributes: Record<string, unknown>): ConvAt
{autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes});
};

const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): number => {
const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => {
const adjustedAttributes = getAdjustedConvAttributes(attributes, inputs);

// check attributes
Expand Down Expand Up @@ -170,12 +170,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
attributes.autoPad === 'VALID'))) {
// TODO: implement conv2dByMatMul()
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
return 0;
return;
}

if (!isChannelsLast || attributes.group !== 1) {
context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes));
return 0;
return;
}

// TODO: implement conv2dWithIm2Col()
Expand Down Expand Up @@ -215,10 +215,9 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias,
sequentialAccessByThreads),
{inputs: convInputs});
return 0;
};

const conv1d = (context: ComputeContext, attributes: ConvAttributes): number => {
const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => {
// extend the input to 2D by adding H dimension
const isChannelLast = attributes.format === 'NHWC';
const inputs = [
Expand All @@ -242,11 +241,13 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): number =>
context.compute(createGroupedConvProgramInfoLoader(
inputs, adjustedAttributes,
outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : []));
return 0;
};

export const conv = (context: ComputeContext, attributes: ConvAttributes): number => {
export const conv = (context: ComputeContext, attributes: ConvAttributes): void => {
validateInputs(context.inputs, attributes); // currently will fail if not conv1D/2D
return context.inputs[0].dims.length === 3 ? conv1d(context, attributes) :
conv2d(context, context.inputs, attributes);
if (context.inputs[0].dims.length === 3) {
conv1d(context, attributes);
} else {
conv2d(context, context.inputs, attributes);
}
};
3 changes: 1 addition & 2 deletions web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,9 @@ const createGemmProgramInfoLoader = (inputs: readonly TensorView[], attributes:
return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)};
};

export const gemm = (context: ComputeContext, attributes: GemmAttributes): number => {
export const gemm = (context: ComputeContext, attributes: GemmAttributes): void => {
validateInputs(context.inputs);
context.compute(createGemmProgramInfoLoader(context.inputs, attributes));
return 0;
};

export const parseGemmAttributes = (attributes: Record<string, unknown>): GemmAttributes =>
Expand Down
3 changes: 1 addition & 2 deletions web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
}
};

export const matMul = (context: ComputeContext): number => {
export const matMul = (context: ComputeContext): void => {
validateInputs(context.inputs);

context.compute(createMatmulProgramInfoLoader(context.inputs, {activation: '', activationCacheKey: ''}));
return 0;
};
12 changes: 4 additions & 8 deletions web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,10 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
return createAttributeWithCacheKey({countIncludePad, ...attr});
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => {
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
const metadata = {name: 'AveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, false, attributes)});
return 0;
};

const globalPoolAttributes = {
Expand All @@ -313,11 +312,10 @@ export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unkn
return {format, ...globalPoolAttributes, cacheKey: format};
};

export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): number => {
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
const metadata = {name: 'GlobalAveragePool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
context.compute({...metadata, get: () => createAveragePoolProgramInfo(context.inputs, metadata, true, attributes)});
return 0;
};

export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
Expand All @@ -343,11 +341,10 @@ const createMaxPoolProgramInfo =
};
};

export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => {
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
const metadata = {name: 'MaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, false, attributes)});
return 0;
};

export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
Expand All @@ -371,9 +368,8 @@ export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>
return {format, ...globalPoolAttributes, cacheKey: format};
};

export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): number => {
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
const metadata = {name: 'GlobalMaxPool', inputTypes: [GpuDataType.default], cacheHint: attributes.cacheKey};
context.compute({...metadata, get: () => createMaxPoolProgramInfo(context.inputs, metadata, true, attributes)});
return 0;
};
3 changes: 1 addition & 2 deletions web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,13 @@ export const createTransposeProgramInfo = (input: TensorView, permAttr: number[]
};
};

export const transpose = (context: ComputeContext, attributes: TransposeAttributes): number => {
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
context.compute({
...transposeProgramMetadata,
cacheHint: attributes.cacheKey,
get: () => createTransposeProgramInfo(context.inputs[0], attributes.perm)
});
return 0;
};

export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>
Expand Down
70 changes: 24 additions & 46 deletions web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,46 +58,39 @@ const createElementwiseProgramInfoLoader =
};
};

export const abs = (context: ComputeContext): number => {
export const abs = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Abs', 'abs'));
return 0;
};

export const acos = (context: ComputeContext): number => {
export const acos = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acos', 'acos'));
return 0;
};

export const acosh = (context: ComputeContext): number => {
export const acosh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Acosh', 'acosh'));
return 0;
};

export const asin = (context: ComputeContext): number => {
export const asin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asin', 'asin'));
return 0;
};

export const asinh = (context: ComputeContext): number => {
export const asinh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Asinh', 'asinh'));
return 0;
};

export const atan = (context: ComputeContext): number => {
export const atan = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atan', 'atan'));
return 0;
};
export const atanh = (context: ComputeContext): number => {
export const atanh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Atanh', 'atanh'));
return 0;
};

export interface ClipAttributes extends AttributeWithCacheKey {
readonly min: number;
readonly max: number;
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): number => {
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
Expand All @@ -106,39 +99,35 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): nu
`,
attributes.cacheKey),
{inputs: [0]});
return 0;
};
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext): number => {
export const clip = (context: ComputeContext): void => {
const attributes = generateClipAttributesFromInputs(context.inputs);
return clipV10(context, attributes);
clipV10(context, attributes);
};

export const ceil = (context: ComputeContext): number => {
export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Ceil', 'ceil'));
return 0;
};

export const cos = (context: ComputeContext): number => {
export const cos = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cos', 'cos'));
return 0;
};

export const cosh = (context: ComputeContext): number => {
export const cosh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh'));
return 0;
};

export interface EluAttributes extends AttributeWithCacheKey {
readonly alpha: number;
}

export const elu = (context: ComputeContext, attributes: EluAttributes): number => {
export const elu = (context: ComputeContext, attributes: EluAttributes): void => {
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
const elu_alpha_: f32 = f32(${attributes.alpha});
Expand All @@ -151,13 +140,12 @@ export const elu = (context: ComputeContext, attributes: EluAttributes): number
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
}`,
attributes.cacheKey));
return 0;
};

export const parseEluAttributes = (attributes: Record<string, unknown>): EluAttributes =>
createAttributeWithCacheKey(attributes as {alpha: number});

export const erf = (context: ComputeContext): number => {
export const erf = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
Expand All @@ -171,50 +159,40 @@ export const erf = (context: ComputeContext): number => {
let x = 1.0 / (1.0 + r0 * absv);
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`));
return 0;
};

export const floor = (context: ComputeContext): number => {
export const floor = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor'));
return 0;
};

export const neg = (context: ComputeContext): number => {
export const neg = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`));
return 0;
};

export const reciprocal = (context: ComputeContext): number => {
export const reciprocal = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`));
return 0;
};

export const sigmoid = (context: ComputeContext): number => {
export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
return 0;
};

export const sin = (context: ComputeContext): number => {
export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sin', 'sin'));
return 0;
};

export const sinh = (context: ComputeContext): number => {
export const sinh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sinh', 'sinh'));
return 0;
};

export const sqrt = (context: ComputeContext): number => {
export const sqrt = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sqrt', 'sqrt'));
return 0;
};

export const tan = (context: ComputeContext): number => {
export const tan = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tan', 'tan'));
return 0;
};

export const tanh = (context: ComputeContext): number => {
export const tanh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh'));
return 0;
};

0 comments on commit 069950d

Please sign in to comment.