diff --git a/extra_data/keydata.py b/extra_data/keydata.py
index 2846a852..5231acb9 100644
--- a/extra_data/keydata.py
+++ b/extra_data/keydata.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import numpy as np

@@ -28,6 +28,12 @@ def __init__(
     def _find_chunks(self):
         """Find contiguous chunks of data for this key, in any order."""
         for file in self.files:
+            if file.file[self.hdf5_data_path].size == 0:
+                # This file does not contain data for this key. We skip it here
+                # for cases where the index claims there's data for this key,
+                # but the dataset is empty.
+                continue
+
             firsts, counts = file.get_index(self.source, self._key_group)

             # Of trains in this file, which are in selection
@@ -116,7 +122,9 @@ def data_counts(self):

         seq_series = []
         for f in self.files:
-            if self.section == 'CONTROL':
+            if f.file[self.hdf5_data_path].size == 0:
+                counts = np.zeros_like(f.train_ids, dtype=np.uint64)
+            elif self.section == 'CONTROL':
                 counts = np.ones(len(f.train_ids), dtype=np.uint64)
             else:
                 _, counts = f.get_index(self.source, self._key_group)
@@ -168,6 +176,8 @@ def _trainid_index(self):
             np.repeat(chunk.train_ids, chunk.counts.astype(np.intp))
             for chunk in self._data_chunks
         ]
+        if not chunks_trainids:
+            return np.array([], dtype=np.uint64)
         return np.concatenate(chunks_trainids)

     def xarray(self, extra_dims=None, roi=(), name=None):
@@ -279,6 +289,9 @@ def dask_array(self, labelled=False):
                 )[chunk.slice]
             )

+        if not chunks_darrs:
+            chunks_darrs = [da.empty(shape=(0,) + self.entry_shape,
+                                     dtype=self.dtype, chunks=self.shape)]
         dask_arr = da.concatenate(chunks_darrs, axis=0)

         if labelled:
@@ -295,7 +308,7 @@ def dask_array(self, labelled=False):

     # Getting data by train: --------------------------------------------------

-    def _find_tid(self, tid) -> (Optional[FileAccess], int):
+    def _find_tid(self, tid) -> Tuple[Optional[FileAccess], int]:
         for fa in self.files:
             matches = (fa.train_ids == tid).nonzero()[0]
             if self.inc_suspect_trains and matches.size > 0:
@@ -316,8 +329,8 @@ def train_from_id(self, tid):
             raise TrainIDError(tid)

         fa, ix = self._find_tid(tid)
-        if fa is None:
-            return np.empty((0,) + self.entry_shape, dtype=self.dtype)
+        if fa is None or fa.file[self.hdf5_data_path].size == 0:
+            return tid, np.empty((0,) + self.entry_shape, dtype=self.dtype)

         firsts, counts = fa.get_index(self.source, self._key_group)
         first, count = firsts[ix], counts[ix]
@@ -341,6 +354,8 @@ def trains(self):
         for chunk in self._data_chunks:
             start = chunk.first
             ds = chunk.dataset
+            if ds.size == 0:
+                continue
             for tid, count in zip(chunk.train_ids, chunk.counts):
                 if count > 1:
                     yield tid, ds[start: start+count]
diff --git a/extra_data/reader.py b/extra_data/reader.py
index 4907cacb..1f2ee718 100644
--- a/extra_data/reader.py
+++ b/extra_data/reader.py
@@ -28,6 +28,7 @@
 import sys
 import tempfile
 import time
+from typing import Tuple
 from warnings import warn

 from .exceptions import (
@@ -381,6 +382,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):

             for key in self.keys_for_source(source):
                 path = '/CONTROL/{}/{}'.format(source, key.replace('.', '/'))
+                if file.file[path].size == 0:
+                    continue
                 source_data[key] = file.file[path][pos]

         for source in self.instrument_sources:
@@ -399,6 +402,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):
                     continue

                 path = '/INSTRUMENT/{}/{}'.format(source, key.replace('.', '/'))
+                if file.file[path].size == 0:
+                    continue
                 if count == 1:
                     source_data[key] = file.file[path][first]
                 else:
@@ -481,6 +486,8 @@ def get_series(self, source, key):
         if source in self.instrument_sources:
             data_path = "/INSTRUMENT/{}/{}".format(source, key.replace('.', '/'))
             for f in self._source_index[source]:
+                if f.file[data_path].size == 0:
+                    continue
                 group = key.partition('.')[0]
                 firsts, counts = f.get_index(source, group)
                 trainids = self._expand_trainids(counts, f.train_ids)
@@ -510,7 +517,10 @@ def get_series(self, source, key):
         else:
             return self._get_key_data(source, key).series()

-        ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))
+        if not seq_series:
+            ser = pd.Series([], dtype=self._get_key_data(source, key).dtype)
+        else:
+            ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))

         # Select out only the train IDs of interest
         if isinstance(ser.index, pd.MultiIndex):
@@ -840,34 +850,23 @@ def select(self, seln_or_source_glob, key_glob='*', require_all=False):
             train_ids = self.train_ids

             for source, keys in selection.items():
-                if source in self.instrument_sources:
-                    # For INSTRUMENT sources, the INDEX is saved by
-                    # key group, which is the first hash component. In
-                    # many cases this is 'data', but not always.
-                    if keys is None:
-                        # All keys are selected.
-                        keys = self.keys_for_source(source)
-
-                    groups = {key.partition('.')[0] for key in keys}
-                else:
-                    # CONTROL data has no key group.
-                    groups = ['']
-
-                for group in groups:
-                    # Empty list would be converted to np.float64 array.
-                    source_tids = np.empty(0, dtype=np.uint64)
-
-                    for f in self._source_index[source]:
-                        valid = True if self.inc_suspect_trains else f.validity_flag
-                        # Add the trains with data in each file.
-                        _, counts = f.get_index(source, group)
-                        source_tids = np.union1d(
-                            f.train_ids[valid & (counts > 0)], source_tids
-                        )
+                if keys is None:
+                    keys = self.keys_for_source(source)
+
+                for key in keys:
+                    counts = self._get_key_data(source, key).data_counts()
+                    key_tids = counts[counts>0].index.values

                     # Remove any trains previously selected, for which this
                     # selected source and key group has no data.
-                    train_ids = np.intersect1d(train_ids, source_tids)
+                    train_ids = np.intersect1d(train_ids, key_tids)
+
+            if train_ids.size == 0:
+                return DataCollection(
+                    [], selection=selection, train_ids=[],
+                    inc_suspect_trains=self.inc_suspect_trains,
+                    is_single_run=self.is_single_run
+                )

         # Filtering may have eliminated previously selected files.
         files = [f for f in files
@@ -997,7 +996,7 @@ def _find_data_chunks(self, source, key):
         """
         return self._get_key_data(source, key)._data_chunks

-    def _find_data(self, source, train_id) -> (FileAccess, int):
+    def _find_data(self, source, train_id) -> Tuple[FileAccess, int]:
         for f in self._source_index[source]:
             ixs = (f.train_ids == train_id).nonzero()[0]
             if self.inc_suspect_trains and ixs.size > 0:
@@ -1367,7 +1366,7 @@ def _assemble_data(self, tid):
             for key in self.data.keys_for_source(source):
                 _, pos, ds = self._find_data(source, key, tid)
-                if ds is None:
+                if ds is None or ds.size == 0:
                     continue
                 self._set_result(res, source, key, ds[pos])

@@ -1376,7 +1375,7 @@ def _assemble_data(self, tid):
                 {'source': source, 'timestamp.tid': tid})
             for key in self.data.keys_for_source(source):
                 file, pos, ds = self._find_data(source, key, tid)
-                if ds is None:
+                if ds is None or ds.size == 0:
                     continue
                 group = key.partition('.')[0]
                 firsts, counts = file.get_index(source, group)
diff --git a/extra_data/tests/conftest.py b/extra_data/tests/conftest.py
index e9f53d6b..901f902c 100644
--- a/extra_data/tests/conftest.py
+++ b/extra_data/tests/conftest.py
@@ -122,3 +122,18 @@ def mock_empty_file():
         path = osp.join(td, 'RAW-R0450-DA01-S00002.h5')
         make_examples.make_sa3_da_file(path, ntrains=0)
         yield path
+
+
+@pytest.fixture(scope='function')
+def mock_empty_dataset_file(format_version):
+    with TemporaryDirectory() as td:
+        path = osp.join(td, 'RAW-R0999-DA10-S00001.h5')
+        make_examples.make_fxe_da_file(path, format_version=format_version)
+
+        with h5py.File(path, 'a') as f:
+            f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/intensityTD'].resize(0, axis=0)
+            f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/trainId'].resize(0, axis=0)
+            f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/pulseEnergy/photonFlux/value'].resize(0, axis=0)
+            f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/beamPosition/ixPos/value'].resize(0, axis=0)
+
+        yield path
diff --git a/extra_data/tests/test_keydata.py b/extra_data/tests/test_keydata.py
index 637a1f83..91130501 100644
--- a/extra_data/tests/test_keydata.py
+++ b/extra_data/tests/test_keydata.py
@@ -1,9 +1,10 @@
 import numpy as np
 import pytest

-from extra_data import RunDirectory
+from extra_data import H5File, RunDirectory
 from extra_data.exceptions import TrainIDError

+
 def test_get_keydata(mock_spb_raw_run):
     run = RunDirectory(mock_spb_raw_run)
     print(run.instrument_sources)
@@ -95,7 +96,7 @@ def test_data_counts(mock_reduced_spb_proc_run):
     assert count.index.tolist() == xgm_beam_x.train_ids
     assert (count.values == 1).all()

-    # intrument data
+    # instrument data
     camera = run['SPB_IRU_CAM/CAM/SIDEMIC:daqOutput', 'data.image.pixels']
     count = camera.data_counts()
     assert count.index.tolist() == camera.train_ids
@@ -113,3 +114,24 @@ def test_select_by(mock_spb_raw_run):
     subrun = run.select(am0)
     assert subrun.all_sources == {am0.source}
     assert subrun.keys_for_source(am0.source) == {am0.key}
+
+
+def test_empty_dataset(mock_empty_dataset_file):
+
+    run = H5File(mock_empty_dataset_file)
+    kd = run['SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD']
+    assert not kd.data_counts().values.any()
+    assert kd.ndarray().size == 0
+    assert kd.xarray().size == 0
+    assert kd.dask_array().size == 0
+
+    tid, data = kd.train_from_index(0)
+    assert data.shape == (0, 1000)
+
+    kd = run['SA1_XTD2_XGM/DOOCS/MAIN', 'pulseEnergy.photonFlux.value']
+    assert kd.series().size == 0
+    tid, data = kd.train_from_index(0)
+    assert tid == 10000
+    assert data.shape == (0,)
+
+    assert len(list(kd.trains())) == 0
diff --git a/extra_data/tests/test_reader_mockdata.py b/extra_data/tests/test_reader_mockdata.py
index 8ba38447..cb80ebfa 100644
--- a/extra_data/tests/test_reader_mockdata.py
+++ b/extra_data/tests/test_reader_mockdata.py
@@ -804,3 +804,37 @@ def test_run_metadata(mock_spb_raw_run):
         'sample', 'sequenceNumber',
     }
     assert isinstance(md['creationDate'], str)
+
+
+def test_empty_dataset(mock_empty_dataset_file):
+
+    run = H5File(mock_empty_dataset_file)
+    device, key = 'SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD'
+
+    assert not run.get_data_counts(device, key).any()
+    assert run.get_array(device, key).size == 0
+    assert run.get_dask_array(device, key).size == 0
+
+    sel = run.select(device, key)
+    _, data = sel.train_from_index(0)
+    assert list(data[device].keys()) == ['metadata']
+
+    for _, data in sel.trains(require_all=True):
+        assert key not in data[device]
+        break
+
+    _, data = sel.train_from_index(0)
+    assert key not in data[device]
+
+    s = run.get_series(device, 'data.trainId')
+    assert isinstance(s, pd.Series)
+    assert len(s) == 0
+
+    df = run.get_dataframe(fields=[("*_XGM/*", "*.i[xy]Pos*")])
+    assert len(df.columns) == 4
+    assert "SA1_XTD2_XGM/DOOCS/MAIN/beamPosition.ixPos" in df.columns
+
+    dc = run.select(device, require_all=True)
+    assert dc.selection == {device: None}
+    assert dc.all_sources == frozenset()
+    assert dc.train_ids == []