From 9ba181e250938349a64036110fce4a8faefb530a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 20 Nov 2023 23:18:06 -0800 Subject: [PATCH] [js/web] unify resolve rules for "Clip" (#18527) ### Description It was a mistake to use 2 different names for Clip operator in op-resolve-rules.ts for different opset. An optimized implementation can handle both cases (opset < 11 and opset >=11). Remove "ClipV10" as an entry from the table. --- web/lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 - web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 19 ++++++++----------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 9f5dceb8f4726..bac44328d8f44 100644 --- a/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], - ['ClipV10', [unaryOps.clipV10]], ['Clip', [unaryOps.clip]], ['Concat', [concat, parseConcatAttributes]], ['Conv', [conv, parseConvAttributes]], diff --git a/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 4238449f9246f..119609e06f5a3 100644 --- a/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey { readonly max: number; } -export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { + const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo attributes.cacheKey), {inputs: [0]}); }; -const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); -}; - -export const clip = (context: ComputeContext): void => { - const attributes = generateClipAttributesFromInputs(context.inputs); - clipV10(context, attributes); -}; export const ceil = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));