Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix various issues in Index.intersection #14054

Merged
merged 9 commits into from
Sep 12, 2023
26 changes: 22 additions & 4 deletions python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,19 +608,37 @@ def intersection(self, other, sort=False):
(1, 'Blue')],
)
"""
if not can_convert_to_column(other):
raise TypeError("Input must be Index or array-like")

if not isinstance(other, BaseIndex):
other = cudf.Index(other, name=self.name)
other = cudf.Index(
other,
name=getattr(other, "name", self.name),
)

if sort not in {None, False}:
raise ValueError(
f"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
)

if self.equals(other):
if not len(self) or self.equals(other):
dtypes = [self.dtype, other.dtype]
common_dtype = cudf.utils.dtypes.common_dtype_compatible(dtypes)

if self.has_duplicates:
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
return self.unique()._get_reconciled_name_object(other)
return self._get_reconciled_name_object(other)
return (
self.unique()
._get_reconciled_name_object(other)
.astype(common_dtype)
)
return self._get_reconciled_name_object(other).astype(common_dtype)
elif not len(other):
dtypes = [self.dtype, other.dtype]
common_dtype = cudf.utils.dtypes.common_dtype_compatible(dtypes)

return other._get_reconciled_name_object(self).astype(common_dtype)

res_name = _get_result_name(self.name, other.name)

Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,9 @@ def _union(self, other, sort=None):
@_cudf_nvtx_annotate
def _intersection(self, other, sort=False):
if not isinstance(other, RangeIndex):
return super()._intersection(other, sort=sort)
return self._try_reconstruct_range_index(
super()._intersection(other, sort=sort)
)

if not len(self) or not len(other):
return RangeIndex(0)
Expand Down Expand Up @@ -722,7 +724,7 @@ def _intersection(self, other, sort=False):
if sort is None:
new_index = new_index.sort_values()

return new_index
return self._try_reconstruct_range_index(new_index)

@_cudf_nvtx_annotate
def difference(self, other, sort=None):
Expand Down
4 changes: 4 additions & 0 deletions python/cudf/cudf/core/join/_join_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def _match_join_keys(
common_type = ltype.categories.dtype
else:
common_type = rtype.categories.dtype
if cudf.get_option(
"mode.pandas_compatible"
) and common_type == cudf.dtype("object"):
common_type = "str"
return lcol.astype(common_type), rcol.astype(common_type)

if is_dtype_equal(ltype, rtype):
Expand Down
40 changes: 32 additions & 8 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest

import cudf
from cudf.api.types import is_bool_dtype
from cudf.core._compat import PANDAS_GE_133, PANDAS_GE_200
from cudf.core.index import (
CategoricalIndex,
Expand Down Expand Up @@ -2099,25 +2100,48 @@ def test_union_index(idx1, idx2, sort):
(pd.Index([0, 1, 2, 30], name=pd.NA), pd.Index([30, 0, 90, 100])),
(pd.Index([0, 1, 2, 30], name="a"), [90, 100]),
(pd.Index([0, 1, 2, 30]), pd.Index([0, 10, 1.0, 11])),
(pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])),
(
pd.Index(["a", "b", "c", "d", "c"]),
pd.Index(["a", "c", "z"], name="abc"),
),
(
pd.Index(["a", "b", "c", "d", "c"]),
pd.Index(["a", "b", "c", "d", "c"]),
),
(pd.Index([True, False, True, True]), pd.Index([10, 11, 12, 0, 1, 2])),
(pd.Index([True, False, True, True]), pd.Index([True, True])),
(pd.RangeIndex(0, 10, name="a"), pd.Index([5, 6, 7], name="b")),
(pd.Index(["a", "b", "c"], dtype="category"), pd.Index(["a", "b"])),
(pd.Index(["a", "b", "c"], dtype="category"), pd.Index([1, 2, 3])),
(pd.Index([0, 1, 2], dtype="category"), pd.RangeIndex(0, 10)),
(pd.Index(["a", "b", "c"], name="abc"), []),
(pd.Index([], name="abc"), pd.RangeIndex(0, 4)),
(pd.Index([1, 2, 3]), pd.Index([1, 2], dtype="category")),
(pd.Index([]), pd.Index([1, 2], dtype="category")),
],
)
@pytest.mark.parametrize("sort", [None, False])
def test_intersection_index(idx1, idx2, sort):
@pytest.mark.parametrize("pandas_compatible", [True, False])
def test_intersection_index(idx1, idx2, sort, pandas_compatible):
expected = idx1.intersection(idx2, sort=sort)

idx1 = cudf.from_pandas(idx1) if isinstance(idx1, pd.Index) else idx1
idx2 = cudf.from_pandas(idx2) if isinstance(idx2, pd.Index) else idx2

actual = idx1.intersection(idx2, sort=sort)

assert_eq(expected, actual, exact=False)
with cudf.option_context("mode.pandas_compatible", pandas_compatible):
idx1 = cudf.from_pandas(idx1) if isinstance(idx1, pd.Index) else idx1
idx2 = cudf.from_pandas(idx2) if isinstance(idx2, pd.Index) else idx2

actual = idx1.intersection(idx2, sort=sort)

# TODO: Resolve the bool vs ints mixed issue
# once pandas has a direction on this issue
# https://github.com/pandas-dev/pandas/issues/44000
assert_eq(
expected,
actual,
exact=False
if (is_bool_dtype(idx1.dtype) and not is_bool_dtype(idx2.dtype))
or (not is_bool_dtype(idx1.dtype) or is_bool_dtype(idx2.dtype))
else True,
)


@pytest.mark.parametrize(
Expand Down
13 changes: 13 additions & 0 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,19 @@ def find_common_type(dtypes):
return cudf.dtype(common_dtype)


def common_dtype_compatible(dtypes):
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
"""
A utility function, that wraps around `find_common_type`
and return `str` when pandas comptibility mode is enabled.
"""
common_dtype = find_common_type(dtypes)
if cudf.get_option(
"mode.pandas_compatible"
) and common_dtype == cudf.dtype("O"):
return "str"
return common_dtype


def _can_cast(from_dtype, to_dtype):
"""
Utility function to determine if we can cast
Expand Down