Skip to content

Commit

Permalink
[js/web] unify resolve rules for "Clip" (#18527)
Browse files Browse the repository at this point in the history
### Description
It was a mistake to use 2 different names for Clip operator in
op-resolve-rules.ts for different opset. An optimized implementation can
handle both cases (opset < 11 and opset >=11). Remove "ClipV10" as an
entry from the table.
  • Loading branch information
fs-eire authored Nov 21, 2023
1 parent abdf8b7 commit c7fd930
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
1 change: 0 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],
Expand Down
19 changes: 8 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
readonly max: number;
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
Expand All @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
attributes.cacheKey),
{inputs: [0]});
};
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};

export const clip = (context: ComputeContext): void => {
const attributes = generateClipAttributesFromInputs(context.inputs);
clipV10(context, attributes);
};

export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)

// activation

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
JSEP_KERNEL_IMPL(Clip, Clip)
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
Expand Down

0 comments on commit c7fd930

Please sign in to comment.