From 95d53770e1a0a88e3bf68ac53ff441c238af2d28 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Tue, 7 May 2024 18:50:37 +0200
Subject: [PATCH] Improved typing

---
 distributed/shuffle/_core.py    |  6 +++---
 distributed/shuffle/_pickle.py  | 10 +++++-----
 distributed/shuffle/_shuffle.py |  6 +++---
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index 14cd82f6b63..726279ce1aa 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -298,8 +298,8 @@ async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None
     def _deduplicate_inputs(
         self,
         data: Iterable[tuple[_T_partition_id, Iterable[tuple[_T_partition_id, _T]]]],
-    ) -> list[_T]:
-        deduplicated = []
+    ) -> list[tuple[_T_partition_id, _T]]:
+        deduplicated: list[tuple[_T_partition_id, _T]] = []
         for input_partition_id, batch in data:
             if input_partition_id in self.received:
                 continue
@@ -324,7 +324,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
         """Get the address of the worker assigned to the output partition"""
 
     @abc.abstractmethod
-    async def _receive(self, data: Iterable[Any]) -> None:
+    async def _receive(self, data: Iterable[tuple[_T_partition_id, Any]]) -> None:
         """Receive shards belonging to output partitions of this shuffle run"""
 
     def add_partition(
diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py
index 9f40407faaa..9226bd7d3b8 100644
--- a/distributed/shuffle/_pickle.py
+++ b/distributed/shuffle/_pickle.py
@@ -116,19 +116,19 @@ def restore_dataframe_shard(
     from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_210
 
     def _ensure_arrow_dtypes_copied(blk: Block) -> Block:
-        if isinstance(blk.dtype, pd.StringDtype) and blk.dtype.storage in (
+        if isinstance(blk.dtype, pd.StringDtype) and blk.dtype.storage in (  # type: ignore[attr-defined]
             "pyarrow",
             "pyarrow_numpy",
         ):
             arr = blk.values._pa_array.combine_chunks()
-            if blk.dtype.storage == "pyarrow":
-                arr = pd.arrays.ArrowStringArray(arr)
+            if blk.dtype.storage == "pyarrow":  # type: ignore[attr-defined]
+                arr = pd.arrays.ArrowStringArray(arr)  # type: ignore[attr-defined]
             else:
                 arr = pd.array(arr, dtype=blk.dtype)
             return make_block(arr, blk.mgr_locs)
         elif PANDAS_GE_150 and isinstance(blk.dtype, pd.ArrowDtype):
             return make_block(
-                pd.arrays.ArrowExtensionArray(blk.values._pa_array.combine_chunks()),
+                pd.arrays.ArrowExtensionArray(blk.values._pa_array.combine_chunks()),  # type: ignore[attr-defined]
                 blk.mgr_locs,
             )
         return blk
@@ -137,6 +137,6 @@ def _ensure_arrow_dtypes_copied(blk: Block) -> Block:
     axes = [meta.columns, index]
     manager = BlockManager(blocks, axes, verify_integrity=False)
     if PANDAS_GE_210:
-        return pd.DataFrame._from_mgr(manager, axes)
+        return pd.DataFrame._from_mgr(manager, axes)  # type: ignore[attr-defined]
     else:
         return pd.DataFrame(manager)
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index ebfd9a7b728..fe74263a6ff 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -332,12 +332,12 @@ def split_by_worker(
     """
     out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)
 
-    base = df[column].values.base
+    base = df[column].values.base  # type: ignore[union-attr]
     for output_part_id, part in df.groupby(column, observed=True):
         assert isinstance(output_part_id, int)
         if output_part_id not in worker_for:
             continue
-        if part[column].values.base is base and len(part) != len(base):
+        if part[column].values.base is base and len(part) != len(base):  # type: ignore[union-attr, arg-type]
             if drop_column:
                 del part[column]
             part = part.copy(deep=True)
@@ -444,7 +444,7 @@
     async def _receive(
         # PickleBuffer objects may have been converted to bytearray by the
         # pickle roundtrip that is done by _core.py when buffers are too small
         self,
-        data: Iterable[list[PickleBuffer | bytes | bytearray]],
+        data: Iterable[tuple[int, list[PickleBuffer | bytes | bytearray]]],
     ) -> None:
         self.raise_if_closed()