Skip to content

Commit

Permalink
Merge pull request #59 from ZLLentz/enh_seq_cond_item
Browse files Browse the repository at this point in the history
ENH: implement SequenceConditionItem from #54
  • Loading branch information
ZLLentz authored Nov 4, 2024
2 parents 53b006f + 9b6f6fc commit 68012b7
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 78 deletions.
6 changes: 6 additions & 0 deletions beams/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _capitalized(name: str) -> str:


generic_name = type_name(_get_generic_name_factory)
tagged_union_cache = set()


def as_tagged_union(cls: Cls) -> Cls:
Expand Down Expand Up @@ -135,4 +136,9 @@ def deserialization() -> Conversion:

deserializer(lazy=deserialization, target=cls)
serializer(lazy=serialization, source=cls)
tagged_union_cache.add(cls)
return cls


def is_tagged_union(cls) -> bool:
return cls in tagged_union_cache
24 changes: 14 additions & 10 deletions beams/tests/artifacts/eggs.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
"name": "self_test",
"description": "",
"check": {
"name": "self_test_check",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
"ConditionItem": {
"name": "self_test_check",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
}
},
"do": {
"name": "self_test_do",
Expand All @@ -17,11 +19,13 @@
"increment": 10,
"loop_period_sec": 0.01,
"termination_check": {
"name": "",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
"ConditionItem": {
"name": "",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
}
}
}
}
Expand Down
48 changes: 28 additions & 20 deletions beams/tests/artifacts/eggs2.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
"name": "ret_find",
"description": "",
"check": {
"name": "ret_find_check",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
"ConditionItem": {
"name": "ret_find_check",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
}
},
"do": {
"name": "ret_find_do",
Expand All @@ -23,11 +25,13 @@
"value": 1,
"loop_period_sec": 0.01,
"termination_check": {
"name": "",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
"ConditionItem": {
"name": "",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
}
}
}
}
Expand All @@ -37,11 +41,13 @@
"name": "ret_insert",
"description": "",
"check": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
"ConditionItem": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
}
},
"do": {
"name": "",
Expand All @@ -50,11 +56,13 @@
"value": 1,
"loop_period_sec": 1.0,
"termination_check": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
"ConditionItem": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
}
}
}
}
Expand Down
72 changes: 42 additions & 30 deletions beams/tests/artifacts/im2l0_test.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
"name": "reticle_state_out",
"description": "",
"check": {
"name": "check_reticule_state",
"description": "",
"pv": "IM2L0:XTES:MMS:STATE:GET_RBV",
"value": "OUT",
"operator": "eq"
"ConditionItem": {
"name": "check_reticule_state",
"description": "",
"pv": "IM2L0:XTES:MMS:STATE:GET_RBV",
"value": "OUT",
"operator": "eq"
}
},
"do": {
"name": "set_reticule_state_to_out",
Expand All @@ -23,11 +25,13 @@
"value": "OUT",
"loop_period_sec": 0.01,
"termination_check": {
"name": "check_reticule_state",
"description": "",
"pv": "IM2L0:XTES:MMS:STATE:GET_RBV",
"value": "OUT",
"operator": "eq"
"ConditionItem": {
"name": "check_reticule_state",
"description": "",
"pv": "IM2L0:XTES:MMS:STATE:GET_RBV",
"value": "OUT",
"operator": "eq"
}
}
}
}
Expand All @@ -37,11 +41,13 @@
"name": "zoom_motor",
"description": "",
"check": {
"name": "check_zoom_motor",
"description": "",
"pv": "IM2L0:XTES:CLZ.RBV",
"value": 25,
"operator": "eq"
"ConditionItem": {
"name": "check_zoom_motor",
"description": "",
"pv": "IM2L0:XTES:CLZ.RBV",
"value": 25,
"operator": "eq"
}
},
"do": {
"name": "set_zoom_motor",
Expand All @@ -50,11 +56,13 @@
"value": 25,
"loop_period_sec": 0.01,
"termination_check": {
"name": "check_zoom_motor",
"description": "",
"pv": "IM2L0:XTES:CLZ.RBV",
"value": 25,
"operator": "eq"
"ConditionItem": {
"name": "check_zoom_motor",
"description": "",
"pv": "IM2L0:XTES:CLZ.RBV",
"value": 25,
"operator": "eq"
}
}
}
}
Expand All @@ -64,11 +72,13 @@
"name": "focus_motor",
"description": "",
"check": {
"name": "check_focus_motor",
"description": "",
"pv": "IM2L0:XTES:CLF.RBV",
"value": 50,
"operator": "eq"
"ConditionItem": {
"name": "check_focus_motor",
"description": "",
"pv": "IM2L0:XTES:CLF.RBV",
"value": 50,
"operator": "eq"
}
},
"do": {
"name": "set_focus_motor",
Expand All @@ -77,11 +87,13 @@
"value": 50,
"loop_period_sec": 0.01,
"termination_check": {
"name": "check_focus_motor",
"description": "",
"pv": "IM2L0:XTES:CLF.RBV",
"value": 50,
"operator": "eq"
"ConditionItem": {
"name": "check_focus_motor",
"description": "",
"pv": "IM2L0:XTES:CLF.RBV",
"value": 50,
"operator": "eq"
}
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions beams/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import sys
from contextlib import contextmanager
from copy import copy
Expand All @@ -21,6 +22,15 @@ def central_logging_setup(caplog):
py_trees.logging.level = py_trees.logging.Level.DEBUG


@pytest.fixture(autouse=True)
def ca_env_vars():
# Pick a non-standard port to avoid collisions with same-named prod PVs
os.environ["EPICS_CA_SERVER_PORT"] = "5066"
# Only broadcast and get on local if
os.environ["EPICS_CA_AUTO_ADDR_LIST"] = "NO"
os.environ["EPICS_CA_ADDR_LIST"] = "localhost"


@contextmanager
def cli_args(args):
"""
Expand Down
84 changes: 75 additions & 9 deletions beams/tests/test_tree_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import combinations, combinations_with_replacement

import apischema
import pytest
from py_trees.behaviour import Behaviour
Expand All @@ -14,13 +16,15 @@
from beams.behavior_tree.ActionNode import ActionNode
from beams.behavior_tree.CheckAndDo import CheckAndDo
from beams.behavior_tree.ConditionNode import ConditionNode
from beams.tree_config import (BaseItem, BlackboardToStatusItem,
CheckAndDoItem,
from beams.serialization import get_all_subclasses, is_tagged_union
from beams.tree_config import (BaseConditionItem, BaseItem,
BlackboardToStatusItem, CheckAndDoItem,
CheckBlackboardVariableExistsItem,
CheckBlackboardVariableValueItem, ConditionItem,
DummyItem, FailureItem, IncPVActionItem,
ParallelItem, PeriodicItem, RunningItem,
SelectorItem, SequenceItem,
DummyConditionItem, DummyItem, FailureItem,
IncPVActionItem, ParallelItem, PeriodicItem,
RunningItem, SelectorItem,
SequenceConditionItem, SequenceItem,
SetBlackboardVariableItem, SetPVActionItem,
StatusQueueItem, SuccessEveryNItem, SuccessItem,
TickCounterItem, UnsetBlackboardVariableItem,
Expand All @@ -35,7 +39,9 @@
(FailureItem, Failure),
(RunningItem, Running),
(DummyItem, Dummy),
(DummyConditionItem, ConditionNode),
(ConditionItem, ConditionNode),
(SequenceConditionItem, Sequence),
(SetPVActionItem, ActionNode),
(IncPVActionItem, ActionNode),
(CheckAndDoItem, CheckAndDo),
Expand All @@ -53,17 +59,77 @@
]


@pytest.mark.parametrize('item, node_type', ITEM_TO_BEHAVIOUR)
def test_get_tree(item: BaseItem, node_type: Behaviour):
item_instance = item()
@pytest.mark.parametrize('item_class, node_type', ITEM_TO_BEHAVIOUR)
def test_get_tree(item_class: type[BaseItem], node_type: type[Behaviour]):
item_instance = item_class()
assert isinstance(item_instance.get_tree(), node_type)


@pytest.mark.parametrize('item_class', [item[0] for item in ITEM_TO_BEHAVIOUR])
def test_item_serialize_roundtrip(item_class: BaseItem):
def test_item_serialize_roundtrip(item_class: type[BaseItem]):
item = item_class()
ser = apischema.serialize(item_class, item)

deser = apischema.deserialize(item_class, ser)

assert item == deser


@pytest.mark.parametrize(
'item_class, attr, expand',
[
(SetPVActionItem, "termination_check", BaseConditionItem),
(IncPVActionItem, "termination_check", BaseConditionItem),
(CheckAndDoItem, "check", BaseConditionItem),
]
)
def test_item_serialize_roundtrip_union_singles(item_class: type[BaseItem], attr: str, expand: type[BaseItem]):
count = 0
for cls in get_all_subclasses(expand):
if is_tagged_union(cls):
continue
item = item_class(**{attr: cls()})
ser = apischema.serialize(item_class, item)
deser = apischema.deserialize(item_class, ser)
assert item == deser
count += 1
# If all subclasses skip assert item == deser, we should also have an error
assert count > 0


@pytest.mark.parametrize(
'item_class, attr, expand',
[
(ParallelItem, "children", BaseItem),
(SelectorItem, "children", BaseItem),
(SequenceItem, "children", BaseItem),
(SequenceConditionItem, "children", BaseConditionItem),
]
)
def test_item_serailize_roundtrip_union_sublists(item_class: type[BaseItem], attr: str, expand: type[BaseItem]):
options = [cls for cls in get_all_subclasses(expand) if not is_tagged_union(cls)]
combos = [tuple(options)]
for size in range(3):
for variant in combinations(options, size + 1):
combos.append(variant)
for tuple_of_cls in combos:
item = item_class(**{attr: [cls() for cls in tuple_of_cls]})
ser = apischema.serialize(item_class, item)
deser = apischema.deserialize(item_class, ser)
assert item == deser


def test_sequence_condition_item_condition_function():
item = SequenceConditionItem(
name="mega_dummy",
children=[
DummyConditionItem(),
DummyConditionItem(),
DummyConditionItem(),
]
)
cond_func = item.get_condition_function()
for variant in combinations_with_replacement((True, False), 3):
for idx in range(3):
item.children[idx].result = variant[idx]
assert cond_func() == all(variant)
Loading

0 comments on commit 68012b7

Please sign in to comment.