Skip to content

Commit

Permalink
MFIter in Python
Browse files Browse the repository at this point in the history
Iterate boxes and particle tiles in Python. This simplifies downstream usage
with derived data types, avoiding "downcasting" iterators to their base
types during iteration.
  • Loading branch information
ax3l committed Feb 7, 2024
1 parent e442401 commit 416845e
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 83 deletions.
63 changes: 0 additions & 63 deletions src/Base/Iterator.H

This file was deleted.

20 changes: 3 additions & 17 deletions src/Base/MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
*/
#include "pyAMReX.H"

#include "Base/Iterator.H"

#include <AMReX_BoxArray.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_FArrayBox.H>
Expand Down Expand Up @@ -69,11 +67,9 @@ void init_MultiFab(py::module &m)
//.def(py::init< iMultiFab const & >())
//.def(py::init< iMultiFab const &, MFItInfo const & >())

// eq. to void operator++()
.def("__next__",
&pyAMReX::iterator_next<MFIter>,
py::return_value_policy::reference_internal
)
// helpers for iteration __next__
.def("_incr", &MFIter::operator++)
.def("finalize", &MFIter::Finalize)

.def("tilebox", py::overload_cast< >(&MFIter::tilebox, py::const_))
.def("tilebox", py::overload_cast< IntVect const & >(&MFIter::tilebox, py::const_))
Expand Down Expand Up @@ -114,16 +110,6 @@ void init_MultiFab(py::module &m)
.def_property_readonly("size", &FabArrayBase::size)

.def_property_readonly("n_grow_vect", &FabArrayBase::nGrowVect)

/* data access in Box index space */
.def("__iter__",
[](FabArrayBase& fab) {
return MFIter(fab);
},
// while the returned iterator (argument 0) exists,
// keep the FabArrayBase (argument 1; usually a MultiFab) alive
py::keep_alive<0, 1>()
)
;

py_FabArray_FArrayBox
Expand Down
8 changes: 5 additions & 3 deletions src/Particle/ParticleContainer.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

#include "pyAMReX.H"

#include "Base/Iterator.H"

#include <AMReX_BoxArray.H>
#include <AMReX_GpuAllocators.H>
#include <AMReX_IntVect.H>
Expand Down Expand Up @@ -75,9 +73,13 @@ void make_Base_Iterators (py::module &m, std::string allocstr)
.def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles)
.def_property_readonly("level", &iterator_base::GetLevel)
.def_property_readonly("pair_index", &iterator_base::GetPairIndex)
.def_property_readonly("is_valid", &iterator_base::isValid)
.def("geom", &iterator_base::Geom, py::arg("level"))

// eq. to void operator++()
.def("_incr", &iterator_base::operator++)
.def("finalize", &iterator_base::Finalize)
/*
.def("__next__",
&pyAMReX::iterator_next<iterator_base>,
py::return_value_policy::reference_internal
Expand All @@ -87,7 +89,7 @@ void make_Base_Iterators (py::module &m, std::string allocstr)
return it;
},
py::return_value_policy::reference_internal
)
)*/
;

// only legacy particle has an AoS data structure for positions and id+cpu
Expand Down
23 changes: 23 additions & 0 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,32 @@ def mf_to_cupy(self, copy=False, order="F"):
return views


def next(self):
if hasattr(self, "first_or_done") is False:
self.first_or_done = True

first_or_done = self.first_or_done
if first_or_done:
first_or_done = False
self.first_or_done = first_or_done
else:
self._incr()
if self.is_valid is False:
# self.first_or_done = True
self.finalize()
raise StopIteration

return self


def register_MultiFab_extension(amr):
"""MultiFab helper methods"""

# register member functions for the MFIter type
amr.MFIter.__next__ = next
# FabArrayBase: iterate as data access in Box index space
amr.FabArrayBase.__iter__ = lambda fab: amr.MFIter(fab)

# register member functions for the MultiFab type
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
amr, self, copy, order
Expand Down
31 changes: 31 additions & 0 deletions src/amrex/ParticleContainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,42 @@ def pc_to_df(self, local=True, comm=None, root_rank=0):
return df


def next(self):
if hasattr(self, "first_or_done") is False:
self.first_or_done = True

first_or_done = self.first_or_done
if first_or_done:
first_or_done = False
self.first_or_done = first_or_done
else:
self._incr()
if self.is_valid is False:
# self.first_or_done = True
self.finalize()
raise StopIteration

return self


def register_ParticleContainer_extension(amr):
"""ParticleContainer helper methods"""
import inspect
import sys

# register member functions for every Par(Const)Iter* type
for _, ParIter_type in inspect.getmembers(
sys.modules[amr.__name__],
lambda member: inspect.isclass(member)
and member.__module__ == amr.__name__
and (
member.__name__.startswith("ParIter")
or member.__name__.startswith("ParConstIter")
),
):
ParIter_type.__next__ = next
ParIter_type.__iter__ = lambda self: self

# register member functions for every ParticleContainer_* type
for _, ParticleContainer_type in inspect.getmembers(
sys.modules[amr.__name__],
Expand Down

0 comments on commit 416845e

Please sign in to comment.