Skip to content

Commit

Permalink
Revert acos changes for non complex numbers (#2449)
Browse files Browse the repository at this point in the history
It looks like something may be broken in the `chlo.acos` lowering from
functional algorithms. The complex lowering works just fine (has
`SelectOp`).

Also I'm wondering how we didn't catch this / could have caught this.
Perhaps we should have special cases for ops with limits..somehow? Open
to iterating on this, either by dev policy on these CLs (ensure total
code coverage) or if there's a way to auto generate cases would be good
too.

Current behavior prior to revert: `chlo.acos(-1) --> 0`, expected
behavior is `pi`.
  • Loading branch information
GleasonK authored Jul 22, 2024
1 parent c28d55e commit 70c210d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
14 changes: 14 additions & 0 deletions stablehlo/tests/math/acos_limits.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s | stablehlo-translate --interpret

func.func @main() -> (tensor<f64>, tensor<complex<f64>>) {
%cst = stablehlo.constant dense<-1.000000e+00> : tensor<f64>
%cst_0 = stablehlo.constant dense<(-1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
%zero = stablehlo.constant dense<0.0> : tensor<f64>
%pi = stablehlo.constant dense<3.1415926535897931> : tensor<f64>
%complex_pi = stablehlo.complex %pi, %zero : tensor<complex<f64>>
%0 = chlo.acos %cst : tensor<f64> -> tensor<f64>
%1 = chlo.acos %cst_0 : tensor<complex<f64>> -> tensor<complex<f64>>
check.expect_close %0, %pi, max_ulp_difference = 1 : tensor<f64>, tensor<f64>
check.expect_close %1, %complex_pi, max_ulp_difference = 1 : tensor<complex<f64>>, tensor<complex<f64>>
return %0, %1 : tensor<f64>, tensor<complex<f64>>
}
31 changes: 31 additions & 0 deletions stablehlo/transforms/ChloDecompositionPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,37 @@ def StableHLO_ConstantLikeSmallestNormalizedValue : NativeCodeCall<
// Unary op patterns.
//===----------------------------------------------------------------------===//

// Expand acos for non-complex arguments to MHLO dialect as follows:
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
// = pi if x == -1
//
// Note: Complex decomposition is in ChloDecompositionPatternsMath.td
def : Pat<(CHLO_AcosOp NonComplexElementType:$input),
(StableHLO_SelectOp
(StableHLO_CompareOp
$input,
(StableHLO_ConstantLike<"-1"> $input),
StableHLO_ComparisonDirectionValue<"NE">,
(STABLEHLO_DEFAULT_COMPARISON_TYPE)
),
(StableHLO_MulOp
(StableHLO_ConstantLike<"2"> $input),
(StableHLO_Atan2Op
(StableHLO_SqrtOp
(StableHLO_SubtractOp
(StableHLO_ConstantLike<"1"> $input),
(StableHLO_MulOp $input, $input)
)
),
(StableHLO_AddOp
(StableHLO_ConstantLike<"1"> $input),
$input
)
)
),
(StableHLO_ConstantLike<"M_PI"> $input)
)>;

// Express `atan` as
// atan(x) = atan2(x, 1)
def : Pat<(CHLO_AtanOp $input),
Expand Down
20 changes: 0 additions & 20 deletions stablehlo/transforms/ChloDecompositionPatternsMath.td
Original file line number Diff line number Diff line change
Expand Up @@ -635,26 +635,6 @@ def : Pat<(CHLO_AcosOp ComplexElementType:$z),
(StableHLO_AddOp $am1, $sq)))),
(StableHLO_NegOp $imag)))>;

// Arcus cosine on real input:
//
// arccos(x) = 2 * arctan2(sqrt(1 - x * x), 1 + x)
//
// To avoid cancellation errors at abs(x) close to 1, we'll use
//
// 1 - x * x == (1 - x) * (1 + x)
//
def : Pat<(CHLO_AcosOp NonComplexElementType:$x),
(StableHLO_MulOp
(StableHLO_ConstantLike<"2"> $x),
(StableHLO_Atan2Op
(StableHLO_SqrtOp
(StableHLO_MulOp
(StableHLO_SubtractOp
(StableHLO_ConstantLike<"1">:$one $x),
$x),
(StableHLO_AddOp:$add_one_x $one, $x))),
$add_one_x))>;

// Inverse hyperbolic cosine on complex input:
//
// acosh(z) = sqrt(z - 1) / sqrt(1 - z) * acos(z)
Expand Down

0 comments on commit 70c210d

Please sign in to comment.