Skip to content

Commit

Permalink
[df] Enable cloning of AsNumpyResult
Browse files Browse the repository at this point in the history
  • Loading branch information
vepadulano committed Mar 28, 2024
1 parent dcebd46 commit 1edca86
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,21 @@ def __setstate__(self, state):
self._py_arrays = state


def _clone_asnumpyresult(res: AsNumpyResult) -> AsNumpyResult:
"""
Clones the internal actions held by the input result and returns a new
result.
"""
import ROOT
return AsNumpyResult(
{
col: ROOT.Internal.RDF.CloneResultAndAction(ptr)
for (col, ptr) in res._result_ptrs.items()
},
res._columns
)


class HistoProfileWrapper(MethodTemplateWrapper):
"""
Subclass of MethodTemplateWrapper that pythonizes HistoXD and ProfileXD
Expand Down
25 changes: 25 additions & 0 deletions bindings/pyroot/pythonizations/test/rdataframe_asnumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pickle

from ROOT._pythonization._rdataframe import _clone_asnumpyresult


def make_tree(*dtypes):
"""
Expand Down Expand Up @@ -308,6 +310,29 @@ def test_memory_adoption_complex_types(self):
pyarr[0][0] = 42
self.assertTrue(cpparr[0][0] == pyarr[0][0])

def test_cloning(self):
"""
Testing cloning of AsNumpy results
"""
df = ROOT.RDataFrame(20).Define("x", "rdfentry_")
ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]

# Get the result for the first range
(begin, end) = ranges.pop(0)
ROOT.Internal.RDF.ChangeEmptyEntryRange(
ROOT.RDF.AsRNode(df), (begin, end))
asnumpyres = df.AsNumpy(["x"], lazy=True) # To return an AsNumpyResult
self.assertSequenceEqual(
asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())

# Clone the result for following ranges
for (begin, end) in ranges:
ROOT.Internal.RDF.ChangeEmptyEntryRange(
ROOT.RDF.AsRNode(df), (begin, end))
asnumpyres = _clone_asnumpyresult(asnumpyres)
self.assertSequenceEqual(
asnumpyres.GetValue()["x"].tolist(), np.arange(begin, end).tolist())


if __name__ == '__main__':
unittest.main()

0 comments on commit 1edca86

Please sign in to comment.