
Commit

Improved typing
hendrikmakait committed May 7, 2024
1 parent fb43405 commit 95d5377
Showing 3 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions distributed/shuffle/_core.py
@@ -298,8 +298,8 @@ async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None
     def _deduplicate_inputs(
         self,
         data: Iterable[tuple[_T_partition_id, Iterable[tuple[_T_partition_id, _T]]]],
-    ) -> list[_T]:
-        deduplicated = []
+    ) -> list[tuple[_T_partition_id, _T]]:
+        deduplicated: list[tuple[_T_partition_id, _T]] = []
         for input_partition_id, batch in data:
             if input_partition_id in self.received:
                 continue
@@ -324,7 +324,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
         """Get the address of the worker assigned to the output partition"""
 
     @abc.abstractmethod
-    async def _receive(self, data: Iterable[Any]) -> None:
+    async def _receive(self, data: Iterable[tuple[_T_partition_id, Any]]) -> None:
         """Receive shards belonging to output partitions of this shuffle run"""
 
     def add_partition(
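Note (not part of the commit): the tightened annotation means `_deduplicate_inputs` now hands back the `(partition id, shard)` pairs rather than bare shards. Below is a minimal standalone sketch of that pattern; the `_Run` class and its bookkeeping of `received` are assumptions for illustration, not the actual distributed implementation.

from __future__ import annotations

from typing import Iterable, TypeVar

_T = TypeVar("_T")

class _Run:
    def __init__(self) -> None:
        self.received: set[int] = set()

    def _deduplicate_inputs(
        self,
        data: Iterable[tuple[int, Iterable[tuple[int, _T]]]],
    ) -> list[tuple[int, _T]]:
        # Keep the (partition id, shard) pairs so later stages still know
        # which output partition each shard belongs to.
        deduplicated: list[tuple[int, _T]] = []
        for input_partition_id, batch in data:
            if input_partition_id in self.received:
                continue  # this input partition was already delivered once
            self.received.add(input_partition_id)
            deduplicated.extend(batch)
        return deduplicated

run = _Run()
shards = [(0, [(0, "a"), (1, "b")]), (0, [(0, "a"), (1, "b")]), (1, [(2, "c")])]
assert run._deduplicate_inputs(shards) == [(0, "a"), (1, "b"), (2, "c")]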
10 changes: 5 additions & 5 deletions distributed/shuffle/_pickle.py
@@ -116,19 +116,19 @@ def restore_dataframe_shard(
     from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_210
 
     def _ensure_arrow_dtypes_copied(blk: Block) -> Block:
-        if isinstance(blk.dtype, pd.StringDtype) and blk.dtype.storage in (
+        if isinstance(blk.dtype, pd.StringDtype) and blk.dtype.storage in (  # type: ignore[attr-defined]
             "pyarrow",
             "pyarrow_numpy",
         ):
             arr = blk.values._pa_array.combine_chunks()
-            if blk.dtype.storage == "pyarrow":
-                arr = pd.arrays.ArrowStringArray(arr)
+            if blk.dtype.storage == "pyarrow":  # type: ignore[attr-defined]
+                arr = pd.arrays.ArrowStringArray(arr)  # type: ignore[attr-defined]
             else:
                 arr = pd.array(arr, dtype=blk.dtype)
             return make_block(arr, blk.mgr_locs)
         elif PANDAS_GE_150 and isinstance(blk.dtype, pd.ArrowDtype):
             return make_block(
-                pd.arrays.ArrowExtensionArray(blk.values._pa_array.combine_chunks()),
+                pd.arrays.ArrowExtensionArray(blk.values._pa_array.combine_chunks()),  # type: ignore[attr-defined]
                 blk.mgr_locs,
             )
         return blk
@@ -137,6 +137,6 @@ def _ensure_arrow_dtypes_copied(blk: Block) -> Block:
     axes = [meta.columns, index]
     manager = BlockManager(blocks, axes, verify_integrity=False)
     if PANDAS_GE_210:
-        return pd.DataFrame._from_mgr(manager, axes)
+        return pd.DataFrame._from_mgr(manager, axes)  # type: ignore[attr-defined]
     else:
         return pd.DataFrame(manager)
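Note (not part of the commit): the ignores added above are scoped to a single mypy error code, so other problems on the same line still get reported. A small illustration of the pattern follows; it assumes pandas is installed and that the installed stubs do not declare `StringDtype.storage` (whether mypy actually flags it depends on your pandas and mypy versions).

import pandas as pd

dtype = pd.StringDtype()

# Some pandas stubs do not declare `.storage`, so mypy reports [attr-defined];
# scoping the ignore to that code keeps every other check on this line active.
storage = dtype.storage  # type: ignore[attr-defined]
print(storage)  # "python" for the default StringDtype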
6 changes: 3 additions & 3 deletions distributed/shuffle/_shuffle.py
@@ -332,12 +332,12 @@ def split_by_worker(
     """
     out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)
 
-    base = df[column].values.base
+    base = df[column].values.base  # type: ignore[union-attr]
     for output_part_id, part in df.groupby(column, observed=True):
         assert isinstance(output_part_id, int)
         if output_part_id not in worker_for:
             continue
-        if part[column].values.base is base and len(part) != len(base):
+        if part[column].values.base is base and len(part) != len(base):  # type: ignore[union-attr, arg-type]
             if drop_column:
                 del part[column]
             part = part.copy(deep=True)
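Note (not part of the commit): a sketch of the view-detection trick in the hunk above, with made-up data. A positional slice of a NumPy-backed column usually shares its memory owner (`.base`) with the parent frame, which is what the identity check exploits; the `# type: ignore[union-attr]` is needed because `.values` may be an ExtensionArray with no `.base` as far as the type checker is concerned. Exact behaviour can vary with the pandas version and copy-on-write settings.

import numpy as np
import pandas as pd

df = pd.DataFrame({"part": np.array([0, 0, 1]), "x": [1.0, 2.0, 3.0]})
base = df["part"].values.base  # memory owner of the column's buffer

view = df.iloc[:2]         # a positional slice is typically a view
copy = df.iloc[:2].copy()  # an explicit copy gets its own buffer

print(view["part"].values.base is base)  # typically True: still backed by df
print(copy["part"].values.base is base)  # False: independent buffer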
@@ -444,7 +444,7 @@ async def _receive(
         # PickleBuffer objects may have been converted to bytearray by the
         # pickle roundtrip that is done by _core.py when buffers are too small
         self,
-        data: Iterable[list[PickleBuffer | bytes | bytearray]],
+        data: Iterable[tuple[int, list[PickleBuffer | bytes | bytearray]]],
     ) -> None:
         self.raise_if_closed()
