Skip to content

Commit

Permalink
Porting ALICE-specific changes (alisw#6)
Browse files Browse the repository at this point in the history
* Additional native functions for float32 in Gandiva
* Extend bitwise operations for more int types
* Adapt to upstream trigonometric function definitions
* Prevent clash with c++ variant
  • Loading branch information
aalkin committed Jul 31, 2024
1 parent 6a2e19a commit a798984
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 45 deletions.
22 changes: 21 additions & 1 deletion cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(not, {}, boolean, boolean),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, int32, int64),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, int64, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, int8, int32),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),

// cast to float32
UNARY_CAST_TO_FLOAT32(int32), UNARY_CAST_TO_FLOAT32(int64),
UNARY_CAST_TO_FLOAT32(float64),
UNARY_CAST_TO_FLOAT32(float64), UNARY_CAST_TO_FLOAT32(int8),

// cast to int32
UNARY_CAST_TO_INT32(float32), UNARY_CAST_TO_INT32(float64),
Expand Down Expand Up @@ -125,8 +126,27 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {"xor"}, int32),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {"xor"}, int64),

BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, uint32),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, uint64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, uint32),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, uint64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, uint32),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, uint64),

BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, uint8),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, uint16),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, uint8),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, uint16),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, uint8),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, uint16),

UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int32, int32),
UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int64, int64),
UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, uint32, uint32),
UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, uint64, uint64),
UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, uint16, uint16),
UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, uint8, uint8),

UNARY_SAFE_NULL_NEVER_BOOL(isnotfalse, ({"is not false"}), boolean),
UNARY_SAFE_NULL_NEVER_BOOL(isnottrue, ({"is not true"}), boolean),
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/gandiva/function_registry_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "gandiva/function_registry_math_ops.h"

#include "gandiva/function_registry_common.h"

namespace gandiva {
Expand All @@ -28,6 +29,11 @@ namespace gandiva {
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64)

#define MATH_UNARY_OPS_FLOAT(name, ALIASES) \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, float32), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, float32), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32)

#define MATH_BINARY_UNSAFE(name, ALIASES) \
BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \
BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \
Expand All @@ -44,6 +50,11 @@ namespace gandiva {
BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32, float64), \
BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64, float64)

#define MATH_BINARY_SAFE_FLOAT(name, ALIASES) \
BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32, float32), \
BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32, float32), \
BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32, float32)

#define UNARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \
NUMERIC_BOOL_DATE_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, name, ALIASES)

Expand All @@ -64,9 +75,14 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {}),
MATH_UNARY_OPS(log10, {}),

MATH_UNARY_OPS_FLOAT(sqrtf, {}), MATH_UNARY_OPS_FLOAT(cbrtf, {}),
MATH_UNARY_OPS_FLOAT(expf, {}), MATH_UNARY_OPS_FLOAT(logf, {}),
MATH_UNARY_OPS_FLOAT(log10f, {}),

MATH_BINARY_UNSAFE(log, {}),

BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(power, {"pow"}, float64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(powerf, {"powf"}, float32),

UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}),
UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}),
Expand All @@ -85,9 +101,16 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
MATH_UNARY_OPS(sinh, {}), MATH_UNARY_OPS(cosh, {}), MATH_UNARY_OPS(tanh, {}),
MATH_UNARY_OPS(cot, {}), MATH_UNARY_OPS(radians, {}),
MATH_UNARY_OPS(degrees, {"udfdegrees"}), MATH_BINARY_SAFE(atan2, {}),
MATH_UNARY_OPS_FLOAT(sinf, {}), MATH_UNARY_OPS_FLOAT(cosf, {}),
MATH_UNARY_OPS_FLOAT(asinf, {}), MATH_UNARY_OPS_FLOAT(acosf, {}),
MATH_UNARY_OPS_FLOAT(tanf, {}), MATH_UNARY_OPS_FLOAT(atanf, {}),
MATH_UNARY_OPS_FLOAT(sinhf, {}), MATH_UNARY_OPS_FLOAT(coshf, {}),
MATH_UNARY_OPS_FLOAT(tanhf, {}), MATH_UNARY_OPS_FLOAT(cotf, {}),
MATH_BINARY_SAFE_FLOAT(atan2f, {}),

// decimal functions
UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128),
UNARY_SAFE_NULL_IF_NULL(absf, {}, float32, float32),
UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128),
UNARY_SAFE_NULL_IF_NULL(floor, {}, decimal128, decimal128),
UNARY_SAFE_NULL_IF_NULL(round, {}, decimal128, decimal128),
Expand All @@ -110,6 +133,8 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {

#undef MATH_UNARY_OPS

#undef MATH_UNARY_OPS_FLOAT

#undef MATH_BINARY_UNSAFE

#undef UNARY_SAFE_NULL_NEVER_BOOL_FN
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/gandiva/precompiled/arithmetic_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,27 @@ extern "C" {
NUMERIC_TYPES(BINARY_SYMMETRIC, add, +)
NUMERIC_TYPES(BINARY_SYMMETRIC, subtract, -)
NUMERIC_TYPES(BINARY_SYMMETRIC, multiply, *)

BINARY_SYMMETRIC(bitwise_and, int32, &)
BINARY_SYMMETRIC(bitwise_and, int64, &)
BINARY_SYMMETRIC(bitwise_or, int32, |)
BINARY_SYMMETRIC(bitwise_or, int64, |)
BINARY_SYMMETRIC(bitwise_xor, int32, ^)
BINARY_SYMMETRIC(bitwise_xor, int64, ^)

BINARY_SYMMETRIC(bitwise_and, uint32, &)
BINARY_SYMMETRIC(bitwise_and, uint64, &)
BINARY_SYMMETRIC(bitwise_or, uint32, |)
BINARY_SYMMETRIC(bitwise_or, uint64, |)
BINARY_SYMMETRIC(bitwise_xor, uint32, ^)
BINARY_SYMMETRIC(bitwise_xor, uint64, ^)

BINARY_SYMMETRIC(bitwise_and, uint8, &)
BINARY_SYMMETRIC(bitwise_and, uint16, &)
BINARY_SYMMETRIC(bitwise_or, uint8, |)
BINARY_SYMMETRIC(bitwise_or, uint16, |)
BINARY_SYMMETRIC(bitwise_xor, uint8, ^)
BINARY_SYMMETRIC(bitwise_xor, uint16, ^)
#undef BINARY_SYMMETRIC

MOD_OP(mod, int64, int32, int32)
Expand Down Expand Up @@ -202,6 +216,8 @@ NUMERIC_DATE_TYPES(COMPARE_SIX_VALUES, least, <)

CAST_UNARY(castBIGINT, int32, int64)
CAST_UNARY(castINT, int64, int32)
CAST_UNARY(castINT, int8, int32)
CAST_UNARY(castFLOAT4, int8, float32)
CAST_UNARY(castFLOAT4, int32, float32)
CAST_UNARY(castFLOAT4, int64, float32)
CAST_UNARY(castFLOAT8, int32, float64)
Expand Down Expand Up @@ -466,6 +482,10 @@ DIV_FLOAT(float64)

BITWISE_NOT(int32)
BITWISE_NOT(int64)
BITWISE_NOT(uint64)
BITWISE_NOT(uint32)
BITWISE_NOT(uint16)
BITWISE_NOT(uint8)

#undef BITWISE_NOT

Expand Down
Loading

0 comments on commit a798984

Please sign in to comment.