Skip to content

Commit

Permalink
added messagse for and code for state ros message publishing server
Browse files Browse the repository at this point in the history
  • Loading branch information
ankithu committed Oct 3, 2023
1 parent 3247732 commit 036c3be
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 14 deletions.
2 changes: 2 additions & 0 deletions msg/StateMachineStateUpdate.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
string stateMachineName
string state
2 changes: 2 additions & 0 deletions msg/StateMachineStructure.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
string machineName
StateMachineTransition[] transitions
2 changes: 2 additions & 0 deletions msg/StateMachineTransition.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
string origin
string[] destinations
43 changes: 31 additions & 12 deletions src/util/state_lib/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .state import State, ExitState
from typing import Dict, Set, List
from typing import Dict, Set, List, Callable
import time
from dataclasses import dataclass
from threading import Lock

@dataclass
class TransitionRecord:
Expand All @@ -10,28 +11,46 @@ class TransitionRecord:
dest_state: str

class StateMachine:
def __init__(self, initial_state: State):
def __init__(self, initial_state: State, name: str):
self.current_state = initial_state
self.state_lock = Lock()
self.state_transitions: Dict[type[State], Set[type[State]]] = {}
self.transition_log: List[TransitionRecord] = []
self.context = None
self.name = name

def __update(self):
next_state = self.current_state.on_loop(self.context)
if type(next_state) not in self.state_transitions[type(self.current_state)]:
raise Exception(f"Invalid transition from {self.current_state} to {next_state}")
if type(next_state) is not type(self.current_state):
self.current_state.on_exit(self.context)
self.transition_log.append(TransitionRecord(time.time(), str(self.current_state), str(next_state)))
self.current_state = next_state
self.current_state.on_enter(self.context)
with self.state_lock:
current_state = self.current_state
next_state = current_state.on_loop(self.context)
if type(next_state) not in self.state_transitions[type(current_state)]:
raise Exception(f"Invalid transition from {current_state} to {next_state}")
if type(next_state) is not type(current_state):
current_state.on_exit(self.context)
self.transition_log.append(TransitionRecord(time.time(), str(current_state), str(next_state)))
with self.state_lock:
self.current_state = next_state
self.current_state.on_enter(self.context)

def run(self):
def run(self, update_rate: float = float('inf'), warning_handle: Callable = print):
'''
Runs the state machine until it returns an ExitState.
Aims for as close to update_rate_hz, updates per second
:param update_rate_z (float): targetted updates per second
'''
target_loop_time = None if update_rate == float('inf') else (1.0 / update_rate)
self.current_state.on_enter(self.context)
while True:
start = time.time()
self.__update()
if type(self.current_state) is ExitState:
if type(self.current_state) == ExitState:
break
elapsed_time = time.time() - start
if target_loop_time is not None and elapsed_time < target_loop_time:
time.sleep(target_loop_time - elapsed_time)
elif target_loop_time is not None and elapsed_time > target_loop_time:
warning_handle(f"[WARNING] state machine loop overran target loop time by {elapsed_time - target_loop_time} s")


def add_transition(self, state_from: State, state_to: State):
if type(state_from) not in self.state_transitions:
Expand Down
61 changes: 61 additions & 0 deletions src/util/state_lib/state_publisher_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from state_machine import StateMachine
from state import State
import rospy
from mrover.msg import StateMachineStructure, StateMachineTransition, StateMachineStateUpdate
import threading
from typing import Callable
import time

class StatePublisher:

structure_publisher: rospy.Publisher
state_publisher: rospy.Publisher
state_machine: StateMachine
__struct_thread: threading.Thread
__state_thread: threading.Thread
__stop_lock: threading.Lock()
__stop: bool

def __init__(self, state_machine: StateMachine, structure_pub_topic: str,
structure_update_rate_hz: float, state_pub_topic: str, state_update_rate_hz: float):
self.state_machine = state_machine
self.structure_publisher = rospy.Publisher(structure_pub_topic, StateMachineStructure, queue_size=1)
self.state_publisher = rospy.Publisher(state_pub_topic, StateMachineStateUpdate, queue_size=1)
self.__stop_lock = threading.Lock()
self.__struct_thread = threading.Thread(target=self.run_at_interval, args=(self.publish_structure, structure_update_rate_hz))
self.__state_thread = threading.Thread(target=self.run_at_interval, args=(self.publish_state, state_update_rate_hz))
self.__stop = False

def stop(self) -> None:
with self.__stop_lock:
self.__stop = True

def publish_structure(self) -> None:
structure = StateMachineStructure()
structure.machineName = self.state_machine.name
for origin, destinations in self.state_machine.state_transitions.items():
transition = StateMachineTransition()
transition.origin = str(origin)
transition.destinations = [str(dest) for dest in destinations]
structure.transitions.append(transition)
structure_publisher.publish(structure)

def publish_state(self) -> None:
with self.state_machine.state_lock:
cur_state = self.state_machine.current_state
state = StateMachineStateUpdate()
state.machineName = self.state_machine.name
state.state = str(cur_state)
state_publiser.publish(state)

def run_at_interval(self, func: Callable[[StatePublisher, [None]]], update_hz: float):
desired_loop_time = 1.0 / update_hz
while True:
start_time = time.time()
with self.__stop_lock:
if self.__stop: break
self.func()
time.sleep(desired_loop_time - (time.time() - start_time))



2 changes: 1 addition & 1 deletion test/util/state_lib/simple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Context:

class TestSimpleStateMachine(unittest.TestCase):
def test_simple(self):
sm = StateMachine(ForwardState())
sm = StateMachine(ForwardState(), "SimpleStateMachine")
context = Context()
sm.set_context(context)
sm.add_transition(ForwardState(), BackwardState())
Expand Down
2 changes: 1 addition & 1 deletion test/util/state_lib/test_external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def on_exit(self, context):
print("Stopped")

if __name__ == "__main__":
sm = StateMachine(WaitingState())
sm = StateMachine(WaitingState(), "RandomForeverStateMachine")
sm.add_transition(WaitingState(), RunningState())
sm.add_transition(RunningState(), WaitingState())
sm.add_transition(RunningState(), RunningState())
Expand Down

0 comments on commit 036c3be

Please sign in to comment.