From 416845e04bb3935f796dd9eaf627a764dc39570f Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Wed, 7 Feb 2024 13:57:55 -0800 Subject: [PATCH] MFIter in Python Iterate boxes and particle tiles in Python. This simplifies downstream usage with derived data types, avoiding "downcasting" iterators to their base types during iteration. --- src/Base/Iterator.H | 63 -------------------------------- src/Base/MultiFab.cpp | 20 ++-------- src/Particle/ParticleContainer.H | 8 ++-- src/amrex/MultiFab.py | 23 ++++++++++++ src/amrex/ParticleContainer.py | 31 ++++++++++++++++ 5 files changed, 62 insertions(+), 83 deletions(-) delete mode 100644 src/Base/Iterator.H diff --git a/src/Base/Iterator.H b/src/Base/Iterator.H deleted file mode 100644 index c4693588..00000000 --- a/src/Base/Iterator.H +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2021-2022 The AMReX Community - * - * Authors: Axel Huebl - * License: BSD-3-Clause-LBNL - */ -#pragma once - -#include "pyAMReX.H" - -#include -#include -#include -#include -#include -#include - -#include -#include - - -namespace pyAMReX -{ - /** This is a helper function for the C++ equivalent of void operator++() - * - * In Python, iterators always are called with __next__, even for the - * first access. This means we need to handle the first iterator element - * explicitly, otherwise we will jump directly to the 2nd element. We do - * this the same way as pybind11 does this, via a little state: - * https://github.com/AMReX-Codes/pyamrex/pull/50 - * https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282 - * - * To avoid unnecessary (and expensive) copies, remember to only call this - * helper always with py::return_value_policy::reference_internal! - * - * - * @tparam T_Iterator This is usally MFIter or Par(Const)Iter or derived classes - * @param it the current iterator - * @return the updated iterator - */ - template< typename T_Iterator > - T_Iterator & - iterator_next( T_Iterator & it ) - { - py::object self = py::cast(it); - if (!py::hasattr(self, "first_or_done")) - self.attr("first_or_done") = true; - - bool first_or_done = self.attr("first_or_done").cast(); - if (first_or_done) { - first_or_done = false; - self.attr("first_or_done") = first_or_done; - } - else - ++it; - if( !it.isValid() ) - { - first_or_done = true; - it.Finalize(); - throw py::stop_iteration(); - } - return it; - } -} diff --git a/src/Base/MultiFab.cpp b/src/Base/MultiFab.cpp index 5853d0f5..61864a97 100644 --- a/src/Base/MultiFab.cpp +++ b/src/Base/MultiFab.cpp @@ -5,8 +5,6 @@ */ #include "pyAMReX.H" -#include "Base/Iterator.H" - #include #include #include @@ -69,11 +67,9 @@ void init_MultiFab(py::module &m) //.def(py::init< iMultiFab const & >()) //.def(py::init< iMultiFab const &, MFItInfo const & >()) - // eq. to void operator++() - .def("__next__", - &pyAMReX::iterator_next, - py::return_value_policy::reference_internal - ) + // helpers for iteration __next__ + .def("_incr", &MFIter::operator++) + .def("finalize", &MFIter::Finalize) .def("tilebox", py::overload_cast< >(&MFIter::tilebox, py::const_)) .def("tilebox", py::overload_cast< IntVect const & >(&MFIter::tilebox, py::const_)) @@ -114,16 +110,6 @@ void init_MultiFab(py::module &m) .def_property_readonly("size", &FabArrayBase::size) .def_property_readonly("n_grow_vect", &FabArrayBase::nGrowVect) - - /* data access in Box index space */ - .def("__iter__", - [](FabArrayBase& fab) { - return MFIter(fab); - }, - // while the returned iterator (argument 0) exists, - // keep the FabArrayBase (argument 1; usually a MultiFab) alive - py::keep_alive<0, 1>() - ) ; py_FabArray_FArrayBox diff --git a/src/Particle/ParticleContainer.H b/src/Particle/ParticleContainer.H index 8cc12b91..92105d27 100644 --- a/src/Particle/ParticleContainer.H +++ b/src/Particle/ParticleContainer.H @@ -7,8 +7,6 @@ #include "pyAMReX.H" -#include "Base/Iterator.H" - #include #include #include @@ -75,9 +73,13 @@ void make_Base_Iterators (py::module &m, std::string allocstr) .def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles) .def_property_readonly("level", &iterator_base::GetLevel) .def_property_readonly("pair_index", &iterator_base::GetPairIndex) + .def_property_readonly("is_valid", &iterator_base::isValid) .def("geom", &iterator_base::Geom, py::arg("level")) // eq. to void operator++() + .def("_incr", &iterator_base::operator++) + .def("finalize", &iterator_base::Finalize) + /* .def("__next__", &pyAMReX::iterator_next, py::return_value_policy::reference_internal @@ -87,7 +89,7 @@ void make_Base_Iterators (py::module &m, std::string allocstr) return it; }, py::return_value_policy::reference_internal - ) + )*/ ; // only legacy particle has an AoS data structure for positions and id+cpu diff --git a/src/amrex/MultiFab.py b/src/amrex/MultiFab.py index d53f1371..67f801da 100644 --- a/src/amrex/MultiFab.py +++ b/src/amrex/MultiFab.py @@ -89,9 +89,32 @@ def mf_to_cupy(self, copy=False, order="F"): return views +def next(self): + if hasattr(self, "first_or_done") is False: + self.first_or_done = True + + first_or_done = self.first_or_done + if first_or_done: + first_or_done = False + self.first_or_done = first_or_done + else: + self._incr() + if self.is_valid is False: + # self.first_or_done = True + self.finalize() + raise StopIteration + + return self + + def register_MultiFab_extension(amr): """MultiFab helper methods""" + # register member functions for the MFIter type + amr.MFIter.__next__ = next + # FabArrayBase: iterate as data access in Box index space + amr.FabArrayBase.__iter__ = lambda fab: amr.MFIter(fab) + # register member functions for the MultiFab type amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy( amr, self, copy, order diff --git a/src/amrex/ParticleContainer.py b/src/amrex/ParticleContainer.py index ab934920..1bfbad7f 100644 --- a/src/amrex/ParticleContainer.py +++ b/src/amrex/ParticleContainer.py @@ -111,11 +111,42 @@ def pc_to_df(self, local=True, comm=None, root_rank=0): return df +def next(self): + if hasattr(self, "first_or_done") is False: + self.first_or_done = True + + first_or_done = self.first_or_done + if first_or_done: + first_or_done = False + self.first_or_done = first_or_done + else: + self._incr() + if self.is_valid is False: + # self.first_or_done = True + self.finalize() + raise StopIteration + + return self + + def register_ParticleContainer_extension(amr): """ParticleContainer helper methods""" import inspect import sys + # register member functions for every Par(Const)Iter* type + for _, ParIter_type in inspect.getmembers( + sys.modules[amr.__name__], + lambda member: inspect.isclass(member) + and member.__module__ == amr.__name__ + and ( + member.__name__.startswith("ParIter") + or member.__name__.startswith("ParConstIter") + ), + ): + ParIter_type.__next__ = next + ParIter_type.__iter__ = lambda self: self + # register member functions for every ParticleContainer_* type for _, ParticleContainer_type in inspect.getmembers( sys.modules[amr.__name__],