Skip to content

Commit

Permalink
move fix to __getitem__ and LocIndexer
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jan 9, 2025
1 parent c7b2552 commit aa022c9
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 18 deletions.
2 changes: 2 additions & 0 deletions dask/dataframe/dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ def __getitem__(self, other):
other = list(other)
elif isinstance(other, list):
other = other.copy()
elif is_scalar(other) and hasattr(other, "item"):
other = other.item()
return new_collection(self.expr.__getitem__(other))

def __dask_tokenize__(self):
Expand Down
5 changes: 2 additions & 3 deletions dask/dataframe/dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from dask.dataframe.dask_expr import _core as core
from dask.dataframe.dask_expr._util import (
_calc_maybe_new_divisions,
_columns_equal,
_convert_to_list,
_tokenize_deterministic,
_tokenize_partial,
Expand Down Expand Up @@ -3853,10 +3852,10 @@ def plain_column_projection(expr, parent, dependents, additional_columns=None):
# we are accessing the index
column_union = []

if _columns_equal(column_union, expr.frame.columns):
if column_union == expr.frame.columns:
return
result = type(expr)(expr.frame[column_union], *expr.operands[1:])
if _columns_equal(column_union, parent.operand("columns")):
if column_union == parent.operand("columns"):
return result
return type(parent)(result, parent.operand("columns"))

Expand Down
3 changes: 3 additions & 0 deletions dask/dataframe/dask_expr/_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def __getitem__(self, key):
iindexer = key
cindexer = None

if is_scalar(cindexer) and hasattr(cindexer, "item"):
cindexer = cindexer.item()

return self._loc(iindexer, cindexer)

def _loc(self, iindexer, cindexer):
Expand Down
10 changes: 0 additions & 10 deletions dask/dataframe/dask_expr/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,6 @@ def is_scalar(x):
return not isinstance(x, Expr)


def _columns_equal(left_columns, right_columns):
    """Return True when two column specifiers select the same columns.

    Either argument may be a scalar (possibly a numpy scalar) or a
    list-like (possibly a numpy array), so a plain ``==`` comparison
    is not reliable here; both sides are normalized to lists first.
    """
    # A scalar can never equal a list-like specifier.
    if is_scalar(left_columns) != is_scalar(right_columns):
        return False
    return _convert_to_list(left_columns) == _convert_to_list(right_columns)


def _tokenize_deterministic(*args, **kwargs) -> str:
    """Tokenize the given operands, insisting on a deterministic token.

    Thin strict wrapper around ``tokenize`` that always forces
    ``ensure_deterministic=True``.
    """
    return tokenize(*args, ensure_deterministic=True, **kwargs)
Expand Down
8 changes: 3 additions & 5 deletions dask/dataframe/dask_expr/io/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,10 @@ def test_from_dask_array_projection():
arr = da.from_array(arr_np, chunks=(50, 10))
pdf = pd.DataFrame(arr_np)
df = from_dask_array(arr)
# Project possible np.int64(0) argument
# Check getitem[np.int64(0)]
dd.assert_eq(pdf[pdf.columns[0]], df[df.columns[0]])
# Project possible Index([0, 1], dtype='int64') argument
dd.assert_eq(pdf[pdf.columns[0:2]], df[df.columns[0:2]])
# Project list argument
dd.assert_eq(pdf[list(pdf.columns[0:2])], df[list(df.columns[0:2])])
# Check loc[:, np.int64(0)]
dd.assert_eq(pdf.loc[:, pdf.columns[0]], df.loc[:, df.columns[0]])


def test_from_dict():
Expand Down

0 comments on commit aa022c9

Please sign in to comment.