Use filter arg to safe extract archives (#1862)

* use filter to safe extract archives

* update release notes

* update actions

* add tests

* fix action

* final test
thehomebrewnerd authored May 13, 2024
1 parent 762e08f commit 27edbcd
Showing 7 changed files with 97 additions and 8 deletions.
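Background: the filter parameter of tarfile.extractall (PEP 706) was added in Python 3.12 and backported to the 3.8.17, 3.9.17, 3.10.12, and 3.11.4 patch releases; the "data" filter rejects archive members that would write outside the destination directory. The guarded extraction pattern used throughout this commit can be sketched standalone as below (the archive path and destination are illustrative, not part of the commit):

import tarfile
from inspect import getfullargspec
from tempfile import mkdtemp


def safe_extract(archive_path):
    """Extract a tar archive only if the interpreter supports PEP 706 filters."""
    destination = mkdtemp()
    with tarfile.open(archive_path) as tar:
        # On patched interpreters, extractall grows a keyword-only "filter" arg.
        if "filter" in getfullargspec(tar.extractall).kwonlyargs:
            tar.extractall(path=destination, filter="data")
        else:
            raise RuntimeError(
                "Please upgrade your Python version to the latest patch "
                "release to allow for safe extraction of the archive.",
            )
    return destination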
2 changes: 1 addition & 1 deletion .github/workflows/pull_request_check.yaml
@@ -7,7 +7,7 @@ jobs:
     name: pull request check
     runs-on: ubuntu-latest
     steps:
-      - uses: nearform/github-action-check-linked-issues@v1
+      - uses: nearform-actions/github-action-check-linked-issues@v1
         id: check-linked-issues
         with:
           exclude-branches: "release_v**, backport_v**, main, latest-dep-update-**, min-dep-update-**, dependabot/**"
4 changes: 3 additions & 1 deletion .github/workflows/release_notes_updated.yaml
@@ -12,6 +12,8 @@ jobs:
       - name: Check for development branch
         id: branch
         shell: python
+        env:
+          REF: ${{ github.event.pull_request.head.ref }}
         run: |
           from re import compile
           main = '^main$'
@@ -21,7 +23,7 @@
           min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
           regex = main, release, backport, dep_update, min_dep_update
           patterns = list(map(compile, regex))
-          ref = "${{ github.event.pull_request.head.ref }}"
+          ref = "$REF"
           is_dev = not any(pattern.match(ref) for pattern in patterns)
           print('::set-output name=is_dev::' + str(is_dev))
       - if: ${{ steps.branch.outputs.is_dev == 'True' }}
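A note on the workflow change above: moving the branch ref from direct ${{ ... }} interpolation in the script body into the step's env: block is the standard mitigation against script injection from untrusted branch names. Since shell: python performs no shell expansion, the conventional way to consume such a variable is via os.environ, as in this illustrative sketch (the variable name matches the env: key above):

from os import environ

# The env: value is exposed through the process environment;
# "$REF" inside a Python string literal is not expanded by any shell.
ref = environ.get("REF", "")
print("head ref is " + ref)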
5 changes: 3 additions & 2 deletions docs/source/release_notes.rst
@@ -6,10 +6,11 @@ Release Notes
 Future Release
 ==============
 * Enhancements
+    * Add support for Python 3.12 :pr:`1855`
 * Fixes
 * Changes
-    * Add support for Python 3.12 :pr:`1855`
-    * Drop support for using Woodwork with Dask or Pyspark dataframes (:pr:`1857`)
+    * Drop support for using Woodwork with Dask or Pyspark dataframes :pr:`1857`
+    * Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize DataFrames :pr:`1862`
 * Documentation Changes
 * Testing Changes
 
8 changes: 7 additions & 1 deletion woodwork/deserializers/deserializer_base.py
@@ -2,6 +2,7 @@
 import tarfile
 import tempfile
 import warnings
+from inspect import getfullargspec
 from itertools import zip_longest
 from pathlib import Path
 
@@ -125,7 +126,12 @@ def read_from_s3(self, profile_name):
 
             use_smartopen(tar_filepath, self.path, transport_params)
             with tarfile.open(str(tar_filepath)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
+                    )
             self.read_path = os.path.join(
                 tmpdir,
                 self.typing_info["loading_info"]["location"],
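For context on what the "data" filter blocks: on a patched interpreter, an archive member whose path escapes the extraction destination raises tarfile.OutsideDestinationError before anything is written. A self-contained demonstration (the in-memory archive is purely illustrative):

import io
import tarfile
import tempfile

# Build an in-memory archive containing a path-traversal member.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:
    payload = b"malicious"
    member = tarfile.TarInfo(name="../evil.txt")
    member.size = len(payload)
    tar.addfile(member, io.BytesIO(payload))
buffer.seek(0)

destination = tempfile.mkdtemp()
with tarfile.open(fileobj=buffer) as tar:
    try:
        tar.extractall(path=destination, filter="data")
    except tarfile.OutsideDestinationError as err:
        print("blocked unsafe member:", err)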
8 changes: 7 additions & 1 deletion woodwork/deserializers/parquet_deserializer.py
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import tempfile
+from inspect import getfullargspec
 from pathlib import Path
 
 import pandas as pd
@@ -61,7 +62,12 @@ def read_from_s3(self, profile_name):
 
             use_smartopen(tar_filepath, self.path, transport_params)
             with tarfile.open(str(tar_filepath)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
+                    )
 
             self.read_path = os.path.join(tmpdir, self.data_subdirectory, self.filename)
 
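For reference, the round trip these deserializers serve looks roughly like the sketch below; the local-path usage and the woodwork.deserialize import location are assumptions inferred from the tests in this commit, and the S3 variants above are where the guarded tar extraction actually runs:

import pandas as pd

import woodwork  # noqa: F401  (registers the DataFrame.ww accessor)
from woodwork.deserialize import read_woodwork_table  # assumed import path

df = pd.DataFrame({"id": [0, 1], "age": [25, 33]})
df.ww.init(index="id")
df.ww.to_disk("table_dir", format="parquet")  # writes data plus typing info

loaded = read_woodwork_table("table_dir")  # reads table and typing info back
assert df.ww.schema == loaded.ww.schema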
8 changes: 7 additions & 1 deletion woodwork/deserializers/utils.py
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import tempfile
+from inspect import getfullargspec
 from pathlib import Path
 
 from woodwork.deserializers import (
@@ -99,7 +100,12 @@ def read_table_typing_information(path, typing_info_filename, profile_name):
 
         use_smartopen(file_path, path, transport_params)
         with tarfile.open(str(file_path)) as tar:
-            tar.extractall(path=tmpdir)
+            if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                tar.extractall(path=tmpdir, filter="data")
+            else:
+                raise RuntimeError(
+                    "Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.",
+                )
 
         file = os.path.join(tmpdir, typing_info_filename)
         with open(file, "r") as file:
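An equivalent feature probe, recommended by PEP 706 and the tarfile docs, is hasattr-based rather than signature-based; a minimal sketch:

import tarfile


def supports_safe_extraction():
    # PEP 706 suggests this probe for the backported extraction-filter feature.
    return hasattr(tarfile, "data_filter")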
70 changes: 69 additions & 1 deletion woodwork/tests/accessor/test_serialization.py
@@ -2,7 +2,7 @@
 import os
 import shutil
 import warnings
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import boto3
 import pandas as pd
@@ -662,6 +662,35 @@ def test_to_csv_S3(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.utils.getfullargspec")
+def test_to_csv_S3_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init(
+        name="test_data",
+        index="id",
+        semantic_tags={"id": "tag1"},
+        logical_types={"age": Ordinal(order=[25, 33, 57])},
+    )
+    sample_df.ww.to_disk(
+        TEST_S3_URL,
+        format="csv",
+        encoding="utf-8",
+        engine="python",
+        profile_name=None,
+    )
+    make_public(s3_client, s3_bucket)
+
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(TEST_S3_URL, profile_name=None)
+
+
 @pytest.mark.parametrize("profile_name", [None, False])
 def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
     sample_df.ww.init()
@@ -673,6 +702,23 @@ def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.deserializer_base.getfullargspec")
+def test_serialize_s3_pickle_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init()
+    sample_df.ww.to_disk(TEST_S3_URL, format="pickle", profile_name=None)
+    make_public(s3_client, s3_bucket)
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(TEST_S3_URL, profile_name=None)
+
+
 @pytest.mark.parametrize("profile_name", [None, False])
 def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
     sample_df.ww.init()
@@ -688,6 +734,28 @@ def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.parquet_deserializer.getfullargspec")
+def test_serialize_s3_parquet_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init()
+    sample_df.ww.to_disk(TEST_S3_URL, format="parquet", profile_name=None)
+    make_public(s3_client, s3_bucket)
+
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(
+            TEST_S3_URL,
+            filename="data.parquet",
+            profile_name=None,
+        )
+
+
 def create_test_credentials(test_path):
     with open(test_path, "w+") as f:
         f.write("[test]\n")
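The three new tests follow one pattern: patch getfullargspec at the module where each deserializer looks it up (not at inspect itself) and return a spec whose kwonlyargs list is empty, so the guard takes its RuntimeError branch even on a patched interpreter. A reduced sketch of that mocking pattern, with hypothetical module and function names:

from unittest.mock import MagicMock, patch

import pytest

from mypackage import loader  # hypothetical module that imports getfullargspec


def test_errors_when_filter_kwarg_missing():
    fake_spec = MagicMock()
    fake_spec.kwonlyargs = []  # pretend extractall has no "filter" kwarg
    # Patch the name where it is used, not where it is defined.
    with patch("mypackage.loader.getfullargspec", return_value=fake_spec):
        with pytest.raises(RuntimeError, match="Please upgrade"):
            loader.load_archive("archive.tar")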
