From 5b065050734e6bc397dc38ba0df246aeb57ac508 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Fri, 26 Jan 2024 00:25:35 +0800 Subject: [PATCH] [js/webgpu] Fix Tanh explosion (#19201) ### Description ```math \tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}= \left\{ \begin{array}{cc} -\frac{1-e^{-2\cdot(-x)}}{1+e^{-2\cdot(-x)}}, & x<0 \\ 0, & x=0 \\ \frac{1-e^{-2x}}{1+e^{-2x}}, & x>0 \end{array} \right. ``` ### Motivation and Context On some platforms, $$\tanh(1000)=\frac{e^{1000}-e^{-1000}}{e^{1000}+e^{-1000}}$$ would produce NaN instead of 0.999... or 1 (imagine $e^{1000}=\infty$ and $\frac{\infty}{\infty}$ explodes). --- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 4 +++- js/web/test/data/ops/tanh.jsonc | 26 +++++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 js/web/test/data/ops/tanh.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 82311d72e58b9..76929efb32537 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => { }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 0000000000000..f7691535bd71c --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 373b3c645df57..56db28b0a379c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1389,6 +1389,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc",