Skip to content

Commit

Permalink
[js/webgpu] Fix Tanh explosion (#19201)
Browse files Browse the repository at this point in the history
### Description
```math
\tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}=
\left\{
\begin{array}{cc}
-\frac{1-e^{-2\cdot(-x)}}{1+e^{-2\cdot(-x)}}, & x<0 \\
0, & x=0 \\
\frac{1-e^{-2x}}{1+e^{-2x}}, & x>0
\end{array}
\right.
```

### Motivation and Context
On some platforms,
$$\tanh(1000)=\frac{e^{1000}-e^{-1000}}{e^{1000}+e^{-1000}}$$ would
produce NaN instead of 0.999... or 1 (imagine $e^{1000}=\infty$ and
$\frac{\infty}{\infty}$ explodes).
  • Loading branch information
hujiajie authored Jan 25, 2024
1 parent 1c92e56 commit 5b06505
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => {
};

export const tanh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh'));
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`));
};

export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
Expand Down
26 changes: 26 additions & 0 deletions js/web/test/data/ops/tanh.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[
{
"name": "tanh with no attributes",
"operator": "Tanh",
"attributes": [],
"cases": [
{
"name": "T[2,4]",
"inputs": [
{
"data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000],
"dims": [2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1],
"dims": [2, 4],
"type": "float32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,7 @@
"sub.jsonc",
"sub_int32.jsonc",
"tan.jsonc",
"tanh.jsonc",
"tile.jsonc",
"transpose.jsonc",
"transpose_int32_uint32.jsonc",
Expand Down

0 comments on commit 5b06505

Please sign in to comment.