From 3571291c533412f8efa4c5d41caa865564b5391b Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 4 Mar 2024 15:04:54 -1000 Subject: [PATCH] Use as_column instead of full (#14698) Similar to https://github.com/rapidsai/cudf/pull/14689, ensures there's 1 entrypoint to create a column from a scalar. This builds on https://github.com/rapidsai/cudf/pull/14620 Authors: - Matthew Roeschke (https://github.com/mroeschke) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Richard (Rick) Zamora (https://github.com/rjzamora) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/14698 --- python/cudf/cudf/core/column/__init__.py | 1 - python/cudf/cudf/core/column/categorical.py | 12 +-- python/cudf/cudf/core/column/column.py | 100 ++++++-------------- python/cudf/cudf/core/column/decimal.py | 4 +- python/cudf/cudf/core/column/numerical.py | 3 +- python/cudf/cudf/core/column/string.py | 12 ++- python/cudf/cudf/core/column/timedelta.py | 4 +- python/cudf/cudf/core/dataframe.py | 26 +++-- python/cudf/cudf/core/index.py | 6 +- python/cudf/cudf/core/indexed_frame.py | 14 ++- python/cudf/cudf/core/multiindex.py | 8 +- python/cudf/cudf/core/series.py | 5 +- python/cudf/cudf/core/tools/datetimes.py | 4 +- python/cudf/cudf/core/window/rolling.py | 5 +- python/cudf/cudf/io/parquet.py | 14 +-- python/cudf/cudf/tests/test_testing.py | 6 +- python/cudf/cudf/utils/utils.py | 6 +- python/dask_cudf/dask_cudf/backends.py | 6 +- 18 files changed, 101 insertions(+), 135 deletions(-) diff --git a/python/cudf/cudf/core/column/__init__.py b/python/cudf/cudf/core/column/__init__.py index a1c86b617b0..2a46654ccc2 100644 --- a/python/cudf/cudf/core/column/__init__.py +++ b/python/cudf/cudf/core/column/__init__.py @@ -16,7 +16,6 @@ column_empty_like_same_mask, concat_columns, deserialize_columns, - full, serialize_columns, ) from cudf.core.column.datetime import DatetimeColumn # noqa: F401 diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index 4c64e7085c9..88bb4521a5b 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -734,8 +734,8 @@ def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn: ) return other - ary = column.full( - len(self), self._encode(other), dtype=self.codes.dtype + ary = column.as_column( + self._encode(other), length=len(self), dtype=self.codes.dtype ) return column.build_categorical_column( categories=self.dtype.categories._values, @@ -1444,11 +1444,9 @@ def _create_empty_categorical_column( return column.build_categorical_column( categories=column.as_column(dtype.categories), codes=column.as_column( - column.full( - categorical_column.size, - _DEFAULT_CATEGORICAL_VALUE, - categorical_column.codes.dtype, - ) + _DEFAULT_CATEGORICAL_VALUE, + length=categorical_column.size, + dtype=categorical_column.codes.dtype, ), offset=categorical_column.offset, size=categorical_column.size, diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index be196833f32..8941d111d02 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -58,7 +58,6 @@ infer_dtype, is_bool_dtype, is_datetime64_dtype, - is_decimal_dtype, is_dtype_equal, is_integer_dtype, is_list_dtype, @@ -866,7 +865,7 @@ def isin(self, values: Sequence) -> ColumnBase: except ValueError: # pandas functionally returns all False when cleansing via # typecasting fails - return full(len(self), False, dtype="bool") + return as_column(False, length=len(self), dtype="bool") return lhs._obtain_isin_result(rhs) @@ -893,9 +892,9 @@ def _isin_earlystop(self, rhs: ColumnBase) -> Union[ColumnBase, None]: if self.null_count and rhs.null_count: return self.isnull() else: - return cudf.core.column.full(len(self), False, dtype="bool") + return as_column(False, length=len(self), dtype="bool") elif self.null_count == 0 and (rhs.null_count == len(rhs)): - return cudf.core.column.full(len(self), False, dtype="bool") + return as_column(False, length=len(self), dtype="bool") else: return None @@ -1356,9 +1355,7 @@ def _label_encoding( na_sentinel = cudf.Scalar(-1) def _return_sentinel_column(): - return cudf.core.column.full( - size=len(self), fill_value=na_sentinel, dtype=dtype - ) + return as_column(na_sentinel, dtype=dtype, length=len(self)) if dtype is None: dtype = min_scalar_type(max(len(cats), na_sentinel), 8) @@ -1455,7 +1452,9 @@ def column_empty( elif isinstance(dtype, ListDtype): data = None children = ( - full(row_count + 1, 0, dtype=libcudf.types.size_type_dtype), + as_column( + 0, length=row_count + 1, dtype=libcudf.types.size_type_dtype + ), column_empty(row_count, dtype=dtype.element_type), ) elif isinstance(dtype, CategoricalDtype): @@ -1474,7 +1473,9 @@ def column_empty( elif dtype.kind in "OU" and not isinstance(dtype, DecimalDtype): data = as_buffer(rmm.DeviceBuffer(size=0)) children = ( - full(row_count + 1, 0, dtype=libcudf.types.size_type_dtype), + as_column( + 0, length=row_count + 1, dtype=libcudf.types.size_type_dtype + ), ) else: data = as_buffer(rmm.DeviceBuffer(size=row_count * dtype.itemsize)) @@ -2017,33 +2018,32 @@ def as_column( if dtype is not None: data = data.astype(dtype) - elif isinstance(arbitrary, (pd.Timestamp, pd.Timedelta)): - # This will always treat NaTs as nulls since it's not technically a - # discrete value like NaN - length = length or 1 - data = as_column( - pa.array(pd.Series([arbitrary] * length), from_pandas=True) - ) - if dtype is not None: - data = data.astype(dtype) - - elif np.isscalar(arbitrary) and not isinstance(arbitrary, memoryview): - length = length or 1 + elif is_scalar(arbitrary) and not isinstance(arbitrary, memoryview): + if length is None: + length = 1 + elif length < 0: + raise ValueError(f"{length=} must be >=0.") + if isinstance(arbitrary, pd.Interval): + # No cudf.Scalar support yet + return as_column( + pd.Series([arbitrary] * length), + nan_as_null=nan_as_null, + dtype=dtype, + length=length, + ) if ( - (nan_as_null is True) + nan_as_null is True and isinstance(arbitrary, (np.floating, float)) and np.isnan(arbitrary) ): - arbitrary = None if dtype is None: - dtype = cudf.dtype("float64") - - data = as_column(full(length, arbitrary, dtype=dtype)) - if not nan_as_null and not is_decimal_dtype(data.dtype): - if np.issubdtype(data.dtype, np.floating): - data = data.fillna(np.nan) - elif np.issubdtype(data.dtype, np.datetime64): - data = data.fillna(np.datetime64("NaT")) + dtype = getattr(arbitrary, "dtype", cudf.dtype("float64")) + arbitrary = None + arbitrary = cudf.Scalar(arbitrary, dtype=dtype) + if length == 0: + return column_empty(length, dtype=arbitrary.dtype) + else: + return ColumnBase.from_scalar(arbitrary, length) elif hasattr(arbitrary, "__array_interface__"): # CUDF assumes values are always contiguous @@ -2161,8 +2161,6 @@ def as_column( return as_column( np.asarray(view), dtype=dtype, nan_as_null=nan_as_null ) - elif isinstance(arbitrary, cudf.Scalar): - data = ColumnBase.from_scalar(arbitrary, length if length else 1) else: if dtype is not None: # Arrow throws a type error if the input is of @@ -2505,42 +2503,6 @@ def deserialize_columns(headers: List[dict], frames: List) -> List[ColumnBase]: return columns -def full( - size: int, fill_value: ScalarLike, dtype: Optional[Dtype] = None -) -> ColumnBase: - """ - Returns a column of given size and dtype, filled with a given value. - - Parameters - ---------- - size : int - size of the expected column. - fill_value : scalar - A scalar value to fill a new array. - dtype : default None - Data type specifier. It is inferred from other arguments by default. - - Returns - ------- - Column - - Examples - -------- - >>> import cudf - >>> col = cudf.core.column.full(size=5, fill_value=7, dtype='int8') - >>> col - - >>> cudf.Series(col) - 0 7 - 1 7 - 2 7 - 3 7 - 4 7 - dtype: int8 - """ - return ColumnBase.from_scalar(cudf.Scalar(fill_value, dtype), size) - - def concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase: """Concatenate a sequence of columns.""" if len(objs) == 0: diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 0e90b522f2c..b83a6ded416 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -69,8 +69,8 @@ def as_string_column( def __pow__(self, other): if isinstance(other, int): if other == 0: - res = cudf.core.column.full( - size=len(self), fill_value=1, dtype=self.dtype + res = cudf.core.column.as_column( + 1, dtype=self.dtype, length=len(self) ) if self.nullable: res = res.set_mask(self.mask) diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 82d82593c77..8d9da8982ac 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -42,7 +42,6 @@ as_column, build_column, column, - full, string, ) from cudf.core.dtypes import CategoricalDtype @@ -513,7 +512,7 @@ def find_and_replace( ) if len(replacement_col) == 1 and len(to_replace_col) > 1: replacement_col = column.as_column( - full(len(to_replace_col), replacement[0], self.dtype) + replacement[0], length=len(to_replace_col), dtype=self.dtype ) elif len(replacement_col) == 1 and len(to_replace_col) == 0: return self.copy() diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index dea60f58690..e947c9375d7 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5499,7 +5499,9 @@ def __init__( if len(children) == 0 and size != 0: # all nulls-column: - offsets = column.full(size + 1, 0, dtype=size_type_dtype) + offsets = column.as_column( + 0, length=size + 1, dtype=size_type_dtype + ) children = (offsets,) @@ -5930,8 +5932,8 @@ def _binaryop( "__eq__", "__ne__", }: - return column.full( - len(self), op == "__ne__", dtype="bool" + return column.as_column( + op == "__ne__", length=len(self), dtype="bool" ).set_mask(self.mask) else: return NotImplemented @@ -5940,7 +5942,9 @@ def _binaryop( if isinstance(other, cudf.Scalar): other = cast( StringColumn, - column.full(len(self), other, dtype="object"), + column.as_column( + other, length=len(self), dtype="object" + ), ) # Explicit types are necessary because mypy infers ColumnBase diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index dab2723795e..ee326b254b9 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -510,7 +510,7 @@ def components(self, index=None) -> "cudf.DataFrame": break for name in keys_list: - res_col = cudf.core.column.full(len(self), 0, dtype="int64") + res_col = column.as_column(0, length=len(self), dtype="int64") if self.nullable: res_col = res_col.set_mask(self.mask) data[name] = res_col @@ -599,7 +599,7 @@ def nanoseconds(self) -> "cudf.core.column.NumericalColumn": # of nanoseconds. if self._time_unit != "ns": - res_col = cudf.core.column.full(len(self), 0, dtype="int64") + res_col = column.as_column(0, length=len(self), dtype="int64") if self.nullable: res_col = res_col.set_mask(self.mask) return cast("cudf.core.column.NumericalColumn", res_col) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index d7d2e1acd85..31a748da856 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -1407,7 +1407,7 @@ def __setitem__(self, arg, value): allow_non_unique=True, ) if is_scalar(value): - self._data[arg] = column.full(len(self), value) + self._data[arg] = as_column(value, length=len(self)) else: value = as_column(value) self._data[arg] = value @@ -1455,8 +1455,8 @@ def __setitem__(self, arg, value): else: for col in arg: if is_scalar(value): - self._data[col] = column.full( - size=len(self), fill_value=value + self._data[col] = as_column( + value, length=len(self) ) else: self._data[col] = column.as_column(value) @@ -3205,10 +3205,16 @@ def _insert(self, loc, name, value, nan_as_null=None, ignore_index=True): ) if _is_scalar_or_zero_d_array(value): - value = column.full( - len(self), + dtype = None + if isinstance(value, (np.ndarray, cupy.ndarray)): + dtype = value.dtype + value = value.item() + if libcudf.scalar._is_null_host_scalar(value): + dtype = "str" + value = as_column( value, - "str" if libcudf.scalar._is_null_host_scalar(value) else None, + length=len(self), + dtype=dtype, ) if len(self) == 0: @@ -5912,7 +5918,7 @@ def isin(self, values): fill_value = cudf.Scalar(False) def make_false_column_like_self(): - return column.full(len(self), fill_value, "bool") + return column.as_column(fill_value, length=len(self), dtype="bool") # Preprocess different input types into a mapping from column names to # a list of values to check. @@ -6031,7 +6037,7 @@ def _prepare_for_rowwise_op(self, method, skipna, numeric_only): { name: filtered._data[name]._get_mask_as_column() if filtered._data[name].nullable - else column.full(len(filtered._data[name]), True) + else as_column(True, length=len(filtered._data[name])) for name in filtered._data.names } ) @@ -7822,8 +7828,8 @@ def func(left, right, output): return output for name in uncommon_columns: - output._data[name] = column.full( - size=len(output), fill_value=value, dtype="bool" + output._data[name] = as_column( + value, length=len(output), dtype="bool" ) return output diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 9d481037ec6..bd9dc1ae3da 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -1231,9 +1231,9 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None): ) needle = as_column(target) - result = cudf.core.column.full( - len(needle), - fill_value=-1, + result = as_column( + -1, + length=len(needle), dtype=libcudf.types.size_type_dtype, ) diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 3c6e1e17142..df703370f78 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -50,7 +50,7 @@ from cudf.core._base_index import BaseIndex from cudf.core._compat import PANDAS_LT_300 from cudf.core.buffer import acquire_spill_lock -from cudf.core.column import ColumnBase, as_column, full +from cudf.core.column import ColumnBase, as_column from cudf.core.column_accessor import ColumnAccessor from cudf.core.copy_types import BooleanMask, GatherMap from cudf.core.dtypes import ListDtype @@ -3048,7 +3048,7 @@ def duplicated(self, subset=None, keep="first"): (result,) = libcudf.copying.scatter( [cudf.Scalar(False, dtype=bool)], distinct, - [full(len(self), True, dtype=bool)], + [as_column(True, length=len(self), dtype=bool)], bounds_check=False, ) return cudf.Series(result, index=self.index) @@ -3327,9 +3327,7 @@ def _apply(self, func, kernel_getter, *args, **kwargs): # Mask and data column preallocated ans_col = _return_arr_from_dtype(retty, len(self)) - ans_mask = cudf.core.column.full( - size=len(self), fill_value=True, dtype="bool" - ) + ans_mask = as_column(True, length=len(self), dtype="bool") output_args = [(ans_col, ans_mask), len(self)] input_args = _get_input_args_from_frame(self) launch_args = output_args + input_args + list(args) @@ -6260,10 +6258,10 @@ def _get_replacement_values_for_columns( values_columns = { col: [value] if _is_non_decimal_numeric_dtype(columns_dtype_map[col]) - else full( - len(to_replace), + else as_column( value, - cudf.dtype(type(value)), + length=len(to_replace), + dtype=cudf.dtype(type(value)), ) for col in columns_dtype_map } diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 70112044f75..315a21020a2 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -667,7 +667,7 @@ def isin(self, values, level=None): self_df = self.to_frame(index=False).reset_index() values_df = values_idx.to_frame(index=False) idx = self_df.merge(values_df, how="leftsemi")._data["index"] - res = cudf.core.column.full(size=len(self), fill_value=False) + res = column.as_column(False, length=len(self)) res[idx] = True result = res.values else: @@ -1845,9 +1845,9 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None): "index must be monotonic increasing or decreasing" ) - result = cudf.core.column.full( - len(target), - fill_value=-1, + result = column.as_column( + -1, + length=len(target), dtype=libcudf.types.size_type_dtype, ) if not len(self): diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index cb5008af3ad..1b18e11c047 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -55,7 +55,6 @@ IntervalColumn, TimeDeltaColumn, as_column, - full, ) from cudf.core.column.categorical import ( CategoricalAccessor as CategoricalAccessor, @@ -1311,7 +1310,7 @@ def map(self, arg, na_action=None) -> "Series": { "x": arg.keys(), "s": arg.values(), - "bool": full(len(arg), True, dtype=self.dtype), + "bool": as_column(True, length=len(arg), dtype=self.dtype), } ) res = lhs.merge(rhs, on="x", how="left").sort_values( @@ -1333,7 +1332,7 @@ def map(self, arg, na_action=None) -> "Series": { "x": arg.keys(), "s": arg, - "bool": full(len(arg), True, dtype=self.dtype), + "bool": as_column(True, length=len(arg), dtype=self.dtype), } ) res = lhs.merge(rhs, on="x", how="left").sort_values( diff --git a/python/cudf/cudf/core/tools/datetimes.py b/python/cudf/cudf/core/tools/datetimes.py index 0e0df4ecf6e..d182b7b4a7c 100644 --- a/python/cudf/cudf/core/tools/datetimes.py +++ b/python/cudf/cudf/core/tools/datetimes.py @@ -770,7 +770,7 @@ def _isin_datetimelike( was_string = len(rhs) and rhs.dtype.kind == "O" if rhs.dtype.kind in {"f", "i", "u"}: - return cudf.core.column.full(len(lhs), False, dtype="bool") + return column.as_column(False, length=len(lhs), dtype="bool") rhs = rhs.astype(lhs.dtype) if was_string: warnings.warn( @@ -787,7 +787,7 @@ def _isin_datetimelike( except ValueError: # pandas functionally returns all False when cleansing via # typecasting fails - return cudf.core.column.full(len(lhs), False, dtype="bool") + return column.as_column(False, length=len(lhs), dtype="bool") res = lhs._obtain_isin_result(rhs) return res diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py index 890e4ecc2f0..2037b1682db 100644 --- a/python/cudf/cudf/core/window/rolling.py +++ b/python/cudf/cudf/core/window/rolling.py @@ -9,7 +9,6 @@ import cudf from cudf import _lib as libcudf from cudf.api.types import is_integer, is_number -from cudf.core import column from cudf.core.buffer import acquire_spill_lock from cudf.core.column.column import as_column from cudf.core.mixins import Reducible @@ -236,8 +235,8 @@ def _apply_agg_column(self, source_column, agg_name): window = None else: preceding_window = as_column(self.window) - following_window = column.full( - self.window.size, 0, dtype=self.window.dtype + following_window = as_column( + 0, length=self.window.size, dtype=self.window.dtype ) window = None diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 6c70b08384f..bead9c352ef 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -20,7 +20,7 @@ import cudf from cudf._lib import parquet as libparquet from cudf.api.types import is_list_like -from cudf.core.column import build_categorical_column, column_empty, full +from cudf.core.column import as_column, build_categorical_column, column_empty from cudf.utils import ioutils from cudf.utils.nvtx_annotation import _cudf_nvtx_annotate @@ -762,9 +762,9 @@ def _parquet_to_frame( _len = len(dfs[-1]) if partition_categories and name in partition_categories: # Build the categorical column from `codes` - codes = full( - size=_len, - fill_value=partition_categories[name].index(value), + codes = as_column( + partition_categories[name].index(value), + length=_len, ) dfs[-1][name] = build_categorical_column( categories=partition_categories[name], @@ -788,10 +788,10 @@ def _parquet_to_frame( masked=True, ) else: - dfs[-1][name] = full( - size=_len, - fill_value=value, + dfs[-1][name] = as_column( + value, dtype=_dtype, + length=_len, ) if len(dfs) > 1: diff --git a/python/cudf/cudf/tests/test_testing.py b/python/cudf/cudf/tests/test_testing.py index 091cd6b57a4..1994536f395 100644 --- a/python/cudf/cudf/tests/test_testing.py +++ b/python/cudf/cudf/tests/test_testing.py @@ -6,7 +6,7 @@ import pytest import cudf -from cudf.core.column.column import as_column, full +from cudf.core.column.column import as_column from cudf.testing import ( assert_frame_equal, assert_index_equal, @@ -172,8 +172,8 @@ def test_assert_column_equal_dtype_edge_cases(other): assert_column_equal(base.slice(0, 0), other.slice(0, 0), check_dtype=False) assert_column_equal(other.slice(0, 0), base.slice(0, 0), check_dtype=False) - base = full(len(base), fill_value=cudf.NA, dtype=base.dtype) - other = full(len(other), fill_value=cudf.NA, dtype=other.dtype) + base = as_column(cudf.NA, length=len(base), dtype=base.dtype) + other = as_column(cudf.NA, length=len(other), dtype=other.dtype) assert_column_equal(base, other, check_dtype=False) assert_column_equal(other, base, check_dtype=False) diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index ec5693e14d2..95621cf9519 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. import decimal import functools @@ -396,8 +396,8 @@ def _all_bools_with_nulls(lhs, rhs, bool_fill_value): else: result_mask = None - result_col = column.full( - size=len(lhs), fill_value=bool_fill_value, dtype=cudf.dtype(np.bool_) + result_col = column.as_column( + bool_fill_value, dtype=cudf.dtype(np.bool_), length=len(lhs) ) if result_mask is not None: result_col = result_col.set_mask(result_mask.as_mask()) diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py index 454cce76ff2..317c45ba582 100644 --- a/python/dask_cudf/dask_cudf/backends.py +++ b/python/dask_cudf/dask_cudf/backends.py @@ -105,8 +105,10 @@ def _get_non_empty_data(s): categories = ( s.categories if len(s.categories) else [UNKNOWN_CATEGORIES] ) - codes = cudf.core.column.full( - size=2, fill_value=0, dtype=cudf._lib.types.size_type_dtype + codes = cudf.core.column.as_column( + 0, + dtype=cudf._lib.types.size_type_dtype, + length=2, ) ordered = s.ordered data = cudf.core.column.build_categorical_column(