Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Nate Parsons committed May 10, 2024
1 parent 529056d commit a6c1ff4
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ Release Notes
Future Release
==============
* Enhancements
* Add support for Python 3.12 :pr:`1855`
* Fixes
* Changes
* Add support for Python 3.12 :pr:`1855`
* Drop support for using Woodwork with Dask or Pyspark dataframes :pr:`1857`
* Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize DataFrames :pr:`1862`
* Documentation Changes
Expand Down
3 changes: 3 additions & 0 deletions woodwork/deserializers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def read_table_typing_information(path, typing_info_filename, profile_name):

use_smartopen(file_path, path, transport_params)
with tarfile.open(str(file_path)) as tar:
import pdb

pdb.set_trace()
if "filter" in getfullargspec(tar.extractall).kwonlyargs:
tar.extractall(path=tmpdir, filter="data")
else:
Expand Down
53 changes: 52 additions & 1 deletion woodwork/tests/accessor/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import shutil
import warnings
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import boto3
import pandas as pd
Expand Down Expand Up @@ -662,6 +662,35 @@ def test_to_csv_S3(sample_df, s3_client, s3_bucket, profile_name):
assert sample_df.ww.schema == deserialized_df.ww.schema


@patch("woodwork.deserializers.deserializer_base.getfullargspec")
def test_to_csv_S3_errors_if_python_version_unsafe(
    mock_inspect,
    sample_df,
    s3_client,
    s3_bucket,
):
    """CSV deserialization from S3 must refuse to run on Python builds whose
    ``tarfile.extractall`` lacks the ``filter`` keyword (unsafe extraction)."""
    # Simulate an old interpreter: getfullargspec reports no keyword-only
    # args, so the deserializer cannot pass filter="data" to extractall.
    mock_inspect.return_value = MagicMock(kwonlyargs=[])

    sample_df.ww.init(
        name="test_data",
        index="id",
        semantic_tags={"id": "tag1"},
        logical_types={"age": Ordinal(order=[25, 33, 57])},
    )
    sample_df.ww.to_disk(
        TEST_S3_URL, format="csv", encoding="utf-8", engine="python", profile_name=None
    )
    make_public(s3_client, s3_bucket)

    # Reading back should fail fast rather than perform an unsafe extraction.
    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
        read_woodwork_table(TEST_S3_URL, profile_name=None)


@pytest.mark.parametrize("profile_name", [None, False])
def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
sample_df.ww.init()
Expand All @@ -688,6 +717,28 @@ def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
assert sample_df.ww.schema == deserialized_df.ww.schema


@patch("woodwork.deserializers.parquet_deserializer.getfullargspec")
def test_serialize_s3_parquet_errors_if_python_version_unsafe(
    mock_inspect,
    sample_df,
    s3_client,
    s3_bucket,
):
    """Parquet deserialization from S3 must refuse to run on Python builds
    whose ``tarfile.extractall`` lacks the ``filter`` keyword."""
    # Simulate an old interpreter: no keyword-only args means no
    # filter="data" support on extractall.
    mock_inspect.return_value = MagicMock(kwonlyargs=[])

    sample_df.ww.init()
    sample_df.ww.to_disk(TEST_S3_URL, format="parquet", profile_name=None)
    make_public(s3_client, s3_bucket)

    # Reading back should fail fast rather than perform an unsafe extraction.
    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
        read_woodwork_table(TEST_S3_URL, filename="data.parquet", profile_name=None)


def create_test_credentials(test_path):
with open(test_path, "w+") as f:
f.write("[test]\n")
Expand Down

0 comments on commit a6c1ff4

Please sign in to comment.