From 1742685826738f955245194e3e3cce0c1a034568 Mon Sep 17 00:00:00 2001
From: Ashwin Srinath
Date: Thu, 13 Jul 2023 17:01:55 -0400
Subject: [PATCH 01/11] Initial commit

---
 python/cudf/cudf/core/udf/groupby_lowering.py | 58 ++++++++++++++-
 python/cudf/cudf/core/udf/groupby_typing.py   | 12 +++
 python/cudf/cudf/core/udf/groupby_utils.py    |  1 -
 python/cudf/cudf/tests/test_groupby.py        | 24 ++++++
 python/cudf/udf_cpp/shim.cu                   | 74 +++++++++++++++++++
 5 files changed, 167 insertions(+), 2 deletions(-)

diff --git a/python/cudf/cudf/core/udf/groupby_lowering.py b/python/cudf/cudf/core/udf/groupby_lowering.py
index 376eccb9308..2c81fc8ea76 100644
--- a/python/cudf/cudf/core/udf/groupby_lowering.py
+++ b/python/cudf/cudf/core/udf/groupby_lowering.py
@@ -2,7 +2,7 @@
 
 from functools import partial
 
-from numba import types
+from numba import cuda, types
 from numba.core import cgutils
 from numba.core.extending import lower_builtin
 from numba.core.typing import signature as nb_signature
@@ -55,6 +55,62 @@ def group_reduction_impl_basic(context, builder, sig, args, function):
     )
 
 
+block_corr = cuda.declare_device(
+    "BlockCorr",
+    types.float64(
+        types.CPointer(types.int64),
+        types.CPointer(types.int64),
+        group_size_type,
+    ),
+)
+
+
+def _block_corr(lhs_ptr, rhs_ptr, size):
+    return block_corr(lhs_ptr, rhs_ptr, size)
+
+
+@cuda_lower(
+    "GroupType.corr",
+    GroupType(types.int64, types.int64),
+    GroupType(types.int64, types.int64),
+)
+def group_corr(context, builder, sig, args):
+    """
+    Instruction boilerplate used for calling a groupby correlation
+    """
+    lhs_grp = cgutils.create_struct_proxy(sig.args[0])(
+        context, builder, value=args[0]
+    )
+    rhs_grp = cgutils.create_struct_proxy(sig.args[1])(
+        context, builder, value=args[1]
+    )
+
+    # logically take the address of the group's data pointer
+    lhs_group_data_ptr = builder.alloca(lhs_grp.group_data.type)
+    builder.store(lhs_grp.group_data, lhs_group_data_ptr)
+
+    # logically take the address of the group's data pointer
+    rhs_group_data_ptr = builder.alloca(rhs_grp.group_data.type)
+    builder.store(rhs_grp.group_data, rhs_group_data_ptr)
+
+    result = context.compile_internal(
+        builder,
+        _block_corr,
+        nb_signature(
+            types.float64,
+            types.CPointer(types.int64),
+            types.CPointer(types.int64),
+            group_size_type,
+        ),
+        (
+            builder.load(lhs_group_data_ptr),
+            builder.load(rhs_group_data_ptr),
+            lhs_grp.size,
+        ),
+    )
+    return result
+
+
 @lower_builtin(Group, types.Array, group_size_type, types.Array)
 def group_constructor(context, builder, sig, args):
     """
diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py
index 37381a95fdf..435a64b5621 100644
--- a/python/cudf/cudf/core/udf/groupby_typing.py
+++ b/python/cudf/cudf/core/udf/groupby_typing.py
@@ -167,6 +167,13 @@ def generic(self, args, kws):
         return nb_signature(self.this.index_type, recvr=self.this)
 
 
+class GroupCorr(AbstractTemplate):
+    key = "GroupType.corr"
+
+    def generic(self, args, kws):
+        return nb_signature(types.float64, args[0], recvr=self.this)
+
+
 @cuda_registry.register_attr
 class GroupAttr(AttributeTemplate):
     key = GroupType
@@ -197,6 +204,11 @@ def resolve_idxmin(self, mod):
             GroupIdxMin, GroupType(mod.group_scalar_type, mod.index_type)
         )
 
+    def resolve_corr(self, mod):
+        return types.BoundFunction(
+            GroupCorr, GroupType(mod.group_scalar_type, mod.index_type)
+        )
+
 
 for ty in SUPPORTED_GROUPBY_NUMBA_TYPES:
     _register_cuda_reduction_caller("Max", ty, ty)
diff --git a/python/cudf/cudf/core/udf/groupby_utils.py b/python/cudf/cudf/core/udf/groupby_utils.py
index ca72c28cd5f..b18720f5db5 100644
--- a/python/cudf/cudf/core/udf/groupby_utils.py
+++ b/python/cudf/cudf/core/udf/groupby_utils.py
@@ -124,7 +124,6 @@ def _get_groupby_apply_kernel(frame, func, args):
         "types": types,
     }
     kernel_string = _groupby_apply_kernel_string_from_template(frame, args)
-
     kernel = _get_kernel(kernel_string, global_exec_context, None, func)
 
     return kernel, return_type
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index dde80639fc7..06e986644eb 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -3302,3 +3302,27 @@ def test_group_by_pandas_sort_order(groups, sort):
         pdf.groupby(groups, sort=sort).sum(),
         df.groupby(groups, sort=sort).sum(),
     )
+
+
+def test_corr_jit():
+    def func(group):
+        return group["b"].corr(group["c"])
+
+    size = int(1000000)
+    gdf = cudf.DataFrame(
+        {
+            "a": np.random.randint(0, 10000, size),
+            "b": np.random.randint(0, 1000, size),
+            "c": np.random.randint(0, 1000, size),
+        }
+    )
+    gdf = gdf.sort_values("a")
+    pdf = gdf.to_pandas()
+
+    gdf_grouped = gdf.groupby("a")
+    pdf_grouped = pdf.groupby("a", as_index=False)
+
+    expect = pdf_grouped.apply(func)
+    got = gdf_grouped.apply(func, engine="jit")
+
+    assert_eq(expect, got)
diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 63ad1039da6..b570e9edc41 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -662,3 +662,77 @@ make_definition_idx(BlockIdxMax, int64, int64_t);
 make_definition_idx(BlockIdxMax, float64, double);
 #undef make_definition_idx
 }
+
+extern "C" __device__ int BlockCorr(double* numba_return_value,
+                                    int64_t* const lhs_ptr,
+                                    int64_t* rhs_ptr,
+                                    int64_t size)
+{
+  double lhs_mean = BlockMean(lhs_ptr, size);
+  double rhs_mean = BlockMean(rhs_ptr, size);
+
+  // cuda::atomic numerator = 0;
+  // cuda::atomic sum_sq_l = 0;
+  // cuda::atomic sum_sq_r = 0;
+
+  __shared__ double numerators[1024];
+  __shared__ double sum_sq_ls[1024];
+  __shared__ double sum_sq_rs[1024];
+
+  numerators[threadIdx.x] = 0.0;
+  sum_sq_ls[threadIdx.x]  = 0.0;
+  sum_sq_rs[threadIdx.x]  = 0.0;
+
+  auto block = cooperative_groups::this_thread_block();
+
+  for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
+    // numerator += data[idx];
+    double delta_l = lhs_ptr[idx] - lhs_mean;
+    double delta_r = rhs_ptr[idx] - rhs_mean;
+
+    numerators[idx] = delta_l * delta_r;
+    sum_sq_ls[idx]  = (delta_l * delta_l);
+    sum_sq_rs[idx]  = (delta_r * delta_r);
+    // printf("cuda THREAD INDEX=%d\n", threadIdx.x);
+    // printf(" GPU d_l=%.6f, d_r =%.6f, num=%.6f, sum_sq_l=%.6f, sum_sq_r=%.6f \n", delta_l,
+    // delta_r, numerator, sum_sq_l, sum_sq_r);
+  }
+  __syncthreads();
+
+  /*
+  if (threadIdx.x == 0 ){
+    printf("nums:\n");
+
+    for (int i = 0; i < block.size(); i++) {
+      printf("%.6f ", numerators[i]);
+    }
+    printf("\n");
+  }
+  */
+  double numerator = BlockSum(numerators, block.size());
+  double denominator =
+    sqrt(BlockSum(sum_sq_ls, block.size())) * sqrt(BlockSum(sum_sq_rs, block.size()));
+  // printf("GPU Numerator: %.6f, Denominator: %.6f", numerator, denominator);
+  // double denominator = sqrt(sum_sq_l) * sqrt(sum_sq_r);
+
+  // double numsum = BlockSum(nums, block.size());
+  // printf("NUMSUM: %.6f", numsum);
+
+  block.sync();
+
+  if (denominator == 0.0) { return 0.0; }
+  *numba_return_value = numerator / denominator;
+  __syncthreads();
+  return 0;
+
+  // numerator = sum(
+  //     (xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y)
+  // )
+  // denominator = sqrt(
+  //     sum(
+  //         (xi - mean_x) ** 2
+  //     )
+  // )
+  //
+  //
+}
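The quantity BlockCorr computes per group is the sample Pearson correlation, as the trailing pseudocode comment sketches. As a reference for what one group's result should be, here is a minimal NumPy sketch of the same computation — not part of the patch, included only for orientation; the zero-denominator guard mirrors the behavior the series converges on:

    import numpy as np


    def pearson_corr_reference(x: np.ndarray, y: np.ndarray) -> float:
        # center both series about their means
        dx = x - x.mean()
        dy = y - y.mean()
        numerator = (dx * dy).sum()
        denominator = np.sqrt((dx * dx).sum()) * np.sqrt((dy * dy).sum())
        # constant input yields a zero denominator; define the result as 0.0
        return 0.0 if denominator == 0.0 else float(numerator / denominator)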
From 0bfc9a87ecc350eb741be720e00e148368f0f11e Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Wed, 26 Jul 2023 08:32:38 -0700
Subject: [PATCH 02/11] cleanup

---
 python/cudf/udf_cpp/shim.cu | 32 +-------------------------------
 1 file changed, 1 insertion(+), 31 deletions(-)

diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index b570e9edc41..d7307b97f8a 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -686,53 +686,23 @@ extern "C" __device__ int BlockCorr(double* numba_return_value,
   auto block = cooperative_groups::this_thread_block();
 
   for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
-    // numerator += data[idx];
     double delta_l = lhs_ptr[idx] - lhs_mean;
     double delta_r = rhs_ptr[idx] - rhs_mean;
 
     numerators[idx] = delta_l * delta_r;
     sum_sq_ls[idx]  = (delta_l * delta_l);
     sum_sq_rs[idx]  = (delta_r * delta_r);
-    // printf("cuda THREAD INDEX=%d\n", threadIdx.x);
-    // printf(" GPU d_l=%.6f, d_r =%.6f, num=%.6f, sum_sq_l=%.6f, sum_sq_r=%.6f \n", delta_l,
-    // delta_r, numerator, sum_sq_l, sum_sq_r);
   }
   __syncthreads();
 
-  /*
-  if (threadIdx.x == 0 ){
-    printf("nums:\n");
-
-    for (int i = 0; i < block.size(); i++) {
-      printf("%.6f ", numerators[i]);
-    }
-    printf("\n");
-  }
-  */
   double numerator = BlockSum(numerators, block.size());
   double denominator =
     sqrt(BlockSum(sum_sq_ls, block.size())) * sqrt(BlockSum(sum_sq_rs, block.size()));
-  // printf("GPU Numerator: %.6f, Denominator: %.6f", numerator, denominator);
-  // double denominator = sqrt(sum_sq_l) * sqrt(sum_sq_r);
-
-  // double numsum = BlockSum(nums, block.size());
-  // printf("NUMSUM: %.6f", numsum);
 
   block.sync();
 
-  if (denominator == 0.0) { return 0.0; }
+  if (denominator == 0.0) { *numba_return_value = 0.0; }
   *numba_return_value = numerator / denominator;
   __syncthreads();
   return 0;
-
-  // numerator = sum(
-  //     (xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y)
-  // )
-  // denominator = sqrt(
-  //     sum(
-  //         (xi - mean_x) ** 2
-  //     )
-  // )
-  //
-  //
 }

From 3273c4f702d85e41527ac998105d6d73455beca7 Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Wed, 26 Jul 2023 11:59:06 -0700
Subject: [PATCH 03/11] reimplement var in terms of covar, corr in terms of
 covar and std

---
 python/cudf/udf_cpp/shim.cu | 85 +++++++++++++++----------------------
 1 file changed, 35 insertions(+), 50 deletions(-)

diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index d7307b97f8a..9da4f904982 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -437,37 +437,49 @@ __device__ double BlockMean(T const* data, int64_t size)
 }
 
 template <typename T>
-__device__ double BlockVar(T const* data, int64_t size)
+__device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
 {
   auto block = cooperative_groups::this_thread_block();
 
-  __shared__ double block_var;
-  __shared__ T block_sum;
+  __shared__ double block_covar;
+
+  __shared__ T block_sum_lhs;
+  __shared__ T block_sum_rhs;
+
   if (block.thread_rank() == 0) {
-    block_var = 0;
-    block_sum = 0;
+    block_covar   = 0;
+    block_sum_lhs = 0;
+    block_sum_rhs = 0;
   }
   block.sync();
 
-  T local_sum      = 0;
-  double local_var = 0;
+  device_sum<T>(block, lhs, size, &block_sum_lhs);
+  device_sum<T>(block, rhs, size, &block_sum_rhs);
+  auto const mu_l = static_cast<double>(block_sum_lhs) / static_cast<double>(size);
+  auto const mu_r = static_cast<double>(block_sum_rhs) / static_cast<double>(size);
 
-  device_sum<T>(block, data, size, &block_sum);
-
-  auto const mean = static_cast<double>(block_sum) / static_cast<double>(size);
+  double local_covar = 0;
 
   for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
-    auto const delta = static_cast<double>(data[idx]) - mean;
-    local_var += delta * delta;
+    auto const delta =
+      (static_cast<double>(lhs[idx]) - mu_l) * (static_cast<double>(rhs[idx]) - mu_r);
+    local_covar += delta;
   }
 
   cuda::atomic_ref<double, cuda::thread_scope_block> ref{block_covar};
   ref.fetch_add(local_covar, cuda::std::memory_order_relaxed);
   block.sync();
 
-  if (block.thread_rank() == 0) { block_var = block_var / static_cast<double>(size - 1); }
+  if (block.thread_rank() == 0) { block_covar = block_covar / static_cast<double>(size - 1); }
   block.sync();
-  return block_var;
+
+  return block_covar;
+}
+
+template <typename T>
+__device__ double BlockVar(T const* data, int64_t size)
+{
+  return BlockCoVar<T>(data, data, size);
 }
 
 template <typename T>
@@ -665,44 +677,17 @@ make_definition_idx(BlockIdxMax, float64, double);
 
 extern "C" __device__ int BlockCorr(double* numba_return_value,
                                     int64_t* const lhs_ptr,
-                                    int64_t* rhs_ptr,
+                                    int64_t* const rhs_ptr,
                                     int64_t size)
 {
-  double lhs_mean = BlockMean(lhs_ptr, size);
-  double rhs_mean = BlockMean(rhs_ptr, size);
-
-  // cuda::atomic numerator = 0;
-  // cuda::atomic sum_sq_l = 0;
-  // cuda::atomic sum_sq_r = 0;
+  auto numerator   = BlockCoVar(lhs_ptr, rhs_ptr, size);
+  auto denominator = BlockStd(lhs_ptr, size) * BlockStd(rhs_ptr, size);
 
-  __shared__ double numerators[1024];
-  __shared__ double sum_sq_ls[1024];
-  __shared__ double sum_sq_rs[1024];
-
-  numerators[threadIdx.x] = 0.0;
-  sum_sq_ls[threadIdx.x]  = 0.0;
-  sum_sq_rs[threadIdx.x]  = 0.0;
-
-  auto block = cooperative_groups::this_thread_block();
-
-  for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
-    double delta_l = lhs_ptr[idx] - lhs_mean;
-    double delta_r = rhs_ptr[idx] - rhs_mean;
-
-    numerators[idx] = delta_l * delta_r;
-    sum_sq_ls[idx]  = (delta_l * delta_l);
-    sum_sq_rs[idx]  = (delta_r * delta_r);
+  if (denominator == 0.0) {
+    *numba_return_value = 0.0;
+  } else {
+    *numba_return_value = numerator / denominator;
   }
   __syncthreads();
-
-  double numerator = BlockSum(numerators, block.size());
-  double denominator =
-    sqrt(BlockSum(sum_sq_ls, block.size())) * sqrt(BlockSum(sum_sq_rs, block.size()));
-
-  block.sync();
-
-  if (denominator == 0.0) { *numba_return_value = 0.0; }
-  *numba_return_value = numerator / denominator;
-  __syncthreads();
   return 0;
 }
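Worth spelling out, since it is the whole point of this patch: a single block-level covariance primitive suffices because of the standard identities

\[
\operatorname{cov}(x,y)=\frac{1}{n-1}\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y}),\qquad
\operatorname{var}(x)=\operatorname{cov}(x,x),\qquad
\operatorname{corr}(x,y)=\frac{\operatorname{cov}(x,y)}{\sigma_x\,\sigma_y},
\]

where \(\sigma_x=\sqrt{\operatorname{var}(x)}\). The Bessel factors \(1/(n-1)\) cancel between the numerator and denominator of \(\operatorname{corr}\), so this agrees with the raw-sum formula of the initial BlockCorr while also retiring its fixed 1024-element shared-memory scratch arrays, which were indexed by the row index rather than the thread rank.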
From 011fce703f02f565623339dad1542fba20061ae3 Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Thu, 27 Jul 2023 12:46:00 -0700
Subject: [PATCH 04/11] generalize dtypes, pass tests

---
 python/cudf/cudf/core/udf/groupby_lowering.py | 41 ++++++--------
 python/cudf/cudf/core/udf/groupby_typing.py   | 55 ++++++++++++-------
 python/cudf/cudf/tests/test_groupby.py        | 40 ++++++--------
 python/cudf/udf_cpp/shim.cu                   | 43 ++++++++++-----
 4 files changed, 97 insertions(+), 82 deletions(-)

diff --git a/python/cudf/cudf/core/udf/groupby_lowering.py b/python/cudf/cudf/core/udf/groupby_lowering.py
index 2c81fc8ea76..eac01c476b6 100644
--- a/python/cudf/cudf/core/udf/groupby_lowering.py
+++ b/python/cudf/cudf/core/udf/groupby_lowering.py
@@ -2,7 +2,7 @@
 
 from functools import partial
 
-from numba import cuda, types
+from numba import types
 from numba.core import cgutils
 from numba.core.extending import lower_builtin
 from numba.core.typing import signature as nb_signature
@@ -55,25 +55,6 @@ def group_reduction_impl_basic(context, builder, sig, args, function):
     )
 
 
-block_corr = cuda.declare_device(
-    "BlockCorr",
-    types.float64(
-        types.CPointer(types.int64),
-        types.CPointer(types.int64),
-        group_size_type,
-    ),
-)
-
-
-def _block_corr(lhs_ptr, rhs_ptr, size):
-    return block_corr(lhs_ptr, rhs_ptr, size)
-
-
-@cuda_lower(
-    "GroupType.corr",
-    GroupType(types.int64, types.int64),
-    GroupType(types.int64, types.int64),
-)
 def group_corr(context, builder, sig, args):
     """
     Instruction boilerplate used for calling a groupby correlation
@@ -84,7 +65,6 @@ def group_corr(context, builder, sig, args):
     rhs_grp = cgutils.create_struct_proxy(sig.args[1])(
         context, builder, value=args[1]
     )
-
     # logically take the address of the group's data pointer
     lhs_group_data_ptr = builder.alloca(lhs_grp.group_data.type)
     builder.store(lhs_grp.group_data, lhs_group_data_ptr)
@@ -92,14 +72,24 @@ def group_corr(context, builder, sig, args):
     rhs_group_data_ptr = builder.alloca(rhs_grp.group_data.type)
     builder.store(rhs_grp.group_data, rhs_group_data_ptr)
-
+    device_func = call_cuda_functions["corr"][
+        (
+            sig.return_type,
+            sig.args[0].group_scalar_type,
+            sig.args[1].group_scalar_type,
+        )
+    ]
     result = context.compile_internal(
         builder,
-        _block_corr,
+        device_func,
         nb_signature(
             types.float64,
-            types.CPointer(types.int64),
-            types.CPointer(types.int64),
+            types.CPointer(
+                sig.args[0].group_scalar_type
+            ),  # this group calls corr
+            types.CPointer(
+                sig.args[1].group_scalar_type
+            ),  # this group is passed
             group_size_type,
         ),
         (
             builder.load(lhs_group_data_ptr),
             builder.load(rhs_group_data_ptr),
             lhs_grp.size,
         ),
     )
     return result
@@ -211,3 +201,4 @@ def cuda_Group_size(context, builder, sig, args):
     cuda_lower("GroupType.idxmin", GroupType(ty, types.int64))(
         cuda_Group_idxmin
     )
+    cuda_lower("GroupType.corr", GroupType(ty), GroupType(ty))(group_corr)
diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py
index e77f93b0e8f..7c897b47330 100644
--- a/python/cudf/cudf/core/udf/groupby_typing.py
+++ b/python/cudf/cudf/core/udf/groupby_typing.py
@@ -104,7 +104,22 @@ def __init__(self, dmm, fe_type):
 call_cuda_functions: Dict[Any, Any] = {}
 
 
-def _register_cuda_reduction_caller(funcname, inputty, retty):
+def _register_cuda_binary_reduction_caller(funcname, lty, rty, retty):
+    cuda_func = cuda.declare_device(
+        f"Block{funcname}_{lty}_{rty}",
+        retty(types.CPointer(lty), types.CPointer(rty), group_size_type),
+    )
+
+    def caller(lhs, rhs, size):
+        return cuda_func(lhs, rhs, size)
+
+    call_cuda_functions.setdefault(funcname.lower(), {})
+
+    type_key = (retty, lty, rty)
+    call_cuda_functions[funcname.lower()][type_key] = caller
+
+
+def _register_cuda_unary_reduction_caller(funcname, inputty, retty):
     cuda_func = cuda.declare_device(
         f"Block{funcname}_{inputty}",
         retty(types.CPointer(inputty), group_size_type),
@@ -234,31 +249,33 @@ def resolve_corr(self, mod):
 
 
 for ty in SUPPORTED_GROUPBY_NUMBA_TYPES:
-    _register_cuda_reduction_caller("Max", ty, ty)
-    _register_cuda_reduction_caller("Min", ty, ty)
+    _register_cuda_unary_reduction_caller("Max", ty, ty)
+    _register_cuda_unary_reduction_caller("Min", ty, ty)
     _register_cuda_idx_reduction_caller("IdxMax", ty)
     _register_cuda_idx_reduction_caller("IdxMin", ty)
+    _register_cuda_binary_reduction_caller("Corr", ty, ty, types.float64)
+
 
-_register_cuda_reduction_caller("Sum", types.int32, types.int64)
-_register_cuda_reduction_caller("Sum", types.int64, types.int64)
-_register_cuda_reduction_caller("Sum", types.float32, types.float32)
-_register_cuda_reduction_caller("Sum", types.float64, types.float64)
+_register_cuda_unary_reduction_caller("Sum", types.int32, types.int64)
+_register_cuda_unary_reduction_caller("Sum", types.int64, types.int64)
+_register_cuda_unary_reduction_caller("Sum", types.float32, types.float32)
+_register_cuda_unary_reduction_caller("Sum", types.float64, types.float64)
 
-_register_cuda_reduction_caller("Mean", types.int32, types.float64)
-_register_cuda_reduction_caller("Mean", types.int64, types.float64)
-_register_cuda_reduction_caller("Mean", types.float32, types.float32)
-_register_cuda_reduction_caller("Mean", types.float64, types.float64)
+_register_cuda_unary_reduction_caller("Mean", types.int32, types.float64)
+_register_cuda_unary_reduction_caller("Mean", types.int64, types.float64)
+_register_cuda_unary_reduction_caller("Mean", types.float32, types.float32)
+_register_cuda_unary_reduction_caller("Mean", types.float64, types.float64)
 
-_register_cuda_reduction_caller("Std", types.int32, types.float64)
-_register_cuda_reduction_caller("Std", types.int64, types.float64)
-_register_cuda_reduction_caller("Std", types.float32, types.float32)
-_register_cuda_reduction_caller("Std", types.float64, types.float64)
+_register_cuda_unary_reduction_caller("Std", types.int32, types.float64)
+_register_cuda_unary_reduction_caller("Std", types.int64, types.float64)
+_register_cuda_unary_reduction_caller("Std", types.float32, types.float32)
+_register_cuda_unary_reduction_caller("Std", types.float64, types.float64)
 
-_register_cuda_reduction_caller("Var", types.int32, types.float64)
-_register_cuda_reduction_caller("Var", types.int64, types.float64)
-_register_cuda_reduction_caller("Var", types.float32, types.float32)
-_register_cuda_reduction_caller("Var", types.float64, types.float64)
+_register_cuda_unary_reduction_caller("Var", types.int32, types.float64)
+_register_cuda_unary_reduction_caller("Var", types.int64, types.float64)
+_register_cuda_unary_reduction_caller("Var", types.float32, types.float32)
+_register_cuda_unary_reduction_caller("Var", types.float64, types.float64)
 
 
 for attr in ("group_data", "index", "size"):
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index f4aa2426a70..3950a83e831 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -389,6 +389,8 @@ def groupby_jit_data():
     df["key2"] = np.random.randint(0, 2, nelem)
     df["val1"] = np.random.random(nelem)
     df["val2"] = np.random.random(nelem)
+    df["val3"] = np.random.randint(0, 10, nelem)
+    df["val4"] = np.random.randint(0, 10, nelem)
     return df
 
 
@@ -433,6 +435,20 @@ def func(df):
     run_groupby_apply_jit_test(groupby_jit_data, func, ["key1"])
 
 
+@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES)
+def test_groupby_apply_jit_correlation(groupby_jit_data, dtype):
+
+    groupby_jit_data["val3"] = groupby_jit_data["val3"].astype(dtype)
+    groupby_jit_data["val4"] = groupby_jit_data["val4"].astype(dtype)
+
+    keys = ["key1", "key2"]
+
+    def func(group):
+        return group["val3"].corr(group["val4"])
+
+    run_groupby_apply_jit_test(groupby_jit_data, func, keys)
+
+
 @pytest.mark.parametrize("dtype", ["float64"])
 @pytest.mark.parametrize("func", ["min", "max", "sum", "mean", "var", "std"])
 @pytest.mark.parametrize("special_val", [np.nan, np.inf, -np.inf])
@@ -3306,27 +3322,3 @@ def test_group_by_pandas_sort_order(groups, sort):
         pdf.groupby(groups, sort=sort).sum(),
         df.groupby(groups, sort=sort).sum(),
     )
-
-
-def test_corr_jit():
-    def func(group):
-        return group["b"].corr(group["c"])
-
-    size = int(1000000)
-    gdf = cudf.DataFrame(
-        {
-            "a": np.random.randint(0, 10000, size),
-            "b": np.random.randint(0, 1000, size),
-            "c": np.random.randint(0, 1000, size),
-        }
-    )
-    gdf = gdf.sort_values("a")
-    pdf = gdf.to_pandas()
-
-    gdf_grouped = gdf.groupby("a")
-    pdf_grouped = pdf.groupby("a", as_index=False)
-
-    expect = pdf_grouped.apply(func)
-    got = gdf_grouped.apply(func, engine="jit")
-
-    assert_eq(expect, got)
diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index a2b4a6a1141..5357f60498d 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -632,6 +632,19 @@ __device__ int64_t BlockIdxMin(T const* data, int64_t* index, int64_t size)
   return block_idx_min;
 }
 
+template <typename T>
+__device__ double BlockCorr(T* const lhs_ptr, T* const rhs_ptr, int64_t size)
+{
+  auto numerator   = BlockCoVar(lhs_ptr, rhs_ptr, size);
+  auto denominator = BlockStd(lhs_ptr, size) * BlockStd(rhs_ptr, size);
+
+  if (denominator == 0.0) {
+    return 0.0;
+  } else {
+    return numerator / denominator;
+  }
+}
+
 extern "C" {
 #define make_definition(name, cname, type, return_type) \
   __device__ int name##_##cname(return_type* numba_return_value, type* const data, int64_t size) \
@@ -697,19 +710,21 @@ make_definition_idx(BlockIdxMax, float64, double);
 #undef make_definition_idx
 }
 
-extern "C" __device__ int BlockCorr(double* numba_return_value,
-                                    int64_t* const lhs_ptr,
-                                    int64_t* const rhs_ptr,
-                                    int64_t size)
-{
-  auto numerator   = BlockCoVar(lhs_ptr, rhs_ptr, size);
-  auto denominator = BlockStd(lhs_ptr, size) * BlockStd(rhs_ptr, size);
-
-  if (denominator == 0.0) {
-    *numba_return_value = 0.0;
-  } else {
-    *numba_return_value = numerator / denominator;
+extern "C" {
+#define make_definition_corr(name, cname, type)                                 \
+  __device__ int name##_##cname##_##cname(                                      \
+    double* numba_return_value, type* const lhs, type* const rhs, int64_t size) \
+  {                                                                             \
+    double const res    = name(lhs, rhs, size);                                 \
+    *numba_return_value = res;                                                  \
+    __syncthreads();                                                            \
+    return 0;                                                                   \
   }
-  __syncthreads();
-  return 0;
+
+make_definition_corr(BlockCorr, int32, int32_t);
+make_definition_corr(BlockCorr, int64, int64_t);
+make_definition_corr(BlockCorr, float32, float);
+make_definition_corr(BlockCorr, float64, double);
+
+#undef make_definition_corr
 }
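Two conventions make this dispatch work: the make_definition_corr macro in shim.cu stamps out extern "C" symbols named Block{Func}_{lty}_{rty} (e.g. BlockCorr_int64_int64), and _register_cuda_binary_reduction_caller declares that same symbol through numba's cuda.declare_device, filing the resulting caller under call_cuda_functions[funcname.lower()][(retty, lty, rty)] — the key group_corr looks up at lowering time. In principle, another binary statistic would hook in with one more registration call; a hypothetical example (these patches do not export any BlockCoVar_* symbols from shim.cu, so this line is illustrative only):

    # hypothetical: wiring a block covariance the same way "Corr" is wired above
    _register_cuda_binary_reduction_caller("CoVar", types.int64, types.int64, types.float64)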
From 5730faa7c52793a0b78e856de8db48c9f79b4d1e Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Mon, 31 Jul 2023 09:32:08 -0700
Subject: [PATCH 05/11] remove unnecessary pointer indirection

---
 python/cudf/cudf/core/udf/groupby_lowering.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/python/cudf/cudf/core/udf/groupby_lowering.py b/python/cudf/cudf/core/udf/groupby_lowering.py
index eac01c476b6..1c6672a7729 100644
--- a/python/cudf/cudf/core/udf/groupby_lowering.py
+++ b/python/cudf/cudf/core/udf/groupby_lowering.py
@@ -65,13 +65,7 @@ def group_corr(context, builder, sig, args):
     rhs_grp = cgutils.create_struct_proxy(sig.args[1])(
         context, builder, value=args[1]
     )
-    # logically take the address of the group's data pointer
-    lhs_group_data_ptr = builder.alloca(lhs_grp.group_data.type)
-    builder.store(lhs_grp.group_data, lhs_group_data_ptr)
-
-    # logically take the address of the group's data pointer
-    rhs_group_data_ptr = builder.alloca(rhs_grp.group_data.type)
-    builder.store(rhs_grp.group_data, rhs_group_data_ptr)
 
     device_func = call_cuda_functions["corr"][
@@ -93,8 +87,8 @@ def group_corr(context, builder, sig, args):
             group_size_type,
         ),
         (
-            builder.load(lhs_group_data_ptr),
-            builder.load(rhs_group_data_ptr),
+            lhs_grp.group_data,
+            rhs_grp.group_data,
             lhs_grp.size,
         ),
     )
From e02cc028253d05c7fe7689ebeb2cd1409b57f023 Mon Sep 17 00:00:00 2001
From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com>
Date: Mon, 31 Jul 2023 13:53:23 -0500
Subject: [PATCH 06/11] Apply suggestions from code review

Co-authored-by: Bradley Dice
---
 python/cudf/cudf/core/udf/groupby_typing.py |  2 +-
 python/cudf/udf_cpp/shim.cu                 | 12 +++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py
index 7c897b47330..de08ee800c0 100644
--- a/python/cudf/cudf/core/udf/groupby_typing.py
+++ b/python/cudf/cudf/core/udf/groupby_typing.py
@@ -115,7 +115,7 @@ def caller(lhs, rhs, size):
 
     call_cuda_functions.setdefault(funcname.lower(), {})
 
-    type_key = (retty, lty, rty)
+    type_key = retty, lty, rty
     call_cuda_functions[funcname.lower()][type_key] = caller
 
 
diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 5357f60498d..1dba2216059 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -454,9 +454,15 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
   block.sync();
 
   device_sum<T>(block, lhs, size, &block_sum_lhs);
-  device_sum<T>(block, rhs, size, &block_sum_rhs);
   auto const mu_l = static_cast<double>(block_sum_lhs) / static_cast<double>(size);
-  auto const mu_r = static_cast<double>(block_sum_rhs) / static_cast<double>(size);
+  auto const mu_r = [=](){
+    if (lhs == rhs) {
+      return mu_l;
+    } else {
+      device_sum<T>(block, rhs, size, &block_sum_rhs);
+      return static_cast<double>(block_sum_rhs) / static_cast<double>(size);
+    }
+  }();
 
   double local_covar = 0;
 
@@ -470,7 +476,7 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
   ref.fetch_add(local_covar, cuda::std::memory_order_relaxed);
   block.sync();
 
-  if (block.thread_rank() == 0) { block_covar = block_covar / static_cast<double>(size - 1); }
+  if (block.thread_rank() == 0) { block_covar /= static_cast<double>(size - 1); }
   block.sync();
 
   return block_covar;

From dc35515b7d4562e7b61df2a6733204bcf60219a0 Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Mon, 31 Jul 2023 13:52:12 -0700
Subject: [PATCH 07/11] style

---
 python/cudf/udf_cpp/shim.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 1dba2216059..3cdabe257b0 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -455,7 +455,7 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
 
   device_sum<T>(block, lhs, size, &block_sum_lhs);
   auto const mu_l = static_cast<double>(block_sum_lhs) / static_cast<double>(size);
-  auto const mu_r = [=](){
+  auto const mu_r = [=]() {
     if (lhs == rhs) {
       return mu_l;
     } else {
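The review change in patches 06-07 skips the second block-wide sum whenever BlockCoVar is invoked with the same pointer for both operands — exactly the BlockVar path, where the two means coincide. A plain-Python analogue of that control flow, for orientation only:

    def covar(lhs, rhs):
        # when lhs and rhs alias (the variance case), reuse the first mean
        n = len(lhs)
        mu_l = sum(lhs) / n
        mu_r = mu_l if lhs is rhs else sum(rhs) / n
        return sum((x - mu_l) * (y - mu_r) for x, y in zip(lhs, rhs)) / (n - 1)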
From e1cdad150892ab2752cca2e7d7d2d36313b5570a Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Tue, 1 Aug 2023 08:31:23 -0700
Subject: [PATCH 08/11] drop float for now

---
 python/cudf/cudf/core/udf/groupby_typing.py | 4 +++-
 python/cudf/cudf/tests/test_groupby.py      | 2 +-
 python/cudf/udf_cpp/shim.cu                 | 2 --
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/cudf/cudf/core/udf/groupby_typing.py b/python/cudf/cudf/core/udf/groupby_typing.py
index de08ee800c0..97afdd1c6ba 100644
--- a/python/cudf/cudf/core/udf/groupby_typing.py
+++ b/python/cudf/cudf/core/udf/groupby_typing.py
@@ -253,7 +253,9 @@ def resolve_corr(self, mod):
     _register_cuda_unary_reduction_caller("Min", ty, ty)
     _register_cuda_idx_reduction_caller("IdxMax", ty)
     _register_cuda_idx_reduction_caller("IdxMin", ty)
-    _register_cuda_binary_reduction_caller("Corr", ty, ty, types.float64)
+
+    if ty in types.integer_domain:
+        _register_cuda_binary_reduction_caller("Corr", ty, ty, types.float64)
 
 
 _register_cuda_unary_reduction_caller("Sum", types.int32, types.int64)
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index 3950a83e831..7d22cb70803 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -435,7 +435,7 @@ def func(df):
     run_groupby_apply_jit_test(groupby_jit_data, func, ["key1"])
 
 
-@pytest.mark.parametrize("dtype", SUPPORTED_GROUPBY_NUMPY_TYPES)
+@pytest.mark.parametrize("dtype", ["int32", "int64"])
 def test_groupby_apply_jit_correlation(groupby_jit_data, dtype):
 
     groupby_jit_data["val3"] = groupby_jit_data["val3"].astype(dtype)
diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 3cdabe257b0..4925f2752e7 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -729,8 +729,6 @@
 make_definition_corr(BlockCorr, int32, int32_t);
 make_definition_corr(BlockCorr, int64, int64_t);
-make_definition_corr(BlockCorr, float32, float);
-make_definition_corr(BlockCorr, float64, double);
 
 #undef make_definition_corr
 }

From 14bd1be8a1fa95a8b29e6a8f6e13d1920260e3cd Mon Sep 17 00:00:00 2001
From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com>
Date: Tue, 1 Aug 2023 10:49:24 -0500
Subject: [PATCH 09/11] Apply suggestions from code review

Co-authored-by: Bradley Dice
---
 python/cudf/udf_cpp/shim.cu | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 4925f2752e7..891f9034e83 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -457,6 +457,8 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
   auto const mu_r = [=]() {
     if (lhs == rhs) {
+      // If the lhs and rhs are the same, this is calculating variance.
+      // Thus we can assume mu_r = mu_l.
       return mu_l;
     } else {
       device_sum<T>(block, rhs, size, &block_sum_rhs);
@@ -467,9 +469,8 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
   double local_covar = 0;
 
   for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
-    auto const delta =
+    local_covar +=
       (static_cast<double>(lhs[idx]) - mu_l) * (static_cast<double>(rhs[idx]) - mu_r);
-    local_covar += delta;
   }
 
   cuda::atomic_ref<double, cuda::thread_scope_block> ref{block_covar};

From ff618ba70572b7d6953841d241dc4adf1d0feed7 Mon Sep 17 00:00:00 2001
From: brandon-b-miller
Date: Tue, 1 Aug 2023 11:13:53 -0700
Subject: [PATCH 10/11] style fixes

---
 python/cudf/udf_cpp/shim.cu | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu
index 891f9034e83..0959b6ba53f 100644
--- a/python/cudf/udf_cpp/shim.cu
+++ b/python/cudf/udf_cpp/shim.cu
@@ -469,8 +469,7 @@ __device__ double BlockCoVar(T const* lhs, T const* rhs, int64_t size)
   double local_covar = 0;
 
   for (int64_t idx = block.thread_rank(); idx < size; idx += block.size()) {
-    local_covar +=
-      (static_cast<double>(lhs[idx]) - mu_l) * (static_cast<double>(rhs[idx]) - mu_r);
+    local_covar += (static_cast<double>(lhs[idx]) - mu_l) * (static_cast<double>(rhs[idx]) - mu_r);
   }

From 3f37b61e26b612d985135d9d4b86a8d342c58e7a Mon Sep 17 00:00:00 2001
From: Ashwin Srinath
Date: Tue, 1 Aug 2023 14:46:14 -0400
Subject: [PATCH 11/11] Empty commit
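End state of the series: inside a groupby-apply UDF compiled with engine="jit", one grouped integer column can call .corr against another of the same dtype, returning float64. A usage sketch distilled from the test added in patch 04 (the frame, key, and column names here are illustrative):

    import cudf
    import numpy as np

    df = cudf.DataFrame(
        {
            "key": np.random.randint(0, 10, 1000),
            "x": np.random.randint(0, 100, 1000),
            "y": np.random.randint(0, 100, 1000),
        }
    )

    def func(group):
        return group["x"].corr(group["y"])

    # engine="jit" routes the UDF through the Block* device functions in shim.cu
    result = df.groupby("key").apply(func, engine="jit")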