Map function over detector frames #333

Draft: wants to merge 4 commits into base: master

Changes from all commits
126 changes: 125 additions & 1 deletion extra_data/components.py
@@ -1,5 +1,6 @@
"""Interfaces to data from specific instruments
"""
import inspect
import logging
import re
from copy import copy
@@ -98,6 +99,7 @@ class MultimodDetectorBase:
# Override in subclass
_main_data_key = '' # Key to use for checking data counts match
_frames_per_entry = 1 # Override if separate pulse dimension in files
_modnos_start_at = 0 # Override if module numbers start at 1 (JUNGFRAU)
module_shape = (0, 0)
n_modules = 0

@@ -282,6 +284,10 @@ def frames_per_train(self):
raise ValueError(f"Varying number of frames per train: {counts}")
return counts.pop() * self._frames_per_entry

@property
def n_frames(self):
return self.frame_counts.sum() * self._frames_per_entry

def __repr__(self):
return "<{}: Data interface for detector {!r} with {} modules>".format(
type(self).__name__, self.detector_name, len(self.source_to_modno),
@@ -407,7 +413,7 @@ def get_array(self, key, *, fill_value=None, roi=(), astype=None):
Specify e.g. ``np.s_[10:60, 100:200]`` to select pixels within each
module when reading data. The selection is applied to each individual
module, so it may only be useful when working with a single module.
astype: Type
astype: dtype
Data type of the output array. If None (default), the dtype matches
the input array's dtype.
"""
@@ -466,6 +472,123 @@ def get_dask_array(self, key, fill_value=None, astype=None):

return self._concat(arrays, modnos, fill_value, astype)

def _get_data(self, key, *, fill_value=None, roi=(), astype=None):
"""Get data as a plain NumPy array with no labels"""
train_ids = self.train_ids_perframe

eg_src = min(self.source_to_modno)
eg_keydata = self.data[eg_src, key]

# Output shape: (modules, frames) + the shape of 1 frame with the ROI applied
out_shape = ((self.n_modules, len(train_ids))
+ roi_shape(eg_keydata.entry_shape, roi))

dtype = eg_keydata.dtype if astype is None else np.dtype(astype)
out = self._out_array(out_shape, dtype, fill_value=fill_value)

for modno, source in sorted(self.modno_to_source.items()):
mod_ix = modno - self._modnos_start_at
for chunk in self.data._find_data_chunks(source, key):
for tgt_slice, chunk_slice in self._split_align_chunk(chunk, train_ids):
chunk.dataset.read_direct(
out[mod_ix, tgt_slice], source_sel=(chunk_slice,) + roi
)

return out
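# Illustrative sketch (hypothetical key and ROI, not from this PR): for a
# 16-module detector `det`, the result is a plain array of shape
# (modules, frames, *roi_shape), e.g.
#     raw = det._get_data('image.data', roi=np.s_[:64, :64])
#     raw.shape  # -> (16, det.n_frames, 64, 64)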

def _apply_framewise(self, f, out, data_params={}):
arr = self._get_data(self._main_data_key)
# Array should be (modules, frames, *pixel_dims)
ndim_px = len(self.module_shape)
ndim_iter = arr.ndim - 1 - ndim_px
arr = arr.reshape((arr.shape[0], -1, *arr.shape[-ndim_px:]))

# Prepare arrays for data to be passed as kwargs (mask, pulseId, etc.)
kw_arrs = {}
for param, key in data_params.items():
a = self._get_data(key)
# Flatten the iteration dims, keeping any trailing per-frame dims
kw_arrs[param] = a.reshape((a.shape[0], -1, *a.shape[1 + ndim_iter:]))

for i in range(arr.shape[1]):
kw = {p: a[:, i] for (p, a) in kw_arrs.items()}
out[i] = f(arr[:, i], **kw)
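# Example of a framewise callable, a minimal sketch: f sees one frame at
# a time as an array of shape (modules, *module_shape):
#     def frame_mean(frames):
#         return np.nanmean(frames)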

def _frame_func_to_chunk_func(self, f, out_shape=None, out_dtype=None):
eg_srcdata = self.data[min(self.source_to_modno)]
main_group = self._main_data_key.rpartition('.')[0] + '.'
data_keys = {k.rpartition('.')[2]: k for k in eg_srcdata.keys()
if k.startswith(main_group)}
data_params = {}
for param_name in list(inspect.signature(f).parameters)[1:]:
if param_name in data_keys:
data_params[param_name] = data_keys[param_name]
else:
raise KeyError(f"No {param_name} data available; "
f"possible names are {', '.join(data_keys)}")

def chunk_func(chunk):
if out_shape is not None:
out = chunk._out_array((chunk.n_frames, *out_shape), dtype=out_dtype)
else:
out = [None] * chunk.n_frames
chunk._apply_framewise(f, out, data_params)
return out

return chunk_func
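# Sketch of a function taking an extra data key as a keyword array; this
# assumes the source records 'data.gain' alongside 'data.adc' (as JUNGFRAU
# does), and the threshold is hypothetical:
#     def lit_pixels(adc, gain):
#         # 'gain' is matched by name to the sibling key 'data.gain'
#         return ((gain == 0) & (adc > 100)).sum()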

def map_frames(
self, f, mapper=None, *,
out=None, out_shape=None, out_dtype=None,
parts=None, trains_per_part=None, frames_per_part=None
):
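"""Map a function over each frame of detector data.

f is called once per frame as f(frames, ...), where frames is an
array of shape (modules, *module_shape). Any further parameters of f
are matched by name to sibling keys of the main data key (e.g. gain
alongside data.adc) and receive the matching per-frame data.

mapper should be an imap-like callable applying a function to each
item of a sequence; by default, a multiprocessing pool is used.
Results go into out if given, else into an array of shape
(n_frames, *out_shape) if out_shape is given, else into a list.
parts, trains_per_part and frames_per_part control chunking, as for
split_trains().
"""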
if mapper is None:
# Default to using multiprocessing with up to 16 cores.
# We're likely to spend a fair bit of time reading the data, so more
# workers than this are unlikely to help.
import multiprocessing
with multiprocessing.Pool(min(multiprocessing.cpu_count(), 16)) as p:
return self.map_frames(
f, p.imap, out=out, out_shape=out_shape, out_dtype=out_dtype,
parts=parts, trains_per_part=trains_per_part,
frames_per_part=frames_per_part,
)

if parts is None and trains_per_part is None and frames_per_part is None:
# Default ~4 GiB chunks for 1 MPx detectors. This is probably too
# big for all cores in parallel on one node, but in many cases the
# limiting step will be loading data, so you want fewer workers
# (or split it over multiple nodes)
frames_per_part = 1000
chunks = list(self.split_trains(parts, trains_per_part, frames_per_part))

map_kwargs = {}
if 'key' in inspect.signature(mapper).parameters:
# Dask workaround: supply ready-made task keys, so Dask doesn't pickle
# and hash the function to produce them
from secrets import token_hex
map_kwargs['key'] = [f"map-frames-{token_hex(16)}" for _ in chunks]

chunk_func = self._frame_func_to_chunk_func(f, out_shape=out_shape, out_dtype=out_dtype)
results_iter = mapper(chunk_func, chunks, **map_kwargs)

if out is None:
if out_shape is not None:
out = self._out_array((self.n_frames, *out_shape), dtype=out_dtype)
else:
out = [None] * self.n_frames

# Assemble per-chunk results into output list/array
out_cursor = 0
for chunk_res in results_iter:
if hasattr(chunk_res, 'result'):
# Dask returns futures rather than direct results
chunk_res = chunk_res.result()
to = out_cursor + len(chunk_res)
out[out_cursor : to] = chunk_res
out_cursor = to

return out
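# A hedged usage sketch (run, detector name and function are hypothetical):
#     jf = JUNGFRAU(run, 'SPB_IRDA_JF4M')
#     def frame_max(frames):
#         return np.nanmax(frames)
#     maxima = jf.map_frames(frame_max, out_shape=(), out_dtype=np.float64)
# With a dask.distributed Client, client.map can be passed as the mapper;
# the assembly loop above unwraps the returned futures via .result().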

def trains(self, require_all=True):
"""Iterate over trains for detector data.

@@ -1429,6 +1552,7 @@ class JUNGFRAU(MultimodDetectorBase):
r'(MODULE_|RECEIVER-|JNGFR)(?P<modno>\d+)'
)
_main_data_key = 'data.adc'
_modnos_start_at = 1
module_shape = (512, 1024)

def __init__(self, data: DataCollection, detector_name=None, modules=None,