Skip to content

Commit

Permalink
added test for detector block distribution; fixed the assignment of s…
Browse files Browse the repository at this point in the history
…ome attributes in observations.py
  • Loading branch information
anand-avinash committed Oct 30, 2024
1 parent 2eeab2c commit a941fd0
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 15 deletions.
14 changes: 7 additions & 7 deletions docs/source/observations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ the number of detector blocks (``n_blocks_time = comm.size // n_blocks_det``).

The detector blocks made in this way can be accessed with
``Observation.detector_blocks``. It is a dictionary object has the tuple of
``self._det_blocks_attributes`` values as dictionary keys and the list of detectors
``det_blocks_attributes`` values as dictionary keys and the list of detectors
corresponding to the key as dictionary values. This dictionary is sorted so that the
group with the largest number of detectors comes first and the one with
the fewest detectors comes last.
Expand All @@ -211,37 +211,37 @@ MPI processes when ``det_blocks_attributes`` is specified.
name="channel1_w9_detA",
wafer="wafer_9",
channel="channel1",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w3_detB",
wafer="wafer_3",
channel="channel1",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w1_detC",
wafer="wafer_1",
channel="channel1",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w1_detD",
wafer="wafer_1",
channel="channel1",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel2_w4_detA",
wafer="wafer_4",
channel="channel2",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel2_w4_detB",
wafer="wafer_4",
channel="channel2",
sampling_rate_hz=1,
sampling_rate_hz=sampling_freq_Hz,
),
]
Expand Down
21 changes: 13 additions & 8 deletions litebird_sim/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataclasses import dataclass
from typing import Union, List, Any, Optional
import numbers

import astropy.time
import numpy as np
Expand All @@ -11,6 +12,7 @@

from .coordinates import DEFAULT_TIME_SCALE
from .distribute import distribute_evenly, distribute_detector_blocks
from .detectors import DetectorInfo


@dataclass
Expand Down Expand Up @@ -142,16 +144,16 @@ def __init__(
if isinstance(detectors, int):
self._n_detectors_global = detectors
else:
if comm and comm.size > 1:
self._n_detectors_global = comm.bcast(len(detectors), root)
if self.comm and self.comm.size > 1:
self._n_detectors_global = self.comm.bcast(len(detectors), root)

if self._det_blocks_attributes is not None:
n_blocks_det, n_blocks_time = self._make_detector_blocks(
detectors, self.comm
)
else:
self._n_detectors_global = len(detectors)

if self._det_blocks_attributes is not None and comm.size > 1:
n_blocks_det, n_blocks_time = self._make_detector_blocks(
detectors, comm
)

# Name of the attributes that store an array with the value of a
# property for each of the (local) detectors
self._attr_det_names = []
Expand Down Expand Up @@ -673,7 +675,10 @@ def setattr_det_global(self, name, info, root=0):
is_in_grid = self.comm.rank < self._n_blocks_det * self._n_blocks_time
comm_grid = self.comm.Split(int(is_in_grid))
if not is_in_grid: # The process does not own any detector (and TOD)
setattr(self, name, None)
null_det = DetectorInfo()
attribute = getattr(null_det, name, None)
value = 0 if isinstance(attribute, numbers.Number) else None
setattr(self, name, value)
return

my_col = comm_grid.rank % self._n_blocks_time
Expand Down
100 changes: 100 additions & 0 deletions test/test_detector_blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import numpy as np
import litebird_sim as lbs


def test_detector_blocks():
comm = lbs.MPI_COMM_WORLD

start_time = 456
duration_s = 100
sampling_freq_Hz = 1
nobs_per_det = 3

# Creating a list of detectors.
dets = [
lbs.DetectorInfo(
name="channel1_w9_detA",
wafer="wafer_9",
channel="channel1",
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w3_detB",
wafer="wafer_3",
channel="channel1",
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w1_detC",
wafer="wafer_1",
channel="channel1",
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel1_w1_detD",
wafer="wafer_1",
channel="channel1",
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel2_w4_detA",
wafer="wafer_4",
channel="channel2",
sampling_rate_hz=sampling_freq_Hz,
),
lbs.DetectorInfo(
name="channel2_w4_detB",
wafer="wafer_4",
channel="channel2",
sampling_rate_hz=sampling_freq_Hz,
),
]

sim = lbs.Simulation(
start_time=start_time,
duration_s=duration_s,
random_seed=12345,
mpi_comm=comm,
)

sim.create_observations(
detectors=dets,
split_list_over_processes=False,
num_of_obs_per_detector=nobs_per_det,
det_blocks_attributes=["channel", "wafer"],
)

tod_len_per_det_per_proc = 0
for obs in sim.observations:
tod_shape = obs.tod.shape

n_blocks_det = obs.n_blocks_det
n_blocks_time = obs.n_blocks_time
tod_len_per_det_per_proc += obs.tod.shape[1]

# No testing required if the proc doesn't owns a detector
if obs.det_idx is not None:
det_names_per_obs = [
obs.detectors_global[idx]["name"] for idx in obs.det_idx
]

# Testing if the mapping between the obs.name and
# obs.det_idx is consistent with obs.detectors_global
np.testing.assert_equal(obs.name, det_names_per_obs)

# Testing the distribution of the number of detectors per
# detector block
np.testing.assert_equal(obs.name.shape[0], tod_shape[0])

# Testing if the distribution of samples along the time axis is consistent
if comm.rank < n_blocks_det * n_blocks_time:
arr = [
span.num_of_elements
for span in lbs.distribute.distribute_evenly(
duration_s * sampling_freq_Hz, n_blocks_time * nobs_per_det
)
]

start_idx = (comm.rank % n_blocks_time) * nobs_per_det
stop_idx = start_idx + nobs_per_det
np.testing.assert_equal(sum(arr[start_idx:stop_idx]), tod_len_per_det_per_proc)

0 comments on commit a941fd0

Please sign in to comment.