Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support corr in GroupBy.apply through the jit engine #13767

Merged
merged 13 commits into from
Aug 2, 2023
41 changes: 41 additions & 0 deletions python/cudf/cudf/core/udf/groupby_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,46 @@ def group_reduction_impl_basic(context, builder, sig, args, function):
)


def group_corr(context, builder, sig, args):
"""
Instruction boilerplate used for calling a groupby correlation
"""
lhs_grp = cgutils.create_struct_proxy(sig.args[0])(
context, builder, value=args[0]
)
rhs_grp = cgutils.create_struct_proxy(sig.args[1])(
context, builder, value=args[1]
)

device_func = call_cuda_functions["corr"][
(
sig.return_type,
sig.args[0].group_scalar_type,
sig.args[1].group_scalar_type,
)
]
result = context.compile_internal(
builder,
device_func,
nb_signature(
types.float64,
types.CPointer(
sig.args[0].group_scalar_type
), # this group calls corr
types.CPointer(
sig.args[1].group_scalar_type
), # this group is passed
group_size_type,
),
(
lhs_grp.group_data,
rhs_grp.group_data,
lhs_grp.size,
),
)
return result


@lower_builtin(Group, types.Array, group_size_type, types.Array)
def group_constructor(context, builder, sig, args):
"""
Expand Down Expand Up @@ -155,3 +195,4 @@ def cuda_Group_size(context, builder, sig, args):
cuda_lower("GroupType.idxmin", GroupType(ty, types.int64))(
cuda_Group_idxmin
)
cuda_lower("GroupType.corr", GroupType(ty), GroupType(ty))(group_corr)
67 changes: 48 additions & 19 deletions python/cudf/cudf/core/udf/groupby_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,22 @@ def __init__(self, dmm, fe_type):
call_cuda_functions: Dict[Any, Any] = {}


def _register_cuda_reduction_caller(funcname, inputty, retty):
def _register_cuda_binary_reduction_caller(funcname, lty, rty, retty):
cuda_func = cuda.declare_device(
f"Block{funcname}_{lty}_{rty}",
retty(types.CPointer(lty), types.CPointer(rty), group_size_type),
)

def caller(lhs, rhs, size):
return cuda_func(lhs, rhs, size)

call_cuda_functions.setdefault(funcname.lower(), {})

type_key = (retty, lty, rty)
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
call_cuda_functions[funcname.lower()][type_key] = caller


def _register_cuda_unary_reduction_caller(funcname, inputty, retty):
cuda_func = cuda.declare_device(
f"Block{funcname}_{inputty}",
retty(types.CPointer(inputty), group_size_type),
Expand Down Expand Up @@ -191,6 +206,13 @@ def generic(self, args, kws):
return nb_signature(self.this.index_type, recvr=self.this)


class GroupCorr(AbstractTemplate):
key = "GroupType.corr"

def generic(self, args, kws):
return nb_signature(types.float64, args[0], recvr=self.this)


@cuda_registry.register_attr
class GroupAttr(AttributeTemplate):
key = GroupType
Expand Down Expand Up @@ -220,33 +242,40 @@ def resolve_idxmin(self, mod):
GroupIdxMin, GroupType(mod.group_scalar_type, mod.index_type)
)

def resolve_corr(self, mod):
return types.BoundFunction(
GroupCorr, GroupType(mod.group_scalar_type, mod.index_type)
)


for ty in SUPPORTED_GROUPBY_NUMBA_TYPES:
_register_cuda_reduction_caller("Max", ty, ty)
_register_cuda_reduction_caller("Min", ty, ty)
_register_cuda_unary_reduction_caller("Max", ty, ty)
_register_cuda_unary_reduction_caller("Min", ty, ty)
_register_cuda_idx_reduction_caller("IdxMax", ty)
_register_cuda_idx_reduction_caller("IdxMin", ty)
_register_cuda_binary_reduction_caller("Corr", ty, ty, types.float64)


_register_cuda_reduction_caller("Sum", types.int32, types.int64)
_register_cuda_reduction_caller("Sum", types.int64, types.int64)
_register_cuda_reduction_caller("Sum", types.float32, types.float32)
_register_cuda_reduction_caller("Sum", types.float64, types.float64)
_register_cuda_unary_reduction_caller("Sum", types.int32, types.int64)
_register_cuda_unary_reduction_caller("Sum", types.int64, types.int64)
_register_cuda_unary_reduction_caller("Sum", types.float32, types.float32)
_register_cuda_unary_reduction_caller("Sum", types.float64, types.float64)


_register_cuda_reduction_caller("Mean", types.int32, types.float64)
_register_cuda_reduction_caller("Mean", types.int64, types.float64)
_register_cuda_reduction_caller("Mean", types.float32, types.float32)
_register_cuda_reduction_caller("Mean", types.float64, types.float64)
_register_cuda_unary_reduction_caller("Mean", types.int32, types.float64)
_register_cuda_unary_reduction_caller("Mean", types.int64, types.float64)
_register_cuda_unary_reduction_caller("Mean", types.float32, types.float32)
_register_cuda_unary_reduction_caller("Mean", types.float64, types.float64)

_register_cuda_reduction_caller("Std", types.int32, types.float64)
_register_cuda_reduction_caller("Std", types.int64, types.float64)
_register_cuda_reduction_caller("Std", types.float32, types.float32)
_register_cuda_reduction_caller("Std", types.float64, types.float64)
_register_cuda_unary_reduction_caller("Std", types.int32, types.float64)
_register_cuda_unary_reduction_caller("Std", types.int64, types.float64)
_register_cuda_unary_reduction_caller("Std", types.float32, types.float32)
_register_cuda_unary_reduction_caller("Std", types.float64, types.float64)

_register_cuda_reduction_caller("Var", types.int32, types.float64)
_register_cuda_reduction_caller("Var", types.int64, types.float64)
_register_cuda_reduction_caller("Var", types.float32, types.float32)
_register_cuda_reduction_caller("Var", types.float64, types.float64)
_register_cuda_unary_reduction_caller("Var", types.int32, types.float64)
_register_cuda_unary_reduction_caller("Var", types.int64, types.float64)
_register_cuda_unary_reduction_caller("Var", types.float32, types.float32)
_register_cuda_unary_reduction_caller("Var", types.float64, types.float64)


for attr in ("group_data", "index", "size"):
Expand Down
1 change: 0 additions & 1 deletion python/cudf/cudf/core/udf/groupby_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def _get_groupby_apply_kernel(frame, func, args):
"types": types,
}
kernel_string = _groupby_apply_kernel_string_from_template(frame, args)

kernel = _get_kernel(kernel_string, global_exec_context, None, func)

return kernel, return_type
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ def groupby_jit_data():
df["key2"] = np.random.randint(0, 2, nelem)
df["val1"] = np.random.random(nelem)
df["val2"] = np.random.random(nelem)
df["val3"] = np.random.randint(0, 10, nelem)
df["val4"] = np.random.randint(0, 10, nelem)
return df


Expand Down Expand Up @@ -433,6 +435,20 @@ def func(df):
run_groupby_apply_jit_test(groupby_jit_data, func, ["key1"])


@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES)
def test_groupby_apply_jit_correlation(groupby_jit_data, dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to test data with NaNs? Infinity? Empty groups? Negative numbers? etc.

I'd like to see stronger test coverage for much more of our JIT code paths, not just corr...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closing the loop on this conversation, after some discussion offline it was found that significant changes are needed to robustly support special values for this reduction which we'll tackle in a separate pull request.

Copy link
Contributor

@bdice bdice Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please file an issue for this -- and we also need to test the behavior of existing functions like variance and standard deviation for NaN support (do other functions ignore the NaN values like corr?).


groupby_jit_data["val3"] = groupby_jit_data["val3"].astype(dtype)
groupby_jit_data["val4"] = groupby_jit_data["val4"].astype(dtype)

keys = ["key1", "key2"]

def func(group):
return group["val3"].corr(group["val4"])

run_groupby_apply_jit_test(groupby_jit_data, func, keys)


@pytest.mark.parametrize("dtype", ["float64"])
@pytest.mark.parametrize("func", ["min", "max", "sum", "mean", "var", "std"])
@pytest.mark.parametrize("special_val", [np.nan, np.inf, -np.inf])
Expand Down
76 changes: 60 additions & 16 deletions python/cudf/udf_cpp/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -437,37 +437,49 @@ __device__ double BlockMean(T const* data, int64_t size)
}

template <typename T>
__device__ double BlockVar(T const* data, int64_t size)
__device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
{
auto block = cooperative_groups::this_thread_block();

__shared__ double block_var;
__shared__ T block_sum;
__shared__ double block_covar;

__shared__ T block_sum_lhs;
__shared__ T block_sum_rhs;

if (block.thread_rank() == 0) {
block_var = 0;
block_sum = 0;
block_covar = 0;
block_sum_lhs = 0;
block_sum_rhs = 0;
}
block.sync();

T local_sum = 0;
double local_var = 0;
device_sum<T>(block, lhs, size, &block_sum_lhs);
device_sum<T>(block, rhs, size, &block_sum_rhs);
auto const mu_l = static_cast<double>(block_sum_lhs) / static_cast<double>(size);
auto const mu_r = static_cast<double>(block_sum_rhs) / static_cast<double>(size);
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved

device_sum<T>(block, data, size, &block_sum);

auto const mean = static_cast<double>(block_sum) / static_cast<double>(size);
double local_covar = 0;

for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
auto const delta = static_cast<double>(data[idx]) - mean;
local_var += delta * delta;
auto const delta =
(static_cast<double>(lhs[idx]) - mu_l) * (static_cast<double>(rhs[idx]) - mu_r);
local_covar += delta;
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
}

cuda::atomic_ref<double, cuda::thread_scope_block> ref{block_var};
ref.fetch_add(local_var, cuda::std::memory_order_relaxed);
cuda::atomic_ref<double, cuda::thread_scope_block> ref{block_covar};
ref.fetch_add(local_covar, cuda::std::memory_order_relaxed);
block.sync();

if (block.thread_rank() == 0) { block_var = block_var / static_cast<double>(size - 1); }
if (block.thread_rank() == 0) { block_covar = block_covar / static_cast<double>(size - 1); }
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
block.sync();
return block_var;

return block_covar;
}

template <typename T>
__device__ double BlockVar(T const* data, int64_t size)
{
return BlockCoVar<T>(data, data, size);
}

template <typename T>
Expand Down Expand Up @@ -620,6 +632,19 @@ __device__ int64_t BlockIdxMin(T const* data, int64_t* index, int64_t size)
return block_idx_min;
}

template <typename T>
__device__ double BlockCorr(T* const lhs_ptr, T* const rhs_ptr, int64_t size)
{
auto numerator = BlockCoVar(lhs_ptr, rhs_ptr, size);
auto denominator = BlockStd(lhs_ptr, size) * BlockStd<T>(rhs_ptr, size);

if (denominator == 0.0) {
return 0.0;
} else {
return numerator / denominator;
}
}

extern "C" {
#define make_definition(name, cname, type, return_type) \
__device__ int name##_##cname(return_type* numba_return_value, type* const data, int64_t size) \
Expand Down Expand Up @@ -684,3 +709,22 @@ make_definition_idx(BlockIdxMax, float32, float);
make_definition_idx(BlockIdxMax, float64, double);
#undef make_definition_idx
}

extern "C" {
#define make_definition_corr(name, cname, type) \
__device__ int name##_##cname##_##cname( \
double* numba_return_value, type* const lhs, type* const rhs, int64_t size) \
{ \
double const res = name<type>(lhs, rhs, size); \
*numba_return_value = res; \
__syncthreads(); \
return 0; \
}

make_definition_corr(BlockCorr, int32, int32_t);
make_definition_corr(BlockCorr, int64, int64_t);
make_definition_corr(BlockCorr, float32, float);
make_definition_corr(BlockCorr, float64, double);

#undef make_definition_corr
}