
Commit 38565aa

Backport PR pandas-dev#60324: REF: centralize pyarrow Table to pandas conversions and types_mapper handling (pandas-dev#60332)

(cherry picked from commit 12d6f60)

Co-authored-by: Joris Van den Bossche <[email protected]>
WillAyd and jorisvandenbossche authored Nov 16, 2024
1 parent 4f13697 commit 38565aa
Showing 8 changed files with 92 additions and 122 deletions.
49 changes: 47 additions & 2 deletions pandas/io/_util.py
@@ -1,14 +1,27 @@
 from __future__ import annotations
 
-from typing import Callable
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+)
 
 import numpy as np
 
+from pandas._config import using_string_dtype
+
+from pandas._libs import lib
 from pandas.compat import pa_version_under18p0
 from pandas.compat._optional import import_optional_dependency
 
 import pandas as pd
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+    import pyarrow
+
+    from pandas._typing import DtypeBackend
+
 
 def _arrow_dtype_mapping() -> dict:
     pa = import_optional_dependency("pyarrow")
@@ -30,7 +43,7 @@ def _arrow_dtype_mapping() -> dict:
     }
 
 
-def arrow_string_types_mapper() -> Callable:
+def _arrow_string_types_mapper() -> Callable:
     pa = import_optional_dependency("pyarrow")
 
     mapping = {
@@ -41,3 +54,35 @@ def arrow_string_types_mapper() -> Callable:
         mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan)
 
     return mapping.get
+
+
+def arrow_table_to_pandas(
+    table: pyarrow.Table,
+    dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default,
+    null_to_int64: bool = False,
+    to_pandas_kwargs: dict | None = None,
+) -> pd.DataFrame:
+    if to_pandas_kwargs is None:
+        to_pandas_kwargs = {}
+
+    pa = import_optional_dependency("pyarrow")
+
+    types_mapper: type[pd.ArrowDtype] | None | Callable
+    if dtype_backend == "numpy_nullable":
+        mapping = _arrow_dtype_mapping()
+        if null_to_int64:
+            # Modify the default mapping to also map null to Int64
+            # (to match other engines - only for CSV parser)
+            mapping[pa.null()] = pd.Int64Dtype()
+        types_mapper = mapping.get
+    elif dtype_backend == "pyarrow":
+        types_mapper = pd.ArrowDtype
+    elif using_string_dtype():
+        types_mapper = _arrow_string_types_mapper()
+    elif dtype_backend is lib.no_default or dtype_backend == "numpy":
+        types_mapper = None
+    else:
+        raise NotImplementedError
+
+    df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs)
+    return df
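
Note (not part of the diff): a minimal sketch of how the new helper is meant to be called. pandas.io._util is private API and the table contents here are illustrative, so this is for orientation only.

import pyarrow as pa

from pandas.io._util import arrow_table_to_pandas

table = pa.table({"a": [1, 2, None], "b": ["x", "y", None]})

# default: NumPy-backed DataFrame (types_mapper=None)
df_numpy = arrow_table_to_pandas(table)

# nullable extension dtypes (Int64, boolean, string, ...)
df_nullable = arrow_table_to_pandas(table, dtype_backend="numpy_nullable")

# ArrowDtype-backed columns
df_arrow = arrow_table_to_pandas(table, dtype_backend="pyarrow")

# CSV-parser behaviour: additionally map pyarrow null columns to Int64
df_csv_like = arrow_table_to_pandas(
    table, dtype_backend="numpy_nullable", null_to_int64=True
)
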
17 changes: 2 additions & 15 deletions pandas/io/feather_format.py
@@ -13,11 +13,10 @@
 from pandas.util._decorators import doc
 from pandas.util._validators import check_dtype_backend
 
-import pandas as pd
 from pandas.core.api import DataFrame
 from pandas.core.shared_docs import _shared_docs
 
-from pandas.io._util import arrow_string_types_mapper
+from pandas.io._util import arrow_table_to_pandas
 from pandas.io.common import get_handle
 
 if TYPE_CHECKING:
@@ -128,16 +127,4 @@ def read_feather(
         pa_table = feather.read_table(
             handles.handle, columns=columns, use_threads=bool(use_threads)
         )
-
-        if dtype_backend == "numpy_nullable":
-            from pandas.io._util import _arrow_dtype_mapping
-
-            return pa_table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
-
-        elif dtype_backend == "pyarrow":
-            return pa_table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        elif using_string_dtype():
-            return pa_table.to_pandas(types_mapper=arrow_string_types_mapper())
-        else:
-            raise NotImplementedError
+        return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
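
Note (not part of the diff): read_feather keeps its public behaviour; a round-trip sketch with an illustrative file name:

import pandas as pd

pd.DataFrame({"a": [1, None]}).to_feather("data.feather")
df = pd.read_feather("data.feather", dtype_backend="numpy_nullable")  # "a" comes back as Float64
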
15 changes: 2 additions & 13 deletions pandas/io/json/_json.py
@@ -40,7 +40,6 @@
 from pandas.core.dtypes.dtypes import PeriodDtype
 
 from pandas import (
-    ArrowDtype,
     DataFrame,
     Index,
     MultiIndex,
@@ -52,6 +51,7 @@
 from pandas.core.reshape.concat import concat
 from pandas.core.shared_docs import _shared_docs
 
+from pandas.io._util import arrow_table_to_pandas
 from pandas.io.common import (
     IOHandles,
     dedup_names,
@@ -997,18 +997,7 @@ def read(self) -> DataFrame | Series:
         if self.engine == "pyarrow":
             pyarrow_json = import_optional_dependency("pyarrow.json")
             pa_table = pyarrow_json.read_json(self.data)
-
-            mapping: type[ArrowDtype] | None | Callable
-            if self.dtype_backend == "pyarrow":
-                mapping = ArrowDtype
-            elif self.dtype_backend == "numpy_nullable":
-                from pandas.io._util import _arrow_dtype_mapping
-
-                mapping = _arrow_dtype_mapping().get
-            else:
-                mapping = None
-
-            return pa_table.to_pandas(types_mapper=mapping)
+            return arrow_table_to_pandas(pa_table, dtype_backend=self.dtype_backend)
         elif self.engine == "ujson":
             if self.lines:
                 if self.chunksize:
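
Note (not part of the diff): the pyarrow engine only accepts line-delimited JSON, so a sketch of the path above looks like this (the input is illustrative):

import io

import pandas as pd

data = io.BytesIO(b'{"a": 1}\n{"a": 2}\n')
df = pd.read_json(data, lines=True, engine="pyarrow", dtype_backend="pyarrow")
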
21 changes: 2 additions & 19 deletions pandas/io/orc.py
@@ -9,16 +9,13 @@
     Literal,
 )
 
-from pandas._config import using_string_dtype
-
 from pandas._libs import lib
 from pandas.compat._optional import import_optional_dependency
 from pandas.util._validators import check_dtype_backend
 
-import pandas as pd
 from pandas.core.indexes.api import default_index
 
-from pandas.io._util import arrow_string_types_mapper
+from pandas.io._util import arrow_table_to_pandas
 from pandas.io.common import (
     get_handle,
     is_fsspec_url,
@@ -117,21 +114,7 @@ def read_orc(
         pa_table = orc.read_table(
             source=source, columns=columns, filesystem=filesystem, **kwargs
         )
-    if dtype_backend is not lib.no_default:
-        if dtype_backend == "pyarrow":
-            df = pa_table.to_pandas(types_mapper=pd.ArrowDtype)
-        else:
-            from pandas.io._util import _arrow_dtype_mapping
-
-            mapping = _arrow_dtype_mapping()
-            df = pa_table.to_pandas(types_mapper=mapping.get)
-        return df
-    else:
-        if using_string_dtype():
-            types_mapper = arrow_string_types_mapper()
-        else:
-            types_mapper = None
-        return pa_table.to_pandas(types_mapper=types_mapper)
+    return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
 
 
 def to_orc(
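
Note (not part of the diff): a read_orc round-trip sketch; the file name is illustrative and pyarrow's ORC support is required:

import pandas as pd

pd.DataFrame({"a": [1, None]}).to_orc("data.orc")
df = pd.read_orc("data.orc", dtype_backend="numpy_nullable")
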
34 changes: 18 additions & 16 deletions pandas/io/parquet.py
@@ -10,9 +10,11 @@
     Literal,
 )
 import warnings
-from warnings import catch_warnings
+from warnings import (
+    catch_warnings,
+    filterwarnings,
+)
 
-from pandas._config import using_string_dtype
 from pandas._config.config import _get_option
 
 from pandas._libs import lib
@@ -22,14 +24,13 @@
 from pandas.util._exceptions import find_stack_level
 from pandas.util._validators import check_dtype_backend
 
-import pandas as pd
 from pandas import (
     DataFrame,
     get_option,
 )
 from pandas.core.shared_docs import _shared_docs
 
-from pandas.io._util import arrow_string_types_mapper
+from pandas.io._util import arrow_table_to_pandas
 from pandas.io.common import (
     IOHandles,
     get_handle,
@@ -250,20 +251,10 @@ def read(
         kwargs["use_pandas_metadata"] = True
 
         to_pandas_kwargs = {}
-        if dtype_backend == "numpy_nullable":
-            from pandas.io._util import _arrow_dtype_mapping
-
-            mapping = _arrow_dtype_mapping()
-            to_pandas_kwargs["types_mapper"] = mapping.get
-        elif dtype_backend == "pyarrow":
-            to_pandas_kwargs["types_mapper"] = pd.ArrowDtype  # type: ignore[assignment]
-        elif using_string_dtype():
-            to_pandas_kwargs["types_mapper"] = arrow_string_types_mapper()
-
         manager = _get_option("mode.data_manager", silent=True)
         if manager == "array":
-            to_pandas_kwargs["split_blocks"] = True  # type: ignore[assignment]
+            to_pandas_kwargs["split_blocks"] = True
 
         path_or_handle, handles, filesystem = _get_path_or_handle(
             path,
             filesystem,
@@ -278,7 +269,18 @@
             filters=filters,
             **kwargs,
         )
-        result = pa_table.to_pandas(**to_pandas_kwargs)
+
+        with catch_warnings():
+            filterwarnings(
+                "ignore",
+                "make_block is deprecated",
+                DeprecationWarning,
+            )
+            result = arrow_table_to_pandas(
+                pa_table,
+                dtype_backend=dtype_backend,
+                to_pandas_kwargs=to_pandas_kwargs,
+            )
 
         if manager == "array":
             result = result._as_manager("array", copy=False)
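
Note (not part of the diff): the catch_warnings block silences the "make_block is deprecated" DeprecationWarning that pyarrow's to_pandas conversion can trigger on recent pandas. A round-trip sketch with an illustrative file name:

import pandas as pd

pd.DataFrame({"a": [1, 2]}).to_parquet("data.parquet")
df = pd.read_parquet("data.parquet", dtype_backend="pyarrow")
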
33 changes: 12 additions & 21 deletions pandas/io/parsers/arrow_parser_wrapper.py
@@ -3,8 +3,6 @@
 from typing import TYPE_CHECKING
 import warnings
 
-from pandas._config import using_string_dtype
-
 from pandas._libs import lib
 from pandas.compat._optional import import_optional_dependency
 from pandas.errors import (
@@ -16,18 +14,14 @@
 from pandas.core.dtypes.common import pandas_dtype
 from pandas.core.dtypes.inference import is_integer
 
-import pandas as pd
-from pandas import DataFrame
-
-from pandas.io._util import (
-    _arrow_dtype_mapping,
-    arrow_string_types_mapper,
-)
+from pandas.io._util import arrow_table_to_pandas
 from pandas.io.parsers.base_parser import ParserBase
 
 if TYPE_CHECKING:
     from pandas._typing import ReadBuffer
+
+    from pandas import DataFrame
 
 
 class ArrowParserWrapper(ParserBase):
     """
@@ -287,17 +281,14 @@ def read(self) -> DataFrame:
 
             table = table.cast(new_schema)
 
-        if dtype_backend == "pyarrow":
-            frame = table.to_pandas(types_mapper=pd.ArrowDtype)
-        elif dtype_backend == "numpy_nullable":
-            # Modify the default mapping to also
-            # map null to Int64 (to match other engines)
-            dtype_mapping = _arrow_dtype_mapping()
-            dtype_mapping[pa.null()] = pd.Int64Dtype()
-            frame = table.to_pandas(types_mapper=dtype_mapping.get)
-        elif using_string_dtype():
-            frame = table.to_pandas(types_mapper=arrow_string_types_mapper())
-        else:
-            frame = table.to_pandas()
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                "make_block is deprecated",
+                DeprecationWarning,
+            )
+            frame = arrow_table_to_pandas(
+                table, dtype_backend=dtype_backend, null_to_int64=True
+            )
 
         return self._finalize_pandas_output(frame)
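
Note (not part of the diff): null_to_int64=True exists because pyarrow infers a fully-empty CSV column as the null type; mapping it to Int64 matches the other parsers. An illustrative sketch:

import io

import pandas as pd

data = io.BytesIO(b"a,b\n1,\n2,\n")  # column "b" is entirely empty
df = pd.read_csv(data, engine="pyarrow", dtype_backend="numpy_nullable")
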
41 changes: 7 additions & 34 deletions pandas/io/sql.py
@@ -49,10 +49,7 @@
     is_object_dtype,
     is_string_dtype,
 )
-from pandas.core.dtypes.dtypes import (
-    ArrowDtype,
-    DatetimeTZDtype,
-)
+from pandas.core.dtypes.dtypes import DatetimeTZDtype
 from pandas.core.dtypes.missing import isna
 
 from pandas import get_option
@@ -68,6 +65,8 @@
 from pandas.core.internals.construction import convert_object_array
 from pandas.core.tools.datetimes import to_datetime
 
+from pandas.io._util import arrow_table_to_pandas
+
 if TYPE_CHECKING:
     from collections.abc import (
         Iterator,
@@ -2221,23 +2220,10 @@ def read_table(
         else:
             stmt = f"SELECT {select_list} FROM {table_name}"
 
-        mapping: type[ArrowDtype] | None | Callable
-        if dtype_backend == "pyarrow":
-            mapping = ArrowDtype
-        elif dtype_backend == "numpy_nullable":
-            from pandas.io._util import _arrow_dtype_mapping
-
-            mapping = _arrow_dtype_mapping().get
-        elif using_string_dtype():
-            from pandas.io._util import arrow_string_types_mapper
-
-            mapping = arrow_string_types_mapper()
-        else:
-            mapping = None
-
         with self.con.cursor() as cur:
             cur.execute(stmt)
-            df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
+            pa_table = cur.fetch_arrow_table()
+            df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
 
         return _wrap_result_adbc(
             df,
@@ -2305,23 +2291,10 @@ def read_query(
         if chunksize:
             raise NotImplementedError("'chunksize' is not implemented for ADBC drivers")
 
-        mapping: type[ArrowDtype] | None | Callable
-        if dtype_backend == "pyarrow":
-            mapping = ArrowDtype
-        elif dtype_backend == "numpy_nullable":
-            from pandas.io._util import _arrow_dtype_mapping
-
-            mapping = _arrow_dtype_mapping().get
-        elif using_string_dtype():
-            from pandas.io._util import arrow_string_types_mapper
-
-            mapping = arrow_string_types_mapper()
-        else:
-            mapping = None
-
         with self.con.cursor() as cur:
             cur.execute(sql)
-            df = cur.fetch_arrow_table().to_pandas(types_mapper=mapping)
+            pa_table = cur.fetch_arrow_table()
+            df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend)
 
         return _wrap_result_adbc(
             df,
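
Note (not part of the diff): both ADBC paths now funnel through the helper; a sketch assuming the optional adbc-driver-sqlite package is installed:

import pandas as pd

from adbc_driver_sqlite import dbapi

with dbapi.connect(":memory:") as conn:
    df = pd.read_sql_query("SELECT 1 AS a", conn, dtype_backend="pyarrow")
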
4 changes: 2 additions & 2 deletions pandas/tests/io/test_sql.py
@@ -959,12 +959,12 @@ def sqlite_buildin_types(sqlite_buildin, types_data):
 
 adbc_connectable_iris = [
     pytest.param("postgresql_adbc_iris", marks=pytest.mark.db),
-    pytest.param("sqlite_adbc_iris", marks=pytest.mark.db),
+    "sqlite_adbc_iris",
 ]
 
 adbc_connectable_types = [
     pytest.param("postgresql_adbc_types", marks=pytest.mark.db),
-    pytest.param("sqlite_adbc_types", marks=pytest.mark.db),
+    "sqlite_adbc_types",
 ]
 
