diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4018509f..f0b95ee1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: - name: Install basic dependencies (MPI, FFTW, poetry, IMO) run: | ./bin/install-deps.sh ${{ matrix.mpi }} - pip install poetry + pip install poetry==1.8.5 mkdir -p $HOME/.config/litebird_imo echo -e "[[repositories]]\nlocation = \"$(pwd)/test/mock_imo/\"\nname = \"Mock IMO\"" | tee $HOME/.config/litebird_imo/imo.toml @@ -87,6 +87,7 @@ jobs: for proc in 1 5 9 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done - name: Tests OpenMPI @@ -95,4 +96,5 @@ jobs: for proc in 1 2 ; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py done diff --git a/.readthedocs.yml b/.readthedocs.yml index 4c13f051..30681daf 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -10,8 +10,8 @@ build: jobs: pre_create_environment: - asdf plugin add poetry - - asdf install poetry latest - - asdf global poetry latest + - asdf install poetry 1.8.5 + - asdf global poetry 1.8.5 - poetry export --without-hashes > requirements.txt post_install: - pip install -r requirements.txt diff --git a/docs/source/images/detector_block_communicator.png b/docs/source/images/detector_block_communicator.png new file mode 100644 index 00000000..f9a6c526 Binary files /dev/null and b/docs/source/images/detector_block_communicator.png differ diff --git a/docs/source/images/detector_groups_case1.png b/docs/source/images/detector_groups_case1.png new file mode 100644 index 00000000..67c1bcd3 Binary files /dev/null and b/docs/source/images/detector_groups_case1.png differ diff --git a/docs/source/images/detector_groups_case2.png b/docs/source/images/detector_groups_case2.png new file mode 100644 index 00000000..aafedb06 Binary files /dev/null and b/docs/source/images/detector_groups_case2.png differ diff --git a/docs/source/images/detector_groups_case3.png b/docs/source/images/detector_groups_case3.png new file mode 100644 index 00000000..2908e0e3 Binary files /dev/null and b/docs/source/images/detector_groups_case3.png differ diff --git a/docs/source/images/grid_communicator.png b/docs/source/images/grid_communicator.png new file mode 100644 index 00000000..b4e850f3 Binary files /dev/null and b/docs/source/images/grid_communicator.png differ diff --git a/docs/source/images/time_block_communicator.png b/docs/source/images/time_block_communicator.png new file mode 100644 index 00000000..cf85aa54 Binary files /dev/null and b/docs/source/images/time_block_communicator.png differ diff --git a/docs/source/mpi.rst b/docs/source/mpi.rst index ec64e75a..79735ad1 100644 --- a/docs/source/mpi.rst +++ b/docs/source/mpi.rst @@ -133,6 +133,31 @@ variable :data:`.MPI_ENABLED`:: To ensure that your code uses MPI in the proper way, you should always use :data:`.MPI_COMM_WORLD` instead of importing ``mpi4py`` directly. +The simulation framework also provides a global object +:data:`.MPI_COMM_GRID`. It has two attributes: + +- ``COMM_OBS_GRID``: This is an MPI communicator that contains all the + MPI processes with the global rank less than ``n_blocks_time * n_blocks_det``. 
+  It provides a safety net for operations and MPI communications that
+  must be performed only on the partition of :data:`.MPI_COMM_WORLD`
+  that contains a non-zero number of pointings and TODs. By default,
+  ``COMM_OBS_GRID`` points to the global MPI communicator :data:`.MPI_COMM_WORLD`.
+  It is updated once the :class:`.Observation` objects are defined. For example,
+  consider the case where a user runs the simulation with 10 MPI
+  processes but, due to the specific ``det_blocks_attributes`` argument
+  passed to the :class:`.Observation` class, the number of detector and time
+  blocks is determined to be 2 and 4 respectively. Then the
+  simulation framework will store the pointings and TODs only on
+  :math:`2\times4=8` MPI processes and the last two ranks of :data:`.MPI_COMM_WORLD`
+  will be left unused. Once this happens, ``COMM_OBS_GRID`` on the first 8
+  ranks (rank 0 to 7) will point to the local sub-communicator
+  containing the processes with global rank 0 to 7. On the unused
+  ranks, it will simply point to the NULL communicator.
+- ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object
+  points to a NULL MPI communicator (``mpi4py.MPI.COMM_NULL``).
+  Otherwise it is set to ``None``. The user should compare
+  ``COMM_OBS_GRID`` with ``COMM_NULL`` on every MPI process in order
+  to avoid running a piece of code on the unused MPI processes.
 
 Enabling/disabling MPI
 ----------------------
diff --git a/docs/source/observations.rst b/docs/source/observations.rst
index 9afe6798..f3379861 100644
--- a/docs/source/observations.rst
+++ b/docs/source/observations.rst
@@ -97,8 +97,15 @@ With this memory layout, typical operations look like this::
 Parallel applications
 ---------------------
 
-The only work that the :class:`.Observation` class actually does is handling
-parallelism. ``obs.tod`` can be distributed over a
+The :class:`.Observation` class allows the distribution of ``obs.tod`` over multiple MPI
+processes to enable the parallelization of computations. The distribution of ``obs.tod``
+can be achieved in two different ways:
+
+1. Uniform distribution of detectors along the detector axis
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+With the ``n_blocks_det`` and ``n_blocks_time`` arguments of the :class:`.Observation`
+class, ``obs.tod`` is evenly distributed over a
 ``n_blocks_det`` by ``n_blocks_time`` grid of MPI ranks.
 The blocks can be changed at run-time.
@@ -111,7 +118,7 @@ The main advantage is that the example operations in the Serial section
 are achieved with the same lines of code.
 The price to pay is that you have to set detector properties with special methods.
 
-::
+.. code-block:: python
 
     import litebird_sim as lbs
     from mpi4py import MPI
@@ -158,21 +165,235 @@ TOD) gets distributed.
 
 .. image:: ./images/observation_data_distribution.png
 
-When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or
-``obs.wn_levels[0]`` are quantities of the first *local* detector, not global.
-This should not be a problem as the only thing that matters is that the two
-quantities refer to the same detector. If you need the global detector index,
-you can get it with ``obs.det_idx[0]``, which is created
-at construction time.
-
-To get a better understanding of how observations are being used in a
-MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`.
-This method must be called *after* the observations have been allocated using -:meth:`.Simulation.create_observations`; it will return an instance of the -class :class:`.MpiDistributionDescr`, which can be inspected to determine -which detectors and time spans are covered by each observation in all the -MPI processes that are being used. For more information, refer to the Section -:ref:`simulations`. +2. Custom grouping of detectors along the detector axis +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +While uniform distribution of detectors along the detector axis optimizes load +balancing, it is less suitable for simulating some effects, like crosstalk and +noise correlation between the detectors. For these effects, uniform distribution +across MPI processes necessitates the transfer of large TOD arrays across multiple +MPI processes, which complicates the code implementation and may potentially lead +to significant performance overhead. To save us from this situation, the +:class:`.Observation` class +accepts an argument ``det_blocks_attributes`` that is a list of string objects +specifying the detector attributes to create the group of detectors. Once the +detector groups are made, the detectors are distributed to the MPI processes in such +a way that all the detectors of a group reside on the same set of MPI processes. + +If a valid ``det_blocks_attributes`` argument is passed to the :class:`.Observation` +class, the arguments ``n_blocks_det`` and ``n_blocks_time`` are ignored. Since the +``det_blocks_attributes`` creates the detector blocks dynamically, the +``n_blocks_time`` is computed during runtime using the size of MPI communicator and +the number of detector blocks (``n_blocks_time = comm.size // n_blocks_det``). + +The detector blocks made in this way can be accessed with +``Observation.detector_blocks``. It is a dictionary object that has the tuple of +``det_blocks_attributes`` values as dictionary keys and the list of detectors +corresponding to the key as dictionary values. This dictionary is sorted so that the +group with the largest number of detectors comes first and the one with +the fewest detectors comes last. + +The following example illustrates the distribution of ``obs.tod`` matrix across the +MPI processes when ``det_blocks_attributes`` is specified. + +.. code-block:: python + + import litebird_sim as lbs + + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + sampling_freq_Hz = 1 + + # Creating a list of detectors. 
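+    # Note: with det_blocks_attributes=["channel"], the six detectors below
+    # form two blocks (four "channel1" detectors and two "channel2" ones),
+    # so n_blocks_det = 2 and n_blocks_time becomes comm.size // 2. With
+    # det_blocks_attributes=["channel", "wafer"] they form four blocks instead.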
+ dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + ] + + # Initializing a simulation + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + # Creating the observations with detector blocks + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=3, + det_blocks_attributes=["channel"], # case 1 and 2 + # det_blocks_attributes=["channel", "wafer"] # case 3 + ) + +With the list of detectors defined in the code snippet above, we can see how the +detectors axis and time axis is divided depending on the size of MPI communicator and +``det_blocks_attributes``. + +**Case 1** + +*Size of global MPI communicator = 3*, ``det_blocks_attributes=["channel"]`` + +.. image:: ./images/detector_groups_case1.png + +**Case 2** + +*Size of global MPI communicator = 4*, ``det_blocks_attributes=["channel"]`` + +.. image:: ./images/detector_groups_case2.png + +**Case 3** + +*Size of global MPI communicator = 10*, ``det_blocks_attributes=["channel", "wafer"]`` + +.. image:: ./images/detector_groups_case3.png + +.. note:: + When ``n_blocks_det != 1``, keep in mind that ``obs.tod[0]`` or + ``obs.wn_levels[0]`` are quantities of the first *local* detector, not global. + This should not be a problem as the only thing that matters is that the two + quantities refer to the same detector. If you need the global detector index, + you can get it with ``obs.det_idx[0]``, which is created at construction time. + ``obs.det_idx`` stores the detector indices of the detectors available to an + :class:`.Observation` class, with respect to the list of detectors stored in + ``obs.detectors_global`` variable. + +.. note:: + To get a better understanding of how observations are being used in a + MPI simulation, use the method :meth:`.Simulation.describe_mpi_distribution`. + This method must be called *after* the observations have been allocated using + :meth:`.Simulation.create_observations`; it will return an instance of the + class :class:`.MpiDistributionDescr`, which can be inspected to determine + which detectors and time spans are covered by each observation in all the + MPI processes that are being used. For more information, refer to the Section + :ref:`simulations`. + +MPI communicators +^^^^^^^^^^^^^^^^^ + +The simulation framework exposes MPI communicators at different +levels. The first one is the global MPI communicator. It can be +accessed with a global variable :data:`.MPI_COMM_WORLD`, and it is +same as ``mpi4py.MPI.COMM_WORLD``. It contains all the MPI processes +used by the script. The other MPI communicators defined in the +simulation framework are the partitions of global MPI communicator and +they contain the MPI processes that have certain properties as we +explain below. 
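+
+As a quick orientation, here is a minimal sketch of where these
+communicators can be reached from user code, assuming a ``sim`` object
+created as in the example above; the ``comm_det_block`` and
+``comm_time_block`` attributes are described later in this section:
+
+.. code-block:: python
+
+    import litebird_sim as lbs
+
+    # The global communicator: all MPI processes used by the script
+    world = lbs.MPI_COMM_WORLD
+
+    # The partition of the global communicator that actually owns pointings
+    # and TODs; on the unused ranks this is the NULL communicator
+    if lbs.MPI_COMM_GRID.COMM_OBS_GRID != lbs.MPI_COMM_GRID.COMM_NULL:
+        for obs in sim.observations:
+            det_comm = obs.comm_det_block    # processes holding the same detectors
+            time_comm = obs.comm_time_block  # processes holding the same time span
+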
+For all the examples in this sub-section, we consider
+the distribution of TODs across 10 MPI processes with ``n_blocks_time = 2``
+and ``n_blocks_det = 4``.
+
+To distribute the TODs across
+``n_blocks_det`` :math:`\times` ``n_blocks_time`` :math:`=\, N`
+MPI ranks, the script must be executed with at least
+:math:`N` MPI processes. There is, however, no upper limit on the
+number of MPI processes to be used. When the number of MPI processes
+is higher than :math:`N`, all the MPI processes with
+``rank`` :math:`\geq N` are left with no detectors (and TODs). In many cases, it
+is useful to identify the unused ranks so that they can be avoided
+while performing some computations. The :class:`.Observation` class
+makes this possible by creating a sub-communicator containing only the
+processes that own a non-zero number of detectors (and TODs).
+
+However, once the detectors and TODs are distributed across several
+processes, it is not trivial to find out all the ranks that contain
+the TOD chunks of a given detector (or detector block). Likewise, it
+is hard to find all the ranks that contain the TOD chunks
+corresponding to the same time interval. To solve this issue, the :class:`.Observation`
+class provides the MPI sub-communicators corresponding to different
+groups of MPI processes.
+
+The three sub-communicators provided by the framework - each generated
+by splitting the global MPI communicator - are listed below. The
+process ranks in the sub-communicators follow the order of their
+global ranks:
+
+- **Grid communicator**: Once the number of detector blocks and the
+  number of time blocks are available, the Observation class splits the
+  global communicator into a grid communicator and a null communicator.
+  The grid communicator consists of all the MPI processes whose global
+  rank is less than :math:`N`. The null communicator contains all other MPI
+  processes. On the MPI processes with global rank less than
+  :math:`N`, the global variable ``MPI_COMM_GRID.COMM_OBS_GRID``
+  points to the grid communicator. On the other MPI processes, it points
+  to the null MPI communicator (``MPI_COMM_GRID.COMM_NULL``, which is the
+  same as ``mpi4py.MPI.COMM_NULL``). In the example below, the
+  processes with global rank from 0 to 7 belong to the grid
+  sub-communicator. Since the processes with rank 8 and 9 are unused,
+  they are excluded from the grid communicator.
+
+  .. image:: ./images/grid_communicator.png
+
+  Note that, to exclude the MPI processes belonging to the null
+  communicator from the computations, one should enclose the
+  computations in an ``if`` condition comparing
+  ``MPI_COMM_GRID.COMM_OBS_GRID`` against ``MPI_COMM_GRID.COMM_NULL``:
+
+  ::
+
+     if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
+         # proceed with the following computations when
+         # MPI_COMM_GRID.COMM_OBS_GRID is not null
+         ...
+
+- **Detector-block communicator**: The detector-block communicator is
+  made by splitting the grid communicator into ``n_blocks_det``
+  sub-communicators, such that each sub-communicator contains the
+  processes where the TODs of the same set of detectors reside. This
+  sub-communicator can be accessed with the ``comm_det_block`` attribute
+  of the :class:`.Observation` class. In the following example, the
+  grid communicator is split into 4 detector-block communicators
+  containing the processes with global ranks 0-1, 2-3, 4-5, and 6-7
+  respectively.
+
+  ..
image:: ./images/detector_block_communicator.png + +- **Time-block communicator**: The time-block communicator is made by + splitting the grid communicator into ``n_blocks_time`` + sub-communicators, such that each sub-communicator contains the + processes where TOD chunks of same time interval reside. This + sub-communicator can be accessed with ``comm_time_block`` attribute + of the :class:`.Observation` class. In the example below, the grid + communicator is split into 2 time-block communicators containing the + processes with global rank 0-2-4-6 and 1-3-5-7 respectively. + + .. image:: ./images/time_block_communicator.png Other notable functionalities ----------------------------- diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py index b4a8aa6b..44285c14 100644 --- a/litebird_sim/__init__.py +++ b/litebird_sim/__init__.py @@ -72,7 +72,7 @@ ) from .madam import save_simulation_for_madam from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo -from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION +from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID from .noise import ( add_white_noise, add_one_over_f_noise, @@ -218,6 +218,7 @@ def destripe_with_toast2(*args, **kwargs): "MPI_COMM_WORLD", "MPI_ENABLED", "MPI_CONFIGURATION", + "MPI_COMM_GRID", # observations.py "Observation", "TodDescription", diff --git a/litebird_sim/detectors.py b/litebird_sim/detectors.py index a489c13b..f2d0d981 100644 --- a/litebird_sim/detectors.py +++ b/litebird_sim/detectors.py @@ -75,6 +75,9 @@ class DetectorInfo: - channel (Union[str, None]): The channel. The default is None + - squid (Union[int, None]): The squid number of the detector. + The default value is None. + - sampling_rate_hz (float): The sampling rate of the ADC associated with this detector. The default is 0.0 @@ -136,6 +139,7 @@ class DetectorInfo: pixel: Union[int, None] = None pixtype: Union[str, None] = None channel: Union[str, None] = None + squid: Union[int, None] = None sampling_rate_hz: float = 0.0 fwhm_arcmin: float = 0.0 ellipticity: float = 0.0 @@ -148,8 +152,6 @@ class DetectorInfo: fknee_mhz: float = 0.0 fmin_hz: float = 0.0 alpha: float = 0.0 - bandcenter_ghz: float = 0.0 - bandwidth_ghz: float = 0.0 pol: Union[str, None] = None orient: Union[str, None] = None quat: Any = None @@ -175,6 +177,7 @@ def from_dict(dictionary: Dict[str, Any]): - ``pixel`` - ``pixtype`` - ``channel`` + - ``squid`` - ``bandcenter_ghz`` - ``bandwidth_ghz`` - ``band_freqs_ghz`` diff --git a/litebird_sim/distribute.py b/litebird_sim/distribute.py index 3649a877..ef893d82 100644 --- a/litebird_sim/distribute.py +++ b/litebird_sim/distribute.py @@ -50,7 +50,7 @@ def distribute_evenly(num_of_elements, num_of_groups): # If leftovers == 0, then the number of elements is divided evenly # by num_of_groups, and the solution is trivial. If it's not, then - # each of the "leftoverss" is placed in one of the first groups. + # each of the "leftovers" is placed in one of the first groups. # # Example: let's split 8 elements in 3 groups. 
In this case, # base_length=2 and leftovers=2 (elements #7 and #8): @@ -68,7 +68,7 @@ def distribute_evenly(num_of_elements, num_of_groups): cur_length = base_length + 1 cur_pos = cur_length * i else: - # No need to accomodate for leftovers, but consider their + # No need to accommodate for leftovers, but consider their # presence in fixing the starting position for this group cur_length = base_length cur_pos = base_length * i + leftovers @@ -84,6 +84,47 @@ def distribute_evenly(num_of_elements, num_of_groups): return result +def distribute_detector_blocks(detector_blocks): + """Similar to the :func:`distribute_evenly()` function, this function + returns a list of named-tuples, with fields `start_idx` (the starting + index of the detector in a group within the global list of detectors) and + num_of_elements` (the number of detectors in the group). Unlike + :func:`distribute_evenly()`, this function simply uses the detector groups + given in the `detector_blocks` attribute. + + Example: + Following the example given in + :meth:`litebird_sim.Observation._make_detector_blocks()`, + `distribute_detector_blocks()` will return + + ``` + [ + Span(start_idx=0, num_of_elements=2), + Span(start_idx=2, num_of_elements=2), + Span(start_idx=4, num_of_elements=1), + ] + ``` + + Args: + detector_blocks (dict): The detector block object. See :meth:`litebird_sim.Observation._make_detector_blocks()`. + + Returns: + A list of 2-elements named-tuples containing (1) the starting index of + the detectors of the block with respect to the flatten list of entire + detector blocks and (2) the number of elements in the detector block. + """ + cur_position = 0 + prev_length = 0 + result = [] + for key in detector_blocks: + cur_length = len(detector_blocks[key]) + cur_position += prev_length + prev_length = cur_length + result.append(Span(start_idx=cur_position, num_of_elements=cur_length)) + + return result + + # The following implementation of the painter's partition problem is # heavily inspired by the code at # https://www.geeksforgeeks.org/painters-partition-problem-set-2/?ref=rp diff --git a/litebird_sim/mapmaking/binner.py b/litebird_sim/mapmaking/binner.py index 682a5fbe..9622a234 100644 --- a/litebird_sim/mapmaking/binner.py +++ b/litebird_sim/mapmaking/binner.py @@ -68,7 +68,7 @@ class BinnerResult: @njit def _solve_binning(nobs_matrix, atd): - # Sove the map-making equation + # Solve the map-making equation # # This method alters the parameter `nobs_matrix`, so that after its completion # each 3×3 matrix in nobs_matrix[idx, :, :] will be the *inverse*. 
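
# Editorial sketch (not part of the patch): minimal usage of the
# distribute_detector_blocks() helper added to litebird_sim/distribute.py above.
# The import path is an assumption; adapt it if the module layout differs.
from litebird_sim.distribute import distribute_detector_blocks

detector_blocks = {
    ("140", "L02"): ["det_A", "det_B"],  # 2 detectors
    ("119", "L04"): ["det_C", "det_D"],  # 2 detectors
    ("140", "L05"): ["det_E"],           # 1 detector
}
print(distribute_detector_blocks(detector_blocks))
# -> [Span(start_idx=0, num_of_elements=2),
#     Span(start_idx=2, num_of_elements=2),
#     Span(start_idx=4, num_of_elements=1)]
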
diff --git a/litebird_sim/mapmaking/common.py b/litebird_sim/mapmaking/common.py index 3f60e1c3..ac2f1393 100644 --- a/litebird_sim/mapmaking/common.py +++ b/litebird_sim/mapmaking/common.py @@ -219,7 +219,10 @@ def _compute_pixel_indices( if output_coordinate_system == CoordinateSystem.Galactic: # Free curr_pointings_det if the output map is already in Galactic coordinates - del curr_pointings_det + try: + del curr_pointings_det + except UnboundLocalError: + pass return pixidx_all, polang_all diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index c56fe6df..71e323af 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -20,7 +20,7 @@ from numba import njit, prange import healpy as hp -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from typing import Callable, Union, List, Optional, Tuple, Any, Dict from litebird_sim.hwp import HWP from litebird_sim.observations import Observation @@ -44,7 +44,7 @@ __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits" -__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits" +__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits" def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]: @@ -498,8 +498,10 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM) + if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM + ) # `nobs_matrix_cholesky` will *not* contain the M_i maps shown in # Eq. 9 of KurkiSuonio2009, but its Cholesky decomposition, i.e., @@ -746,8 +748,12 @@ def _compute_binned_map( ) if MPI_ENABLED: - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM) - MPI_COMM_WORLD.Allreduce(mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM) + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM + ) + MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( + mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM + ) # Step 2: compute the “binned map” (Eq. 
21) _sum_map_to_binned_map( @@ -987,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.SUM) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1004,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float: """ local_result = np.max(np.abs(residual)) if MPI_ENABLED: - return MPI_COMM_WORLD.allreduce(local_result, op=mpi4py.MPI.MAX) + return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX) else: return local_result @@ -1418,7 +1424,7 @@ def _run_destriper( bytes_in_temporary_buffers += mask.nbytes if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( bytes_in_temporary_buffers, op=mpi4py.MPI.SUM, ) @@ -1613,91 +1619,103 @@ def my_gui_callback( binned_map = np.empty((3, number_of_pixels)) hit_map = np.empty(number_of_pixels) - if do_destriping: - try: - # This will fail if the parameter is a scalar - len(params.samples_per_baseline) - - baseline_lengths_list = params.samples_per_baseline - assert len(baseline_lengths_list) == len(obs_list), ( - f"The list baseline_lengths_list has {len(baseline_lengths_list)} " - f"elements, but there are {len(obs_list)} observations" - ) - except TypeError: - # Ok, params.samples_per_baseline is a scalar, so we must - # figure out the number of samples in each baseline within - # each observation - baseline_lengths_list = [ - split_items_evenly( - n=getattr(cur_obs, components[0]).shape[1], - sub_n=int(params.samples_per_baseline), + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + # perform the following operations when MPI is not being used + # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator + if do_destriping: + try: + # This will fail if the parameter is a scalar + len(params.samples_per_baseline) + + baseline_lengths_list = params.samples_per_baseline + assert len(baseline_lengths_list) == len(obs_list), ( + f"The list baseline_lengths_list has {len(baseline_lengths_list)} " + f"elements, but there are {len(obs_list)} observations" ) - for cur_obs in obs_list - ] + except TypeError: + # Ok, params.samples_per_baseline is a scalar, so we must + # figure out the number of samples in each baseline within + # each observation + baseline_lengths_list = [ + split_items_evenly( + n=getattr(cur_obs, components[0]).shape[1], + sub_n=int(params.samples_per_baseline), + ) + for cur_obs in obs_list + ] + + # Each element of this list is a 2D array with shape (N_det, N_baselines), + # where N_det is the number of detectors in the i-th Observation object + recycle_baselines = False + if baselines_list is None: + baselines_list = [ + np.zeros( + (getattr(cur_obs, components[0]).shape[0], len(cur_baseline)) + ) + for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) + ] + else: + recycle_baselines = True + + destriped_map = np.empty((3, number_of_pixels)) + ( + baselines_list, + baseline_errors_list, + history_of_stopping_factors, + best_stopping_factor, + converged, + bytes_in_temporary_buffers, + ) = _run_destriper( + obs_list=obs_list, + nobs_matrix_cholesky=nobs_matrix_cholesky, + binned_map=binned_map, + destriped_map=destriped_map, + hit_map=hit_map, + baseline_lengths_list=baseline_lengths_list, + 
baselines_list_start=baselines_list, + recycle_baselines=recycle_baselines, + recycled_convergence=recycled_convergence, + dm_list=detector_mask_list, + tm_list=time_mask_list, + component=components[0], + threshold=params.threshold, + max_steps=params.iter_max, + use_preconditioner=params.use_preconditioner, + callback=callback, + callback_kwargs=callback_kwargs if callback_kwargs else {}, + ) - # Each element of this list is a 2D array with shape (N_det, N_baselines), - # where N_det is the number of detectors in the i-th Observation object - recycle_baselines = False - if baselines_list is None: - baselines_list = [ - np.zeros((getattr(cur_obs, components[0]).shape[0], len(cur_baseline))) - for (cur_obs, cur_baseline) in zip(obs_list, baseline_lengths_list) - ] + if MPI_ENABLED: + bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + bytes_in_temporary_buffers, + op=mpi4py.MPI.SUM, + ) else: - recycle_baselines = True - - destriped_map = np.empty((3, number_of_pixels)) - ( - baselines_list, - baseline_errors_list, - history_of_stopping_factors, - best_stopping_factor, - converged, - bytes_in_temporary_buffers, - ) = _run_destriper( - obs_list=obs_list, - nobs_matrix_cholesky=nobs_matrix_cholesky, - binned_map=binned_map, - destriped_map=destriped_map, - hit_map=hit_map, - baseline_lengths_list=baseline_lengths_list, - baselines_list_start=baselines_list, - recycle_baselines=recycle_baselines, - recycled_convergence=recycled_convergence, - dm_list=detector_mask_list, - tm_list=time_mask_list, - component=components[0], - threshold=params.threshold, - max_steps=params.iter_max, - use_preconditioner=params.use_preconditioner, - callback=callback, - callback_kwargs=callback_kwargs if callback_kwargs else {}, - ) - - if MPI_ENABLED: - bytes_in_temporary_buffers = MPI_COMM_WORLD.allreduce( - bytes_in_temporary_buffers, - op=mpi4py.MPI.SUM, + # No need to run the destriping, just compute the binned map with + # one single baseline set to zero + _compute_binned_map( + obs_list=obs_list, + output_sky_map=binned_map, + output_hit_map=hit_map, + nobs_matrix_cholesky=nobs_matrix_cholesky, + component=components[0], + dm_list=detector_mask_list, + tm_list=time_mask_list, + baselines_list=None, + baseline_lengths_list=[ + np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) + for cur_obs in obs_list + ], ) + bytes_in_temporary_buffers = 0 + destriped_map = None + baseline_lengths_list = None + baselines_list = None + baseline_errors_list = None + history_of_stopping_factors = None + best_stopping_factor = None + converged = True else: - # No need to run the destriping, just compute the binned map with - # one single baseline set to zero - _compute_binned_map( - obs_list=obs_list, - output_sky_map=binned_map, - output_hit_map=hit_map, - nobs_matrix_cholesky=nobs_matrix_cholesky, - component=components[0], - dm_list=detector_mask_list, - tm_list=time_mask_list, - baselines_list=None, - baseline_lengths_list=[ - np.array([getattr(cur_obs, components[0]).shape[1]], dtype=int) - for cur_obs in obs_list - ], - ) - bytes_in_temporary_buffers = 0 - destriped_map = None baseline_lengths_list = None baselines_list = None @@ -1707,14 +1725,18 @@ def my_gui_callback( converged = True # Add the temporary memory that was allocated *before* calling the destriper - bytes_in_temporary_buffers += sum( - [ - cur_obs.destriper_weights.nbytes - + cur_obs.destriper_pixel_idx.nbytes - + cur_obs.destriper_pol_angle_rad.nbytes - for cur_obs in obs_list - ] - ) + try: + bytes_in_temporary_buffers += 
sum( + [ + cur_obs.destriper_weights.nbytes + + cur_obs.destriper_pixel_idx.nbytes + + cur_obs.destriper_pol_angle_rad.nbytes + for cur_obs in obs_list + ] + ) + except UnboundLocalError: + # The case when `bytes_in_temporary_buffers` is not defined + bytes_in_temporary_buffers = 0 # We're nearly done! Let's clean up some stuff… if not keep_weights: @@ -1992,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None: primary_hdu = fits.PrimaryHDU() primary_hdu.header["MPIRANK"] = ( - MPI_COMM_WORLD.rank, + MPI_COMM_GRID.COMM_OBS_GRID.rank, "The rank of the MPI process that wrote this file", ) primary_hdu.header["MPISIZE"] = ( - MPI_COMM_WORLD.size, + MPI_COMM_GRID.COMM_OBS_GRID.size, "The number of MPI processes used in the computation", ) @@ -2212,11 +2234,11 @@ def load_destriper_results( baselines_file_name = folder / __BASELINES_FILE_NAME with fits.open(baselines_file_name) as inpf: - assert MPI_COMM_WORLD.rank == inpf[0].header["MPIRANK"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results " ) - assert MPI_COMM_WORLD.size == inpf[0].header["MPISIZE"], ( + assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], ( "You must call load_destriper_results using the " "same MPI layout that was used for save_destriper_results" ) diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index fe623248..64c31181 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -22,10 +22,57 @@ class _SerialMpiCommunicator: size = 1 +class _GridCommClass: + """ + This class encapsulates the `COMM_OBS_GRID` and `COMM_NULL` communicators. It + offers explicitly defined setter functions so that the communicators cannot be + changed accidentally. + + Attributes: + + COMM_OBS_GRID (mpi4py.MPI.Intracomm): A subset of `MPI.COMM_WORLD` that + contain all the processes associated with non-zero observations. + + COMM_NULL (mpi4py.MPI.Comm): A NULL communicator. When MPI is not enabled, it + is set as `None`. If MPI is enabled, it is set as `MPI.COMM_NULL` + + """ + + def __init__(self, comm_obs_grid=_SerialMpiCommunicator(), comm_null=None): + self._MPI_COMM_OBS_GRID = comm_obs_grid + self._MPI_COMM_NULL = comm_null + + @property + def COMM_OBS_GRID(self): + return self._MPI_COMM_OBS_GRID + + @property + def COMM_NULL(self): + return self._MPI_COMM_NULL + + def _set_comm_obs_grid(self, comm_obs_grid): + self._MPI_COMM_OBS_GRID = comm_obs_grid + + def _set_null_comm(self, comm_null): + self._MPI_COMM_NULL = comm_null + + #: Global variable equal either to `mpi4py.MPI.COMM_WORLD` or a object #: that defines the member variables `rank = 0` and `size = 1`. MPI_COMM_WORLD = _SerialMpiCommunicator() + +#: Global object with two attributes: +#: +#: - ``COMM_OBS_GRID``: It is a partition of ``MPI_COMM_WORLD`` that includes all the +#: MPI processes with global rank less than ``n_blocks_time * n_blocks_det``. On MPI +#: processes with higher ranks, it points to NULL MPI communicator +#: ``mpi4py.MPI.COMM_NULL``. +#: +#: - ``COMM_NULL``: If :data:`.MPI_ENABLED` is ``True``, this object points to a NULL +#: MPI communicator (``mpi4py.MPI.COMM_NULL``). Otherwise it is ``None``. +MPI_COMM_GRID = _GridCommClass() + #: `True` if MPI should be used by the application. 
The value of this #: variable is set according to the following rules: #: @@ -53,6 +100,8 @@ class _SerialMpiCommunicator: from mpi4py import MPI MPI_COMM_WORLD = MPI.COMM_WORLD + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD) + MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL) MPI_ENABLED = True MPI_CONFIGURATION = mpi4py.get_config() except ImportError: diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 7997de56..91cf0cb8 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -2,13 +2,18 @@ from dataclasses import dataclass from typing import Union, List, Any, Optional +import numbers import astropy.time import numpy as np import numpy.typing as npt +from collections import defaultdict + from .coordinates import DEFAULT_TIME_SCALE -from .distribute import distribute_evenly +from .distribute import distribute_evenly, distribute_detector_blocks +from .detectors import DetectorInfo +from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator @dataclass @@ -80,11 +85,20 @@ class Observation: sampling_rate_hz (float): The sampling frequency, in Hertz. + det_blocks_attributes (list of strings): The list of detector + attributes that will be used to divide the detector axis of the + tod matrix and all its attributes. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into blocks such that all detectors in a block will + have the same ``wafer`` and ``pixel`` attribute. + n_blocks_det (int): divide the detector axis of the tod (and all the - arrays of detector attributes) in `n_blocks_det` blocks + arrays of detector attributes) in `n_blocks_det` blocks. It will + be ignored if ``det_blocks_attributes`` is not `None`. n_blocks_time (int): divide the time axis of the tod in - `n_blocks_time` blocks + `n_blocks_time` blocks. It will be ignored + if ``det_blocks_attributes`` is not `None`. comm: either `None` (do not use MPI) or a MPI communicator object, like `mpi4py.MPI.COMM_WORLD`. 
Its size is required to be at @@ -103,6 +117,7 @@ def __init__( sampling_rate_hz: float, allocate_tod=True, tods=None, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, comm=None, @@ -123,27 +138,36 @@ def __init__( delta = 1.0 / sampling_rate_hz self.end_time_global = start_time_global + n_samples_global * delta + self._sampling_rate_hz = sampling_rate_hz + self._det_blocks_attributes = det_blocks_attributes + self.detector_blocks = None + if isinstance(detectors, int): self._n_detectors_global = detectors else: - if comm and comm.size > 1: - self._n_detectors_global = comm.bcast(len(detectors), root) + if self.comm and self.comm.size > 1: + self._n_detectors_global = self.comm.bcast(len(detectors), root) + + if self._det_blocks_attributes is not None: + n_blocks_det, n_blocks_time = self._make_detector_blocks( + detectors, self.comm + ) else: self._n_detectors_global = len(detectors) - self._sampling_rate_hz = sampling_rate_hz - - # Neme of the attributes that store an array with the value of a + # Name of the attributes that store an array with the value of a # property for each of the (local) detectors self._attr_det_names = [] self._check_blocks(n_blocks_det, n_blocks_time) - if comm and comm.size > 1: + if self.comm and self.comm.size > 1: self._n_blocks_det = n_blocks_det self._n_blocks_time = n_blocks_time else: self._n_blocks_det = 1 self._n_blocks_time = 1 + self._set_mpi_subcommunicators() + self.tod_list = tods for cur_tod in self.tod_list: if allocate_tod: @@ -159,8 +183,17 @@ def __init__( setattr(self, cur_tod.name, None) self.setattr_det_global("det_idx", np.arange(self._n_detectors_global), root) + + self.detectors_global = [] + + if self.detector_blocks is not None: + for key in self.detector_blocks: + self.detectors_global += self.detector_blocks[key] + else: + self.detectors_global = detectors + if not isinstance(detectors, int): - self._set_attributes_from_list_of_dict(detectors, root) + self._set_attributes_from_list_of_dict(self.detectors_global, root) ( self.start_time, @@ -176,7 +209,10 @@ def sampling_rate_hz(self): @property def n_detectors(self): - return len(self.det_idx) + if self.det_idx is None: + return 0 + else: + return len(self.det_idx) def _get_local_start_time_start_and_n_samples(self): _, _, start, num = self._get_start_and_num( @@ -203,7 +239,7 @@ def _get_local_start_time_start_and_n_samples(self): return self.start_time_global + start * delta, start, num def _set_attributes_from_list_of_dict(self, list_of_dict, root): - assert len(list_of_dict) == self.n_detectors_global + np.testing.assert_equal(len(list_of_dict), self.n_detectors_global) # Turn list of dict into dict of arrays if not self.comm or self.comm.rank == root: @@ -273,10 +309,86 @@ def n_blocks_time(self): def n_blocks_det(self): return self._n_blocks_det + def _make_detector_blocks(self, detectors, comm): + """This function distributes the detectors in groups such that each + group has the same set of attributes specified by the strings in + `self._det_block_attributes`. Once the groups are made, the number of + detector blocks is set to be the total number of detector groups, + whereas the number of time blocks is computed using the number of + detector blocks and the size of `comm` communicator. + + The detector blocks are stored in `self.detector_blocks`. This + dictionary object has the tuple of `self._det_blocks_attributes` values + as dictionary keys and the list of detectors corresponding to the key + as dictionary values. 
This dictionary is sorted so that the + group with the largest number of detectors comes first and the one with + the fewest detectors comes last. + + Example: + For + + ``` + detectors = [ + "000_002_123_xx_140_x", + "000_005_321_xx_140_x", + "000_004_456_xx_119_x", + "000_002_654_xx_140_x", + "000_004_789_xx_119_x", + ] + ``` + + and `self._det_blocks_attributes = ["channel", "wafer"]`, + `_make_detector_blocks()` will set + + ``` + self.detector_blocks = { + ("140", "L02"): ["000_002_123_xx_140_x", "000_002_654_xx_140_x"], + ("119", "L04"): ["000_004_456_xx_119_x", "000_004_789_xx_119_x"], + ("140", "L05"): ["000_005_321_xx_140_x"], + } + ``` + + and return `n_blocks_det = 3` + + Args: + detectors (List[dict]): List of detectors + + comm: The MPI communicator + + Returns: + n_blocks_det (int): Number of detector blocks + + n_blocks_time (int): Number of time blocks + + """ + self.detector_blocks = defaultdict(list) + for det in detectors: + key = tuple(det[attribute] for attribute in self._det_blocks_attributes) + self.detector_blocks[key].append(det) + + self.detector_blocks = dict( + sorted( + self.detector_blocks.items(), + key=lambda item: len(item[1]), + reverse=True, + ) + ) + n_blocks_det = len(self.detector_blocks) + n_blocks_time = comm.size // n_blocks_det + + return n_blocks_det, n_blocks_time + def _check_blocks(self, n_blocks_det, n_blocks_time): if self.comm is None: if n_blocks_det != 1 or n_blocks_time != 1: raise ValueError("Only one block allowed without an MPI comm") + elif n_blocks_det == 0 or n_blocks_time == 0: + raise ValueError( + "The number of detector blocks and the number of time blocks " + "must be must be non-zero\n" + f"n_blocks_det = {n_blocks_det}, " + f"n_blocks_time = {n_blocks_time}" + ) elif n_blocks_det > self.n_detectors_global: raise ValueError( "You can not have more detector blocks than detectors " @@ -296,21 +408,33 @@ def _check_blocks(self, n_blocks_det, n_blocks_time): def _get_start_and_num(self, n_blocks_det, n_blocks_time): """For both detectors and time, returns the starting (global) - index and lenght of each block if the number of blocks is changed to the + index and length of each block if the number of blocks is changed to the values passed as arguments """ - det_start, det_n = np.array( - [ - [span.start_idx, span.num_of_elements] - for span in distribute_evenly(self._n_detectors_global, n_blocks_det) - ] - ).T + if self._det_blocks_attributes is None or self.comm.size == 1: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_evenly( + self._n_detectors_global, n_blocks_det + ) + ] + ).T + else: + det_start, det_n = np.array( + [ + [span.start_idx, span.num_of_elements] + for span in distribute_detector_blocks(self.detector_blocks) + ] + ).T + time_start, time_n = np.array( [ [span.start_idx, span.num_of_elements] for span in distribute_evenly(self._n_samples_global, n_blocks_time) ] ).T + return ( np.array(det_start), np.array(det_n), @@ -327,6 +451,7 @@ def _get_tod_shape(self, n_blocks_det, n_blocks_time): return (self._n_detectors_global, self._n_samples_global) _, det_n, _, time_n = self._get_start_and_num(n_blocks_det, n_blocks_time) + try: return ( det_n[self.comm.rank // n_blocks_time], @@ -478,6 +603,9 @@ def set_n_blocks(self, n_blocks_det=1, n_blocks_time=1): self._n_blocks_det = n_blocks_det self._n_blocks_time = n_blocks_time + # Update the sub-communicators + self._set_mpi_subcommunicators() + for name in self._attr_det_names: info = None if is_in_old_fist_col: @@ 
-550,26 +678,27 @@ def setattr_det_global(self, name, info, root=0): setattr(self, name, info) return - is_in_grid = self.comm.rank < self._n_blocks_det * self._n_blocks_time - comm_grid = self.comm.Split(int(is_in_grid)) - if not is_in_grid: # The process does not own any detector (and TOD) - setattr(self, name, None) + if ( + MPI_COMM_GRID.COMM_OBS_GRID == MPI_COMM_GRID.COMM_NULL + ): # The process does not own any detector (and TOD) + null_det = DetectorInfo() + attribute = getattr(null_det, name, None) + value = 0 if isinstance(attribute, numbers.Number) else None + setattr(self, name, value) return - my_col = comm_grid.rank % self._n_blocks_time - comm_col = comm_grid.Split(my_col) + my_col = MPI_COMM_GRID.COMM_OBS_GRID.rank % self._n_blocks_time root_col = root // self._n_blocks_det if my_col == root_col: - if comm_grid.rank == root: + if MPI_COMM_GRID.COMM_OBS_GRID.rank == root: starts, nums, _, _ = self._get_start_and_num( self._n_blocks_det, self._n_blocks_time ) info = [info[s : s + n] for s, n in zip(starts, nums)] - info = comm_col.scatter(info, root) + info = self.comm_time_block.scatter(info, root) - comm_row = comm_grid.Split(comm_grid.rank // self._n_blocks_time) - info = comm_row.bcast(info, root_col) + info = self.comm_det_block.bcast(info, root_col) assert (not self.tod_list) or len(info) == len( getattr(self, self.tod_list[0].name) ) @@ -662,7 +791,7 @@ def get_pointings( pointing_buffer: Optional[npt.NDArray] = None, hwp_buffer: Optional[npt.NDArray] = None, pointings_dtype=np.float32, - ) -> (npt.NDArray, Optional[npt.NDArray]): + ) -> tuple[npt.NDArray, Optional[npt.NDArray]]: """ Compute the pointings for one or more detectors in this observation @@ -785,3 +914,55 @@ def get_pointings( ) return pointing_buffer, hwp_buffer + + def _set_mpi_subcommunicators(self): + """ + This function splits the global MPI communicator into three kinds of + sub-communicators: + + 1. A sub-communicator containing all the processes with global rank less than + `n_blocks_det * n_blocks_time`. Outside of this global rank, the + sub-communicator is NULL. + + 2. A sub-communicator for each block of detectors, that contains all the + processes corresponding to that detector block. This sub-communicator + is an attribute of the :class:`.Observation` class. + + 3. A sub-communicator for each block of time that contains all the processes + corresponding to that time block. This sub-communicator + is an attribute of the :class:`.Observation` class. + """ + + # Set the detector and time block sub-communicators to + # `_SerialMpiCommunicator()` when MPI is not being used + self.comm_det_block = _SerialMpiCommunicator() + self.comm_time_block = _SerialMpiCommunicator() + + if self.comm and self.comm.size > 1: + if self.comm.rank < self.n_blocks_det * self.n_blocks_time: + matrix_color = 1 + else: + from .mpi import MPI + + matrix_color = MPI.UNDEFINED + + # Case1: For `0 < rank < n_blocks_det * n_blocks_time`, + # `comm_obs_grid` is a sub-communicator that includes processes + # from rank 0 to `n_blocks_det * n_blocks_time - 1`. 
+ # Case 2: For `n_blocks_det * n_blocks_time <= rank < comm.size`, + # `comm_obs_grid = MPI.COMM_NULL` + comm_obs_grid = self.comm.Split(matrix_color, self.comm.rank) + MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid) + + # If the `MPI_COMM_GRID.COMM_OBS_GRID` is not NULL, we split it in + # communicators corresponding to each detector and time block + # If `MPI_COMM_GRID.COMM_OBS_GRID` is NULL, we set the communicators + # corresponding to detector and time blocks to NULL. + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + det_color = MPI_COMM_GRID.COMM_OBS_GRID.rank // self.n_blocks_time + time_color = MPI_COMM_GRID.COMM_OBS_GRID.rank % self.n_blocks_time + self.comm_det_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(det_color) + self.comm_time_block = MPI_COMM_GRID.COMM_OBS_GRID.Split(time_color) + else: + self.comm_det_block = MPI_COMM_GRID.COMM_NULL + self.comm_time_block = MPI_COMM_GRID.COMM_NULL diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 7a213a0a..51dde0cf 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -44,7 +44,7 @@ DestriperResult, destriper_log_callback, ) -from .mpi import MPI_ENABLED, MPI_COMM_WORLD +from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from .noise import add_noise_to_observations from .observations import Observation, TodDescription from .pointings_in_obs import prepare_pointings, precompute_pointings @@ -889,6 +889,7 @@ def create_observations( detectors: List[DetectorInfo], num_of_obs_per_detector: int = 1, split_list_over_processes=True, + det_blocks_attributes: Union[List[str], None] = None, n_blocks_det=1, n_blocks_time=1, root=0, @@ -922,7 +923,12 @@ def create_observations( simulating 10 detectors and you specify ``n_blocks_det=5``, this means that each observation will handle ``10 / 5 = 2`` detectors. The default is that *all* the detectors be kept - together (``n_blocks_det=1``). + together (``n_blocks_det=1``). On the other hand, the parameter + `det_blocks_attributes` specifies the list of detector attributes + to create the groups of detectors. For example, with + ``det_blocks_attributes = ["wafer", "pixel"]``, the detectors will + be divided into groups such that all detectors in a group will + have the same ``wafer`` and ``pixel`` attribute. The parameter `n_blocks_time` specifies the number of time splits of the observations. 
In the case of a 3-month-long @@ -1013,6 +1019,7 @@ def create_observations( start_time_global=cur_time, sampling_rate_hz=sampfreq_hz, n_samples_global=nsamples, + det_blocks_attributes=det_blocks_attributes, n_blocks_det=n_blocks_det, n_blocks_time=n_blocks_time, comm=(None if split_list_over_processes else self.mpi_comm), @@ -1214,7 +1221,8 @@ def set_scanning_strategy( num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_quaternions.md") @@ -1311,8 +1319,11 @@ def prepare_pointings( memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes num_of_obs = len(self.observations) if append_to_report and MPI_ENABLED: - memory_occupation = MPI_COMM_WORLD.allreduce(memory_occupation) - num_of_obs = MPI_COMM_WORLD.allreduce(num_of_obs) + if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + memory_occupation + ) + num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs) if append_to_report and MPI_COMM_WORLD.rank == 0: template_file_path = get_template_file_path("report_pointings.md") diff --git a/test/test_detector_blocks.py b/test/test_detector_blocks.py new file mode 100644 index 00000000..a329a46d --- /dev/null +++ b/test/test_detector_blocks.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest +import litebird_sim as lbs + +# data for testing detector blocks and MPI sub-communicators +sampling_freq_Hz = 1 +dets = [ + lbs.DetectorInfo( + name="channel1_w9_detA", + wafer="wafer_9", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w3_detB", + wafer="wafer_3", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detC", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel1_w1_detD", + wafer="wafer_1", + channel="channel1", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detA", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), + lbs.DetectorInfo( + name="channel2_w4_detB", + wafer="wafer_4", + channel="channel2", + sampling_rate_hz=sampling_freq_Hz, + ), +] + + +def test_detector_blocks(dets=dets, sampling_freq_Hz=sampling_freq_Hz): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + nobs_per_det = 3 + + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=det_blocks_attribute, + ) + + tod_len_per_det_per_proc = 0 + for obs in sim.observations: + tod_shape = obs.tod.shape + + n_blocks_det = obs.n_blocks_det + n_blocks_time = obs.n_blocks_time + tod_len_per_det_per_proc += obs.tod.shape[1] + + # No testing required if the proc doesn't owns a detector + if obs.det_idx is not None: + det_names_per_obs = [ + obs.detectors_global[idx]["name"] for idx in obs.det_idx + ] + + # Testing if the mapping between the obs.name and + # obs.det_idx is consistent with 
obs.detectors_global + np.testing.assert_equal(obs.name, det_names_per_obs) + + # Testing the distribution of the number of detectors per + # detector block + np.testing.assert_equal(obs.name.shape[0], tod_shape[0]) + + # Testing if the distribution of samples along the time axis is consistent + if comm.rank < n_blocks_det * n_blocks_time: + arr = [ + span.num_of_elements + for span in lbs.distribute.distribute_evenly( + duration_s * sampling_freq_Hz, n_blocks_time * nobs_per_det + ) + ] + + start_idx = (comm.rank % n_blocks_time) * nobs_per_det + stop_idx = start_idx + nobs_per_det + np.testing.assert_equal(sum(arr[start_idx:stop_idx]), tod_len_per_det_per_proc) + + +def test_mpi_subcommunicators(dets=dets): + comm = lbs.MPI_COMM_WORLD + + start_time = 456 + duration_s = 100 + nobs_per_det = 3 + + if comm.size > 4: + det_blocks_attribute = ["channel", "wafer"] + else: + det_blocks_attribute = ["channel"] + + sim = lbs.Simulation( + start_time=start_time, + duration_s=duration_s, + random_seed=12345, + mpi_comm=comm, + ) + + sim.create_observations( + detectors=dets, + split_list_over_processes=False, + num_of_obs_per_detector=nobs_per_det, + det_blocks_attributes=det_blocks_attribute, + ) + + if lbs.MPI_COMM_GRID.COMM_OBS_GRID != lbs.MPI_COMM_GRID.COMM_NULL: + # since unused MPI processes stay at the end of global, + # communicator, the rank of the used processes in + # `MPI_COMM_GRID.COMM_OBS_GRID` must be same as their rank in + # global communicator + np.testing.assert_equal(lbs.MPI_COMM_GRID.COMM_OBS_GRID.rank, comm.rank) + + for obs in sim.observations: + # comm_det_block.rank + comm_time_block.rank * n_block_time + # must be equal to the global communicator rank for the + # used processes. It follows from the way split colors + # were defined. + np.testing.assert_equal( + obs.comm_det_block.rank + obs.comm_time_block.rank * obs.n_blocks_time, + comm.rank, + ) + else: + for obs in sim.observations: + # the global rank of the unused MPI processes must be larger than the number of used processes. 
+ assert comm.rank > (obs.n_blocks_det * obs.n_blocks_time - 1) + + # The block communicators on the unused MPI processes must + # be the NULL communicators + np.testing.assert_equal(obs.comm_det_block, lbs.MPI_COMM_GRID.COMM_NULL) + np.testing.assert_equal(obs.comm_time_block, lbs.MPI_COMM_GRID.COMM_NULL) + + +if __name__ == "__main__": + pytest.main([f"{__file__}"]) diff --git a/test/test_mpi.py b/test/test_mpi.py index 4f6d5deb..d7e3d4d6 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -166,9 +166,11 @@ def test_construction_from_detectors(): assert obs.wafer is None assert obs.pixel is None assert obs.pixtype is None - assert obs.ellipticity is None assert obs.quat is None - assert obs.alpha is None + # On the processes, that does not own any detector (and TOD), the numerical + # attributes of `DetectorInfo()` are assigned to zero + assert obs.ellipticity == 0 + assert obs.alpha == 0 obs.set_n_blocks(n_blocks_time=1, n_blocks_det=1) if comm_world.rank == 0: @@ -216,7 +218,7 @@ def test_observation_tod_two_block_time(): comm=comm_world, ) except ValueError: - # Not enough processes to split the TOD, constuctor expected to rise + # Not enough processes to split the TOD, constructor expected to rise if comm_world.size < 2: return @@ -240,7 +242,7 @@ def test_observation_tod_two_block_det(): comm=comm_world, ) except ValueError: - # Not enough processes to split the TOD, constuctor expected to rise + # Not enough processes to split the TOD, constructor expected to rise if comm_world.size < 2: return @@ -264,7 +266,7 @@ def test_observation_tod_set_blocks(): comm=comm_world, ) except ValueError: - # Not enough processes to split the TOD, constuctor expected to rise + # Not enough processes to split the TOD, constructor expected to rise if comm_world.size < 2: return @@ -501,7 +503,7 @@ def test_simulation_random(): assert state3["state"]["state"] != state4["state"]["state"] -def main(): +if __name__ == "__main__": test_observation_time() test_construction_from_detectors() test_observation_tod_single_block() @@ -534,6 +536,3 @@ def main(): if tmp_dir: tmp_dir.cleanup() - - -main()
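
# Editorial sketch (not part of the patch): a pure-Python check of the rank
# identity asserted in test_mpi_subcommunicators() above. Inside the grid
# communicator, Split() gives comm_det_block.rank = rank % n_blocks_time and
# comm_time_block.rank = rank // n_blocks_time, so the identity below holds.
n_blocks_det, n_blocks_time = 4, 2
for rank in range(n_blocks_det * n_blocks_time):
    det_block_rank = rank % n_blocks_time    # rank within its detector block
    time_block_rank = rank // n_blocks_time  # rank within its time block
    assert det_block_rank + time_block_rank * n_blocks_time == rank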