Skip to content

Commit

Permalink
Use common observation utility functions. (#2141)
Browse files Browse the repository at this point in the history
  • Loading branch information
Adaickalavan authored Feb 12, 2024
1 parent e38fbca commit 9fa8e28
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 117 deletions.
2 changes: 1 addition & 1 deletion examples/e10_drive/inference/contrib_policy/filter_obs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Tuple

import gymnasium as gym
import numpy as np
Expand Down
124 changes: 14 additions & 110 deletions examples/e11_platoon/inference/contrib_policy/filter_obs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Tuple

import gymnasium as gym
import numpy as np

from smarts.core.agent_interface import RGB
from smarts.core.colors import Colors, SceneColors
from smarts.core.utils.observations import points_to_pixels, replace_rgb_image_color


class FilterObs:
Expand Down Expand Up @@ -72,19 +73,19 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
# Get rgb image, remove road, and replace other egos (if any) as background vehicles
rgb = obs["top_down_rgb"]
h, w, _ = rgb.shape
rgb_noroad = replace_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
rgb_ego = replace_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)
rgb_noroad = replace_rgb_image_color(rgb=rgb, old_color=[self._road_color, self._lane_divider_color, self._edge_divider_color], new_color=self._no_color)
rgb_ego = replace_rgb_image_color(rgb=rgb_noroad, old_color=[self._ego_color], new_color=self._traffic_color, mask=self._rgb_mask)

# Superimpose waypoints onto rgb image
wps = obs["waypoint_paths"]["position"][0:11, 3:, 0:3]
for path in wps[:]:
wps_valid = points_to_pixels(
points=path,
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
h=h,
res=self._res,
center_position=ego_pos,
heading=ego_heading,
width=w,
height=h,
resolution=self._res,
)
for point in wps_valid:
img_x, img_y = point[0], point[1]
Expand All @@ -95,11 +96,11 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
if not all((goal:=obs["ego_vehicle_state"]["mission"]["goal_position"]) == np.zeros((3,))):
goal_pixel = points_to_pixels(
points=np.expand_dims(goal,axis=0),
ego_pos=ego_pos,
ego_heading=ego_heading,
w=w,
h=h,
res=self._res,
center_position=ego_pos,
heading=ego_heading,
width=w,
height=h,
resolution=self._res,
)
if len(goal_pixel) != 0:
img_x, img_y = goal_pixel[0][0], goal_pixel[0][1]
Expand All @@ -121,100 +122,3 @@ def filter(self, obs: Dict[str, Any]) -> Dict[str, Any]:
return filtered_obs
# fmt: on


def replace_color(
rgb: np.ndarray,
old_color: Sequence[np.ndarray],
new_color: np.ndarray,
mask: np.ndarray = np.ma.nomask,
) -> np.ndarray:
"""Convert pixels of value `old_color` to `new_color` within the masked
region in the received RGB image.
Args:
rgb (np.ndarray): RGB image. Shape = (m,n,3).
old_color (Sequence[np.ndarray]): List of old colors to be removed from the RGB image. Shape = (3,).
new_color (np.ndarray): New color to be added to the RGB image. Shape = (3,).
mask (np.ndarray, optional): Valid regions for color replacement. Shape = (m,n,3).
Defaults to np.ma.nomask .
Returns:
np.ndarray: RGB image with `old_color` pixels changed to `new_color`
within the masked region. Shape = (m,n,3).
"""
# fmt: off
assert all(color.shape == (3,) for color in old_color), (
f"Expected old_color to be of shape (3,), but got {[color.shape for color in old_color]}.")
assert new_color.shape == (3,), (
f"Expected new_color to be of shape (3,), but got {new_color.shape}.")

nc = new_color.reshape((1, 1, 3))
nc_array = np.full_like(rgb, nc)
rgb_masked = np.ma.MaskedArray(data=rgb, mask=mask)

rgb_condition = rgb_masked
result = rgb
for color in old_color:
result = np.ma.where((rgb_condition == color.reshape((1, 1, 3))).all(axis=-1)[..., None], nc_array, result)

return result
# fmt: on


def points_to_pixels(
points: np.ndarray,
ego_pos: np.ndarray,
ego_heading: float,
w: int,
h: int,
res: float,
) -> np.ndarray:
"""Converts points into pixel coordinates in order to superimpose the
points onto the RGB image.
Args:
points (np.ndarray): Array of points. Shape (n,3).
ego_pos (np.ndarray): Ego position. Shape = (3,).
ego_heading (float): Ego heading in radians.
w (int): Width of RGB image
h (int): Height of RGB image.
res (float): Resolution of RGB image in meters/pixels. Computed as
ground_size/image_size.
Returns:
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
"""
# fmt: off
mask = [False if all(point == np.zeros(3,)) else True for point in points]
points_nonzero = points[mask]
points_delta = points_nonzero - ego_pos
points_rotated = rotate_axes(points_delta, theta=ego_heading)
points_pixels = points_rotated / np.array([res, res, res])
points_overlay = np.array([w / 2, h / 2, 0]) + points_pixels * np.array([1, -1, 1])
points_rfloat = np.rint(points_overlay)
points_valid = points_rfloat[(points_rfloat[:,0] >= 0) & (points_rfloat[:,0] < w) & (points_rfloat[:,1] >= 0) & (points_rfloat[:,1] < h)]
points_rint = points_valid.astype(int)
return points_rint
# fmt: on


def rotate_axes(points: np.ndarray, theta: float) -> np.ndarray:
"""A counterclockwise rotation of the x-y axes by an angle theta θ about
the z-axis.
Args:
points (np.ndarray): x,y,z coordinates in original axes. Shape = (n,3).
theta (np.float): Axes rotation angle in radians.
Returns:
np.ndarray: x,y,z coordinates in rotated axes. Shape = (n,3).
"""
# fmt: off
theta = (theta + np.pi) % (2 * np.pi) - np.pi
ct, st = np.cos(theta), np.sin(theta)
R = np.array([[ ct, st, 0],
[-st, ct, 0],
[ 0, 0, 1]])
rotated_points = (R.dot(points.T)).T
return rotated_points
# fmt: on
14 changes: 8 additions & 6 deletions smarts/core/utils/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def points_to_pixels(
Args:
points (np.ndarray): Array of points. Shape (n,3).
ego_pos (np.ndarray): Ego position. Shape = (3,).
ego_heading (float): Ego heading in radians.
w (int): Width of RGB image
h (int): Height of RGB image.
res (float): Resolution of RGB image in meters/pixels. Computed as
ground_size/image_size.
center_position (np.ndarray): Center position of image. Generally, this
is equivalent to ego position. Shape = (3,).
heading (float): Heading of image in radians. Generally, this is
equivalent to ego heading.
width (int): Width of RGB image
height (int): Height of RGB image.
resolution (float): Resolution of RGB image in meters/pixels. Computed
as ground_size/image_size.
Returns:
np.ndarray: Array of point coordinates on the RGB image. Shape = (m,3).
Expand Down

0 comments on commit 9fa8e28

Please sign in to comment.