Skip to content

Commit

Permalink
Issue #3: Improves readability of the code
Browse files Browse the repository at this point in the history
- Updates some of the formatting of the controller file to match the PEP-8 standard.
- Updates some of the formatting of the planner file to match the PEP-8 standard.
  • Loading branch information
exoticDFT committed Oct 3, 2020
1 parent 8df4765 commit dd5ef62
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
54 changes: 23 additions & 31 deletions script/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from ackermann_msgs.msg import AckermannDrive
from visualization_msgs.msg import Marker

from carla_msgs.msg import CarlaEgoVehicleControl
from carla_msgs.msg import CarlaEgoVehicleInfo
from scipy.spatial import KDTree
import numpy as np
import timeit


class odom_state(object):
'''
Expand All @@ -32,12 +31,11 @@ def __init__(self):
self.vy = None
self.speed = None


def update_vehicle_state(self, odom_msg):
'''
Updates the member variables to the instantaneous odometry state of a
vehicle.
Parameters
----------
odom_msg : Odometry
Expand All @@ -58,24 +56,23 @@ def update_vehicle_state(self, odom_msg):
self.vy = odom_msg.twist.twist.linear.y
self.speed = np.sqrt(self.vx**2 + self.vy**2)


def get_position(self):
'''
Returns the stored global position :math:`(x, y)` of the vehicle center.
Returns the stored global position :math:`(x, y)` of the vehicle
center.
Returns
-------
Array
The 2D position (global) of the vehicle.
'''
return [self.x, self.y]


def get_pose(self):
'''
Returns the stored global position :math:`(x, y)` of the vehicle center
and heading/yaw :math:`(\theta)` with respect to the x-axis.
Returns
-------
Array
Expand All @@ -84,30 +81,29 @@ def get_pose(self):
'''
return [self.x, self.y, self.yaw]


def get_velocity(self):
'''
Returns the stored global velocity :math:`(v_x, v_y)` of the vehicle.
Returns
-------
Array
The 2D velocity (global) of the vehicle.
'''
return [self.vx, self.vy]


def get_speed(self):
'''
Returns the stored global speed of the vehicle.
Returns
-------
Array
The speed (global) of the vehicle.
'''
return self.speed


class AckermannController:
'''
A ROS node used to control the vehicle based on some trajectory.
Expand Down Expand Up @@ -139,7 +135,7 @@ def __init__(self):

# path tracking information
self.pathReady = False
self.path = np.zeros(shape=(self.traj_steps, 2)) # Absolute position
self.path = np.zeros(shape=(self.traj_steps, 2)) # Absolute position
self.path_tree = KDTree(self.path)
self.vel_path = np.zeros(shape=(self.traj_steps, 2))

Expand Down Expand Up @@ -176,7 +172,7 @@ def __init__(self):
topic = "/carla/{}/vehicle_info".format(rolename)
rospy.loginfo_once(
"Vehicle information for %s: %s",
rolename,
rolename,
rospy.wait_for_message(topic, CarlaEgoVehicleInfo)
)

Expand All @@ -186,12 +182,11 @@ def __init__(self):
self.timer_cb
)


def desired_waypoints_cb(self, msg):
'''
A callback function that updates the desired the desired trajectory the
ego vehicle should follow.
Parameters
----------
msg : MultiDOFJointTrajectory
Expand All @@ -208,12 +203,11 @@ def desired_waypoints_cb(self, msg):
if not self.pathReady:
self.pathReady = True


def odom_cb(self, msg):
'''
Callback function to update the odometry state of the vehicle according
to the received message.
Parameters
----------
msg : Odometry
Expand All @@ -224,13 +218,13 @@ def odom_cb(self, msg):
if not self.stateReady:
self.stateReady = True


def timer_cb(self, event):
'''
This is a callback for the class timer, which will be called every tick.
The callback calculates the target values of the AckermannDrive message
and publishes the command to the ego vehicle's Ackermann control topic.
This is a callback for the class timer, which will be called every
tick. The callback calculates the target values of the AckermannDrive
message and publishes the command to the ego vehicle's Ackermann
control topic.
Parameters
----------
event : rospy.TimerEvent
Expand All @@ -250,8 +244,8 @@ def timer_cb(self, event):
if self.pathReady and self.stateReady:
pos_x, pos_y = self.state.get_position()

# Pick target w/o collision avoidance. Find the closest point in the
# trajectory tree.
# Pick target w/o collision avoidance. Find the closest point in
# the trajectory tree.
_, idx = self.path_tree.query([pos_x, pos_y])

# Steering target. Three points ahead of closest point.
Expand All @@ -273,7 +267,7 @@ def timer_cb(self, event):
acceleration = abs(speed_diff) / (2.0 * delta_t)
cmd_msg.acceleration = np.min([1.5, acceleration])
steer = self.compute_ackermann_steer(target_pt)

steer_diff = abs(steer - self.steer_prev)
rospy.logdebug("Steering difference: %f", steer_diff)

Expand Down Expand Up @@ -318,7 +312,6 @@ def timer_cb(self, event):
# self.vehicle_cmd_pub.publish(vehicle_cmd_msg)
self.command_pub.publish(cmd_msg)


def compute_ackermann_long_params(self, target_velocity):
desired_speed = 0.0
desired_acceleration = 0.0
Expand All @@ -333,17 +326,16 @@ def compute_ackermann_long_params(self, target_velocity):

return (desired_speed, desired_acceleration, desired_jerk)


def compute_ackermann_steer(self, target_pt):
'''
Calculates the desired steering command [-1, 1] for the vehicle to
manuever towards the provided global position :math:`(x, y)`.
Parameters
----------
target_pt : Array
The 2D target position (global) in which the vehicle should head.
Returns
-------
float
Expand Down
4 changes: 3 additions & 1 deletion script/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def get_angle(x, y):
return t_min, t_max, relevant_flag

t_min, t_max, flag = enter_area_time(pos_x, pos_y, ang_speed)

if flag:
other_min_list = []
other_max_list = []
Expand All @@ -259,6 +260,7 @@ def overlap(inter1, inter2):
min2, max2 = inter2
min = np.max([min1, min2])
max = np.min([max1, max2])

if min < max - 0.01:
return [min, max]
else:
Expand All @@ -282,7 +284,6 @@ def overlap(inter1, inter2):
interaction_flag = False
else:
interaction_flag = True

else:
interaction_flag = False

Expand All @@ -296,6 +297,7 @@ def overlap(inter1, inter2):
traj_msg.header.frame_id = 'map'
traj_msg.points = []
angle_increment = ang_speed * self.time_step

for i in range(self.steps):
ang = current_ang + angle_increment * i
traj_point = MultiDOFJointTrajectoryPoint()
Expand Down

0 comments on commit dd5ef62

Please sign in to comment.