Skip to content

Commit

Permalink
add fix for issue 21
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed May 28, 2024
1 parent 7ffa702 commit 86ab3cf
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/nested_dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def add_nested(self, nested, name, how="outer") -> NestedFrame: # type: ignore[
-------
`nested_dask.NestedFrame`
"""
nested = nested.map_partitions(lambda x: pack_flat(x)).rename(name)
nested = nested.map_partitions(lambda x: pack_flat(npd.NestedFrame(x))).rename(name)
return self.join(nested, how=how)

def query(self, expr) -> Self: # type: ignore # noqa: F821:
Expand Down Expand Up @@ -213,7 +213,7 @@ def query(self, expr) -> Self: # type: ignore # noqa: F821:
>>> df.query("mynested.a > 2")
"""
return self.map_partitions(lambda x: x.query(expr), meta=self._meta)
return self.map_partitions(lambda x: npd.NestedFrame(x).query(expr), meta=self._meta)

def dropna(
self,
Expand Down Expand Up @@ -283,7 +283,7 @@ def dropna(
"""
# propagate meta, assumes row-based operation
return self.map_partitions(
lambda x: x.dropna(
lambda x: npd.NestedFrame(x).dropna(
axis=axis,
how=how,
thresh=thresh,
Expand Down Expand Up @@ -332,7 +332,9 @@ def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
"""

# apply nested_pandas reduce via map_partitions
return self.map_partitions(lambda x: x.reduce(func, *args, **kwargs), meta=meta)
# wrap the partition in a npd.NestedFrame call for:
# https://github.com/lincc-frameworks/nested-dask/issues/21
return self.map_partitions(lambda x: npd.NestedFrame(x).reduce(func, *args, **kwargs), meta=meta)

def to_parquet(self, path, by_layer=True, **kwargs) -> None:
"""Creates parquet file(s) with the data of a NestedFrame, either
Expand Down
30 changes: 30 additions & 0 deletions tests/nested_dask/test_nestedframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import nested_dask as nd
import numpy as np
import pandas as pd
import pytest
from nested_pandas.series.dtype import NestedDtype

Expand Down Expand Up @@ -143,3 +144,32 @@ def test_to_parquet_by_layer(test_dataset, tmp_path):
loaded_dataset = loaded_dataset.compute()

assert test_dataset.equals(loaded_dataset)


def test_from_epyc():
"""test a dataset from epyc. Motivated by https://github.com/lincc-frameworks/nested-dask/issues/21"""
# Load some ZTF data
catalogs_dir = "https://epyc.astro.washington.edu/~lincc-frameworks/half_degree_surveys/ztf/"

object_ndf = (
nd.read_parquet(f"{catalogs_dir}/ztf_object", columns=["ra", "dec", "ps1_objid"])
.set_index("ps1_objid", sort=True)
.persist()
)

source_ndf = (
nd.read_parquet(
f"{catalogs_dir}/ztf_source", columns=["mjd", "mag", "magerr", "band", "ps1_objid", "catflags"]
)
.set_index("ps1_objid", sort=True)
.persist()
)

object_ndf = object_ndf.add_nested(source_ndf, "ztf_source")

# Apply a mean function
meta = pd.Series(name="mean", dtype=float)
result = object_ndf.reduce(np.mean, "ztf_source.mag", meta=meta).compute()

# just make sure the result was successfully computed
assert len(result) == 9817

0 comments on commit 86ab3cf

Please sign in to comment.