Skip to content

Commit

Permalink
[js/web] unify resolve rules for "Clip" (microsoft#18527)
Browse files Browse the repository at this point in the history
### Description
It was a mistake to use 2 different names for Clip operator in
op-resolve-rules.ts for different opset. An optimized implementation can
handle both cases (opset < 11 and opset >=11). Remove "ClipV10" as an
entry from the table.
  • Loading branch information
fs-eire authored Nov 21, 2023
1 parent de8704e commit 9ba181e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
1 change: 0 additions & 1 deletion web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],
Expand Down
19 changes: 8 additions & 11 deletions web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
readonly max: number;
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
Expand All @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
attributes.cacheKey),
{inputs: [0]});
};
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext): void => {
const attributes = generateClipAttributesFromInputs(context.inputs);
clipV10(context, attributes);
};

export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
Expand Down

0 comments on commit 9ba181e

Please sign in to comment.