Feat: Adds interop with Arrow library using new method `Dataset.to_arrow()` (#281)
avirajsingh7 authored Jul 9, 2024
1 parent 3e1ffb0 commit 2682ed2
Showing 8 changed files with 261 additions and 107 deletions.
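For orientation, a minimal sketch of how the new API reads end to end; the connector, config, and stream name below are placeholders for illustration, not part of this commit:

```python
import airbyte as ab

# Hypothetical source; any connector that syncs into the default cache would work.
source = ab.get_source("source-faker", config={"count": 1_000}, install_if_missing=True)
source.select_all_streams()
result = source.read()

# New in this commit: convert a cached stream to a PyArrow Dataset.
users = result.cache.streams["users"]
arrow_dataset = users.to_arrow(max_chunk_size=100_000)
print(arrow_dataset.to_table().num_rows)
```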
35 changes: 35 additions & 0 deletions airbyte/caches/base.py
@@ -7,6 +7,8 @@
from typing import TYPE_CHECKING, Any, Optional, final

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr

from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -19,6 +21,7 @@
from airbyte._future_cdk.state_writers import StdOutStateWriter
from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend
from airbyte.caches._state_backend import SqlStateBackend
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.datasets._sql import CachedDataset


@@ -146,6 +149,38 @@ def get_pandas_dataframe(
        engine = self.get_sql_engine()
        return pd.read_sql_table(table_name, engine, schema=self.schema_name)

    def get_arrow_dataset(
        self,
        stream_name: str,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> ds.Dataset:
        """Return an Arrow Dataset with the stream's data."""
        table_name = self._read_processor.get_sql_table_name(stream_name)
        engine = self.get_sql_engine()

        # Read the table in chunks to handle large tables that do not fit in memory.
        pandas_chunks = pd.read_sql_table(
            table_name=table_name,
            con=engine,
            schema=self.schema_name,
            chunksize=max_chunk_size,
        )

        arrow_batches_list = []
        arrow_schema = None

        for pandas_chunk in pandas_chunks:
            if arrow_schema is None:
                # Initialize the schema with the first chunk.
                arrow_schema = pa.Schema.from_pandas(pandas_chunk)

            # Convert each pandas chunk to an Arrow record batch using the shared schema.
            arrow_table = pa.RecordBatch.from_pandas(pandas_chunk, schema=arrow_schema)
            arrow_batches_list.append(arrow_table)

        return ds.dataset(arrow_batches_list)

    @final
    @property
    def streams(self) -> dict[str, CachedDataset]:
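Since `ds.dataset()` over a list of record batches yields a standard in-memory `pyarrow.dataset.Dataset`, callers can use the usual Arrow APIs on the result. A small sketch, assuming `cache` is a populated cache with a `users` stream:

```python
arrow_dataset = cache.get_arrow_dataset("users", max_chunk_size=50_000)

# Materialize the whole stream as a single Arrow Table...
table = arrow_dataset.to_table()
print(table.schema)

# ...or walk the record batches one at a time.
for batch in arrow_dataset.to_batches():
    print(batch.num_rows)
```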
15 changes: 15 additions & 0 deletions airbyte/caches/bigquery.py
@@ -23,9 +23,24 @@
from airbyte.caches.base import (
    CacheBase,
)
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE


class BigQueryCache(BigQueryConfig, CacheBase):
    """The BigQuery cache implementation."""

    _sql_processor_class: type[BigQuerySqlProcessor] = PrivateAttr(default=BigQuerySqlProcessor)

    def get_arrow_dataset(
        self,
        stream_name: str,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> None:
        """Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.

        https://github.com/airbytehq/PyAirbyte/issues/165
        """
        raise NotImplementedError(
            "BigQuery doesn't currently support to_arrow. "
            "Please consider using a different cache implementation for these functionalities."
        )
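Because `BigQueryCache` overrides the method only to raise, code that accepts arbitrary cache implementations may want to guard the call. A hedged sketch (the stream name is illustrative):

```python
try:
    arrow_dataset = cache.get_arrow_dataset("users")
except NotImplementedError:
    # The Arrow path is unavailable on BigQuery; fall back to iterating records.
    records = list(cache.streams["users"])
```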
3 changes: 3 additions & 0 deletions airbyte/constants.py
@@ -38,3 +38,6 @@
Specific caches may override this value with a different schema name.
"""

DEFAULT_ARROW_MAX_CHUNK_SIZE = 100_000
"""The default number of records to include in each batch of an Arrow dataset."""
15 changes: 15 additions & 0 deletions airbyte/datasets/_base.py
@@ -6,11 +6,15 @@
from typing import TYPE_CHECKING, Any, cast

from pandas import DataFrame
from pyarrow.dataset import Dataset

from airbyte._util.document_rendering import DocumentRenderer
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE


if TYPE_CHECKING:
    from pyarrow.dataset import Dataset

    from airbyte_protocol.models import ConfiguredAirbyteStream

    from airbyte.documents import Document
@@ -37,6 +41,17 @@ def to_pandas(self) -> DataFrame:
        # duck typing is correct for this use case.
        return DataFrame(cast(Iterator[dict[str, Any]], self))

    def to_arrow(
        self,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> Dataset:
        """Return an Arrow Dataset representation of the dataset.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError("Not implemented in base class")

    def to_documents(
        self,
        title_property: str | None = None,
29 changes: 29 additions & 0 deletions airbyte/datasets/_sql.py
@@ -12,13 +12,15 @@

from airbyte_protocol.models.airbyte_protocol import ConfiguredAirbyteStream

from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.datasets._base import DatasetBase


if TYPE_CHECKING:
    from collections.abc import Iterator

    from pandas import DataFrame
    from pyarrow.dataset import Dataset
    from sqlalchemy import Table
    from sqlalchemy.sql import ClauseElement
    from sqlalchemy.sql.selectable import Selectable
@@ -102,6 +104,13 @@ def __len__(self) -> int:
    def to_pandas(self) -> DataFrame:
        return self._cache.get_pandas_dataframe(self._stream_name)

    def to_arrow(
        self,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> Dataset:
        return self._cache.get_arrow_dataset(self._stream_name, max_chunk_size=max_chunk_size)

    def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
        """Filter the dataset by a set of column values.
@@ -166,6 +175,26 @@ def to_pandas(self) -> DataFrame:
        """Return the underlying dataset data as a pandas DataFrame."""
        return self._cache.get_pandas_dataframe(self._stream_name)

    @overrides
    def to_arrow(
        self,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> Dataset:
        """Return an Arrow Dataset containing the data from this stream.

        Args:
            max_chunk_size (int): The maximum number of records to include in each batch
                of the PyArrow dataset.

        Returns:
            pa.dataset.Dataset: Arrow Dataset containing the stream's data.
        """
        return self._cache.get_arrow_dataset(
            stream_name=self._stream_name,
            max_chunk_size=max_chunk_size,
        )

    def to_sql_table(self) -> Table:
        """Return the underlying SQL table as a SQLAlchemy Table object."""
        return self._cache.processor.get_sql_table(self.stream_name)
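One payoff of returning a real `pyarrow.dataset.Dataset` is that downstream Arrow tooling works directly on the result; for example, writing a stream out as Parquet (a sketch, with the output path and stream name assumed):

```python
import pyarrow.dataset as ds

arrow_dataset = cache.streams["users"].to_arrow()

# Hand the dataset straight to pyarrow's writer to produce Parquet files.
ds.write_dataset(arrow_dataset, "exports/users", format="parquet")
```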
