From 263466d467bd4ac78fec1d9f971037997b41f600 Mon Sep 17 00:00:00 2001 From: Ankith Udupa Date: Tue, 31 Oct 2023 17:32:17 -0400 Subject: [PATCH] State library (#559) * state library and tests * working simple test * fixed external data test, added comment * added messagse for and code for state ros message publishing server * basically fully integrated with current nav system, tested on a few scenarios in sim * address comments: fix and integrate publisher + logging levels * style * style w correct black version * fix mypy issues and add stop funcitonality * Cleanup imports, improve some type hinting for member vars, fix wrong call to context * address comments * fix missing config * fix style * fix import issue --------- Co-authored-by: qhdwight --- config/navigation.yaml | 1 + msg/StateMachineStateUpdate.msg | 2 + msg/StateMachineStructure.msg | 2 + msg/StateMachineTransition.msg | 2 + setup.py | 1 + src/navigation/approach_post.py | 55 +++---- src/navigation/context.py | 9 +- src/navigation/drive.py | 17 ++- .../failure_identification/__init__.py | 0 .../failure_identification.py | 8 +- .../failure_identification/watchdog.py | 11 +- src/navigation/navigation.py | 94 ++++-------- src/navigation/post_backup.py | 76 ++++------ src/navigation/recovery.py | 60 +++----- src/navigation/search.py | 82 ++++------ src/navigation/state.py | 143 ++++-------------- src/navigation/waypoint.py | 88 ++++------- src/util/state_lib/__init__.py | 0 src/util/state_lib/state.py | 66 ++++++++ src/util/state_lib/state_machine.py | 127 ++++++++++++++++ src/util/state_lib/state_publisher_server.py | 76 ++++++++++ test/util/state_lib/simple_test.py | 72 +++++++++ test/util/state_lib/test_external_data.py | 82 ++++++++++ test/util/state_lib/test_loop_overrun.py | 84 ++++++++++ 24 files changed, 738 insertions(+), 420 deletions(-) create mode 100644 msg/StateMachineStateUpdate.msg create mode 100644 msg/StateMachineStructure.msg create mode 100644 msg/StateMachineTransition.msg create mode 100644 src/navigation/failure_identification/__init__.py create mode 100644 src/util/state_lib/__init__.py create mode 100644 src/util/state_lib/state.py create mode 100644 src/util/state_lib/state_machine.py create mode 100644 src/util/state_lib/state_publisher_server.py create mode 100644 test/util/state_lib/simple_test.py create mode 100644 test/util/state_lib/test_external_data.py create mode 100644 test/util/state_lib/test_loop_overrun.py diff --git a/config/navigation.yaml b/config/navigation.yaml index 2994bca36..fe5a6adfb 100644 --- a/config/navigation.yaml +++ b/config/navigation.yaml @@ -26,6 +26,7 @@ single_fiducial: stop_thresh: 1.0 fiducial_stop_threshold: 1.75 post_avoidance_multiplier: 1.42 + post_radius: 0.7 waypoint: stop_thresh: 0.5 diff --git a/msg/StateMachineStateUpdate.msg b/msg/StateMachineStateUpdate.msg new file mode 100644 index 000000000..bd4a5926a --- /dev/null +++ b/msg/StateMachineStateUpdate.msg @@ -0,0 +1,2 @@ +string stateMachineName +string state \ No newline at end of file diff --git a/msg/StateMachineStructure.msg b/msg/StateMachineStructure.msg new file mode 100644 index 000000000..9dfad8e89 --- /dev/null +++ b/msg/StateMachineStructure.msg @@ -0,0 +1,2 @@ +string machineName +StateMachineTransition[] transitions \ No newline at end of file diff --git a/msg/StateMachineTransition.msg b/msg/StateMachineTransition.msg new file mode 100644 index 000000000..a39102164 --- /dev/null +++ b/msg/StateMachineTransition.msg @@ -0,0 +1,2 @@ +string origin +string[] destinations \ No newline at end of file diff --git a/setup.py b/setup.py index 5c1a1350e..5bb0b94e7 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ packages=[ "localization", "util", + "util.state_lib", "navigation", "navigation.failure_identification", "esw", diff --git a/src/navigation/approach_post.py b/src/navigation/approach_post.py index bd4c505b0..b3ebfb676 100644 --- a/src/navigation/approach_post.py +++ b/src/navigation/approach_post.py @@ -1,60 +1,47 @@ import tf2_ros -from aenum import Enum, NoAlias -from geometry_msgs.msg import Twist -from util.ros_utils import get_rosparam - -from context import Context -from state import BaseState - -class ApproachPostStateTransitions(Enum): - _settings_ = NoAlias +from util.ros_utils import get_rosparam +from util.state_lib.state import State - finished_fiducial = "WaypointState" - continue_fiducial_id = "ApproachPostState" - no_fiducial = "SearchState" - recovery_state = "RecoveryState" +from navigation import search, waypoint -class ApproachPostState(BaseState): +class ApproachPostState(State): STOP_THRESH = get_rosparam("single_fiducial/stop_thresh", 0.7) FIDUCIAL_STOP_THRESHOLD = get_rosparam("single_fiducial/fiducial_stop_threshold", 1.75) DRIVE_FWD_THRESH = get_rosparam("waypoint/drive_fwd_thresh", 0.34) # 20 degrees - def __init__(self, context: Context): - own_transitions = [ApproachPostStateTransitions.continue_fiducial_id.name] # type: ignore - super().__init__(context, own_transitions, add_outcomes=[transition.name for transition in ApproachPostStateTransitions]) # type: ignore + def on_enter(self, context) -> None: + pass + + def on_exit(self, context) -> None: + pass - def evaluate(self, ud) -> str: + def on_loop(self, context) -> State: """ Drive towards a single fiducial if we see it and stop within a certain threshold if we see it. Else conduct a search to find it. :param ud: State machine user data :return: Next state """ - fid_pos = self.context.env.current_fid_pos() + fid_pos = context.env.current_fid_pos() if fid_pos is None: - # We have arrived at the waypoint where the fiducial should be but we have not seen it yet - cmd_vel = Twist() - cmd_vel.linear.x = 0.0 - self.context.rover.send_drive_command(cmd_vel) - return ApproachPostStateTransitions.no_fiducial.name # type: ignore + return search.SearchState() try: - cmd_vel, arrived = self.context.rover.driver.get_drive_command( + cmd_vel, arrived = context.rover.driver.get_drive_command( fid_pos, - self.context.rover.get_pose(in_odom_frame=True), + context.rover.get_pose(in_odom_frame=True), self.STOP_THRESH, self.DRIVE_FWD_THRESH, - in_odom=self.context.use_odom, + in_odom=context.use_odom, ) if arrived: - self.context.env.arrived_at_post = True - self.context.env.last_post_location = self.context.env.current_fid_pos(odom_override=False) - print(f"set last post location to {self.context.env.last_post_location}.") - self.context.course.increment_waypoint() - return ApproachPostStateTransitions.finished_fiducial.name # type: ignore - self.context.rover.send_drive_command(cmd_vel) + context.env.arrived_at_post = True + context.env.last_post_location = context.env.current_fid_pos(odom_override=False) + context.course.increment_waypoint() + return waypoint.WaypointState() + context.rover.send_drive_command(cmd_vel) except ( tf2_ros.LookupException, tf2_ros.ConnectivityException, @@ -63,4 +50,4 @@ def evaluate(self, ud) -> str: # TODO: probably go into some waiting state pass - return ApproachPostStateTransitions.continue_fiducial_id.name # type: ignore + return self diff --git a/src/navigation/context.py b/src/navigation/context.py index d4fe20407..3c242edc5 100644 --- a/src/navigation/context.py +++ b/src/navigation/context.py @@ -11,13 +11,12 @@ import tf2_ros from geometry_msgs.msg import Twist from mrover.msg import Waypoint, GPSWaypoint, EnableAuton, WaypointType, GPSPointList -from shapely.geometry import Point from std_msgs.msg import Time, Bool -from util.SE3 import SE3 -from util.ros_utils import get_rosparam from visualization_msgs.msg import Marker -from drive import DriveController +from util.SE3 import SE3 + +from navigation.drive import DriveController TAG_EXPIRATION_TIME_SECONDS = 60 @@ -162,7 +161,7 @@ def setup_course(ctx: Context, waypoints: List[Tuple[Waypoint, SE3]]) -> Course: return Course(ctx=ctx, course_data=mrover.msg.Course([waypoint[0] for waypoint in waypoints])) -def convert_gps_to_cartesian(waypoint: GPSWaypoint) -> Waypoint: +def convert_gps_to_cartesian(waypoint: GPSWaypoint) -> Tuple[Waypoint, SE3]: """ Converts a GPSWaypoint into a "Waypoint" used for publishing to the CourseService. """ diff --git a/src/navigation/drive.py b/src/navigation/drive.py index 1264c345c..2ed873124 100644 --- a/src/navigation/drive.py +++ b/src/navigation/drive.py @@ -5,6 +5,7 @@ import numpy as np from geometry_msgs.msg import Twist + from util.SE3 import SE3 from util.np_utils import angle_to_rotate, normalized from util.ros_utils import get_rosparam @@ -71,33 +72,33 @@ def _get_state_machine_output( # if we are at the target position, reset the controller and return a zero command if abs(linear_error) < completion_thresh: self.reset() - return (Twist(), True) + return Twist(), True if self._driver_state == self.DriveMode.STOPPED: # if the drive mode is STOP (we know we aren't at the target) so we must start moving towards it # just switch to the TURN_IN_PLACE state for now under the assumption that we need to turn to face the target # return a zero command and False to indicate we aren't at the target (and are also not in the correct state to figure out how to get there) self._driver_state = self.DriveMode.TURN_IN_PLACE - return (Twist(), False) + return Twist(), False elif self._driver_state == self.DriveMode.TURN_IN_PLACE: # if we are in the TURN_IN_PLACE state, we need to turn to face the target # if we are within the turn threshold to face the target, we can start driving straight towards it if abs(angular_error) < turn_in_place_thresh: self._driver_state = self.DriveMode.DRIVE_FORWARD - return (Twist(), False) + return Twist(), False # IVT (Intermediate Value Theorem) check. If the sign of the angular error has changed, this means we've crossed through 0 error # in order to prevent osciallation, we 'give up' and just switch to the drive forward state elif self._last_angular_error is not None and np.sign(self._last_angular_error) != np.sign(angular_error): self._driver_state = self.DriveMode.DRIVE_FORWARD - return (Twist(), False) + return Twist(), False # if neither of those things are true, we need to turn in place towards our target heading, so set the z component of the output Twist message else: cmd_vel = Twist() cmd_vel.angular.z = np.clip(angular_error * TURNING_P, MIN_TURNING_EFFORT, MAX_TURNING_EFFORT) - return (cmd_vel, False) + return cmd_vel, False elif self._driver_state == self.DriveMode.DRIVE_FORWARD: # if we are driving straight towards the target and our last angular error was inside the threshold @@ -112,13 +113,13 @@ def _get_state_machine_output( cur_angular_is_outside = abs(angular_error) >= turn_in_place_thresh if cur_angular_is_outside and last_angular_was_inside: self._driver_state = self.DriveMode.TURN_IN_PLACE - return (Twist(), False) + return Twist(), False # otherwise we compute a drive command with both a linear and angular component in the Twist message else: cmd_vel = Twist() cmd_vel.linear.x = np.clip(linear_error * DRIVING_P, MIN_DRIVING_EFFORT, MAX_DRIVING_EFFORT) cmd_vel.angular.z = np.clip(angular_error * TURNING_P, MIN_TURNING_EFFORT, MAX_TURNING_EFFORT) - return (cmd_vel, False) + return cmd_vel, False else: raise ValueError(f"Invalid drive state {self._driver_state}") @@ -189,7 +190,7 @@ def get_drive_command( if np.linalg.norm(target_pos - rover_pos) < completion_thresh: self.reset() - return (Twist(), True) + return Twist(), True if path_start is not None: target_pos = self.get_lookahead_pt(path_start, target_pos, rover_pos, LOOKAHEAD_DISTANCE) diff --git a/src/navigation/failure_identification/__init__.py b/src/navigation/failure_identification/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/navigation/failure_identification/failure_identification.py b/src/navigation/failure_identification/failure_identification.py index 6ad2c5e01..27f9833a1 100755 --- a/src/navigation/failure_identification/failure_identification.py +++ b/src/navigation/failure_identification/failure_identification.py @@ -10,12 +10,12 @@ from geometry_msgs.msg import Twist from mrover.msg import MotorsStatus from nav_msgs.msg import Odometry -from pandas import DataFrame from smach_msgs.msg import SmachContainerStatus from std_msgs.msg import Bool + from util.ros_utils import get_rosparam -from watchdog import WatchDog +from navigation.failure_identification.watchdog import WatchDog DATAFRAME_MAX_SIZE = get_rosparam("failure_identification/dataframe_max_size", 200) POST_RECOVERY_GRACE_PERIOD = get_rosparam("failure_identification/post_recovery_grace_period", 5.0) @@ -28,7 +28,7 @@ class FailureIdentifier: """ stuck_publisher: rospy.Publisher - _df: DataFrame + _df: pd.DataFrame watchdog: WatchDog actively_collecting: bool cur_cmd: Twist @@ -164,7 +164,7 @@ def update(self, nav_status: SmachContainerStatus, drive_status: MotorsStatus, o cur_row[f"wheel_{wheel_num}_velocity"] = drive_status.joint_states.velocity[wheel_num] # update the data frame with the cur row - self._df = pd.concat([self._df, DataFrame([cur_row])]) + self._df = pd.concat([self._df, pd.DataFrame([cur_row])]) if len(self._df) == DATAFRAME_MAX_SIZE: self.write_to_csv() diff --git a/src/navigation/failure_identification/watchdog.py b/src/navigation/failure_identification/watchdog.py index 8a63fe9e6..a3adc0d0c 100644 --- a/src/navigation/failure_identification/watchdog.py +++ b/src/navigation/failure_identification/watchdog.py @@ -1,7 +1,8 @@ from typing import Tuple import numpy as np -from pandas import DataFrame +import pandas as pd + from util.SO3 import SO3 from util.ros_utils import get_rosparam @@ -11,14 +12,14 @@ class WatchDog: - def get_start_end_positions(self, dataframe: DataFrame) -> Tuple[np.ndarray, np.ndarray]: + def get_start_end_positions(self, dataframe: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: start_x, start_y, start_z = dataframe["x"].iloc[0], dataframe["y"].iloc[0], dataframe["z"].iloc[0] start_pos = np.array([start_x, start_y, start_z]) end_x, end_y, end_z = dataframe["x"].iloc[-1], dataframe["y"].iloc[-1], dataframe["z"].iloc[-1] end_pos = np.array([end_x, end_y, end_z]) return start_pos, end_pos - def get_start_end_rotations(self, dataframe: DataFrame) -> Tuple[SO3, SO3]: + def get_start_end_rotations(self, dataframe: pd.DataFrame) -> Tuple[SO3, SO3]: start_rot = np.array( [ dataframe["rot_x"].iloc[0], @@ -37,7 +38,7 @@ def get_start_end_rotations(self, dataframe: DataFrame) -> Tuple[SO3, SO3]: ) return SO3(start_rot), SO3(end_rot) - def get_start_end_time(self, dataframe: DataFrame) -> Tuple[float, float]: + def get_start_end_time(self, dataframe: pd.DataFrame) -> Tuple[float, float]: start_time = dataframe["time"].iloc[0] end_time = dataframe["time"].iloc[-1] return start_time, end_time @@ -79,7 +80,7 @@ def check_linear_stuck(self, delta_time, delta_pos, dataframe) -> bool: print(linear_velocity, LINEAR_THRESHOLD) return linear_velocity < LINEAR_THRESHOLD - def is_stuck(self, dataframe: DataFrame) -> bool: + def is_stuck(self, dataframe: pd.DataFrame) -> bool: if len(dataframe) > WINDOW_SIZE: dataframe_sliced = dataframe.tail(WINDOW_SIZE) # get the start and end position and rotation diff --git a/src/navigation/navigation.py b/src/navigation/navigation.py index cb82b177b..410823572 100755 --- a/src/navigation/navigation.py +++ b/src/navigation/navigation.py @@ -5,85 +5,51 @@ import threading import rospy -import smach -import smach_ros -from smach.log import loginfo -from smach.log import set_loggers -from std_msgs.msg import String -from approach_post import ApproachPostState, ApproachPostStateTransitions -from context import Context -from post_backup import PostBackupState, PostBackupTransitions -from recovery import RecoveryState, RecoveryStateTransitions -from search import SearchState, SearchStateTransitions -from state import DoneState, DoneStateTransitions, OffState, OffStateTransitions -from waypoint import WaypointState, WaypointStateTransitions +from util.state_lib.state_machine import StateMachine +from util.state_lib.state_publisher_server import StatePublisher + +from navigation.approach_post import ApproachPostState +from navigation.context import Context +from navigation.post_backup import PostBackupState +from navigation.recovery import RecoveryState +from navigation.search import SearchState +from navigation.state import DoneState, OffState, off_check +from navigation.waypoint import WaypointState class Navigation(threading.Thread): - state_machine: smach.StateMachine + state_machine: StateMachine context: Context - sis: smach_ros.IntrospectionServer + state_machine_server: StatePublisher def __init__(self, context: Context): super().__init__() - set_loggers(info=lambda _: None, warn=loginfo, error=loginfo, debug=loginfo) self.name = "NavigationThread" - self.state_machine = smach.StateMachine(outcomes=["terminated"]) - self.state_machine.userdata.waypoint_index = 0 - self.context = context - self.sis = smach_ros.IntrospectionServer("", self.state_machine, "/SM_ROOT") - self.sis.start() - self.state_publisher = rospy.Publisher("/nav_state", String, queue_size=1) - with self.state_machine: - self.state_machine.add( - "OffState", OffState(self.context), transitions=self.get_transitions(OffStateTransitions) - ) - self.state_machine.add( - "DoneState", DoneState(self.context), transitions=self.get_transitions(DoneStateTransitions) - ) - self.state_machine.add( - "WaypointState", WaypointState(self.context), transitions=self.get_transitions(WaypointStateTransitions) - ) - self.state_machine.add( - "ApproachPostState", - ApproachPostState(self.context), - transitions=self.get_transitions(ApproachPostStateTransitions), - ) - self.state_machine.add( - "SearchState", SearchState(self.context), transitions=self.get_transitions(SearchStateTransitions) - ) - self.state_machine.add( - "RecoveryState", RecoveryState(self.context), transitions=self.get_transitions(RecoveryStateTransitions) - ) - self.state_machine.add( - "PostBackupState", - PostBackupState(self.context), - transitions=self.get_transitions(PostBackupTransitions), - ) - rospy.Timer(rospy.Duration(0.1), self.publish_state) - - def get_transitions(self, transitions_enum): - transition_dict = {transition.name: transition.value for transition in transitions_enum} - transition_dict["off"] = "OffState" # logic for switching to offstate is built into OffState - return transition_dict - - def publish_state(self, event=None): - with self.state_machine: - active_states = self.state_machine.get_active_states() - if len(active_states) > 0: - self.state_publisher.publish(active_states[0]) + self.state_machine = StateMachine(OffState(), "NavStateMachine") + self.state_machine.set_context(context) + self.state_machine.add_transitions(ApproachPostState(), [WaypointState(), SearchState(), RecoveryState()]) + self.state_machine.add_transitions(PostBackupState(), [WaypointState(), RecoveryState()]) + self.state_machine.add_transitions( + RecoveryState(), [WaypointState(), SearchState(), PostBackupState(), ApproachPostState()] + ) + self.state_machine.add_transitions(SearchState(), [ApproachPostState(), WaypointState(), RecoveryState()]) + self.state_machine.add_transitions(DoneState(), [WaypointState()]) + self.state_machine.add_transitions( + WaypointState(), [PostBackupState(), ApproachPostState(), SearchState(), RecoveryState()] + ) + self.state_machine.add_transitions(OffState(), [WaypointState()]) + self.state_machine.configure_off_switch(OffState(), off_check) + self.state_machine_server = StatePublisher(self.state_machine, "nav_structure", 1, "nav_state", 10) def run(self): - self.state_machine.execute() + self.state_machine.run() def stop(self): - self.sis.stop() # Requests current state to go into 'terminated' to cleanly exit state machine - self.state_machine.request_preempt() - # Wait for smach thread to terminate + self.state_machine.stop() self.join() - self.context.rover.send_drive_stop() + self.state_machine.context.rover.send_drive_stop() def main(): diff --git a/src/navigation/post_backup.py b/src/navigation/post_backup.py index 66f278b27..b7c51246f 100644 --- a/src/navigation/post_backup.py +++ b/src/navigation/post_backup.py @@ -4,19 +4,20 @@ from typing import Optional import numpy as np -import rospy import tf2_ros -from aenum import Enum, NoAlias from shapely.geometry import Point, LineString + from util.SE3 import SE3 from util.np_utils import perpendicular_2d from util.ros_utils import get_rosparam +from util.state_lib.state import State -from context import Context -from state import BaseState -from trajectory import Trajectory +from navigation import waypoint, recovery +from navigation.trajectory import Trajectory -POST_RADIUS = get_rosparam("gate/post_radius", 0.7) * get_rosparam("single_fiducial/post_avoidance_multiplier", 1.42) +POST_RADIUS = get_rosparam("single_fiducial/post_radius", 0.7) * get_rosparam( + "single_fiducial/post_avoidance_multiplier", 1.42 +) BACKUP_DISTANCE = get_rosparam("recovery/recovery_distance", 2.0) STOP_THRESH = get_rosparam("search/stop_thresh", 0.2) DRIVE_FWD_THRESH = get_rosparam("search/drive_fwd_thresh", 0.34) @@ -24,6 +25,7 @@ @dataclass class AvoidPostTrajectory(Trajectory): + @staticmethod def avoid_post_trajectory(rover_pose: SE3, post_pos: np.ndarray, waypoint_pos: np.ndarray) -> AvoidPostTrajectory: """ Generates a trajectory that avoids a post until the rover has a clear path to the waypoint @@ -90,39 +92,27 @@ def avoid_post_trajectory(rover_pose: SE3, post_pos: np.ndarray, waypoint_pos: n return AvoidPostTrajectory(coords) -class PostBackupTransitions(Enum): - _settings_ = NoAlias - # State Transitions - finished_traj = "WaypointState" - recovery_state = "RecoveryState" - continue_post_backup = "PostBackupState" - +class PostBackupState(State): + traj: Optional[AvoidPostTrajectory] -class PostBackupState(BaseState): - def __init__( - self, - context: Context, - ): - own_transitions = [PostBackupTransitions.continue_post_backup.name] # type: ignore - super().__init__(context, own_transitions, add_outcomes=[transition.name for transition in PostBackupTransitions]) # type: ignore - self.traj: Optional[AvoidPostTrajectory] = None - - def reset(self): + def on_exit(self, context): self.traj = None - def evaluate(self, ud): + def on_enter(self, context) -> None: + if context.env.last_post_location is None: + self.traj = None + else: + self.traj = AvoidPostTrajectory.avoid_post_trajectory( + context.rover.get_pose(), + context.env.last_post_location, + context.course.current_waypoint_pose().position, + ) + self.traj.cur_pt = 0 + + def on_loop(self, context) -> State: try: if self.traj is None: - if self.context.env.last_post_location is None: - rospy.logerr("PostBackupState: last_post_location is None") - return PostBackupTransitions.finished_traj.name # type: ignore - - self.traj = AvoidPostTrajectory.avoid_post_trajectory( - self.context.rover.get_pose(), - self.context.env.last_post_location, - self.context.course.current_waypoint_pose().position, - ) - self.traj.cur_pt = 0 + return waypoint.WaypointState() target_pos = self.traj.get_cur_pt() @@ -130,9 +120,9 @@ def evaluate(self, ud): point_index = self.traj.cur_pt drive_backwards = point_index == 0 - cmd_vel, arrived = self.context.rover.driver.get_drive_command( + cmd_vel, arrived = context.rover.driver.get_drive_command( target_pos, - self.context.rover.get_pose(), + context.rover.get_pose(), STOP_THRESH, DRIVE_FWD_THRESH, drive_back=drive_backwards, @@ -141,18 +131,18 @@ def evaluate(self, ud): print(f"ARRIVED AT POINT {point_index}") if self.traj.increment_point(): self.traj = None - return PostBackupTransitions.finished_traj.name # type: ignore + return waypoint.WaypointState() - if self.context.rover.stuck: - self.context.rover.previous_state = PostBackupTransitions.continue_post_backup.name # type: ignore + if context.rover.stuck: + context.rover.previous_state = self self.traj = None - return PostBackupTransitions.recovery_state.name # type: ignore + return recovery.RecoveryState() - self.context.rover.send_drive_command(cmd_vel) - return PostBackupTransitions.continue_post_backup.name # type: ignore + context.rover.send_drive_command(cmd_vel) + return self except ( tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException, ): - return PostBackupTransitions.continue_post_backup.name # type: ignore + return self diff --git a/src/navigation/recovery.py b/src/navigation/recovery.py index d8ab217e8..33f5af69d 100644 --- a/src/navigation/recovery.py +++ b/src/navigation/recovery.py @@ -1,14 +1,12 @@ -from context import Context from typing import Optional import numpy as np import rospy -from aenum import Enum, NoAlias +from aenum import Enum + from util.np_utils import rotate_2d from util.ros_utils import get_rosparam - -from context import Context -from state import BaseState +from util.state_lib.state import State STOP_THRESH = get_rosparam("recovery/stop_thresh", 0.2) DRIVE_FWD_THRESH = get_rosparam("recovery/drive_fwd_thresh", 0.34) # 20 degrees @@ -16,46 +14,36 @@ GIVE_UP_TIME = get_rosparam("recovery/give_up_time", 10.0) -class RecoveryStateTransitions(Enum): - _settings_ = NoAlias - continue_waypoint_traverse = "WaypointState" - continue_search = "SearchState" - continue_recovery = "RecoveryState" - continue_post_backup = "PostBackupState" - recovery_state = "RecoveryState" - - class JTurnAction(Enum): moving_back: Enum = 0 j_turning: Enum = 1 -class RecoveryState(BaseState): +class RecoveryState(State): waypoint_behind: Optional[np.ndarray] current_action: JTurnAction start_time: Optional[rospy.Time] = None + waypoint_calculated: bool - def __init__(self, context: Context): - own_transitions = [RecoveryStateTransitions.continue_recovery.name] # type: ignore - super().__init__(context, own_transitions, add_outcomes=[transition.name for transition in RecoveryStateTransitions]) # type: ignore - self.waypoint_calculated = False - self.waypoint_behind = None - self.current_action = JTurnAction.moving_back - - def reset(self) -> None: + def reset(self, context) -> None: self.waypoint_calculated = False self.waypoint_behind = None self.current_action = JTurnAction.moving_back - self.context.rover.stuck = False + context.rover.stuck = False self.start_time = None - def evaluate(self, ud) -> str: - if self.start_time is None: - self.start_time = rospy.Time.now() + def on_enter(self, context) -> None: + self.reset(context) + self.start_time = rospy.Time.now() + + def on_exit(self, context) -> None: + self.reset(context) + + def on_loop(self, context) -> State: if rospy.Time.now() - self.start_time > rospy.Duration(GIVE_UP_TIME): - return self.context.rover.previous_state + return context.rover.previous_state # Making waypoint behind the rover to go backwards - pose = self.context.rover.get_pose() + pose = context.rover.get_pose() # if first round, set a waypoint directly behind the rover and command it to # drive backwards toward it until it arrives at that point. if self.current_action == JTurnAction.moving_back: @@ -65,15 +53,15 @@ def evaluate(self, ud) -> str: dir_vector = -1 * RECOVERY_DISTANCE * pose.rotation.direction_vector() self.waypoint_behind = pose.position + dir_vector - cmd_vel, arrived_back = self.context.rover.driver.get_drive_command( + cmd_vel, arrived_back = context.rover.driver.get_drive_command( self.waypoint_behind, pose, STOP_THRESH, DRIVE_FWD_THRESH, drive_back=True ) - self.context.rover.send_drive_command(cmd_vel) + context.rover.send_drive_command(cmd_vel) if arrived_back: self.current_action = JTurnAction.j_turning # move to second part of turn self.waypoint_behind = None - self.context.rover.driver.reset() + context.rover.driver.reset() # if second round, set a waypoint off to the side of the rover and command it to # turn and drive backwards towards it until it arrives at that point. So it will begin @@ -85,13 +73,13 @@ def evaluate(self, ud) -> str: dir_vector[:2] = RECOVERY_DISTANCE * rotate_2d(dir_vector[:2], 3 * np.pi / 4) self.waypoint_behind = pose.position + dir_vector - cmd_vel, arrived_turn = self.context.rover.driver.get_drive_command( + cmd_vel, arrived_turn = context.rover.driver.get_drive_command( self.waypoint_behind, pose, STOP_THRESH, DRIVE_FWD_THRESH, drive_back=True ) - self.context.rover.send_drive_command(cmd_vel) + context.rover.send_drive_command(cmd_vel) # set stuck to False if arrived_turn: - return self.context.rover.previous_state + return context.rover.previous_state - return RecoveryStateTransitions.continue_recovery.name # type: ignore + return self diff --git a/src/navigation/search.py b/src/navigation/search.py index 86531e66f..c56d6ec37 100644 --- a/src/navigation/search.py +++ b/src/navigation/search.py @@ -4,13 +4,14 @@ from typing import Optional import numpy as np -from aenum import Enum, NoAlias from mrover.msg import GPSPointList + from util.ros_utils import get_rosparam +from util.state_lib.state import State -from context import Context, convert_cartesian_to_gps -from state import BaseState -from trajectory import Trajectory +from navigation import approach_post, recovery, waypoint +from navigation.context import convert_cartesian_to_gps +from navigation.trajectory import Trajectory @dataclass @@ -79,59 +80,38 @@ def spiral_traj( ) -class SearchStateTransitions(Enum): - _settings_ = NoAlias - - no_fiducial = "WaypointState" - continue_search = "SearchState" - found_fiducial_post = "ApproachPostState" - recovery_state = "RecoveryState" - +class SearchState(State): + traj: SearchTrajectory + prev_target: Optional[np.ndarray] = None + is_recovering: bool = False -class SearchState(BaseState): STOP_THRESH = get_rosparam("search/stop_thresh", 0.2) DRIVE_FWD_THRESH = get_rosparam("search/drive_fwd_thresh", 0.34) # 20 degrees SPIRAL_COVERAGE_RADIUS = get_rosparam("search/coverage_radius", 20) SEGMENTS_PER_ROTATION = get_rosparam("search/segments_per_rotation", 8) DISTANCE_BETWEEN_SPIRALS = get_rosparam("search/distance_between_spirals", 2.5) - def __init__( - self, - context: Context, - ): - own_transitions = [SearchStateTransitions.continue_search.name] # type: ignore - super().__init__( - context, - own_transitions, - add_outcomes=[transition.name for transition in SearchStateTransitions], # type: ignore - ) - self.traj: Optional[SearchTrajectory] = None - self.prev_target: Optional[np.ndarray] = None - self.is_recovering = False - - def reset(self) -> None: + def on_enter(self, context) -> None: + search_center = context.course.current_waypoint() if not self.is_recovering: - self.traj = None - self.prev_target = None - - def evaluate(self, ud): - # Check if a path has been generated, and it's associated with the same - # waypoint as the previous one. Generate one if not - waypoint = self.context.course.current_waypoint() - if self.traj is None or self.traj.fid_id != waypoint.fiducial_id: self.traj = SearchTrajectory.spiral_traj( - self.context.rover.get_pose().position[0:2], + context.rover.get_pose().position[0:2], self.SPIRAL_COVERAGE_RADIUS, self.DISTANCE_BETWEEN_SPIRALS, self.SEGMENTS_PER_ROTATION, - waypoint.fiducial_id, + search_center.fiducial_id, ) + self.prev_target = None - # continue executing this path from wherever it left off + def on_exit(self, context) -> None: + pass + + def on_loop(self, context) -> State: + # continue executing the path from wherever it left off target_pos = self.traj.get_cur_pt() - cmd_vel, arrived = self.context.rover.driver.get_drive_command( + cmd_vel, arrived = context.rover.driver.get_drive_command( target_pos, - self.context.rover.get_pose(), + context.rover.get_pose(), self.STOP_THRESH, self.DRIVE_FWD_THRESH, path_start=self.prev_target, @@ -140,22 +120,20 @@ def evaluate(self, ud): self.prev_target = target_pos # if we finish the spiral without seeing the fiducial, move on with course if self.traj.increment_point(): - return SearchStateTransitions.no_fiducial.name # type: ignore + return waypoint.WaypointState() - if self.context.rover.stuck: - self.context.rover.previous_state = SearchStateTransitions.continue_search.name # type: ignore + if context.rover.stuck: + context.rover.previous_state = self self.is_recovering = True - return SearchStateTransitions.recovery_state.name # type: ignore + return recovery.RecoveryState() else: self.is_recovering = False - self.context.search_point_publisher.publish( + context.search_point_publisher.publish( GPSPointList([convert_cartesian_to_gps(pt) for pt in self.traj.coordinates]) ) - self.context.rover.send_drive_command(cmd_vel) - - # if we see the fiduicial go to either fiducial - if self.context.env.current_fid_pos() is not None and self.context.course.look_for_post(): - return SearchStateTransitions.found_fiducial_post.name # type: ignore + context.rover.send_drive_command(cmd_vel) - return SearchStateTransitions.continue_search.name # type: ignore + if context.env.current_fid_pos() is not None and context.course.look_for_post(): + return approach_post.ApproachPostState() + return self diff --git a/src/navigation/state.py b/src/navigation/state.py index 9084ab741..9e7ec8b43 100644 --- a/src/navigation/state.py +++ b/src/navigation/state.py @@ -1,129 +1,52 @@ -from abc import ABC, abstractmethod -from typing import List, Optional - -import smach -from aenum import Enum, NoAlias from geometry_msgs.msg import Twist -from context import Context - - -class BaseState(smach.State, ABC): - """ - Custom base state which handles termination cleanly via smach preemption. - """ +from util.state_lib.state import State - context: Context - own_transitions: List[str] # any transitions that map back to the same state +from navigation import waypoint - def __init__( - self, - context: Context, - own_transitions: List[str], - add_outcomes: Optional[List[str]] = None, - add_input_keys: Optional[List[str]] = None, - add_output_keys: Optional[List[str]] = None, - ): - add_outcomes = add_outcomes or [] - add_input_keys = add_input_keys or [] - add_output_keys = add_output_keys or [] - self.own_transitions = own_transitions - super().__init__( - add_outcomes + ["terminated", "off"], - add_input_keys + ["waypoint_index"], - add_output_keys + ["waypoint_index"], - ) - self.context = context - def execute(self, ud): - """ - Override execute method to add logic for early termination. - Base classes should override evaluate instead of this! - :param ud: State machine user data - :return: Next state, 'terminated' if we want to quit early - """ - if self.preempt_requested(): - self.service_preempt() - self.context.rover.stuck = False - return "terminated" - if self.context.disable_requested: - self.context.disable_requested = False - self.context.course = None - self.context.rover.stuck = False - self.context.rover.driver.reset() - self.reset() - return "off" - transition = self.evaluate(ud) - - if transition in self.own_transitions: - # we are staying in the same state - return transition - else: - # we are exiting the state so cleanup - self.reset() - return transition - - def reset(self): - """ - Is called anytime we transition out of the current state. - Override this function to reset any state variables - that need to reset everytime we exit the state. - """ +class DoneState(State): + def on_enter(self, context) -> None: pass - @abstractmethod - def evaluate(self, ud: smach.UserData) -> str: - """Override me instead of execute!""" - ... - - -class DoneStateTransitions(Enum): - _settings_ = NoAlias - - idle = "DoneState" - begin_course = "WaypointState" - - -class DoneState(BaseState): - def __init__(self, context: Context): - super().__init__( - context, - [DoneStateTransitions.idle.name], # type: ignore - add_outcomes=[transition.name for transition in DoneStateTransitions], # type: ignore - ) + def on_exit(self, context) -> None: + pass - def evaluate(self, ud): + def on_loop(self, context): # Check if we have a course to traverse - if self.context.course and (not self.context.course.is_complete()): - return DoneStateTransitions.begin_course.name # type: ignore + if context.course and (not context.course.is_complete()): + return waypoint.WaypointState() # Stop rover cmd_vel = Twist() - self.context.rover.send_drive_command(cmd_vel) - return DoneStateTransitions.idle.name # type: ignore + context.rover.send_drive_command(cmd_vel) + return self -class OffStateTransitions(Enum): - _settings_ = NoAlias +class OffState(State): + def on_enter(self, context) -> None: + pass - idle = "OffState" - begin_course = "WaypointState" + def on_exit(self, context) -> None: + pass + def on_loop(self, context): + if context.course and (not context.course.is_complete()): + return waypoint.WaypointState() -class OffState(BaseState): - def __init__(self, context: Context): - super().__init__( - context, - [OffStateTransitions.idle.name], # type: ignore - add_outcomes=[transition.name for transition in OffStateTransitions], # type: ignore - ) - self.stop_count = 0 + cmd_vel = Twist() + context.rover.send_drive_command(cmd_vel) + return self - def evaluate(self, ud): - # Check if we need to ignore on - if self.context.course and (not self.context.course.is_complete()): - self.stop_count = 0 - return OffStateTransitions.begin_course.name # type: ignore - return OffStateTransitions.idle.name # type: ignore - # We have determined the Rover is off, now ignore Rover on ... +def off_check(context) -> bool: + """ + function that state machine will call to check if the rover is turned off. + """ + if context.disable_requested: + context.disable_requested = False + context.course = None + context.rover.stuck = False + context.rover.driver.reset() + return True + return False diff --git a/src/navigation/waypoint.py b/src/navigation/waypoint.py index 47e89ce77..489aea42d 100644 --- a/src/navigation/waypoint.py +++ b/src/navigation/waypoint.py @@ -1,53 +1,23 @@ -from typing import List, Optional - -import numpy as np import tf2_ros -from aenum import Enum, NoAlias -from util.ros_utils import get_rosparam - -from context import Context -from state import BaseState +from util.ros_utils import get_rosparam +from util.state_lib.state import State -class WaypointStateTransitions(Enum): - _settings_ = NoAlias - - continue_waypoint_traverse = "WaypointState" - search_at_waypoint = "SearchState" - no_waypoint = "DoneState" - find_approach_post = "ApproachPostState" - recovery_state = "RecoveryState" - backup_from_post = "PostBackupState" +from navigation import search, recovery, approach_post, post_backup, state -class WaypointState(BaseState): +class WaypointState(State): STOP_THRESH = get_rosparam("waypoint/stop_thresh", 0.5) DRIVE_FWD_THRESH = get_rosparam("waypoint/drive_fwd_thresh", 0.34) # 20 degrees NO_FIDUCIAL = get_rosparam("waypoint/no_fiducial", -1) - def __init__( - self, - context: Context, - add_outcomes: Optional[List[str]] = None, - add_input_keys: Optional[List[str]] = None, - add_output_keys: Optional[List[str]] = None, - ): - add_outcomes = add_outcomes or [] - add_input_keys = add_input_keys or [] - add_output_keys = add_output_keys or [] - own_transitions = [WaypointStateTransitions.continue_waypoint_traverse.name] # type: ignore - super().__init__( - context, - own_transitions, - add_outcomes + [transition.name for transition in WaypointStateTransitions], # type: ignore - add_input_keys, - add_output_keys, - ) + def on_enter(self, context) -> None: + pass - def rover_forward(self) -> np.ndarray: - return self.context.rover.get_pose().rotation.direction_vector() + def on_exit(self, context) -> None: + pass - def evaluate(self, ud) -> str: + def on_loop(self, context) -> State: """ Handle driving to a waypoint defined by a linearized cartesian position. If the waypoint is associated with a fiducial id, go into that state early if we see it, @@ -55,41 +25,41 @@ def evaluate(self, ud) -> str: :param ud: State machine user data :return: Next state """ - current_waypoint = self.context.course.current_waypoint() + current_waypoint = context.course.current_waypoint() if current_waypoint is None: - return WaypointStateTransitions.no_waypoint.name # type: ignore + return state.DoneState() # if we are at a post currently (from a previous leg), backup to avoid collision - if self.context.env.arrived_at_post: - self.context.env.arrived_at_post = False - return WaypointStateTransitions.backup_from_post.name # type: ignore - if self.context.course.look_for_post(): - if self.context.env.current_fid_pos() is not None: - return WaypointStateTransitions.find_approach_post.name # type: ignore + if context.env.arrived_at_post: + context.env.arrived_at_post = False + return post_backup.PostBackupState() + + if context.course.look_for_post(): + if context.env.current_fid_pos() is not None: + return approach_post.ApproachPostState() # Attempt to find the waypoint in the TF tree and drive to it try: - waypoint_pos = self.context.course.current_waypoint_pose().position - cmd_vel, arrived = self.context.rover.driver.get_drive_command( + waypoint_pos = context.course.current_waypoint_pose().position + cmd_vel, arrived = context.rover.driver.get_drive_command( waypoint_pos, - self.context.rover.get_pose(), + context.rover.get_pose(), self.STOP_THRESH, self.DRIVE_FWD_THRESH, ) if arrived: - if not self.context.course.look_for_post(): + if not context.course.look_for_post(): # We finished a regular waypoint, go onto the next one - self.context.course.increment_waypoint() + context.course.increment_waypoint() else: # We finished a waypoint associated with a fiducial id, but we have not seen it yet. - return WaypointStateTransitions.search_at_waypoint.name # type: ignore + return search.SearchState() - if self.context.rover.stuck: - # Removed .name - self.context.rover.previous_state = WaypointStateTransitions.continue_waypoint_traverse.name # type: ignore - return WaypointStateTransitions.recovery_state.name # type: ignore + if context.rover.stuck: + context.rover.previous_state = self + return recovery.RecoveryState() - self.context.rover.send_drive_command(cmd_vel) + context.rover.send_drive_command(cmd_vel) except ( tf2_ros.LookupException, @@ -98,4 +68,4 @@ def evaluate(self, ud) -> str: ): pass - return WaypointStateTransitions.continue_waypoint_traverse.name # type: ignore + return self diff --git a/src/util/state_lib/__init__.py b/src/util/state_lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/util/state_lib/state.py b/src/util/state_lib/state.py new file mode 100644 index 000000000..39f701a6a --- /dev/null +++ b/src/util/state_lib/state.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class State(ABC): + """ + Abstract class that represents a state in the state machine. + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def on_enter(self, context): + """ + Called exactly once when the state is entered. If any state needs to be initialized, do it here. + :param context: The context object that is passed to the state machine. + """ + raise NotImplementedError + + @abstractmethod + def on_exit(self, context): + """ + Called exactly once when the state is exited. + No cleanup of internal state is necessary since this state will be destroyed. + An example usecase of this may be to write to an external log or send a service call to an external system. + :param context: The context object that is passed to the state machine. + """ + raise NotImplementedError + + @abstractmethod + def on_loop(self, context) -> State: + """ + Called repeatedly while the state is active. + :param context: The context object that is passed to the state machine. + :return: The next state to transition to. If the state should not change, return self. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + + def __str__(self): + return self.__repr__() + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +##state to be returned to signfy that the state machine should exit +class ExitState(State): + def on_enter(self, ctx): + pass + + def on_exit(self, ctx): + pass + + def on_loop(self, ctx) -> State: + return self diff --git a/src/util/state_lib/state_machine.py b/src/util/state_lib/state_machine.py new file mode 100644 index 000000000..9236493b4 --- /dev/null +++ b/src/util/state_lib/state_machine.py @@ -0,0 +1,127 @@ +import time +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from threading import Lock +from typing import DefaultDict, Set, List, Callable, Any, Optional + +from util.state_lib.state import State, ExitState + + +class LogLevel(Enum): + OFF = 0 + DEBUG = 1 + VERBOSE = 2 + + +@dataclass +class TransitionRecord: + time: float + origin_state: str + dest_state: str + + +class StateMachine: + current_state: State + state_lock: Lock + state_transitions: DefaultDict[type[State], Set[type[State]]] + transition_log: List[TransitionRecord] + context: Any + name: str + off_lambda: Optional[Callable[[Any], bool]] + off_state: Optional[State] + log_level: LogLevel + logger: Callable[[str], None] + on: bool + onLock: Lock + + def __init__( + self, + initial_state: State, + name: str, + log_level: LogLevel = LogLevel.DEBUG, + logger: Callable[[str], None] = print, + ): + self.current_state = initial_state + self.state_lock = Lock() + self.state_transitions = defaultdict(set) + self.state_transitions[type(self.current_state)] = set() + self.transition_log: List[TransitionRecord] = [] + self.context = None + self.name = name + self.off_lamdba = None + self.off_state = None + self.log_level = log_level + self.logger = logger + self.on = True + self.onLock = Lock() + + def __update(self): + with self.state_lock: + current_state = self.current_state + if self.log_level == LogLevel.VERBOSE: + self.logger(f"{self.name} state machine, current state = {str(current_state)}") + if self.off_lambda is not None and self.off_lambda(self.context) and self.off_state is not None: + next_state = self.off_state + else: + next_state = current_state.on_loop(self.context) + if type(next_state) not in self.state_transitions[type(current_state)]: + raise Exception(f"Invalid transition from {current_state} to {next_state}") + if type(next_state) is not type(current_state): + if self.log_level == LogLevel.DEBUG or self.log_level == LogLevel.VERBOSE: + self.logger(f"{self.name} state machine, transistioning to {str(next_state)}") + current_state.on_exit(self.context) + self.transition_log.append(TransitionRecord(time.time(), str(current_state), str(next_state))) + with self.state_lock: + self.current_state = next_state + self.current_state.on_enter(self.context) + + def stop(self): + with self.onLock: + self.on = False + + def run(self, update_rate: float = float("inf"), warning_handle: Callable = print): + """ + Runs the state machine until it returns an ExitState. + Aims for as close to update_rate_hz, updates per second + :param update_rate: targeted updates per second + """ + target_loop_time = None if update_rate == float("inf") else (1.0 / update_rate) + self.current_state.on_enter(self.context) + is_on = True + + while is_on: + start = time.time() + self.__update() + if isinstance(self.current_state, ExitState): + break + with self.onLock: + is_on = self.on + elapsed_time = time.time() - start + if target_loop_time is not None and elapsed_time < target_loop_time: + time.sleep(target_loop_time - elapsed_time) + elif target_loop_time is not None and elapsed_time > target_loop_time: + warning_handle( + f"[WARNING] state machine loop overran target loop time by {elapsed_time - target_loop_time} s" + ) + + def add_transition(self, state_from: State, state_to: State) -> None: + self.state_transitions[type(state_from)].add(type(state_to)) + + def add_transitions(self, state_from: State, states_to: List[State]) -> None: + for state_to in states_to: + self.add_transition(state_from, state_to) + self.add_transition(state_from, state_from) + if self.off_state is not None: + self.add_transition(state_from, self.off_state) + + def set_context(self, context: Any): + self.context = context + + def configure_off_switch(self, off_state: State, off_lambda: Callable[[Any], bool]): + if type(off_state) not in self.state_transitions: + raise Exception("Attempted to configure an Off State that doesn't exist") + self.off_state = off_state + self.off_lambda = off_lambda + for _, to_states in self.state_transitions.items(): + to_states.add(type(off_state)) diff --git a/src/util/state_lib/state_publisher_server.py b/src/util/state_lib/state_publisher_server.py new file mode 100644 index 000000000..9b69262ad --- /dev/null +++ b/src/util/state_lib/state_publisher_server.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import threading +import time +from typing import Callable + +import rospy +from mrover.msg import StateMachineStructure, StateMachineTransition, StateMachineStateUpdate + +from util.state_lib.state_machine import StateMachine + + +class StatePublisher: + structure_publisher: rospy.Publisher + state_publisher: rospy.Publisher + state_machine: StateMachine + __structure_thread: threading.Thread + __state_thread: threading.Thread + __stop_lock: threading.Lock + __stop: bool + + def __init__( + self, + state_machine: StateMachine, + structure_pub_topic: str, + structure_update_rate_hz: float, + state_pub_topic: str, + state_update_rate_hz: float, + ): + self.state_machine = state_machine + self.structure_publisher = rospy.Publisher(structure_pub_topic, StateMachineStructure, queue_size=1) + self.state_publisher = rospy.Publisher(state_pub_topic, StateMachineStateUpdate, queue_size=1) + self.__stop_lock = threading.Lock() + self.__structure_thread = threading.Thread( + target=self.run_at_interval, args=(self.publish_structure, structure_update_rate_hz) + ) + self.__state_thread = threading.Thread( + target=self.run_at_interval, args=(self.publish_state, state_update_rate_hz) + ) + self.__stop = False + self.__structure_thread.start() + self.__state_thread.start() + + def stop(self) -> None: + with self.__stop_lock: + self.__stop = True + + def publish_structure(self) -> None: + structure = StateMachineStructure() + structure.machineName = self.state_machine.name + for origin, destinations in self.state_machine.state_transitions.items(): + transition = StateMachineTransition() + transition.origin = origin.__name__ + transition.destinations = [dest.__name__ for dest in destinations] + structure.transitions.append(transition) + self.structure_publisher.publish(structure) + + def publish_state(self) -> None: + with self.state_machine.state_lock: + cur_state = self.state_machine.current_state + state = StateMachineStateUpdate() + state.stateMachineName = self.state_machine.name + state.state = str(cur_state) + self.state_publisher.publish(state) + + def run_at_interval(self, func: Callable[[], None], update_hz: float): + desired_loop_time = 1.0 / update_hz + while True: + start_time = time.time() + with self.__stop_lock: + if self.__stop: + break + func() + elapsed_time = time.time() - start_time + if desired_loop_time - elapsed_time > 0: + time.sleep(desired_loop_time - elapsed_time) diff --git a/test/util/state_lib/simple_test.py b/test/util/state_lib/simple_test.py new file mode 100644 index 000000000..175225579 --- /dev/null +++ b/test/util/state_lib/simple_test.py @@ -0,0 +1,72 @@ +from __future__ import annotations +import sys +import unittest +from util.state_lib.state_machine import StateMachine +from util.state_lib.state import State, ExitState +import unittest + + +class ForwardState(State): + def on_enter(self, context): + pass + + def on_exit(self, context): + pass + + def on_loop(self, context) -> State: + context.var += 1 + context.forward_loop_count += 1 + if context.var == 3: + return BackwardState() + return ForwardState() + + +class BackwardState(State): + def on_enter(self, context): + pass + + def on_exit(self, context): + pass + + def on_loop(self, context) -> State: + context.var -= 1 + context.backward_loop_count += 1 + if context.var == 0: + return ExitState() + return BackwardState() + + +from dataclasses import dataclass + + +@dataclass +class Context: + var: int = 0 + forward_loop_count: int = 0 + backward_loop_count: int = 0 + + +class TestSimpleStateMachine(unittest.TestCase): + def test_simple(self): + sm = StateMachine(ForwardState(), "SimpleStateMachine") + context = Context() + sm.set_context(context) + sm.add_transition(ForwardState(), BackwardState()) + sm.add_transition(BackwardState(), ForwardState()) + sm.add_transition(BackwardState(), ExitState()) + sm.add_transition(ForwardState(), ForwardState()) + sm.add_transition(BackwardState(), BackwardState()) + sm.run() + tlog = sm.transition_log + timeless = [(t.origin_state, t.dest_state) for t in tlog] + expected = [("ForwardState", "BackwardState"), ("BackwardState", "ExitState")] + for t, e in zip(timeless, expected): + self.assertEqual(t, e) + self.assertEqual(context.forward_loop_count, 3) + self.assertEqual(context.backward_loop_count, 3) + + +if __name__ == "__main__": + import rostest + + rostest.rosrun("mrover", "SimpleStateLibraryTaste", TestSimpleStateMachine) diff --git a/test/util/state_lib/test_external_data.py b/test/util/state_lib/test_external_data.py new file mode 100644 index 000000000..34a06280d --- /dev/null +++ b/test/util/state_lib/test_external_data.py @@ -0,0 +1,82 @@ +from util.state_lib.state_machine import StateMachine +from util.state_lib.state import State, ExitState +import random +from threading import Thread, Lock +import time + +""" +Test multi-threaded program that has an external thread that feeds a resource that the +Context object queries. Not a unit-test but run manually to ensure relatively predicatble +behavior +""" + + +class Context: + def __init__(self): + self.stateCapture = ExternalStateCapture() + + def getTrigger(self): + self.stateCapture.triggerLock.acquire() + trigger = self.stateCapture.trigger + self.stateCapture.triggerLock.release() + return trigger + + +class ExternalStateCapture: + def random_loop(self): + # run forever and with probability 0.5 flip the trigger + while True: + if random.random() < 0.5: + self.triggerLock.acquire() + self.trigger = not self.trigger + self.triggerLock.release() + time.sleep(1) + + def __init__(self): + self.triggerLock = Lock() + self.trigger = False + + +class WaitingState(State): + def __init__(self): + super().__init__() + + def on_enter(self, context): + print("Waiting for trigger") + + def on_loop(self, context): + if context.getTrigger(): + return RunningState() + return self + + def on_exit(self, context): + print("Triggered!") + + +class RunningState(State): + def __init__(self): + super().__init__() + + def on_enter(self, context): + print("Running") + + def on_loop(self, context): + if not context.getTrigger(): + return WaitingState() + return self + + def on_exit(self, context): + print("Stopped") + + +if __name__ == "__main__": + sm = StateMachine(WaitingState(), "RandomForeverStateMachine") + sm.add_transition(WaitingState(), RunningState()) + sm.add_transition(RunningState(), WaitingState()) + sm.add_transition(RunningState(), RunningState()) + sm.add_transition(WaitingState(), WaitingState()) + context = Context() + sm.set_context(context) + thread = Thread(target=context.stateCapture.random_loop) + thread.start() + sm.run() diff --git a/test/util/state_lib/test_loop_overrun.py b/test/util/state_lib/test_loop_overrun.py new file mode 100644 index 000000000..caefc17b4 --- /dev/null +++ b/test/util/state_lib/test_loop_overrun.py @@ -0,0 +1,84 @@ +from __future__ import annotations +import sys +import unittest +from util.state_lib.state_machine import StateMachine +from util.state_lib.state import State, ExitState +import unittest +import time + + +class ForwardState(State): + def on_enter(self, context): + pass + + def on_exit(self, context): + pass + + def on_loop(self, context) -> State: + context.var += 1 + context.forward_loop_count += 1 + if context.var == 3: + return BackwardState() + return ForwardState() + + +class BackwardState(State): + def on_enter(self, context): + pass + + def on_exit(self, context): + pass + + def on_loop(self, context) -> State: + context.var -= 1 + context.backward_loop_count += 1 + time.sleep(1) + if context.var == 0: + return ExitState() + return BackwardState() + + +from dataclasses import dataclass + + +@dataclass +class Context: + var: int = 0 + forward_loop_count: int = 0 + backward_loop_count: int = 0 + + +@dataclass +class WarningHandle: + got_warning: bool = False + + def set_warning(self, s): + self.got_warning = True + + +class TestLoopOverrun(unittest.TestCase): + def test_loop_overrun(self): + sm = StateMachine(ForwardState(), "LoopOverrun") + context = Context() + sm.set_context(context) + sm.add_transition(ForwardState(), BackwardState()) + sm.add_transition(BackwardState(), ForwardState()) + sm.add_transition(BackwardState(), ExitState()) + sm.add_transition(ForwardState(), ForwardState()) + sm.add_transition(BackwardState(), BackwardState()) + wh = WarningHandle() + sm.run(10, wh.set_warning) + self.assertTrue(wh.got_warning) + tlog = sm.transition_log + timeless = [(t.origin_state, t.dest_state) for t in tlog] + expected = [("ForwardState", "BackwardState"), ("BackwardState", "ExitState")] + for t, e in zip(timeless, expected): + self.assertEqual(t, e) + self.assertEqual(context.forward_loop_count, 3) + self.assertEqual(context.backward_loop_count, 3) + + +if __name__ == "__main__": + import rostest + + rostest.rosrun("mrover", "SimpleStateLibraryTaste", TestLoopOverrun)