Skip to content

Commit

Permalink
proj: update test for interpol, and docs about "threads"
Browse files Browse the repository at this point in the history
  • Loading branch information
mhasself committed Dec 20, 2023
1 parent af724bb commit 693b4e4
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
51 changes: 44 additions & 7 deletions python/proj/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,22 @@ class Projectionist:
* ``wcs`` - the WCS describing the celestial axes of the map.
Together with ``shape`` this is a geometry; see pixell.enmap
documentation.
* ``threads`` - the thread assignment, which is a RangesMatrix
with shape (n_threads,n_dets,n_samps), used to specify which
samples should be treated by each thread in TOD-to-map
operations. Such objects should satisfy the condition that
threads[x,j]*threads[y,j] is the empty Range for x != y;
i.e. each detector-sample is assigned to at most one thread.
* ``threads`` - the thread assignment, consisting of a list of
RangesMatrix objects. Each RangesMatrix object must have shape
(n_threads, n_dets, n_samps). The n_threads does not need to
be the same for every entry in the list. In TOD-to-map
operations, each entry of this list is processed fully before
proceeding to the next one. Each entry "ranges" is processed
using (up to) the specified number of threads, such that thread
i performs operations only on the samples included in
ranges[i,:,:]. Most thread assignment routines in this module
will return a list of two RangesMatrix objects,
[ranges_parallel, ranges_serial]. The first item represents the
part of the computation that can be done in parallel, and has
shape (n_threads, n_dets, n_samps). The ranges_serial object
has shape (1, n_dets, n_samps) and represents any samples that
need to be treated in a single thread. The ranges_serial is
only non-trivial when interpolation is active.
* ``interpol``: How positions that fall between pixel centers will
be handled. Options are "nearest" (default): Use Nearest
Neighbor interpolation, so a sample takes the value of
Expand Down Expand Up @@ -645,7 +655,34 @@ def tile_offset(self, tile):
return row * self.tile_shape[0], col * self.tile_shape[1]

def wrap_ivals(ivals):
return tuple([RangesMatrix([RangesMatrix(y) for y in x]) for x in ivals])
"""Thread computation routines at C++ level return nested lists of
Ranges objects; i.e. something like this::
ivals = [
[ # thread assignments for first "bunch"
[Ranges, Ranges, ... ], # for thread 0
[Ranges, Ranges, ... ],
...
[Ranges, Ranges, ... ], # for thread n-1.
],
[ # thread assignments for second "bunch"
[Ranges, Ranges, ... ], # for thread 0
],
]
This function wraps and returns each highest level entry into a
RangesMatrix, i.e.::
wrapped = [
RangesMatrix(n_threads1, n_det, n_samp),
RangesMatrix(n_threads2, n_det, n_samp),
]
Currently all use cases have len(ivals) == 2 and n_threads2 = 1
but the scheme is more general than that.
"""
return [RangesMatrix([RangesMatrix(y) for y in x]) for x in ivals]

THREAD_ASSIGNMENT_METHODS = [
'simple',
Expand Down
35 changes: 24 additions & 11 deletions test/test_proj_eng.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,19 @@ def test_10_tiled(self):

@requires_pixell
def test_20_threads(self):
for (clipped, tiled, method) in itertools.product(
for (clipped, tiled, interpol, method) in itertools.product(
[False, True],
[False, True],
['nearest', 'bilinear'],
proj.wcs.THREAD_ASSIGNMENT_METHODS):
# For error messages ...
detail = f'(method={method}, tiled={tiled}, clipped={clipped})'
detail = f'(method={method}, tiled={tiled}, clipped={clipped}, interpol={interpol})'
scan, asm, (shape, wcs) = get_basics(clipped=clipped)
if tiled:
p = proj.Projectionist.for_tiled(shape, wcs, (150, 150), active_tiles=False)
p = proj.Projectionist.for_tiled(shape, wcs, (150, 150), active_tiles=False,
interpol=interpol)
else:
p = proj.Projectionist.for_geom(shape, wcs)
p = proj.Projectionist.for_geom(shape, wcs, interpol=interpol)
sig = np.ones((2, len(scan[0])), 'float32')
n_threads = 3

Expand All @@ -104,17 +106,28 @@ def test_20_threads(self):
else:
threads = p.assign_threads(asm, method=method, n_threads=n_threads)
# This may need to be generalized if we implement fancier threads schemes.
self.assertEqual(threads[0].shape, (n_threads,) + sig.shape,
msg=f'threads has wrong shape ({detail})')
self.assertIsInstance(threads, list,
msg=f'a thread assignment routine did not return a list ({detail})')

# Make sure the threads cover the TOD, or not,
# depending on clipped. This may also need generalization
counts = np.zeros(threads[0].shape[1:], int)
for t in threads[0]:
counts += t.mask()
# depending on clipped.
counts0 = threads[0].mask().sum(axis=0)
counts1 = np.zeros(counts0.shape, int)

self.assertEqual(threads[0].shape, (n_threads,) + sig.shape,
msg=f'a thread bunch has wrong shape ({detail})')

for t in threads[1:]:
counts1 += t.mask().sum(axis=0)
self.assertEqual(t.shape[1:], sig.shape,
msg=f'a thread bunch has unexpected shape ({detail})')

target = set([0,1]) if clipped else set([1])
self.assertEqual(set(counts.ravel()), target,
self.assertEqual(set((counts0 + counts1).ravel()), target,
msg=f'threads does not cover TOD ({detail})')
# Only the first segment should be non-empty, unless bilinear.
if interpol == 'nearest':
self.assertEqual(counts1.sum(), 0)


if __name__ == '__main__':
Expand Down

0 comments on commit 693b4e4

Please sign in to comment.