diff --git a/databroker/mongo_normalized.py b/databroker/mongo_normalized.py index 90d6cb765..3ad914510 100644 --- a/databroker/mongo_normalized.py +++ b/databroker/mongo_normalized.py @@ -977,7 +977,7 @@ def populate_columns(keys, min_seq_num, max_seq_num): # their size if we need to squeeze more performance out here. But maybe # we can get away with never adding that complexity. for key, est_row_bytesize in zip(nonscalars, estimated_nonscalar_row_bytesizes): - page_size = TARGET_PAGE_BYTESIZE // est_row_bytesize + page_size = max(1, TARGET_PAGE_BYTESIZE // est_row_bytesize) boundaries = list(range(min_seq_num, 1 + max_seq_num, page_size)) if boundaries[-1] != max_seq_num: boundaries.append(max_seq_num) @@ -2209,10 +2209,11 @@ def default_validate_shape(key, data, expected_shape): Check that data.shape == expected.shape. * If number of dimensions differ, raise BadShapeMetadata - * If any dimension is larger than expected, raise BadShapeMetadata. + * If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata. * If some dimensions are smaller than expected,, pad "right" edge of each dimension that falls short with NaN. """ + MAX_SIZE_DIFF = 2 if data.shape == expected_shape: return data if len(data.shape) != len(expected_shape): @@ -2224,12 +2225,10 @@ def default_validate_shape(key, data, expected_shape): ) # Pad at the "end" along any dimension that is too short. padding = [] - trimming = [] for actual, expected in zip(data.shape, expected_shape): margin = expected - actual # Limit how much padding or trimming we are willing to do. - SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE = 2 - if abs(margin) > SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE: + if abs(margin) > MAX_SIZE_DIFF: raise BadShapeMetadata( f"For data key {key} " f"shape {data.shape} does not " @@ -2237,19 +2236,12 @@ def default_validate_shape(key, data, expected_shape): ) if margin > 0: padding.append((0, margin)) - trimming.append(slice(None, None)) elif margin < 0: padding.append((0, 0)) - trimming.append(slice(None)) else: # margin == 0 padding.append((0, 0)) - trimming.append(slice(None, None)) - # TODO Rethink this! - # We cannot do NaN because that does not work for integers - # and it is too late to change our mind about the data type. - padded = numpy.pad(data, padding, "edge") - padded_and_trimmed = padded[tuple(trimming)] - return padded_and_trimmed + padded = numpy.pad(data, padding, "edge") + return padded def build_summary(run_start_doc, run_stop_doc, stream_names): diff --git a/databroker/tests/test_validate_shape.py b/databroker/tests/test_validate_shape.py index 01a37df6c..df563e384 100644 --- a/databroker/tests/test_validate_shape.py +++ b/databroker/tests/test_validate_shape.py @@ -1,10 +1,13 @@ from bluesky import RunEngine from bluesky.plans import count -from ophyd.sim import img +from ophyd.sim import img, DirectImage from tiled.client import Context, from_context from tiled.server.app import build_app -from ..mongo_normalized import MongoAdapter +from ..mongo_normalized import MongoAdapter, BadShapeMetadata + +import numpy as np +import pytest def test_validate_shape(tmpdir): @@ -29,3 +32,70 @@ def post_document(name, doc): assert not shapes client[uid]["primary"]["data"]["img"][:] assert shapes + + +@pytest.mark.parametrize( + "shape,expected_shape", + [ + ((10,), (11,)), + ((10, 20), (10, 21)), + ((10, 20), (10, 19)), + ((10, 20, 30), (10, 21, 30)), + ((10, 20, 30), (10, 20, 31)), + ((20, 20, 20, 20), (20, 21, 20, 22)), + ], +) +def test_padding(tmpdir, shape, expected_shape): + adapter = MongoAdapter.from_mongomock() + + direct_img = DirectImage( + func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"} + ) + direct_img.img.name = "img" + + with Context.from_app(build_app(adapter), token_cache=tmpdir) as context: + client = from_context(context) + + def post_document(name, doc): + if name == "descriptor": + doc["data_keys"]["img"]["shape"] = expected_shape + + client.post_document(name, doc) + + RE = RunEngine() + RE.subscribe(post_document) + (uid,) = RE(count([direct_img])) + assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape + + +@pytest.mark.parametrize( + "shape,expected_shape", + [ + ((10,), (11, 12)), + ((10, 20), (10, 200)), + ((20, 20, 20, 20), (20, 21, 20, 200)), + ((10, 20), (5, 20)), + ], +) +def test_default_validate_shape(tmpdir, shape, expected_shape): + adapter = MongoAdapter.from_mongomock() + + direct_img = DirectImage( + func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"} + ) + direct_img.img.name = "img" + + with Context.from_app(build_app(adapter), token_cache=tmpdir) as context: + client = from_context(context) + + def post_document(name, doc): + if name == "descriptor": + doc["data_keys"]["img"]["shape"] = expected_shape + + client.post_document(name, doc) + + RE = RunEngine() + RE.subscribe(post_document) + (uid,) = RE(count([direct_img])) + with pytest.raises(BadShapeMetadata): + client[uid]["primary"]["data"]["img"][:]