Skip to content

Commit

Permalink
Used typed enums for null/nan equality in list methods
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Nov 6, 2024
1 parent 0fb6980 commit 26cddb3
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 135 deletions.
18 changes: 11 additions & 7 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ from cudf.core.buffer import acquire_spill_lock

from libcpp cimport bool

from pylibcudf.libcudf.types cimport null_order, size_type
from pylibcudf.libcudf.types cimport (
nan_equality, null_equality, null_order, order, size_type
)

from cudf._lib.column cimport Column
from cudf._lib.utils cimport columns_from_pylibcudf_table
Expand Down Expand Up @@ -37,8 +39,8 @@ def distinct(Column col, bool nulls_equal, bool nans_all_equal):
return Column.from_pylibcudf(
plc.lists.distinct(
col.to_pylibcudf(mode="read"),
nulls_equal,
nans_all_equal,
null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL,
nan_equality.ALL_EQUAL if nans_all_equal else nan_equality.UNEQUAL,
)
)

Expand All @@ -48,7 +50,7 @@ def sort_lists(Column col, bool ascending, str na_position):
return Column.from_pylibcudf(
plc.lists.sort_lists(
col.to_pylibcudf(mode="read"),
ascending,
order.ASCENDING if ascending else order.DESCENDING,
null_order.BEFORE if na_position == "first" else null_order.AFTER,
False,
)
Expand Down Expand Up @@ -91,7 +93,7 @@ def index_of_scalar(Column col, object py_search_key):
plc.lists.index_of(
col.to_pylibcudf(mode="read"),
<Scalar> py_search_key.device_value.c_value,
True,
plc.lists.DuplicateFindOption.FIND_FIRST,
)
)

Expand All @@ -102,7 +104,7 @@ def index_of_column(Column col, Column search_keys):
plc.lists.index_of(
col.to_pylibcudf(mode="read"),
search_keys.to_pylibcudf(mode="read"),
True,
plc.lists.DuplicateFindOption.FIND_FIRST,
)
)

Expand All @@ -123,7 +125,9 @@ def concatenate_list_elements(Column input_column, dropna=False):
return Column.from_pylibcudf(
plc.lists.concatenate_list_elements(
input_column.to_pylibcudf(mode="read"),
dropna,
plc.lists.ConcatenateNullPolicy.IGNORE
if dropna
else plc.lists.ConcatenateNullPolicy.NULLIFTY_OUTPUT_ROW,
)
)

Expand Down
8 changes: 4 additions & 4 deletions python/pylibcudf/pylibcudf/libcudf/lists/combine.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from libc.stdint cimport int32_t
from libcpp.memory cimport unique_ptr
from pylibcudf.libcudf.column.column cimport column
from pylibcudf.libcudf.column.column_view cimport column_view
Expand All @@ -9,10 +10,9 @@ from pylibcudf.libcudf.table.table_view cimport table_view
cdef extern from "cudf/lists/combine.hpp" namespace \
"cudf::lists" nogil:

ctypedef enum concatenate_null_policy:
IGNORE "cudf::lists::concatenate_null_policy::IGNORE"
NULLIFY_OUTPUT_ROW \
"cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW"
cpdef enum class concatenate_null_policy(int32_t):
IGNORE
NULLIFY_OUTPUT_ROW

cdef unique_ptr[column] concatenate_rows(
const table_view input_table
Expand Down
Empty file.
Empty file.
30 changes: 21 additions & 9 deletions python/pylibcudf/pylibcudf/lists.pxd
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from libcpp cimport bool
from pylibcudf.libcudf.types cimport null_order, size_type
from pylibcudf.libcudf.types cimport (
nan_equality, null_equality, null_order, order, size_type
)
from pylibcudf.libcudf.lists.combine cimport concatenate_null_policy
from pylibcudf.libcudf.lists.contains cimport duplicate_find_option

from .column cimport Column
from .scalar cimport Scalar
Expand All @@ -19,13 +23,13 @@ cpdef Table explode_outer(Table, size_type explode_column_idx)

cpdef Column concatenate_rows(Table)

cpdef Column concatenate_list_elements(Column, bool dropna)
cpdef Column concatenate_list_elements(Column, concatenate_null_policy null_policy)

cpdef Column contains(Column, ColumnOrScalar)

cpdef Column contains_nulls(Column)

cpdef Column index_of(Column, ColumnOrScalar, bool)
cpdef Column index_of(Column, ColumnOrScalar, duplicate_find_option)

cpdef Column reverse(Column)

Expand All @@ -37,16 +41,24 @@ cpdef Column count_elements(Column)

cpdef Column sequences(Column, Column, Column steps = *)

cpdef Column sort_lists(Column, bool, null_order, bool stable = *)
cpdef Column sort_lists(Column, order, null_order, bool stable = *)

cpdef Column difference_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)
cpdef Column difference_distinct(
Column, Column, null_equality nulls_equal=*, nan_equality nans_equal=*
)

cpdef Column have_overlap(Column, Column, bool nulls_equal=*, bool nans_equal=*)
cpdef Column have_overlap(
Column, Column, null_equality nulls_equal=*, nan_equality nans_equal=*
)

cpdef Column intersect_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)
cpdef Column intersect_distinct(
Column, Column, null_equality nulls_equal=*, nan_equality nans_equal=*
)

cpdef Column union_distinct(Column, Column, bool nulls_equal=*, bool nans_equal=*)
cpdef Column union_distinct(
Column, Column, null_equality nulls_equal=*, nan_equality nans_equal=*
)

cpdef Column apply_boolean_mask(Column, Column)

cpdef Column distinct(Column, bool, bool)
cpdef Column distinct(Column, null_equality, nan_equality)
46 changes: 37 additions & 9 deletions python/pylibcudf/pylibcudf/lists.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from enum import IntEnum, auto

from pylibcudf.column import Column
from pylibcudf.scalar import Scalar
from pylibcudf.table import Table
from pylibcudf.types import NullOrder
from pylibcudf.types import NanEquality, NullEquality, NullOrder, Order

class ConcatenateNullPolicy(IntEnum):
IGNORE = auto()
NULLIFY_OUTPUT_ROW = auto()

class DuplicateFindOption(IntEnum):
FIND_FIRST = auto()
FIND_LAST = auto()

def explode_outer(input: Table, explode_column_idx: int) -> Table: ...
def concatenate_rows(input: Table) -> Column: ...
def concatenate_list_elements(input: Column, dropna: bool) -> Column: ...
def concatenate_list_elements(
input: Column, null_policy: ConcatenateNullPolicy
) -> Column: ...
def contains(input: Column, search_key: Column | Scalar) -> Column: ...
def contains_nulls(input: Column) -> Column: ...
def index_of(
input: Column, search_key: Column | Scalar, find_first_option: bool
input: Column,
search_key: Column | Scalar,
find_option: DuplicateFindOption,
) -> Column: ...
def reverse(input: Column) -> Column: ...
def segmented_gather(input: Column, gather_map_list: Column) -> Column: ...
Expand All @@ -22,21 +36,35 @@ def sequences(
) -> Column: ...
def sort_lists(
input: Column,
ascending: bool,
sort_order: Order,
na_position: NullOrder,
stable: bool = False,
) -> Column: ...
def difference_distinct(
lhs: Column, rhs: Column, nulls_equal: bool = True, nans_equal: bool = True
lhs: Column,
rhs: Column,
nulls_equal: NullEquality = NullEquality.EQUAL,
nans_equal: NanEquality = NanEquality.ALL_EQUAL,
) -> Column: ...
def have_overlap(
lhs: Column, rhs: Column, nulls_equal: bool = True, nans_equal: bool = True
lhs: Column,
rhs: Column,
nulls_equal: NullEquality = NullEquality.EQUAL,
nans_equal: NanEquality = NanEquality.ALL_EQUAL,
) -> Column: ...
def intersect_distinct(
lhs: Column, rhs: Column, nulls_equal: bool = True, nans_equal: bool = True
lhs: Column,
rhs: Column,
nulls_equal: NullEquality = NullEquality.EQUAL,
nans_equal: NanEquality = NanEquality.ALL_EQUAL,
) -> Column: ...
def union_distinct(
lhs: Column, rhs: Column, nulls_equal: bool = True, nans_equal: bool = True
lhs: Column,
rhs: Column,
nulls_equal: NullEquality = NullEquality.EQUAL,
nans_equal: NanEquality = NanEquality.ALL_EQUAL,
) -> Column: ...
def apply_boolean_mask(input: Column, mask: Column) -> Column: ...
def distinct(input: Column, nulls_equal: bool, nans_equal: bool) -> Column: ...
def distinct(
input: Column, nulls_equal: NullEquality, nans_equal: NanEquality
) -> Column: ...
Loading

0 comments on commit 26cddb3

Please sign in to comment.