diff --git a/script/controller.py b/script/controller.py index 65f25ad..a678a6d 100755 --- a/script/controller.py +++ b/script/controller.py @@ -1,11 +1,6 @@ #!/usr/bin/env python # author: mingyuw@stanford.edu -""" -given desired waypoints, this ros node sends the ackermann steering command for -the ego vehicle to follow the trajectory -""" - import rospy import tf from nav_msgs.msg import Odometry @@ -19,7 +14,14 @@ import timeit class odom_state(object): + ''' + This class stores an instantaneous odometry state of the vehicle. All units + are SI. + ''' def __init__(self): + ''' + Initializes the class variables. + ''' self.time = None self.x = None self.y = None @@ -28,14 +30,26 @@ def __init__(self): self.vy = None self.speed = None + def update_vehicle_state(self, odom_msg): + ''' + Updates the member variables to the instantaneous odometry state of a + vehicle. + + Parameters + ---------- + odom_msg : Odometry + A message containing the current odometry state of the vehicle. + ''' self.time = odom_msg.header.stamp.to_sec() self.x = odom_msg.pose.pose.position.x self.y = odom_msg.pose.pose.position.y - ori_quat = (odom_msg.pose.pose.orientation.x, - odom_msg.pose.pose.orientation.y, - odom_msg.pose.pose.orientation.z, - odom_msg.pose.pose.orientation.w) + ori_quat = ( + odom_msg.pose.pose.orientation.x, + odom_msg.pose.pose.orientation.y, + odom_msg.pose.pose.orientation.z, + odom_msg.pose.pose.orientation.w + ) ori_euler = tf.transformations.euler_from_quaternion(ori_quat) self.yaw = ori_euler[2] self.vx = odom_msg.twist.twist.linear.x @@ -47,20 +61,67 @@ def update_vehicle_state(self, odom_msg): # self.time # ) + def get_position(self): + ''' + Returns the stored global position :math:`(x, y)` of the vehicle center. + + Returns + ------- + Array + The 2D position (global) of the vehicle. + ''' return [self.x, self.y] + def get_pose(self): + ''' + Returns the stored global position :math:`(x, y)` of the vehicle center + and heading/yaw :math:`(\theta)` with respect to the x-axis. + + Returns + ------- + Array + A SE(2) vector representing the global position and heading of the + vehicle. + ''' return [self.x, self.y, self.yaw] + def get_velocity(self): + ''' + Returns the stored global velocity :math:`(v_x, v_y)` of the vehicle. + + Returns + ------- + Array + The 2D velocity (global) of the vehicle. + ''' return [self.vx, self.vy] + def get_speed(self): + ''' + Returns the stored global speed of the vehicle. + + Returns + ------- + Array + The speed (global) of the vehicle. + ''' return self.speed class AckermannController: + ''' + A ROS node used to control the vehicle based on some trajectory. + Specifically, this node publishes AckermannDrive messages for the ego + vehicle determined by a set of desired waypoints. + ''' def __init__(self): + ''' + Initializes the class, including creating the used publishers and + subscribers used throughout the class. + ''' rospy.init_node("controller", anonymous=True) self.time = rospy.get_time() @@ -78,17 +139,17 @@ def __init__(self): # path tracking information self.pathReady = False - self.path = np.zeros(shape=(self.traj_steps, 2)) # absolute position + self.path = np.zeros(shape=(self.traj_steps, 2)) # Absolute position self.path_tree = KDTree(self.path) self.vel_path = np.zeros(shape=(self.traj_steps, 2)) self.steer_cache = None - # # PID controller parameter + # PID controller parameter self.pid_str_prop = rospy.get_param("~str_prop") self.pid_str_deriv = rospy.get_param("~str_deriv") - # subscribers, publishers + # Subscribers rospy.Subscriber( "/carla/" + rolename + "/odometry", Odometry, @@ -99,21 +160,36 @@ def __init__(self): MultiDOFJointTrajectory, self.desired_waypoints_cb ) + + # Publishers self.command_pub = rospy.Publisher( "/carla/" + rolename + "/ackermann_cmd", AckermannDrive, queue_size=10 ) - # self.vehicle_cmd_pub = rospy.Publisher( - # "/carla/" + rolename + "/vehicle_control_cmd", - # CarlaEgoVehicleControl, - # queue_size=10 - # ) - self.tracking_pt_viz_pub = rospy.Publisher("tracking_point_mkr", Marker, queue_size=10) - self.ctrl_timer = rospy.Timer(rospy.Duration(1.0/ctrl_freq), self.timer_cb) + self.tracking_pt_viz_pub = rospy.Publisher( + "tracking_point_mkr", + Marker, + queue_size=10 + ) - def desired_waypoints_cb(self, msg): + # Class timer + self.ctrl_timer = rospy.Timer( + rospy.Duration(1.0/ctrl_freq), + self.timer_cb + ) + + def desired_waypoints_cb(self, msg): + ''' + A callback function that updates the desired the desired trajectory the + ego vehicle should follow. + + Parameters + ---------- + msg : MultiDOFJointTrajectory + The previously calculated trajectory from the planner. + ''' for i, pt in enumerate(msg.points): self.path[i, 0] = pt.transforms[0].translation.x self.path[i, 1] = pt.transforms[0].translation.y @@ -121,16 +197,38 @@ def desired_waypoints_cb(self, msg): self.vel_path[i, 1] = pt.velocities[0].linear.y self.path_tree = KDTree(self.path) + if not self.pathReady: self.pathReady = True def odom_cb(self, msg): + ''' + Callback function to update the odometry state of the vehicle according + to the received message. + + Parameters + ---------- + msg : Odometry + The message containing the vehicle's odometry state. + ''' self.state.update_vehicle_state(msg) + if not self.stateReady: self.stateReady = True + def timer_cb(self, event): + ''' + This is a callback for the class timer, which will be called every tick. + The callback calculates the target values of the AckermannDrive message + and publishes the command to the ego vehicle's Ackermann control topic. + + Parameters + ---------- + event : rospy.TimerEvent + The timer's tick event. + ''' cmd_msg = AckermannDrive() current_time = rospy.get_time() # delta_t = current_time - self.time @@ -159,6 +257,7 @@ def timer_cb(self, event): target_pt = self.path_tree.data[-1, :] # str_idx = self.vel_path.shape[0] - 1 print("CONTROLLER: at the end of the desired waypoits!!!") + # Target point for velocity if idx < self.traj_steps: target_vel = self.vel_path[idx, :] @@ -179,6 +278,7 @@ def timer_cb(self, event): steer = self.compute_ackermann_steer(target_pt) steer_diff = abs(steer - self.steer_prev) + if steer_diff >= 0.3: print(" 0.3") steer = self.steer_prev @@ -237,10 +337,24 @@ def timer_cb(self, event): self.command_pub.publish(cmd_msg) - - def compute_ackermann_steer(self, target_pt): + ''' + Calculates the desired steering command [-1, 1] for the vehicle to + manuever towards the provided global position :math:`(x, y)`. + + Parameters + ---------- + target_pt : Array + The 2D target position (global) in which the vehicle should head. + + Returns + ------- + float + The steering command for the vehicle in the range :math:`[-1, 1]`, + representing the max left and right control, respectively. + ''' pos_x, pos_y, yaw = self.state.get_pose() + if np.linalg.norm([target_pt[0] - pos_x, target_pt[1] - pos_y]) < 1: print("target point too close!!!!!!!!") if self.steer_cache: @@ -264,8 +378,8 @@ def compute_ackermann_steer(self, target_pt): # print("steer, d_steer:", steer, d_steer) steer = steer + d_steer*self.pid_str_deriv - self.steer_cache = steer + return steer