Skip to content

Commit

Permalink
Merge pull request #819 from danielballan/shape-fixer-patch
Browse files Browse the repository at this point in the history
Update shape fixer and add example
  • Loading branch information
danielballan authored Aug 28, 2024
2 parents 5676d95 + c026f7e commit 11fa21f
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 33 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,6 @@ docs/output_directory
docs/data.csv
docs/data.xlsx
docs/data.h5

# generated by docker-compose
data/*
66 changes: 36 additions & 30 deletions databroker/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional

import tiled.client
import tiled.server.app
import typer

from tiled.utils import import_object
Expand Down Expand Up @@ -35,7 +37,7 @@ def shape_fix(
import databroker.queries
from databroker.mongo_normalized import MongoAdapter, discover_handlers

from .shape_fixer import measure, fix
from .shape_fixer import measure, patch

if handler is None:
handler_registry = discover_handlers()
Expand Down Expand Up @@ -65,36 +67,40 @@ def shape_fix(
typer.echo(f"Limited to first {limit} BlueskyRuns only")
items = items[:limit]

with Progress() as progress:
task = progress.add_task("Migrating...", total=len(items))
for uid, run in items:
try:
for stream_name, stream in run.items():
descriptor = stream.metadata["descriptors"][0]
recorded_shapes, measured_shapes = measure(
mds_database,
asset_database,
descriptor,
adapter.root_map,
handler_registry,
patch_resource=patch_resource,
)
if dry_run:
msg = "Dry run"
else:
msg = "Edited"
fix(mds_database, descriptor, measured_shapes)
if recorded_shapes != measured_shapes:
progress.console.print(
f"{msg} {uid} {stream_name}: {recorded_shapes} -> {measured_shapes}"
app = tiled.server.app.build_app(adapter)
with tiled.client.Context.from_app(app) as context:
tiled_client = tiled.client.from_context(context)

with Progress() as progress:
task = progress.add_task("Migrating...", total=len(items))
for uid, run in items:
try:
for stream_name, stream in run.items():
descriptor = stream.metadata()["descriptors"][0]
recorded_shapes, measured_shapes = measure(
mds_database,
asset_database,
descriptor,
adapter.root_map,
handler_registry,
patch_resource=patch_resource,
)
except Exception as exc:
if strict:
raise
progress.console.print(
f"Failed: {uid} {exc!r} (Use --strict for more.)"
)
progress.update(task, advance=1)
if dry_run:
msg = "Dry run"
else:
msg = "Edited"
patch(tiled_client[uid][stream_name], measured_shapes)
if recorded_shapes != measured_shapes:
progress.console.print(
f"{msg} {uid} {stream_name}: {recorded_shapes} -> {measured_shapes}"
)
except Exception as exc:
if strict:
raise
progress.console.print(
f"Failed: {uid} {exc!r} (Use --strict for more.)"
)
progress.update(task, advance=1)


main = cli_app
Expand Down
27 changes: 27 additions & 0 deletions databroker/shape_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Given a Run, fill one Event from each stream, measure its shape, and patch the
Descriptors in that stream.
"""

from event_model import Filler


Expand Down Expand Up @@ -59,9 +60,35 @@ def measure(


def fix(mds_database, descriptor, measured_shapes):
    """
    Patch shape via direct update in the database.

    Parameters
    ----------
    mds_database :
        Mapping of collection name to a pymongo-style collection; the
        "event_descriptor" collection is updated in place.
    descriptor : dict
        Event Descriptor document; only its "uid" is used to locate the
        document to update.
    measured_shapes : dict
        Mapping of data key name to the shape measured from the actual data.

    DEPRECATED: Prefer the function patch below because it uses Tiled's
    PATCH API which means:
    - The database write is more constrained, reducing the possibility of
      unexpected behavior and corruption.
    - The change is logged in HTTP logs.
    - The original value is retained in the revisions collection.
    """
    for key, measured_shape in measured_shapes.items():
        # upsert=False: only modify an existing Descriptor, never create one.
        mds_database["event_descriptor"].update_one(
            {"uid": descriptor["uid"]},
            {"$set": {f"data_keys.{key}.shape": measured_shape}},
            upsert=False,
        )


def patch(tiled_client, measured_shapes):
    """Patch shape using Tiled's PATCH (JSON Patch) request.

    For each data key, replace the recorded shape in every Descriptor
    of this stream with the measured shape.
    """
    for key, measured_shape in measured_shapes.items():
        # Build one JSON Patch "replace" operation per descriptor in
        # this stream, then apply them all in a single request.
        operations = []
        num_descriptors = len(tiled_client.metadata["descriptors"])
        for index in range(num_descriptors):
            operations.append(
                {
                    "op": "replace",
                    "path": f"/descriptors/{index}/data_keys/{key}/shape",
                    "value": measured_shape,
                }
            )
        tiled_client.patch_metadata(operations)
11 changes: 9 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ services:
image: ghcr.io/bluesky/databroker:v2.0.0b47
volumes:
- type: bind
source: ./example_config.yml
target: /deploy/config.yml
source: ./example_config
target: /deploy/config
- type: bind
source: ./data
target: /deploy/data
environment:
- TILED_SINGLE_USER_API_KEY=$TILED_SINGLE_USER_API_KEY
ports:
Expand All @@ -29,6 +32,10 @@ services:
- mongo
mongo:
image: docker.io/mongo:6.0.4
# These ports should not be exposed in production, but direct access
# to the MongoDB may be useful for development and debugging.
ports:
- 27017:27017
networks:
- backend

Expand Down
4 changes: 3 additions & 1 deletion example_config.yml → example_config/config.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This is intended to be used with docker-compose.yml.
trees:
- path: /
- path: /raw
tree: databroker.mongo_normalized:MongoAdapter.from_uri
args:
uri: mongodb://mongo:27017/example_database
handler_registry:
NPY_SEQ: "handlers:NumpySeqHandler"
20 changes: 20 additions & 0 deletions example_config/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"This is vendored from ophyd.sim"
import os

import numpy as np


class NumpySeqHandler:
    """Handler for a sequence of ``.npy`` files named ``<prefix>_<index>.npy``."""

    def __init__(self, filename, root=""):
        # Path prefix shared by every file in the sequence.
        self._name = os.path.join(root, filename)

    def __call__(self, index):
        # Load the single file corresponding to this index.
        path = "{}_{}.npy".format(self._name, index)
        return np.load(path, allow_pickle=False)

    def get_file_list(self, datum_kwarg_gen):
        "This method is optional. It is not needed for access, but for export."
        template = "{name}_{index}.npy"
        return [
            template.format(name=self._name, **kwargs) for kwargs in datum_kwarg_gen
        ]
64 changes: 64 additions & 0 deletions examples/generate_data_with_wrong_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
# Generate data with wrong shape in a demo MongoDB and Tiled server.
TILED_SINGLE_USER_API_KEY=secret podman-compose up
python examples/generate_data_with_wrong_shape.py
# Try to load the data in the Tiled Python client.
# It will fail because the shape metadata is wrong.
from tiled.client import from_uri
c = from_uri('http://localhost:8000', api_key='secret')
c['raw'].values().last()['primary']['data']['img'][:] # ERROR!
# The server logs should show:
# databroker.mongo_normalized.BadShapeMetadata:
# For data key img shape (5, 7) does not match expected shape (1, 11, 3).
# Run the shape-fixer CLI. Start with a dry run.
# The `--strict` mode ensures that errors are raised, not skipped.
databroker admin shape-fixer mongodb://localhost:27017/example_database --strict --dry-run
databroker admin shape-fixer mongodb://localhost:27017/example_database --strict
# The output should include something like this.
# (Of course, the uid(s) will be different.)
Edited 90b7ffa8-ba02-4163-a2aa-5f47d1eb322b primary: {'img': [1, 11, 3]} -> {'img': [5, 7]}
Migrating... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
# Back in the Python client, try loading again.
# There is no need to reconnect; just run this line again:
c['raw'].values().last()['primary']['data']['img'][:] # Now it works!
"""

import numpy
from ophyd.sim import SynSignalWithRegistry
from bluesky import RunEngine
from bluesky.plans import count
from tiled.client import from_uri


RE = RunEngine()
client = from_uri("http://localhost:8000?api_key=secret")["raw"]


def post_document(name, doc):
    # RunEngine subscription callback (see RE.subscribe below): forward each
    # (name, document) pair to the Tiled server via the module-level `client`.
    client.post_document(name, doc)


RE.subscribe(post_document)


class LyingDetector(SynSignalWithRegistry):
    """Detector that deliberately reports the wrong shape in describe().

    Used to generate example data whose shape metadata does not match
    the array actually saved.
    """

    def describe(self):
        description = super().describe()
        # Report a shape that does not match the (5, 7) array produced
        # by the signal's func below.
        description["img"]["shape"] = (1, 11, 3)
        return description


img = LyingDetector(
func=lambda: numpy.ones((5, 7), dtype=numpy.uint8),
name="img",
save_path="./data",
)
RE(count([img]))

0 comments on commit 11fa21f

Please sign in to comment.