diff --git a/msg/StateMachineStateUpdate.msg b/msg/StateMachineStateUpdate.msg new file mode 100644 index 000000000..bd4a5926a --- /dev/null +++ b/msg/StateMachineStateUpdate.msg @@ -0,0 +1,2 @@ +string stateMachineName +string state \ No newline at end of file diff --git a/msg/StateMachineStructure.msg b/msg/StateMachineStructure.msg new file mode 100644 index 000000000..9dfad8e89 --- /dev/null +++ b/msg/StateMachineStructure.msg @@ -0,0 +1,2 @@ +string machineName +StateMachineTransition[] transitions \ No newline at end of file diff --git a/msg/StateMachineTransition.msg b/msg/StateMachineTransition.msg new file mode 100644 index 000000000..a39102164 --- /dev/null +++ b/msg/StateMachineTransition.msg @@ -0,0 +1,2 @@ +string origin +string[] destinations \ No newline at end of file diff --git a/src/util/state_lib/state_machine.py b/src/util/state_lib/state_machine.py index 13aae1e5f..78556eda4 100644 --- a/src/util/state_lib/state_machine.py +++ b/src/util/state_lib/state_machine.py @@ -1,7 +1,8 @@ from .state import State, ExitState -from typing import Dict, Set, List +from typing import Dict, Set, List, Callable import time from dataclasses import dataclass +from threading import Lock @dataclass class TransitionRecord: @@ -10,28 +11,46 @@ class TransitionRecord: dest_state: str class StateMachine: - def __init__(self, initial_state: State): + def __init__(self, initial_state: State, name: str): self.current_state = initial_state + self.state_lock = Lock() self.state_transitions: Dict[type[State], Set[type[State]]] = {} self.transition_log: List[TransitionRecord] = [] self.context = None + self.name = name def __update(self): - next_state = self.current_state.on_loop(self.context) - if type(next_state) not in self.state_transitions[type(self.current_state)]: - raise Exception(f"Invalid transition from {self.current_state} to {next_state}") - if type(next_state) is not type(self.current_state): - self.current_state.on_exit(self.context) - self.transition_log.append(TransitionRecord(time.time(), str(self.current_state), str(next_state))) - self.current_state = next_state - self.current_state.on_enter(self.context) + with self.state_lock: + current_state = self.current_state + next_state = current_state.on_loop(self.context) + if type(next_state) not in self.state_transitions[type(current_state)]: + raise Exception(f"Invalid transition from {current_state} to {next_state}") + if type(next_state) is not type(current_state): + current_state.on_exit(self.context) + self.transition_log.append(TransitionRecord(time.time(), str(current_state), str(next_state))) + with self.state_lock: + self.current_state = next_state + self.current_state.on_enter(self.context) - def run(self): + def run(self, update_rate: float = float('inf'), warning_handle: Callable = print): + ''' + Runs the state machine until it returns an ExitState. + Aims for as close to update_rate_hz, updates per second + :param update_rate_z (float): targetted updates per second + ''' + target_loop_time = None if update_rate == float('inf') else (1.0 / update_rate) self.current_state.on_enter(self.context) while True: + start = time.time() self.__update() - if type(self.current_state) is ExitState: + if type(self.current_state) == ExitState: break + elapsed_time = time.time() - start + if target_loop_time is not None and elapsed_time < target_loop_time: + time.sleep(target_loop_time - elapsed_time) + elif target_loop_time is not None and elapsed_time > target_loop_time: + warning_handle(f"[WARNING] state machine loop overran target loop time by {elapsed_time - target_loop_time} s") + def add_transition(self, state_from: State, state_to: State): if type(state_from) not in self.state_transitions: diff --git a/src/util/state_lib/state_publisher_server.py b/src/util/state_lib/state_publisher_server.py new file mode 100644 index 000000000..7c0be49df --- /dev/null +++ b/src/util/state_lib/state_publisher_server.py @@ -0,0 +1,61 @@ +from state_machine import StateMachine +from state import State +import rospy +from mrover.msg import StateMachineStructure, StateMachineTransition, StateMachineStateUpdate +import threading +from typing import Callable +import time + +class StatePublisher: + + structure_publisher: rospy.Publisher + state_publisher: rospy.Publisher + state_machine: StateMachine + __struct_thread: threading.Thread + __state_thread: threading.Thread + __stop_lock: threading.Lock() + __stop: bool + + def __init__(self, state_machine: StateMachine, structure_pub_topic: str, + structure_update_rate_hz: float, state_pub_topic: str, state_update_rate_hz: float): + self.state_machine = state_machine + self.structure_publisher = rospy.Publisher(structure_pub_topic, StateMachineStructure, queue_size=1) + self.state_publisher = rospy.Publisher(state_pub_topic, StateMachineStateUpdate, queue_size=1) + self.__stop_lock = threading.Lock() + self.__struct_thread = threading.Thread(target=self.run_at_interval, args=(self.publish_structure, structure_update_rate_hz)) + self.__state_thread = threading.Thread(target=self.run_at_interval, args=(self.publish_state, state_update_rate_hz)) + self.__stop = False + + def stop(self) -> None: + with self.__stop_lock: + self.__stop = True + + def publish_structure(self) -> None: + structure = StateMachineStructure() + structure.machineName = self.state_machine.name + for origin, destinations in self.state_machine.state_transitions.items(): + transition = StateMachineTransition() + transition.origin = str(origin) + transition.destinations = [str(dest) for dest in destinations] + structure.transitions.append(transition) + structure_publisher.publish(structure) + + def publish_state(self) -> None: + with self.state_machine.state_lock: + cur_state = self.state_machine.current_state + state = StateMachineStateUpdate() + state.machineName = self.state_machine.name + state.state = str(cur_state) + state_publiser.publish(state) + + def run_at_interval(self, func: Callable[[StatePublisher, [None]]], update_hz: float): + desired_loop_time = 1.0 / update_hz + while True: + start_time = time.time() + with self.__stop_lock: + if self.__stop: break + self.func() + time.sleep(desired_loop_time - (time.time() - start_time)) + + + diff --git a/test/util/state_lib/simple_test.py b/test/util/state_lib/simple_test.py index 5fec47d6f..4ea5264e1 100644 --- a/test/util/state_lib/simple_test.py +++ b/test/util/state_lib/simple_test.py @@ -44,7 +44,7 @@ class Context: class TestSimpleStateMachine(unittest.TestCase): def test_simple(self): - sm = StateMachine(ForwardState()) + sm = StateMachine(ForwardState(), "SimpleStateMachine") context = Context() sm.set_context(context) sm.add_transition(ForwardState(), BackwardState()) diff --git a/test/util/state_lib/test_external_data.py b/test/util/state_lib/test_external_data.py index aea64d1af..a28c39b9c 100644 --- a/test/util/state_lib/test_external_data.py +++ b/test/util/state_lib/test_external_data.py @@ -68,7 +68,7 @@ def on_exit(self, context): print("Stopped") if __name__ == "__main__": - sm = StateMachine(WaitingState()) + sm = StateMachine(WaitingState(), "RandomForeverStateMachine") sm.add_transition(WaitingState(), RunningState()) sm.add_transition(RunningState(), WaitingState()) sm.add_transition(RunningState(), RunningState())