Add type stubs for pylibcudf #17258

Open · wants to merge 17 commits into base: branch-24.12
73 changes: 72 additions & 1 deletion docs/cudf/source/developer_guide/pylibcudf.md
@@ -15,7 +15,8 @@ To satisfy the goals of pylibcudf, we impose the following set of design principles:
- All typing in code should be written using Cython syntax, not PEP 484 Python typing syntax. Not only does this ensure compatibility with Cython < 3, but even with Cython 3 PEP 484 support remains incomplete as of this writing.
- All cudf code should interact only with pylibcudf, never with libcudf directly. This is not currently the case, but is the direction that the library is moving towards.
- Ideally, pylibcudf should depend on no RAPIDS component other than rmm, and should in general have minimal runtime dependencies.

- Type stubs are provided and generated manually. When adding new
functionality, ensure that the matching type stub is appropriately updated.

## Relationship to libcudf

@@ -249,3 +250,73 @@ In the event that libcudf provides multiple overloads for the same function with
and set arguments not shared between overloads to `None`. If a user tries to pass in an unsupported argument for a specific overload type, you should raise `ValueError`.

Finally, consider making a libcudf issue if you think this inconsistency can be addressed on the libcudf side.

### Type stubs

Since static type checkers like `mypy` and `pyright` cannot parse
Cython code, we provide type stubs for the pylibcudf package. These
are currently maintained manually, alongside the matching pylibcudf
files.

Every `pyx` file should have a matching `pyi` file that provides the
type stubs. Most functions can be exposed straightforwardly. Some
guiding principles:

- For typed integer arguments in libcudf (for example `size_type`), use `int`
  as the type annotation.
- For functions which are annotated as taking a `list` in Cython, but whose
  body does more detailed checking, try to encode the detailed information in
  the type (see the sketch after this list).
- For Cython fused types there are two options:
  1. If the fused type appears only once in the function signature, use a
     `Union` type;
  2. If the fused type appears more than once (or as both an input and
     output type), use a `TypeVar` with the variants in the fused type
     provided as constraints.
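
For instance, consider a hypothetical Cython function (not an actual
pylibcudf API) declared as
`cpdef Column take_rows(Column input, size_type n, list gather_maps)`, whose
body requires `gather_maps` to contain only columns. A stub entry following
the first two principles might look like:

```python
from pylibcudf.column import Column

# `size_type` becomes `int`, and the required element type of the list is
# encoded in the annotation even though Cython only sees `list`.
def take_rows(
    input: Column, n: int, gather_maps: list[Column]
) -> Column: ...
```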


As an example of the fused-type case, `pylibcudf.copying.split` is typed in Cython as:

```cython
ctypedef fused ColumnOrTable:
    Table
    Column

cpdef list split(ColumnOrTable input, list splits): ...
```

Here the fused type appears only once in the signature, and the `list`
arguments do not specify their element types. However, the return type is
correlated with the input type: if we provide a `Column` we receive a
`list[Column]` as output, and if we provide a `Table` we receive a
`list[Table]`.

In the type stub, we can encode this correlation with a `TypeVar`. We can
also provide typing for the `splits` argument to indicate that the split
values must be integers:

```python
ColumnOrTable = TypeVar("ColumnOrTable", Column, Table)

def split(input: ColumnOrTable, splits: list[int]) -> list[ColumnOrTable]: ...
```

Conversely, `pylibcudf.copying.scatter` uses its fused type only once, in its
input, and the return type does not depend on it:

```cython
ctypedef fused TableOrListOfScalars:
    Table
    list

cpdef Table scatter(
    TableOrListOfScalars source, Column scatter_map, Table target
)
```

In the type stub, we can use a normal union in this case:

```python
def scatter(
    source: Table | list[Scalar], scatter_map: Column, target: Table
) -> Table: ...
```
4 changes: 2 additions & 2 deletions python/cudf/cudf/_lib/labeling.pyx
@@ -17,8 +17,8 @@ def label_bins(Column input, Column left_edges, cbool left_inclusive,
     plc_column = plc.labeling.label_bins(
         input.to_pylibcudf(mode="read"),
         left_edges.to_pylibcudf(mode="read"),
-        left_inclusive,
+        plc.labeling.Inclusive.YES if left_inclusive else plc.labeling.Inclusive.NO,
         right_edges.to_pylibcudf(mode="read"),
-        right_inclusive
+        plc.labeling.Inclusive.YES if right_inclusive else plc.labeling.Inclusive.NO,
     )
     return Column.from_pylibcudf(plc_column)
18 changes: 11 additions & 7 deletions python/cudf/cudf/_lib/lists.pyx
@@ -4,7 +4,9 @@ from cudf.core.buffer import acquire_spill_lock
 
 from libcpp cimport bool
 
-from pylibcudf.libcudf.types cimport null_order, size_type
+from pylibcudf.libcudf.types cimport (
+    nan_equality, null_equality, null_order, order, size_type
+)
 
 from cudf._lib.column cimport Column
 from cudf._lib.utils cimport columns_from_pylibcudf_table
@@ -37,8 +39,8 @@ def distinct(Column col, bool nulls_equal, bool nans_all_equal):
     return Column.from_pylibcudf(
         plc.lists.distinct(
             col.to_pylibcudf(mode="read"),
-            nulls_equal,
-            nans_all_equal,
+            null_equality.EQUAL if nulls_equal else null_equality.UNEQUAL,
+            nan_equality.ALL_EQUAL if nans_all_equal else nan_equality.UNEQUAL,
         )
     )

@@ -48,7 +50,7 @@ def sort_lists(Column col, bool ascending, str na_position):
     return Column.from_pylibcudf(
         plc.lists.sort_lists(
             col.to_pylibcudf(mode="read"),
-            ascending,
+            order.ASCENDING if ascending else order.DESCENDING,
             null_order.BEFORE if na_position == "first" else null_order.AFTER,
             False,
         )
@@ -91,7 +93,7 @@ def index_of_scalar(Column col, object py_search_key):
         plc.lists.index_of(
             col.to_pylibcudf(mode="read"),
             <Scalar> py_search_key.device_value.c_value,
-            True,
+            plc.lists.DuplicateFindOption.FIND_FIRST,
         )
     )

@@ -102,7 +104,7 @@ def index_of_column(Column col, Column search_keys):
         plc.lists.index_of(
             col.to_pylibcudf(mode="read"),
             search_keys.to_pylibcudf(mode="read"),
-            True,
+            plc.lists.DuplicateFindOption.FIND_FIRST,
         )
     )

@@ -123,7 +125,9 @@ def concatenate_list_elements(Column input_column, dropna=False):
     return Column.from_pylibcudf(
         plc.lists.concatenate_list_elements(
             input_column.to_pylibcudf(mode="read"),
-            dropna,
+            plc.lists.ConcatenateNullPolicy.IGNORE
+            if dropna
+            else plc.lists.ConcatenateNullPolicy.NULLIFY_OUTPUT_ROW,
         )
     )

2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/containers/dataframe.py
@@ -60,7 +60,7 @@ def to_polars(self) -> pl.DataFrame:
         # To guarantee we produce correct names, we therefore
         # serialise with names we control and rename with that map.
         name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
-        table: pa.Table = plc.interop.to_arrow(
+        table = plc.interop.to_arrow(
             self.table,
             [plc.interop.ColumnMetadata(name=name) for name in name_map],
         )
@@ -27,7 +27,9 @@
 
 class TemporalFunction(Expr):
     __slots__ = ("name", "options")
-    _COMPONENT_MAP: ClassVar[dict[pl_expr.TemporalFunction, str]] = {
+    _COMPONENT_MAP: ClassVar[
+        dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent]
+    ] = {
         pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
         pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
         pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY,
@@ -58,7 +58,7 @@ def collect_agg(self, *, depth: int) -> AggInfo:
 class LiteralColumn(Expr):
     __slots__ = ("value",)
     _non_child = ("dtype", "value")
-    value: pa.Array[Any, Any]
+    value: pa.Array[Any]
 
     def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
         self.dtype = dtype
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
@@ -500,7 +500,7 @@ def do_evaluate(
             # Mask must have been applied.
             return df
         elif typ == "ndjson":
-            json_schema: list[tuple[str, str, list]] = [
+            json_schema: list[plc.io.json.NameAndType] = [
                 (name, typ, []) for name, typ in schema.items()
             ]
             plc_tbl_w_meta = plc.io.json.read_json(
110 changes: 110 additions & 0 deletions python/pylibcudf/pylibcudf/aggregation.pyi
@@ -0,0 +1,110 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from enum import IntEnum

from pylibcudf.types import (
    DataType,
    Interpolation,
    NanEquality,
    NullEquality,
    NullOrder,
    NullPolicy,
    Order,
)

class Kind(IntEnum):
    SUM = ...
    PRODUCT = ...
    MIN = ...
    MAX = ...
    COUNT_VALID = ...
    COUNT_ALL = ...
    ANY = ...
    ALL = ...
    SUM_OF_SQUARES = ...
    MEAN = ...
    VARIANCE = ...
    STD = ...
    MEDIAN = ...
    QUANTILE = ...
    ARGMAX = ...
    ARGMIN = ...
    NUNIQUE = ...
    NTH_ELEMENT = ...
    RANK = ...
    COLLECT_LIST = ...
    COLLECT_SET = ...
    PTX = ...
    CUDA = ...
    CORRELATION = ...
    COVARIANCE = ...

class CorrelationType(IntEnum):
    PEARSON = ...
    KENDALL = ...
    SPEARMAN = ...

class EWMHistory(IntEnum):
    INFINITE = ...
    FINITE = ...

class RankMethod(IntEnum):
    FIRST = ...
    AVERAGE = ...
    MIN = ...
    MAX = ...
    DENSE = ...

class RankPercentage(IntEnum):
    NONE = ...
    ZERO_NORMALIZED = ...
    ONE_NORMALIZED = ...

class UdfType(IntEnum):
    CUDA = ...
    PTX = ...

class Aggregation:
    def __init__(self): ...
    def kind(self) -> Kind: ...

def sum() -> Aggregation: ...
def product() -> Aggregation: ...
def min() -> Aggregation: ...
def max() -> Aggregation: ...
def count(null_handling: NullPolicy = NullPolicy.INCLUDE) -> Aggregation: ...

Contributor:

Can we elide default parameters? They don't add anything for typing, and that is mypy's recommendation:

> Stub files are written in normal Python syntax, but generally leaving out runtime logic like variable initializers, function bodies, and default arguments.

It might also simplify the requirements for the automation scripts.

Contributor Author:

We can elide default values by writing `def count(null_handling: NullPolicy = ...)`.

However, type stubs are not just for type checking; they also power LSP features. In the latter case, having the default value encoded is useful since I can immediately see from the signature what behaviour I get if I don't provide the argument.

I will see if I can automatically put the right value in here.
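
For illustration only (this snippet is not part of the PR), the two styles
under discussion would look as follows; only one form would actually appear
in the stub:

```python
from pylibcudf.aggregation import Aggregation
from pylibcudf.types import NullPolicy

# Style 1: elide the default value, per mypy's stub-file guidance.
def count(null_handling: NullPolicy = ...) -> Aggregation: ...

# Style 2: encode the actual default so editors can surface it in
# signature help. (Redefined here only to show both forms side by side.)
def count(null_handling: NullPolicy = NullPolicy.INCLUDE) -> Aggregation: ...
```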

def any() -> Aggregation: ...
def all() -> Aggregation: ...
def sum_of_squares() -> Aggregation: ...
def mean() -> Aggregation: ...
def variance(ddof: int = 1) -> Aggregation: ...
def std(ddof: int = 1) -> Aggregation: ...
def median() -> Aggregation: ...
def quantile(
    quantiles: list[float], interp: Interpolation = Interpolation.LINEAR
) -> Aggregation: ...
def argmax() -> Aggregation: ...
def argmin() -> Aggregation: ...
def ewma(center_of_mass: float, history: EWMHistory) -> Aggregation: ...
def nunique(null_handling: NullPolicy = NullPolicy.EXCLUDE) -> Aggregation: ...
def nth_element(
    n: int, null_handling: NullPolicy = NullPolicy.INCLUDE
) -> Aggregation: ...
def collect_list(
    null_handling: NullPolicy = NullPolicy.INCLUDE,
) -> Aggregation: ...
def collect_set(
    null_handling: NullPolicy = NullPolicy.INCLUDE,
    nulls_equal: NullEquality = NullEquality.EQUAL,
    nans_equal: NanEquality = NanEquality.ALL_EQUAL,
) -> Aggregation: ...
def udf(operation: str, output_type: DataType) -> Aggregation: ...
def correlation(type: CorrelationType, min_periods: int) -> Aggregation: ...
def covariance(min_periods: int, ddof: int) -> Aggregation: ...
def rank(
    method: RankMethod,
    column_order: Order = Order.ASCENDING,
    null_handling: NullPolicy = NullPolicy.EXCLUDE,
    null_precedence: NullOrder = NullOrder.AFTER,
    percentage: RankPercentage = RankPercentage.NONE,
) -> Aggregation: ...
34 changes: 34 additions & 0 deletions python/pylibcudf/pylibcudf/aggregation.pyx
@@ -64,6 +64,40 @@ from pylibcudf.libcudf.aggregation import udf_type as UdfType # no-cython-lint
 from .types cimport DataType
 
 
+__all__ = [
+    "Aggregation",
+    "CorrelationType",
+    "EWMHistory",
+    "Kind",
+    "RankMethod",
+    "RankPercentage",
+    "UdfType",
+    "all",
+    "any",
+    "argmax",
+    "argmin",
+    "collect_list",
+    "collect_set",
+    "correlation",
+    "count",
+    "covariance",
+    "ewma",
+    "max",
+    "mean",
+    "median",
+    "min",
+    "nth_element",
+    "nunique",
+    "product",
+    "quantile",
+    "rank",
+    "std",
+    "sum",
+    "sum_of_squares",
+    "udf",
+    "variance",
+]
+
 cdef class Aggregation:
     """A type of aggregation to perform.
54 changes: 54 additions & 0 deletions python/pylibcudf/pylibcudf/binaryop.pyi
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from enum import IntEnum

from pylibcudf.column import Column
from pylibcudf.scalar import Scalar
from pylibcudf.types import DataType

class BinaryOperator(IntEnum):
    ADD = ...
    SUB = ...
    MUL = ...
    DIV = ...
    TRUE_DIV = ...
    FLOOR_DIV = ...
    MOD = ...
    PMOD = ...
    PYMOD = ...
    POW = ...
    INT_POW = ...
    LOG_BASE = ...
    ATAN2 = ...
    SHIFT_LEFT = ...
    SHIFT_RIGHT = ...
    SHIFT_RIGHT_UNSIGNED = ...
    BITWISE_AND = ...
    BITWISE_OR = ...
    BITWISE_XOR = ...
    LOGICAL_AND = ...
    LOGICAL_OR = ...
    EQUAL = ...
    NOT_EQUAL = ...
    LESS = ...
    GREATER = ...
    LESS_EQUAL = ...
    GREATER_EQUAL = ...
    NULL_EQUALS = ...
    NULL_MAX = ...
    NULL_MIN = ...
    NULL_NOT_EQUALS = ...
    GENERIC_BINARY = ...
    NULL_LOGICAL_AND = ...
    NULL_LOGICAL_OR = ...
    INVALID_BINARY = ...

def binary_operation(
    lhs: Column | Scalar,
    rhs: Column | Scalar,
    op: BinaryOperator,
    output_type: DataType,
) -> Column: ...
def is_supported_operation(
    out: DataType, lhs: DataType, rhs: DataType, op: BinaryOperator
) -> bool: ...