Skip to content

Commit

Permalink
Add support pandas 2 in flytekit (#1818)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 18, 2023
1 parent 8fd3dfa commit 74f2f53
Show file tree
Hide file tree
Showing 43 changed files with 404 additions and 187 deletions.
44 changes: 41 additions & 3 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,46 @@ jobs:
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
- name: Install dependencies
run: |
make setup && pip freeze
make setup
pip uninstall -y pandas
pip freeze
- name: Test with coverage
env:
PYTEST_OPTS: -n2
run: |
make unit_test_codecov
- name: Codecov
uses: codecov/[email protected]
with:
fail_ci_if_error: false
files: coverage.xml

build-with-pandas:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest ]
python-version: [ "3.11" ]
pandas: [ "pandas<2.0.0", "pandas>=2.0.0" ]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v3
with:
# This path is specific to Ubuntu
path: ~/.cache/pip
# Look to see if there is a cache hit for the corresponding requirements files
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
- name: Install dependencies
run: |
make setup
pip install --force-reinstall "${{ matrix.pandas }}"
pip freeze
- name: Test with coverage
env:
PYTEST_OPTS: -n2
Expand Down Expand Up @@ -69,8 +108,7 @@ jobs:
# Look to see if there is a cache hit for the corresponding requirements files
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
- name: Install dependencies
run: |
make setup && pip freeze
run: make setup && pip freeze
- name: Test with coverage
env:
PYTEST_OPTS: -n2
Expand Down
7 changes: 5 additions & 2 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ torch<=2.0.0; python_version>='3.11' or platform_system!='Windows'
# Once a solution is found, this should be updated to support Windows as well.
python-magic; (platform_system=='Darwin' or platform_system=='Linux')

pillow
scikit-learn
types-protobuf
types-croniter
types-mock
autoflake

pillow
numpy
pandas
scikit-learn
types-requests
prometheus-client
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ def lazy_import_transformers(cls):
from flytekit.extras import sklearn # noqa: F401
if is_imported("pandas"):
try:
from flytekit.types import schema # noqa: F401
from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401
except ValueError:
logger.debug("Transformer for pandas is already registered.")
register_pandas_handlers()
Expand Down
9 changes: 6 additions & 3 deletions flytekit/extras/sqlite3/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
import typing
from dataclasses import dataclass

import pandas as pd

from flytekit import FlyteContext, kwtypes
from flytekit import FlyteContext, kwtypes, lazy_module
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
from flytekit.core.shim_task import ShimTaskExecutor
from flytekit.models import task as task_models

if typing.TYPE_CHECKING:
import pandas as pd
else:
pd = lazy_module("pandas")


def unarchive_file(local_path: str, to_dir: str):
"""
Expand Down
14 changes: 14 additions & 0 deletions flytekit/lazy_import/lazy_module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import importlib.util
import sys
import types

LAZY_MODULES = []


class LazyModule(types.ModuleType):
def __init__(self, module_name: str):
super().__init__(module_name)
self._module_name = module_name

def __getattribute__(self, attr):
raise ImportError(f"Module {object.__getattribute__(self, '_module_name')} is not yet installed.")


def is_imported(module_name):
"""
This function is used to check if a module has been imported by the regular import.
Expand All @@ -24,6 +34,10 @@ def lazy_module(fullname):
return sys.modules[fullname]
# https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec = importlib.util.find_spec(fullname)
if spec is None:
# Return a lazy module if the module is not found in the python environment,
# so that we can raise a proper error when the user tries to access an attribute in the module.
return LazyModule(fullname)
loader = importlib.util.LazyLoader(spec.loader)
spec.loader = loader
module = importlib.util.module_from_spec(spec)
Expand Down
1 change: 0 additions & 1 deletion flytekit/types/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@
SchemaReader,
SchemaWriter,
)
from .types_pandas import PandasSchemaReader, PandasSchemaWriter
6 changes: 4 additions & 2 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Type

import numpy as _np
import pandas
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
Expand Down Expand Up @@ -270,7 +269,7 @@ def supported_mode(self) -> SchemaOpenMode:
return self._supported_mode

def open(
self, dataframe_fmt: type = pandas.DataFrame, override_mode: typing.Optional[SchemaOpenMode] = None
self, dataframe_fmt: typing.Optional[type] = None, override_mode: typing.Optional[SchemaOpenMode] = None
) -> typing.Union[SchemaReader, SchemaWriter]:
"""
Returns a reader or writer depending on the mode of the object when created. This mode can be
Expand All @@ -287,6 +286,9 @@ def open(
raise AssertionError("Readonly schema cannot be opened in write mode!")

mode = override_mode if override_mode else self._supported_mode
import pandas as pd

dataframe_fmt = dataframe_fmt if dataframe_fmt else pd.DataFrame
h = SchemaEngine.get_handler(dataframe_fmt)
if not h.handles_remote_io:
# The Schema Handler does not manage its own IO, and this it will expect the files are on local file-system
Expand Down
22 changes: 15 additions & 7 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from pathlib import Path
from typing import TypeVar

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from botocore.exceptions import NoCredentialsError
from fsspec.core import split_protocol, strip_protocol
from fsspec.utils import get_protocol

from flytekit import FlyteContext, logger
from flytekit import FlyteContext, lazy_module, logger
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import get_fsspec_storage_options
from flytekit.models import literals
Expand All @@ -24,6 +21,13 @@
StructuredDatasetEncoder,
)

if typing.TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
else:
pd = lazy_module("pandas")
pa = lazy_module("pyarrow")

T = TypeVar("T")


Expand Down Expand Up @@ -70,7 +74,7 @@ def decode(
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pd.DataFrame:
) -> "pd.DataFrame":
uri = flyte_value.uri
columns = None
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
Expand Down Expand Up @@ -121,7 +125,7 @@ def decode(
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pd.DataFrame:
) -> "pd.DataFrame":
uri = flyte_value.uri
columns = None
kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config)
Expand All @@ -145,6 +149,8 @@ def encode(
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
import pyarrow.parquet as pq

uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.join(
ctx.file_access.raw_output_prefix, ctx.file_access.get_random_string()
)
Expand All @@ -165,7 +171,9 @@ def decode(
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pa.Table:
) -> "pa.Table":
import pyarrow.parquet as pq

uri = flyte_value.uri
if not ctx.file_access.is_remote(uri):
Path(uri).parent.mkdir(parents=True, exist_ok=True)
Expand Down
11 changes: 8 additions & 3 deletions flytekit/types/structured/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import re
import typing

import pandas as pd
import pyarrow as pa
from google.cloud import bigquery, bigquery_storage
from google.cloud.bigquery_storage_v1 import types

from flytekit import FlyteContext
from flytekit import FlyteContext, lazy_module
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
Expand All @@ -16,6 +14,13 @@
StructuredDatasetMetadata,
)

if typing.TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
else:
pd = lazy_module("pandas")
pa = lazy_module("pyarrow")

BIGQUERY = "bq"


Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-duckdb/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "duckdb"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "duckdb", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-mlflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

# TODO: support mlflow 2.0+
plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow<2.0.0"]
plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow<2.0.0", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
3 changes: 2 additions & 1 deletion plugins/flytekit-pandera/flytekitplugins/pandera/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.schema import FlyteSchema, PandasSchemaWriter, SchemaFormat, SchemaOpenMode
from flytekit.types.schema import FlyteSchema, SchemaFormat, SchemaOpenMode
from flytekit.types.schema.types import FlyteSchemaTransformer
from flytekit.types.schema.types_pandas import PandasSchemaWriter

pandas = lazy_module("pandas")
pandera = lazy_module("pandera")
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-pandera/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pandera>=0.7.1"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pandera>=0.7.1", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
5 changes: 1 addition & 4 deletions plugins/flytekit-polars/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
"flytekit>=1.3.0b2,<2.0.0",
"polars>=0.8.27,<0.17.0",
]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "polars>=0.8.27,<0.17.0", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.10.0"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.10.0", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-sqlalchemy/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "sqlalchemy>=1.4.7"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "sqlalchemy>=1.4.7", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,11 @@ dependencies = [
"keyring>=18.0.1",
"kubernetes>=12.0.1",
"marshmallow-enum",
# TODO: remove upper-bound after fixing change in contract
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.9.1",
"numpy",
"pandas>=1.0.0,<2.0.0",
# TODO: Remove upper-bound after protobuf community fixes it. https://github.com/flyteorg/flyte/issues/4359
"protobuf<4.25.0",
"pyarrow>=4.0.0",
"pyarrow",
"python-json-logger>=2.0.0",
"pytimeparse>=1.1.8,<2.0.0",
"pyyaml!=6.0.0,!=5.4.0,!=5.4.1", # pyyaml is broken with cython 3: https://github.com/yaml/pyyaml/issues/601
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from flytekit.interaction.click_types import DirParamType, FileParamType
from flytekit.remote import FlyteRemote

pytest.importorskip("pandas")

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py"
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
Expand Down
6 changes: 5 additions & 1 deletion tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import pathlib
import random
import string
import sys
import tempfile

import mock
import pandas as pd
import pytest
from azure.identity import ClientSecretCredential, DefaultAzureCredential

from flytekit.core.data_persistence import FileAccessProvider
Expand All @@ -27,8 +28,11 @@ def test_is_remote():
assert fp.is_remote("s3://my-bucket/foo/bar") is True


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
@mock.patch("flytekit.core.data_persistence.UUID")
def test_write_folder_put_raw(mock_uuid_class):
import pandas as pd

"""
A test that writes this structure
raw/
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import sys
from dataclasses import dataclass
from typing import List

import pytest
from dataclasses_json import DataClassJsonMixin

from flytekit.core.task import task
from flytekit.core.workflow import workflow


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
def test_dataclass():
@dataclass
class AppParams(DataClassJsonMixin):
Expand Down
Loading

0 comments on commit 74f2f53

Please sign in to comment.