Skip empty datasets #192

Open
wants to merge 11 commits into base: master
20 changes: 16 additions & 4 deletions extra_data/keydata.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy as np

@@ -28,6 +28,12 @@ def __init__(
def _find_chunks(self):
"""Find contiguous chunks of data for this key, in any order."""
for file in self.files:
if file.file[self.hdf5_data_path].size == 0:
# This file does not contain data for this key. Skip it here, for cases
# where the index claims there's data for this key but the dataset is
# empty.
continue

firsts, counts = file.get_index(self.source, self._key_group)

# Of trains in this file, which are in selection
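
For context: a file's INDEX section can list counts for a key even when the corresponding dataset was never filled, leaving it with zero length, and that is the situation this check guards against. A minimal sketch of spotting such a dataset with h5py, using the mock file name and dataset path from the test fixture added below:

    import h5py

    # File name and dataset path are taken from the test fixture; any raw file works.
    with h5py.File('RAW-R0999-DA10-S00001.h5', 'r') as f:
        ds = f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/intensityTD']
        if ds.size == 0:
            # No entries were recorded for this key, even if INDEX claims some.
            print('empty dataset, skipping')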
@@ -168,6 +174,8 @@ def _trainid_index(self):
np.repeat(chunk.train_ids, chunk.counts.astype(np.intp))
for chunk in self._data_chunks
]
if not chunks_trainids:
return chunks_trainids
return np.concatenate(chunks_trainids)

def xarray(self, extra_dims=None, roi=(), name=None):
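
The early return matters because np.concatenate rejects an empty sequence, so a key whose chunks were all skipped above would otherwise raise instead of producing an empty train-ID index. A minimal illustration:

    import numpy as np

    chunks_trainids = []            # no chunks carried data for this key
    try:
        np.concatenate(chunks_trainids)
    except ValueError as err:
        print(err)                  # "need at least one array to concatenate"
    # Returning the empty list instead yields a harmless zero-length index.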
@@ -279,6 +287,8 @@ def dask_array(self, labelled=False):
)[chunk.slice]
)

if not chunks_darrs:
chunks_darrs = [da.empty(shape=self.shape, dtype=self.dtype, chunks=self.shape)]
dask_arr = da.concatenate(chunks_darrs, axis=0)

if labelled:
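
dask.array.concatenate has the same restriction on an empty list, so the fallback inserts a single empty array with the key's shape and dtype before concatenating, mirroring the call added above. A rough sketch, assuming a key shape whose first axis is zero:

    import numpy as np
    import dask.array as da

    shape, dtype = (0, 1000), np.float32      # zero trains, 1000 values per entry
    placeholder = da.empty(shape=shape, dtype=dtype, chunks=shape)
    arr = da.concatenate([placeholder], axis=0)
    print(arr.shape, arr.dtype)               # (0, 1000) float32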
@@ -295,7 +305,7 @@

# Getting data by train: --------------------------------------------------

def _find_tid(self, tid) -> (Optional[FileAccess], int):
def _find_tid(self, tid) -> Tuple[Optional[FileAccess], int]:
for fa in self.files:
matches = (fa.train_ids == tid).nonzero()[0]
if self.inc_suspect_trains and matches.size > 0:
@@ -316,8 +326,8 @@ def train_from_id(self, tid):
raise TrainIDError(tid)

fa, ix = self._find_tid(tid)
if fa is None:
return np.empty((0,) + self.entry_shape, dtype=self.dtype)
if fa is None or fa.file[self.hdf5_data_path].size == 0:
return tid, np.empty((0,) + self.entry_shape, dtype=self.dtype)

firsts, counts = fa.get_index(self.source, self._key_group)
first, count = firsts[ix], counts[ix]
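
Besides skipping empty datasets, the change above also makes the early return match the method's (train_id, data) contract; previously only the empty array was returned. A short usage sketch, assuming kd is the KeyData object exercised in test_empty_dataset below:

    tid, data = kd.train_from_id(10000)   # 10000 is a train ID present in the mock file
    assert data.shape[0] == 0             # no entries, but entry shape and dtype are kept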
@@ -341,6 +351,8 @@ def trains(self):
for chunk in self._data_chunks:
start = chunk.first
ds = chunk.dataset
if ds.size == 0:
continue
for tid, count in zip(chunk.train_ids, chunk.counts):
if count > 1:
yield tid, ds[start: start+count]
18 changes: 14 additions & 4 deletions extra_data/reader.py
@@ -28,6 +28,7 @@
import sys
import tempfile
import time
from typing import Tuple
from warnings import warn

from .exceptions import (
@@ -381,6 +382,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):

for key in self.keys_for_source(source):
path = '/CONTROL/{}/{}'.format(source, key.replace('.', '/'))
if file.file[path].size == 0:
continue
source_data[key] = file.file[path][pos]

for source in self.instrument_sources:
@@ -399,6 +402,8 @@ def train_from_id(self, train_id, devices=None, *, flat_keys=False):
continue

path = '/INSTRUMENT/{}/{}'.format(source, key.replace('.', '/'))
if file.file[path].size == 0:
continue
if count == 1:
source_data[key] = file.file[path][first]
else:
@@ -481,6 +486,8 @@ def get_series(self, source, key):
if source in self.instrument_sources:
data_path = "/INSTRUMENT/{}/{}".format(source, key.replace('.', '/'))
for f in self._source_index[source]:
if f.file[data_path].size == 0:
continue
group = key.partition('.')[0]
firsts, counts = f.get_index(source, group)
trainids = self._expand_trainids(counts, f.train_ids)
@@ -510,7 +517,10 @@ def get_series(self, source, key):
else:
return self._get_key_data(source, key).series()

ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))
if not seq_series:
ser = pd.Series([])
else:
ser = pd.concat(sorted(seq_series, key=lambda s: s.index[0]))

# Select out only the train IDs of interest
if isinstance(ser.index, pd.MultiIndex):
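
pandas.concat likewise raises a ValueError when given no objects, so the new branch returns an empty Series when every file was skipped. A minimal illustration:

    import pandas as pd

    seq_series = []                # no per-file series were collected
    try:
        pd.concat(seq_series)
    except ValueError as err:
        print(err)                 # "No objects to concatenate"
    ser = pd.Series([])            # fallback used above: an empty series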
@@ -997,7 +1007,7 @@ def _find_data_chunks(self, source, key):
"""
return self._get_key_data(source, key)._data_chunks

def _find_data(self, source, train_id) -> (FileAccess, int):
def _find_data(self, source, train_id) -> Tuple[FileAccess, int]:
for f in self._source_index[source]:
ixs = (f.train_ids == train_id).nonzero()[0]
if self.inc_suspect_trains and ixs.size > 0:
@@ -1367,7 +1377,7 @@ def _assemble_data(self, tid):

for key in self.data.keys_for_source(source):
_, pos, ds = self._find_data(source, key, tid)
if ds is None:
if ds is None or ds.size == 0:
continue
self._set_result(res, source, key, ds[pos])

@@ -1376,7 +1386,7 @@
{'source': source, 'timestamp.tid': tid})
for key in self.data.keys_for_source(source):
file, pos, ds = self._find_data(source, key, tid)
if ds is None:
if ds is None or ds.size == 0:
continue
group = key.partition('.')[0]
firsts, counts = file.get_index(source, group)
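
Taken together, the reader-side checks mean a key backed by an empty dataset is simply absent from the assembled per-train dictionary instead of raising on indexing. Roughly, in terms of the tests added below:

    sel = run.select(device, key)          # run, device, key as in test_empty_dataset
    _, data = sel.train_from_index(0)
    assert key not in data[device]         # only the 'metadata' entry remains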
22 changes: 22 additions & 0 deletions extra_data/tests/conftest.py
@@ -122,3 +122,25 @@ def mock_empty_file():
path = osp.join(td, 'RAW-R0450-DA01-S00002.h5')
make_examples.make_sa3_da_file(path, ntrains=0)
yield path


@pytest.fixture(scope='function')
def mock_empty_dataset_file(format_version):
with TemporaryDirectory() as td:
path = osp.join(td, 'RAW-R0999-DA10-S00001.h5')
make_examples.make_fxe_da_file(path, format_version=format_version)

with h5py.File(path, 'a') as f:
shape = f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/intensityTD'].shape
f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/intensityTD'].resize((0, *shape[1:]))

shape = f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/trainId'].shape
f['INSTRUMENT/SA1_XTD2_XGM/DOOCS/MAIN:output/data/trainId'].resize((0, *shape[1:]))

shape = f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/pulseEnergy/photonFlux/value'].shape
f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/pulseEnergy/photonFlux/value'].resize((0, *shape[1:]))

shape = f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/beamPosition/ixPos/value'].shape
f['CONTROL/SA1_XTD2_XGM/DOOCS/MAIN/beamPosition/ixPos/value'].resize((0, *shape[1:]))

yield path
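
The fixture reproduces the situation by shrinking a few datasets of an otherwise ordinary mock file to zero length along the first axis; this relies on the mock datasets being created resizable (an unlimited first dimension via maxshape), which is assumed of make_examples here. In isolation the trick looks like:

    import h5py
    import numpy as np

    with h5py.File('example.h5', 'w') as f:   # hypothetical file, for illustration only
        ds = f.create_dataset('data', shape=(16, 1000), maxshape=(None, 1000),
                              dtype=np.float32)
        ds.resize((0, 1000))                  # keep the entry shape, drop all rows
        print(ds.shape, ds.size)              # (0, 1000) 0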
25 changes: 23 additions & 2 deletions extra_data/tests/test_keydata.py
@@ -1,9 +1,10 @@
import numpy as np
import pytest

from extra_data import RunDirectory
from extra_data import H5File, RunDirectory
from extra_data.exceptions import TrainIDError


def test_get_keydata(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
print(run.instrument_sources)
@@ -95,7 +96,7 @@ def test_data_counts(mock_reduced_spb_proc_run):
assert count.index.tolist() == xgm_beam_x.train_ids
assert (count.values == 1).all()

# intrument data
# instrument data
camera = run['SPB_IRU_CAM/CAM/SIDEMIC:daqOutput', 'data.image.pixels']
count = camera.data_counts()
assert count.index.tolist() == camera.train_ids
@@ -113,3 +114,23 @@ def test_select_by(mock_spb_raw_run):
subrun = run.select(am0)
assert subrun.all_sources == {am0.source}
assert subrun.keys_for_source(am0.source) == {am0.key}


def test_empty_dataset(mock_empty_dataset_file):

run = H5File(mock_empty_dataset_file)
kd = run['SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD']
assert kd.ndarray().size == 0
assert kd.xarray().size == 0
assert kd.dask_array().size == 0

tid, data = kd.train_from_index(0)
assert data.shape == (0, 1000)

kd = run['SA1_XTD2_XGM/DOOCS/MAIN', 'pulseEnergy.photonFlux.value']
assert kd.series().size == 0
tid, data = kd.train_from_index(0)
assert tid == 10000
assert data.shape == (0,)

assert len(list(kd.trains())) == 0
28 changes: 28 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
@@ -804,3 +804,31 @@ def test_run_metadata(mock_spb_raw_run):
'sample', 'sequenceNumber',
}
assert isinstance(md['creationDate'], str)


def test_empty_dataset(mock_empty_dataset_file):

run = H5File(mock_empty_dataset_file)
device, key = 'SA1_XTD2_XGM/DOOCS/MAIN:output', 'data.intensityTD'

assert run.get_array(device, key).size == 0
assert run.get_dask_array(device, key).size == 0

sel = run.select(device, key)
_, data = sel.train_from_index(0)
assert list(data[device].keys()) == ['metadata']

for _, data in sel.trains(require_all=True):
assert key not in data[device]
break

_, data = sel.train_from_index(0)
assert key not in data[device]

s = run.get_series(device, 'data.trainId')
assert isinstance(s, pd.Series)
assert len(s) == 0

df = run.get_dataframe(fields=[("*_XGM/*", "*.i[xy]Pos*")])
assert len(df.columns) == 4
assert "SA1_XTD2_XGM/DOOCS/MAIN/beamPosition.ixPos" in df.columns