Skip to content

Commit

Permalink
Merge pull request #825 from hyperrealist/fix-padding-bug
Browse files Browse the repository at this point in the history
FIX: data padding bug in mongo_normalized
  • Loading branch information
danielballan authored Oct 10, 2024
2 parents 47623b0 + e3723ea commit 84b8c76
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 16 deletions.
20 changes: 6 additions & 14 deletions databroker/mongo_normalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
# their size if we need to squeeze more performance out here. But maybe
# we can get away with never adding that complexity.
for key, est_row_bytesize in zip(nonscalars, estimated_nonscalar_row_bytesizes):
page_size = TARGET_PAGE_BYTESIZE // est_row_bytesize
page_size = max(1, TARGET_PAGE_BYTESIZE // est_row_bytesize)
boundaries = list(range(min_seq_num, 1 + max_seq_num, page_size))
if boundaries[-1] != max_seq_num:
boundaries.append(max_seq_num)
Expand Down Expand Up @@ -2209,10 +2209,11 @@ def default_validate_shape(key, data, expected_shape):
Check that data.shape == expected.shape.
* If number of dimensions differ, raise BadShapeMetadata
* If any dimension is larger than expected, raise BadShapeMetadata.
* If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata.
* If some dimensions are smaller than expected, pad the "right" edge of each
dimension that falls short by repeating its edge values (numpy.pad mode "edge").
"""
MAX_SIZE_DIFF = 2
if data.shape == expected_shape:
return data
if len(data.shape) != len(expected_shape):
Expand All @@ -2224,32 +2225,23 @@ def default_validate_shape(key, data, expected_shape):
)
# Pad at the "end" along any dimension that is too short.
padding = []
trimming = []
for actual, expected in zip(data.shape, expected_shape):
margin = expected - actual
# Limit how much padding or trimming we are willing to do.
SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE = 2
if abs(margin) > SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE:
if abs(margin) > MAX_SIZE_DIFF:
raise BadShapeMetadata(
f"For data key {key} "
f"shape {data.shape} does not "
f"match expected shape {expected_shape}."
)
if margin > 0:
padding.append((0, margin))
trimming.append(slice(None, None))
elif margin < 0:
padding.append((0, 0))
trimming.append(slice(None))
else: # margin == 0
padding.append((0, 0))
trimming.append(slice(None, None))
# TODO Rethink this!
# We cannot do NaN because that does not work for integers
# and it is too late to change our mind about the data type.
padded = numpy.pad(data, padding, "edge")
padded_and_trimmed = padded[tuple(trimming)]
return padded_and_trimmed
padded = numpy.pad(data, padding, "edge")
return padded


def build_summary(run_start_doc, run_stop_doc, stream_names):
Expand Down
74 changes: 72 additions & 2 deletions databroker/tests/test_validate_shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from bluesky import RunEngine
from bluesky.plans import count
from ophyd.sim import img
from ophyd.sim import img, DirectImage
from tiled.client import Context, from_context
from tiled.server.app import build_app

from ..mongo_normalized import MongoAdapter
from ..mongo_normalized import MongoAdapter, BadShapeMetadata

import numpy as np
import pytest


def test_validate_shape(tmpdir):
Expand All @@ -29,3 +32,70 @@ def post_document(name, doc):
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes


@pytest.mark.parametrize(
    "shape,expected_shape",
    [
        ((10,), (11,)),
        ((10, 20), (10, 21)),
        ((10, 20), (10, 19)),
        ((10, 20, 30), (10, 21, 30)),
        ((10, 20, 30), (10, 20, 31)),
        ((20, 20, 20, 20), (20, 21, 20, 22)),
    ],
)
def test_padding(tmpdir, shape, expected_shape):
    """
    Check that a slightly mis-declared shape is reconciled on read.

    The detector emits frames of ``shape`` while the descriptor is rewritten
    to declare ``expected_shape``. Every parametrized pair has the same number
    of dimensions and differs by at most 2 along any one of them. The frame
    read back through the client must have the declared shape.
    """
    adapter = MongoAdapter.from_mongomock()

    def make_frame():
        # Frames are all-ones arrays of the *actual* (undeclared) shape.
        return np.array(np.ones(shape))

    detector = DirectImage(func=make_frame, name="direct", labels={"detectors"})
    detector.img.name = "img"

    app = build_app(adapter)
    with Context.from_app(app, token_cache=tmpdir) as context:
        client = from_context(context)

        def forward_document(name, doc):
            # Overwrite the declared shape before the descriptor is stored,
            # so the stored metadata disagrees with the recorded data.
            if name == "descriptor":
                doc["data_keys"]["img"]["shape"] = expected_shape

            client.post_document(name, doc)

        run_engine = RunEngine()
        run_engine.subscribe(forward_document)
        (uid,) = run_engine(count([detector]))

        first_frame = client[uid]["primary"]["data"]["img"][0]
        assert first_frame.shape == expected_shape


@pytest.mark.parametrize(
    "shape,expected_shape",
    [
        ((10,), (11, 12)),
        ((10, 20), (10, 200)),
        ((20, 20, 20, 20), (20, 21, 20, 200)),
        ((10, 20), (5, 20)),
    ],
)
def test_default_validate_shape(tmpdir, shape, expected_shape):
    """
    Check that a badly mis-declared shape is rejected on read.

    The detector emits frames of ``shape`` while the descriptor is rewritten
    to declare ``expected_shape``. Each parametrized pair either differs in
    number of dimensions or is off by a large amount along some dimension.
    Reading the data through the client must raise ``BadShapeMetadata``.
    """
    adapter = MongoAdapter.from_mongomock()

    def make_frame():
        # Frames are all-ones arrays of the *actual* (undeclared) shape.
        return np.array(np.ones(shape))

    detector = DirectImage(func=make_frame, name="direct", labels={"detectors"})
    detector.img.name = "img"

    app = build_app(adapter)
    with Context.from_app(app, token_cache=tmpdir) as context:
        client = from_context(context)

        def forward_document(name, doc):
            # Overwrite the declared shape before the descriptor is stored,
            # so the stored metadata disagrees with the recorded data.
            if name == "descriptor":
                doc["data_keys"]["img"]["shape"] = expected_shape

            client.post_document(name, doc)

        run_engine = RunEngine()
        run_engine.subscribe(forward_document)
        (uid,) = run_engine(count([detector]))

        img_column = client[uid]["primary"]["data"]["img"]
        with pytest.raises(BadShapeMetadata):
            img_column[:]

0 comments on commit 84b8c76

Please sign in to comment.