Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get and set interface #48

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions src/nested_dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import dask.dataframe as dd
import dask_expr as dx
import nested_pandas as npd
import pandas as pd
import pyarrow as pa
from dask_expr._collection import new_collection
from nested_pandas.series.dtype import NestedDtype
from nested_pandas.series.packer import pack_flat
from nested_pandas.series.packer import pack, pack_flat
from pandas._libs import lib
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default
Expand Down Expand Up @@ -59,9 +61,64 @@ class NestedFrame(

_partition_type = npd.NestedFrame # Tracks the underlying data type

def __getitem__(self, key):
result = super().__getitem__(key)
return result
def __getitem__(self, item):
"""Adds custom __getitem__ functionality for nested columns"""
if isinstance(item, str) and self._is_known_hierarchical_column(item):
nested, col = item.split(".")
meta = pd.Series(name=col, dtype=pd.ArrowDtype(self.dtypes[nested].fields[col]))
return self.map_partitions(lambda x: x[nested].nest.get_flat_series(col), meta=meta)
else:
return super().__getitem__(item)

def _nested_meta_from_flat(self, flat, name):
"""construct meta for a packed series from a flat dataframe"""
pd_fields = flat.dtypes.to_dict() # grabbing pandas dtypes
pyarrow_fields = {} # grab underlying pyarrow dtypes
for field, dtype in pd_fields.items():
if hasattr(dtype, "pyarrow_dtype"):
pyarrow_fields[field] = dtype.pyarrow_dtype
else: # or convert from numpy types
pyarrow_fields[field] = pa.from_numpy_dtype(dtype)
return pd.Series(name=name, dtype=NestedDtype.from_fields(pyarrow_fields))

def __setitem__(self, key, value):
"""Adds custom __setitem__ behavior for nested columns"""

# Replacing or adding columns to a nested structure
# Allows statements like ndf["nested.t"] = ndf["nested.t"] - 5
# Or ndf["nested.base_t"] = ndf["nested.t"] - 5
# Performance note: This requires building a new nested structure
if self._is_known_hierarchical_column(key) or (
"." in key and key.split(".")[0] in self.nested_columns
):
nested, col = key.split(".")

# View the nested column as a flat df
new_flat = self[nested].nest.to_flat()
new_flat[col] = value

# Handle strings specially
if isinstance(value, str):
new_flat = new_flat.astype({col: pd.ArrowDtype(pa.string())})

# pack the modified df back into a nested column
meta = self._nested_meta_from_flat(new_flat, nested)
packed = new_flat.map_partitions(lambda x: pack(x), meta=meta)
return super().__setitem__(nested, packed)

# Adding a new nested structure from a column
# Allows statements like ndf["new_nested.t"] = ndf["nested.t"] - 5
elif "." in key:
new_nested, col = key.split(".")
if isinstance(value, dd.Series):
value.name = col
value = value.to_frame()

meta = self._nested_meta_from_flat(value, new_nested)
packed = value.map_partitions(lambda x: pack(x), meta=meta)
return super().__setitem__(new_nested, packed)

return super().__setitem__(key, value)

@classmethod
def from_pandas(
Expand Down Expand Up @@ -246,6 +303,15 @@ def nested_columns(self) -> list:
nest_cols.append(column)
return nest_cols

def _is_known_hierarchical_column(self, colname) -> bool:
"""Determine whether a string is a known hierarchical column name"""
if "." in colname:
left, right = colname.split(".")
if left in self.nested_columns:
return right in self.all_columns[left]
return False
return False

def add_nested(self, nested, name, how="outer") -> NestedFrame: # type: ignore[name-defined] # noqa: F821
"""Packs a dataframe into a nested column

Expand Down
50 changes: 50 additions & 0 deletions tests/nested_dask/test_nestedframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from nested_dask.datasets import generate_data
from nested_pandas.series.dtype import NestedDtype


Expand Down Expand Up @@ -39,6 +40,55 @@ def test_nested_columns(test_dataset):
assert test_dataset.nested_columns == ["nested"]


def test_getitem_on_nested():
"""test getitem with nested columns"""
ndf = generate_data(10, 10, npartitions=3, seed=1)

nest_col = ndf["nested.t"]

assert len(nest_col) == 100
assert nest_col.name == "t"


def test_set_or_replace_nested_col():
"""Test that __setitem__ can set or replace a column in a existing nested structure"""

ndf = generate_data(10, 10, npartitions=3, seed=1)

# test direct replacement, with ints
orig_t_head = ndf["nested.t"].head(10, npartitions=-1)

ndf["nested.t"] = ndf["nested.t"] + 1
assert np.array_equal(ndf["nested.t"].head(10).values.to_numpy(), orig_t_head.values.to_numpy() + 1)

# test direct replacement, with str
ndf["nested.band"] = "lsst"
assert np.all(ndf["nested.band"].compute().values.to_numpy() == "lsst")

# test setting a new column within nested
ndf["nested.t_plus_flux"] = ndf["nested.t"] + ndf["nested.flux"]

true_vals = (ndf["nested.t"] + ndf["nested.flux"]).head(10).values.to_numpy()
assert np.array_equal(ndf["nested.t_plus_flux"].head(10).values.to_numpy(), true_vals)


def test_set_new_nested_col():
"""Test that __setitem__ can create a new nested structure"""

ndf = generate_data(10, 10, npartitions=3, seed=1)

# assign column in new nested structure from columns in nested
ndf["new_nested.t_plus_flux"] = ndf["nested.t"] + ndf["nested.flux"]

assert "new_nested" in ndf.nested_columns
assert "t_plus_flux" in ndf["new_nested"].nest.fields

assert np.array_equal(
ndf["new_nested.t_plus_flux"].compute().values.to_numpy(),
ndf["nested.t"].compute().values.to_numpy() + ndf["nested.flux"].compute().values.to_numpy(),
)


def test_add_nested(test_dataset_no_add_nested):
"""test the add_nested function"""
base, layer = test_dataset_no_add_nested
Expand Down
Loading