Skip empty datasets #192

Open · wants to merge 11 commits into base: master
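In short (a sketch of the behaviour targeted here, based on the tests added below — the file name and keys come from the new mock_empty_dataset_file fixture, not a real run): reading a key whose HDF5 dataset exists but has zero length now returns empty results instead of failing.

```python
from extra_data import H5File

# The mock_empty_dataset_file fixture (added below) builds a file where some
# datasets are present but have been resized to zero length.
run = H5File('RAW-R0999-DA10-S00001.h5')

kd = run['SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD']

# With this change, empty datasets are skipped rather than breaking reads:
assert kd.ndarray().size == 0
assert kd.xarray().size == 0
assert not kd.data_counts().values.any()
```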
25 changes: 20 additions & 5 deletions extra_data/keydata.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy as np

@@ -28,6 +28,12 @@ def __init__(
def _find_chunks(self):
"""Find contiguous chunks of data for this key, in any order."""
for file in self.files:
if file.file[self.hdf5_data_path].size == 0:
# This file contains no data for this key. Skip it here to cover cases
# where the index claims there's data for this key, but the dataset is
# empty.
continue

firsts, counts = file.get_index(self.source, self._key_group)

# Of trains in this file, which are in selection
@@ -116,7 +122,9 @@ def data_counts(self):
seq_series = []

for f in self.files:
if self.section == 'CONTROL':
if f.file[self.hdf5_data_path].size == 0:
counts = np.zeros_like(f.train_ids, dtype=np.uint64)
elif self.section == 'CONTROL':
counts = np.ones(len(f.train_ids), dtype=np.uint64)
else:
_, counts = f.get_index(self.source, self._key_group)
@@ -168,6 +176,8 @@ def _trainid_index(self):
np.repeat(chunk.train_ids, chunk.counts.astype(np.intp))
for chunk in self._data_chunks
]
if not chunks_trainids:
return np.array([], dtype=np.uint64)
return np.concatenate(chunks_trainids)

def xarray(self, extra_dims=None, roi=(), name=None):
@@ -279,6 +289,9 @@ def dask_array(self, labelled=False):
)[chunk.slice]
)

if not chunks_darrs:
chunks_darrs = [da.empty(shape=(0,) + self.entry_shape,
dtype=self.dtype, chunks=self.shape)]
dask_arr = da.concatenate(chunks_darrs, axis=0)

if labelled:
@@ -295,7 +308,7 @@ def dask_array(self, labelled=False):

# Getting data by train: --------------------------------------------------

def _find_tid(self, tid) -> (Optional[FileAccess], int):
def _find_tid(self, tid) -> Tuple[Optional[FileAccess], int]:
for fa in self.files:
matches = (fa.train_ids == tid).nonzero()[0]
if self.inc_suspect_trains and matches.size > 0:
@@ -316,8 +329,8 @@ def train_from_id(self, tid):
raise TrainIDError(tid)

fa, ix = self._find_tid(tid)
if fa is None:
return np.empty((0,) + self.entry_shape, dtype=self.dtype)
if fa is None or fa.file[self.hdf5_data_path].size == 0:
return tid, np.empty((0,) + self.entry_shape, dtype=self.dtype)

firsts, counts = fa.get_index(self.source, self._key_group)
first, count = firsts[ix], counts[ix]
@@ -341,6 +354,8 @@ def trains(self):
for chunk in self._data_chunks:
start = chunk.first
ds = chunk.dataset
if ds.size == 0:
continue
for tid, count in zip(chunk.train_ids, chunk.counts):
if count > 1:
yield tid, ds[start: start+count]
57 changes: 28 additions & 29 deletions extra_data/reader.py
@@ -28,6 +28,7 @@
import sys
import tempfile
import time
from typing import Tuple
from warnings import warn

from .exceptions import (
@@ -381,6 +382,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):

for key in self.keys_for_source(source):
path = '/CONTROL/{}/{}'.format(source, key.replace('.', '/'))
if file.file[path].size == 0:
continue
source_data[key] = file.file[path][pos]

for source in self.instrument_sources:
@@ -399,6 +402,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):
continue

path = '/INSTRUMENT/{}/{}'.format(source, key.replace('.', '/'))
if file.file[path].size == 0:
continue
if count == 1:
source_data[key] = file.file[path][first]
else:
@@ -481,6 +486,8 @@ def get_series(self, source, key):
if source in self.instrument_sources:
data_path = "/INSTRUMENT/{}/{}".format(source, key.replace('.', '/'))
for f in self._source_index[source]:
if f.file[data_path].size == 0:
continue
group = key.partition('.')[0]
firsts, counts = f.get_index(source, group)
trainids = self._expand_trainids(counts, f.train_ids)
@@ -510,7 +517,10 @@ def get_series(self, source, key):
else:
return self._get_key_data(source, key).series()

ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))
if not seq_series:
ser = pd.Series([], dtype=self._get_key_data(source, key).dtype)
else:
ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))

# Select out only the train IDs of interest
if isinstance(ser.index, pd.MultiIndex):
@@ -840,34 +850,23 @@ def select(self, seln_or_source_glob, key_glob='*', require_all=False):
train_ids = self.train_ids

for source, keys in selection.items():
if source in self.instrument_sources:
# For INSTRUMENT sources, the INDEX is saved by
# key group, which is the first hash component. In
# many cases this is 'data', but not always.
if keys is None:
# All keys are selected.
keys = self.keys_for_source(source)

groups = {key.partition('.')[0] for key in keys}
else:
# CONTROL data has no key group.
groups = ['']

for group in groups:
# Empty list would be converted to np.float64 array.
source_tids = np.empty(0, dtype=np.uint64)

for f in self._source_index[source]:
valid = True if self.inc_suspect_trains else f.validity_flag
# Add the trains with data in each file.
_, counts = f.get_index(source, group)
source_tids = np.union1d(
f.train_ids[valid & (counts > 0)], source_tids
)
if keys is None:
keys = self.keys_for_source(source)

for key in keys:
counts = self._get_key_data(source, key).data_counts()
key_tids = counts[counts>0].index.values

# Remove any trains previously selected, for which this
# selected source and key group has no data.
train_ids = np.intersect1d(train_ids, source_tids)
train_ids = np.intersect1d(train_ids, key_tids)
Member:
This code now needs to:

  • Find all keys for control sources (which are often more numerous than instrument keys, IIRC), where the selection keeps all keys.
  • Do one intersection per key, rather than one per source (/source group). I found in Speed up _check_source_conflicts() #183 that NumPy set operations are not that efficient for working with many small sets. (I appreciate that it avoids a set union per file, though).

So I am concerned that this could be markedly slower in some circumstances (probably when selecting sources with many keys) than the code it replaces.

Member Author:
It is slower, but I don't think it's a problem here. Selections are not an operation you'd do very often, and you wouldn't select many sources either.
When testing with this branch, a selection always took less than 1 ms.
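For context on the trade-off discussed above (not part of the PR): a rough micro-benchmark sketch with synthetic train IDs. The key and group counts here are made up, and real timings depend on the run layout.

```python
import time
import numpy as np

train_ids = np.arange(10_000, 11_000, dtype=np.uint64)   # selected trains
key_tids = [train_ids[:-1] for _ in range(200)]           # e.g. 200 keys, one array each
group_tids = [train_ids[:-1] for _ in range(5)]           # e.g. 5 key groups

def intersect_all(start, arrays):
    # One np.intersect1d call per array, as in the per-key loop.
    out = start
    for a in arrays:
        out = np.intersect1d(out, a)
    return out

t0 = time.perf_counter()
intersect_all(train_ids, group_tids)   # roughly the old cost: one intersection per key group
t1 = time.perf_counter()
intersect_all(train_ids, key_tids)     # roughly the new cost: one intersection per key
t2 = time.perf_counter()
print(f"per group: {(t1 - t0) * 1e3:.2f} ms, per key: {(t2 - t1) * 1e3:.2f} ms")
```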


if train_ids.size == 0:
return DataCollection(
[], selection=selection, train_ids=[],
inc_suspect_trains=self.inc_suspect_trains,
is_single_run=self.is_single_run
)

# Filtering may have eliminated previously selected files.
files = [f for f in files
@@ -997,7 +996,7 @@ def _find_data_chunks(self, source, key):
"""
return self._get_key_data(source, key)._data_chunks

def _find_data(self, source, train_id) -> (FileAccess, int):
def _find_data(self, source, train_id) -> Tuple[FileAccess, int]:
for f in self._source_index[source]:
ixs = (f.train_ids == train_id).nonzero()[0]
if self.inc_suspect_trains and ixs.size > 0:
@@ -1367,7 +1366,7 @@ def _assemble_data(self, tid):

for key in self.data.keys_for_source(source):
_, pos, ds = self._find_data(source, key, tid)
if ds is None:
if ds is None or ds.size == 0:
continue
self._set_result(res, source, key, ds[pos])

@@ -1376,7 +1375,7 @@ def _assemble_data(self, tid):
{'source': source, 'timestamp.tid': tid})
for key in self.data.keys_for_source(source):
file, pos, ds = self._find_data(source, key, tid)
if ds is None:
if ds is None or ds.size == 0:
continue
group = key.partition('.')[0]
firsts, counts = file.get_index(source, group)
15 changes: 15 additions & 0 deletions extra_data/tests/conftest.py
@@ -122,3 +122,18 @@ def mock_empty_file():
path = osp.join(td, 'RAW-R0450-DA01-S00002.h5')
make_examples.make_sa3_da_file(path, ntrains=0)
yield path


@pytest.fixture(scope='function')
def mock_empty_dataset_file(format_version):
with TemporaryDirectory() as td:
path = osp.join(td, 'RAW-R0999-DA10-S00001.h5')
make_examples.make_fxe_da_file(path, format_version=format_version)

with h5py.File(path, 'a') as f:
f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/intensityTD'].resize(0, axis=0)
f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/trainId'].resize(0, axis=0)
f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/pulseEnergy/photonFlux/value'].resize(0, axis=0)
f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/beamPosition/ixPos/value'].resize(0, axis=0)

yield path
26 changes: 24 additions & 2 deletions extra_data/tests/test_keydata.py
@@ -1,9 +1,10 @@
import numpy as np
import pytest

from extra_data import RunDirectory
from extra_data import H5File, RunDirectory
from extra_data.exceptions import TrainIDError


def test_get_keydata(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
print(run.instrument_sources)
@@ -95,7 +96,7 @@ def test_data_counts(mock_reduced_spb_proc_run):
assert count.index.tolist() == xgm_beam_x.train_ids
assert (count.values == 1).all()

# intrument data
# instrument data
camera = run['SPB_IRU_CAM/CAM/SIDEMIC:daqOutput', 'data.image.pixels']
count = camera.data_counts()
assert count.index.tolist() == camera.train_ids
@@ -113,3 +114,24 @@ def test_select_by(mock_spb_raw_run):
subrun = run.select(am0)
assert subrun.all_sources == {am0.source}
assert subrun.keys_for_source(am0.source) == {am0.key}


def test_empty_dataset(mock_empty_dataset_file):

run = H5File(mock_empty_dataset_file)
kd = run['SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD']
assert not kd.data_counts().values.any()
assert kd.ndarray().size == 0
assert kd.xarray().size == 0
assert kd.dask_array().size == 0

tid, data = kd.train_from_index(0)
assert data.shape == (0, 1000)

kd = run['SA1_XTD2_XGM/DOOCS/MAIN', 'pulseEnergy.photonFlux.value']
assert kd.series().size == 0
tid, data = kd.train_from_index(0)
assert tid == 10000
assert data.shape == (0,)

assert len(list(kd.trains())) == 0
34 changes: 34 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
@@ -804,3 +804,37 @@ def test_run_metadata(mock_spb_raw_run):
'sample', 'sequenceNumber',
}
assert isinstance(md['creationDate'], str)


def test_empty_dataset(mock_empty_dataset_file):

run = H5File(mock_empty_dataset_file)
device, key = 'SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD'

assert not run.get_data_counts(device, key).any()
assert run.get_array(device, key).size == 0
assert run.get_dask_array(device, key).size == 0

sel = run.select(device, key)
_, data = sel.train_from_index(0)
assert list(data[device].keys()) == ['metadata']

for _, data in sel.trains(require_all=True):
assert key not in data[device]
break

_, data = sel.train_from_index(0)
assert key not in data[device]

s = run.get_series(device, 'data.trainId')
assert isinstance(s, pd.Series)
assert len(s) == 0

df = run.get_dataframe(fields=[("*_XGM/*", "*.i[xy]Pos*")])
assert len(df.columns) == 4
assert "SA1_XTD2_XGM/DOOCS/MAIN/beamPosition.ixPos" in df.columns

dc = run.select(device, require_all=True)
assert dc.selection == {device: None}
assert dc.all_sources == frozenset()
assert dc.train_ids == []