Add CHLO log1p operation.
pearu committed Nov 22, 2024
1 parent 7efac85 commit 27f35fc
Showing 12 changed files with 458 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build_tools/math/README.md
@@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.11.1 or newer
- functional_algorithms 0.12 or newer

that can be installed via pypi:

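The only change to the README is the minimum functional_algorithms version. As a quick sanity check (a hypothetical snippet, not part of the repository), the installed tooling can be compared against the listed requirements with `importlib.metadata`, assuming the PyPI distribution names match the import names:

```python
# Hypothetical pre-flight check for the generator tooling requirements above;
# assumes the distributions are installed under these names.
from importlib.metadata import version

print("mpmath:", version("mpmath"))                                # want >= 1.3
print("functional_algorithms:", version("functional_algorithms"))  # want >= 0.12
```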
1 change: 1 addition & 0 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
@@ -98,6 +98,7 @@ def main():
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("CHLO_Log1pOp", "complex_log1p", ("z:complex",)),
]:
print(f'Generating {chloname} from {fname}{args}')
func = getattr(fa.algorithms, fname, None)
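For context, the new table entry is resolved the same way as the surrounding ones: the generator looks the algorithm up by name on `fa.algorithms` and skips the op when it is missing. A minimal sketch of that lookup (hypothetical, not part of the commit; it assumes functional_algorithms 0.12+, which provides `complex_log1p`):

```python
# Sketch of the name-based lookup generate_ChloDecompositionPatternsMath.py
# performs for the entry added above.
import functional_algorithms as fa

chloname, fname, args = ("CHLO_Log1pOp", "complex_log1p", ("z:complex",))
func = getattr(fa.algorithms, fname, None)
if func is None:
    print(f"Skipping {chloname}: {fname} not found in functional_algorithms")
else:
    print(f"Generating {chloname} from {fname}{args}")
```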
1 change: 1 addition & 0 deletions build_tools/math/generate_tests.py
@@ -64,6 +64,7 @@
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
dict(name="log1p", mpmath_name="log1p"),
]


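The new `log1p` entry pairs the CHLO op with mpmath as the accuracy reference. A minimal sketch of that role (hypothetical, not the generated test harness itself) is to evaluate log(1 + z) at elevated precision and round the result back to the working precision:

```python
# Hypothetical reference evaluation in the spirit of generate_tests.py:
# mpmath at high precision serves as the ground truth for complex log1p.
import cmath
import mpmath

def log1p_reference(z: complex, prec: int = 200) -> complex:
    with mpmath.workprec(prec):
        return complex(mpmath.log(1 + mpmath.mpc(z.real, z.imag)))

z = complex(1e-9, 1e-9)
print(log1p_reference(z))   # accurate value
print(cmath.log(1 + z))     # naive formula; loses accuracy as |z| shrinks
```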
1 change: 1 addition & 0 deletions stablehlo/dialect/ChloOps.cpp
@@ -87,6 +87,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfInvOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
14 changes: 14 additions & 0 deletions stablehlo/dialect/ChloOps.td
@@ -747,6 +747,20 @@ def CHLO_LgammaOp : CHLO_UnaryElementwiseOp<"lgamma",
}];
}

def CHLO_Log1pOp : CHLO_UnaryElementwiseOp<"log1p",
[HLO_CompatibleOperandsAndResultType], HLO_AnyFpOrComplexTensor> {
let summary = "Log1p function";

let description = [{
Returns `Log1p(operand)` element-wise.

    $$
    \mathrm{log1p}(x) = \begin{cases}
      \mathrm{complex}\big(\log(\mathrm{hypot}(x.real + 1,\ x.imag)),\ \mathrm{arctan2}(x.imag,\ x.real + 1)\big) & \text{if } x \text{ is complex} \\
      \log(x + 1) & \text{otherwise}
    \end{cases}
    $$
}];
}

def CHLO_SquareOp : CHLO_UnaryElementwiseOp<"square",
[HLO_CompatibleOperandsAndResultType], HLO_AnyFpOrComplexTensor> {
let summary = "Square operation";
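As a small numerical illustration of the description above (hypothetical, not part of the dialect sources): for complex inputs the real part of log1p is the log of the modulus of 1 + x and the imaginary part is the phase of 1 + x, which the formula expresses through hypot and arctan2:

```python
# Hypothetical check of the log1p formula from the op description against
# Python's built-in complex log; the two should agree to within a few ulps.
import cmath
import math

def chlo_log1p_formula(z: complex) -> complex:
    return complex(math.log(math.hypot(z.real + 1.0, z.imag)),
                   math.atan2(z.imag, z.real + 1.0))

z = complex(0.5, -0.25)
print(chlo_log1p_formula(z))
print(cmath.log(1 + z))
```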
119 changes: 119 additions & 0 deletions stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
@@ -3925,3 +3925,122 @@ func.func @square_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.square"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}

// CHECK-LABEL: @log1p_complex_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<complex<f32>>) -> tensor<complex<f32>> {
// CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_3:.*]] = stablehlo.imag %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = stablehlo.abs %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_5:.*]] = stablehlo.maximum %[[VAL_2]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.sqrt %[[VAL_6]] : tensor<f32>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.00999999977> : tensor<f32>
// CHECK: %[[VAL_9:.*]] = stablehlo.multiply %[[VAL_7]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.compare GT, %[[VAL_5]], %[[VAL_9]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_11:.*]] = stablehlo.log %[[VAL_5]] : tensor<f32>
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK: %[[VAL_13:.*]] = stablehlo.minimum %[[VAL_2]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_14:.*]] = stablehlo.compare EQ, %[[VAL_13]], %[[VAL_5]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_15:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_16:.*]] = stablehlo.divide %[[VAL_13]], %[[VAL_5]] : tensor<f32>
// CHECK: %[[VAL_17:.*]] = stablehlo.multiply %[[VAL_16]], %[[VAL_16]] : tensor<f32>
// CHECK: %[[VAL_18:.*]] = stablehlo.select %[[VAL_14]], %[[VAL_15]], %[[VAL_17]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_19:.*]] = stablehlo.log_plus_one %[[VAL_18]] : tensor<f32>
// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_19]] : tensor<f32>
// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_11]], %[[VAL_20]] : tensor<f32>
// CHECK: %[[VAL_22:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_15]] : tensor<f32>
// CHECK: %[[VAL_23:.*]] = stablehlo.abs %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_24:.*]] = stablehlo.add %[[VAL_23]], %[[VAL_4]] : tensor<f32>
// CHECK: %[[VAL_25:.*]] = stablehlo.constant dense<2.000000e-01> : tensor<f32>
// CHECK: %[[VAL_26:.*]] = stablehlo.compare LT, %[[VAL_24]], %[[VAL_25]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_27:.*]] = stablehlo.multiply %[[VAL_22]], %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_28:.*]] = stablehlo.multiply %[[VAL_3]], %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_29:.*]] = stablehlo.add %[[VAL_27]], %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_30:.*]] = stablehlo.log %[[VAL_29]] : tensor<f32>
// CHECK: %[[VAL_31:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_30]] : tensor<f32>
// CHECK: %[[VAL_32:.*]] = stablehlo.add %[[VAL_1]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_33:.*]] = stablehlo.add %[[VAL_32]], %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_34:.*]] = stablehlo.multiply %[[VAL_1]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_35:.*]] = stablehlo.add %[[VAL_33]], %[[VAL_34]] : tensor<f32>
// CHECK: %[[VAL_36:.*]] = stablehlo.negate %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_37:.*]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[VAL_38:.*]] = stablehlo.compare GT, %[[VAL_6]], %[[VAL_37]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_39:.*]] = stablehlo.constant dense<0x4D000000> : tensor<f32>
// CHECK: %[[VAL_40:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor<f32>
// CHECK: %[[VAL_41:.*]] = stablehlo.compare GT, %[[VAL_6]], %[[VAL_40]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_42:.*]] = stablehlo.constant dense<4.097000e+03> : tensor<f32>
// CHECK: %[[VAL_43:.*]] = stablehlo.constant dense<6.500000e+01> : tensor<f32>
// CHECK: %[[VAL_44:.*]] = stablehlo.select %[[VAL_41]], %[[VAL_42]], %[[VAL_43]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_45:.*]] = stablehlo.select %[[VAL_38]], %[[VAL_39]], %[[VAL_44]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_46:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_3]] : tensor<f32>
// CHECK: %[[VAL_47:.*]] = stablehlo.subtract %[[VAL_3]], %[[VAL_46]] : tensor<f32>
// CHECK: %[[VAL_48:.*]] = stablehlo.add %[[VAL_46]], %[[VAL_47]] : tensor<f32>
// CHECK: %[[VAL_49:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_48]] : tensor<f32>
// CHECK: %[[VAL_50:.*]] = stablehlo.add %[[VAL_36]], %[[VAL_49]] : tensor<f32>
// CHECK: %[[VAL_51:.*]] = stablehlo.subtract %[[VAL_3]], %[[VAL_48]] : tensor<f32>
// CHECK: %[[VAL_52:.*]] = stablehlo.multiply %[[VAL_48]], %[[VAL_51]] : tensor<f32>
// CHECK: %[[VAL_53:.*]] = stablehlo.add %[[VAL_50]], %[[VAL_52]] : tensor<f32>
// CHECK: %[[VAL_54:.*]] = stablehlo.add %[[VAL_53]], %[[VAL_52]] : tensor<f32>
// CHECK: %[[VAL_55:.*]] = stablehlo.multiply %[[VAL_51]], %[[VAL_51]] : tensor<f32>
// CHECK: %[[VAL_56:.*]] = stablehlo.add %[[VAL_54]], %[[VAL_55]] : tensor<f32>
// CHECK: %[[VAL_57:.*]] = stablehlo.add %[[VAL_35]], %[[VAL_56]] : tensor<f32>
// CHECK: %[[VAL_58:.*]] = stablehlo.negate %[[VAL_34]] : tensor<f32>
// CHECK: %[[VAL_59:.*]] = stablehlo.multiply %[[VAL_45]], %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_60:.*]] = stablehlo.subtract %[[VAL_1]], %[[VAL_59]] : tensor<f32>
// CHECK: %[[VAL_61:.*]] = stablehlo.add %[[VAL_59]], %[[VAL_60]] : tensor<f32>
// CHECK: %[[VAL_62:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_61]] : tensor<f32>
// CHECK: %[[VAL_63:.*]] = stablehlo.add %[[VAL_58]], %[[VAL_62]] : tensor<f32>
// CHECK: %[[VAL_64:.*]] = stablehlo.subtract %[[VAL_1]], %[[VAL_61]] : tensor<f32>
// CHECK: %[[VAL_65:.*]] = stablehlo.multiply %[[VAL_61]], %[[VAL_64]] : tensor<f32>
// CHECK: %[[VAL_66:.*]] = stablehlo.add %[[VAL_63]], %[[VAL_65]] : tensor<f32>
// CHECK: %[[VAL_67:.*]] = stablehlo.add %[[VAL_66]], %[[VAL_65]] : tensor<f32>
// CHECK: %[[VAL_68:.*]] = stablehlo.multiply %[[VAL_64]], %[[VAL_64]] : tensor<f32>
// CHECK: %[[VAL_69:.*]] = stablehlo.add %[[VAL_67]], %[[VAL_68]] : tensor<f32>
// CHECK: %[[VAL_70:.*]] = stablehlo.add %[[VAL_57]], %[[VAL_69]] : tensor<f32>
// CHECK: %[[VAL_71:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_32]] : tensor<f32>
// CHECK: %[[VAL_72:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_71]] : tensor<f32>
// CHECK: %[[VAL_73:.*]] = stablehlo.subtract %[[VAL_32]], %[[VAL_72]] : tensor<f32>
// CHECK: %[[VAL_74:.*]] = stablehlo.subtract %[[VAL_28]], %[[VAL_71]] : tensor<f32>
// CHECK: %[[VAL_75:.*]] = stablehlo.add %[[VAL_73]], %[[VAL_74]] : tensor<f32>
// CHECK: %[[VAL_76:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_33]] : tensor<f32>
// CHECK: %[[VAL_77:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_76]] : tensor<f32>
// CHECK: %[[VAL_78:.*]] = stablehlo.subtract %[[VAL_33]], %[[VAL_77]] : tensor<f32>
// CHECK: %[[VAL_79:.*]] = stablehlo.subtract %[[VAL_34]], %[[VAL_76]] : tensor<f32>
// CHECK: %[[VAL_80:.*]] = stablehlo.add %[[VAL_78]], %[[VAL_79]] : tensor<f32>
// CHECK: %[[VAL_81:.*]] = stablehlo.add %[[VAL_75]], %[[VAL_80]] : tensor<f32>
// CHECK: %[[VAL_82:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_35]] : tensor<f32>
// CHECK: %[[VAL_83:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_82]] : tensor<f32>
// CHECK: %[[VAL_84:.*]] = stablehlo.subtract %[[VAL_35]], %[[VAL_83]] : tensor<f32>
// CHECK: %[[VAL_85:.*]] = stablehlo.subtract %[[VAL_56]], %[[VAL_82]] : tensor<f32>
// CHECK: %[[VAL_86:.*]] = stablehlo.add %[[VAL_84]], %[[VAL_85]] : tensor<f32>
// CHECK: %[[VAL_87:.*]] = stablehlo.add %[[VAL_81]], %[[VAL_86]] : tensor<f32>
// CHECK: %[[VAL_88:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_57]] : tensor<f32>
// CHECK: %[[VAL_89:.*]] = stablehlo.subtract %[[VAL_70]], %[[VAL_88]] : tensor<f32>
// CHECK: %[[VAL_90:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_89]] : tensor<f32>
// CHECK: %[[VAL_91:.*]] = stablehlo.subtract %[[VAL_69]], %[[VAL_88]] : tensor<f32>
// CHECK: %[[VAL_92:.*]] = stablehlo.add %[[VAL_90]], %[[VAL_91]] : tensor<f32>
// CHECK: %[[VAL_93:.*]] = stablehlo.add %[[VAL_87]], %[[VAL_92]] : tensor<f32>
// CHECK: %[[VAL_94:.*]] = stablehlo.add %[[VAL_70]], %[[VAL_93]] : tensor<f32>
// CHECK: %[[VAL_95:.*]] = stablehlo.log_plus_one %[[VAL_94]] : tensor<f32>
// CHECK: %[[VAL_96:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_95]] : tensor<f32>
// CHECK: %[[VAL_97:.*]] = stablehlo.select %[[VAL_26]], %[[VAL_31]], %[[VAL_96]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_98:.*]] = stablehlo.select %[[VAL_10]], %[[VAL_21]], %[[VAL_97]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_99:.*]] = stablehlo.atan2 %[[VAL_3]], %[[VAL_22]] : tensor<f32>
// CHECK: %[[VAL_100:.*]] = stablehlo.complex %[[VAL_98]], %[[VAL_99]] : tensor<complex<f32>>
// CHECK: return %[[VAL_100]] : tensor<complex<f32>>
// CHECK: }
func.func @log1p_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
%result = "chlo.log1p"(%arg) : (tensor<complex<f32>>) -> tensor<complex<f32>>
func.return %result : tensor<complex<f32>>
}
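Most of the expansion above goes into computing the real part, log|1 + z|, without cancellation: roughly, very large inputs take an overflow-safe path, inputs near z = -1 use log((1 + x)^2 + y^2) directly, and the remaining inputs evaluate 0.5 * log1p(2x + x^2 + y^2) with the sum formed carefully. A simplified sketch of that last idea (hypothetical; the generated pattern additionally guards the extreme ranges and uses exact-splitting tricks when forming the sum):

```python
# Simplified, hypothetical version of the cancellation-free real-part
# computation used by the complex log1p decomposition; not the exact
# generated algorithm.
import math

def log1p_real_part(x: float, y: float) -> float:
    # log|1 + (x + iy)| = 0.5 * log((1 + x)**2 + y**2)
    #                   = 0.5 * log1p(2*x + x*x + y*y)
    return 0.5 * math.log1p(2.0 * x + x * x + y * y)

def log1p_real_part_naive(x: float, y: float) -> float:
    return math.log(math.hypot(1.0 + x, y))

x, y = 1e-12, 1e-12
print(log1p_real_part(x, y))        # ~1e-12, no cancellation
print(log1p_real_part_naive(x, y))  # forming 1 + x first loses low-order bits
```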

// CHECK-LABEL: @log1p_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[VAL_1:.*]] = stablehlo.log_plus_one %[[VAL_0]] : tensor<f32>
// CHECK: return %[[VAL_1]] : tensor<f32>
// CHECK: }
func.func @log1p_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.log1p"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}