From 1bb0360f4a45152691fd6f740ae2bcd248b97294 Mon Sep 17 00:00:00 2001 From: Tucker Date: Thu, 20 Apr 2023 09:22:19 -0400 Subject: [PATCH 1/5] Make base interface. --- smarts/core/plan.py | 42 +++++++++++++++---- smarts/primatives/__init__.py | 21 ++++++++++ .../types => primatives}/constants.py | 13 +++++- smarts/sstudio/generators.py | 31 +++++++++----- smarts/sstudio/genscenario.py | 37 +++++++++------- smarts/sstudio/types/__init__.py | 2 +- smarts/sstudio/types/bubble_limits.py | 2 +- smarts/sstudio/types/mission.py | 15 ++++--- smarts/sstudio/types/route.py | 7 ++-- 9 files changed, 122 insertions(+), 48 deletions(-) create mode 100644 smarts/primatives/__init__.py rename smarts/{sstudio/types => primatives}/constants.py (85%) diff --git a/smarts/core/plan.py b/smarts/core/plan.py index 255d49180b..0b262141c9 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -24,19 +24,17 @@ import math import random import sys -import warnings from dataclasses import dataclass, field -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple, Union import numpy as np from smarts.core.coordinates import Dimensions, Heading, Point, Pose, RefLinePoint from smarts.core.road_map import RoadMap from smarts.core.utils.math import min_angles_difference_signed, vec_to_radians +from smarts.primatives.constants import SmartsLiteral from smarts.sstudio.types import EntryTactic, TrapEntryTactic -MISSING = sys.maxsize - class PlanningError(Exception): """Raised in cases when map related planning fails.""" @@ -44,9 +42,18 @@ class PlanningError(Exception): pass +@dataclass(frozen=True) +class StartBase: + """The base type for Start objects.""" + + def resolve(self, scenario, vehicle) -> "Start": + """Converts an abstract start into a concrete one.""" + raise NotImplementedError() + + # XXX: consider using smarts.core.coordinates.Pose for this @dataclass(frozen=True) -class Start: +class Start(StartBase): """A starting state for a route or mission.""" position: np.ndarray @@ -68,6 +75,13 @@ def from_pose(cls, pose: Pose): ) +@dataclass(frozen=True) +class AutomaticStart(StartBase): + """Generates a start""" + + pass + + @dataclass(frozen=True, unsafe_hash=True) class Goal: """Describes an expected end state for a route or mission.""" @@ -81,6 +95,13 @@ def is_reached(self, vehicle_state) -> bool: return False +@dataclass(frozen=True, unsafe_hash=True) +class AutomaticGoal(Goal): + """A goal that determines an end result from pre-existing vehicle and mission values.""" + + pass + + @dataclass(frozen=True, unsafe_hash=True) class EndlessGoal(Goal): """A goal that can never be completed.""" @@ -170,7 +191,7 @@ def _drove_off_map(self, veh_pos: Point, veh_heading: float) -> bool: def default_entry_tactic(default_entry_speed: Optional[float] = None) -> EntryTactic: """The default tactic the simulation will use to acquire an actor for an agent.""" return TrapEntryTactic( - start_time=MISSING, + start_time=SmartsLiteral.MISSING, wait_to_hijack_limit_s=0, exclusion_prefixes=tuple(), zone=None, @@ -213,7 +234,7 @@ class Mission: # An optional list of road IDs between the start and end goal that we want to # ensure the mission includes route_vias: Tuple[str, ...] = field(default_factory=tuple) - start_time: float = MISSING + start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING entry_tactic: Optional[EntryTactic] = None via: Tuple[Via, ...] = () # if specified, will use vehicle_spec to build the vehicle (for histories) @@ -263,9 +284,12 @@ def random_endless_mission( return Mission.endless_mission(start_pose=target_pose) def __post_init__(self): - if self.entry_tactic is not None and self.entry_tactic.start_time != MISSING: + if ( + self.entry_tactic is not None + and self.entry_tactic.start_time != SmartsLiteral.MISSING + ): object.__setattr__(self, "start_time", self.entry_tactic.start_time) - elif self.start_time == MISSING: + elif self.start_time == SmartsLiteral.MISSING: object.__setattr__(self, "start_time", 0.1) diff --git a/smarts/primatives/__init__.py b/smarts/primatives/__init__.py new file mode 100644 index 0000000000..2c65923417 --- /dev/null +++ b/smarts/primatives/__init__.py @@ -0,0 +1,21 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. diff --git a/smarts/sstudio/types/constants.py b/smarts/primatives/constants.py similarity index 85% rename from smarts/sstudio/types/constants.py rename to smarts/primatives/constants.py index 78e58ce442..a3f24df5fb 100644 --- a/smarts/sstudio/types/constants.py +++ b/smarts/primatives/constants.py @@ -21,6 +21,15 @@ # THE SOFTWARE. import sys +from enum import Enum -MAX = sys.maxsize -MISSING = sys.maxsize + +class SmartsLiteral(Enum): + AUTO = "auto" + MAX = sys.maxsize + MISSING = sys.maxsize + + +AUTO = SmartsLiteral.AUTO +MAX = SmartsLiteral.MAX +MISSING = SmartsLiteral.MISSING diff --git a/smarts/sstudio/generators.py b/smarts/sstudio/generators.py index 4762fa9634..9571ecf706 100644 --- a/smarts/sstudio/generators.py +++ b/smarts/sstudio/generators.py @@ -28,7 +28,7 @@ import random import subprocess import tempfile -from typing import Optional +from typing import Dict, Optional, Union from yattag import Doc, indent @@ -262,7 +262,9 @@ def _writexml( # Make sure all routes are "resolved" (e.g. `RandomRoute` are converted to # `Route`) so that we can write them all to file. - resolved_routes = {} + resolved_routes: Dict[ + Union[types.RandomRoute, types.Route], types.Route + ] = {} for route in {flow.route for flow in traffic.flows}: resolved_routes[route] = self.resolve_route(route, fill_in_route_gaps) @@ -277,12 +279,15 @@ def _writexml( route = resolved_routes[flow.route] for actor_idx, (actor, weight) in enumerate(flow.actors.items()): vehs_per_hour = flow.rate * (weight / total_weight) - rate_option = {} + options = {} if flow.randomly_spaced: vehs_per_sec = vehs_per_hour * SECONDS_PER_HOUR_INV - rate_option = dict(probability=vehs_per_sec) + options["probability"] = vehs_per_sec else: - rate_option = dict(vehsPerHour=vehs_per_hour) + options["vehsPerHour"] = vehs_per_hour + + if len(route.via): + options["via"] = " ".join(route.via) doc.stag( "flow", # have to encode the flow.repeat_route within the vehcile id b/c @@ -303,26 +308,26 @@ def _writexml( arrivalPos=route.end[2], begin=flow.begin, end=flow.end, - **rate_option, + **options, ) # write trip into xml format if traffic.trips: self.write_trip_xml(traffic, doc, fill_in_route_gaps) - with open(route_path, "w") as f: + with open(route_path, "w", encoding="utf-8") as f: f.write( indent( doc.getvalue(), indentation=" ", newline="\r\n", indent_text=True ) ) - def write_trip_xml(self, traffic, doc, fill_in_gaps): + def write_trip_xml(self, traffic: types.Traffic, doc: Doc, fill_in_gaps: bool): """Writes a trip spec into a route file. Typically this would be the source data to SUMO's DUAROUTER. """ # Make sure all routes are "resolved" (e.g. `RandomRoute` are converted to # `Route`) so that we can write them all to file. - resolved_routes = {} + resolved_routes: Dict[Union[types.RandomRoute, types.Route], types.Route] = {} for route in {trip.route for trip in traffic.trips}: resolved_routes[route] = self.resolve_route(route, fill_in_gaps) @@ -332,6 +337,9 @@ def write_trip_xml(self, traffic, doc, fill_in_gaps): # We don't de-dup flows since defining the same flow multiple times should # create multiple traffic flows. Since IDs can't be reused, we also unique # them here. + options: Dict[str, Union[str, int, float]] = {} + if len(route.via): + options["via"] = " ".join(route.via) for trip_idx, trip in enumerate(traffic.trips): route = resolved_routes[trip.route] actor = trip.actor @@ -346,6 +354,7 @@ def write_trip_xml(self, traffic, doc, fill_in_gaps): departSpeed=actor.depart_speed, arrivalLane=route.end[1], arrivalPos=route.end[2], + **options, ) def _cache_road_network(self): @@ -383,7 +392,7 @@ def _map_for_route(self, route) -> RoadMap: road_map, _ = map_spec.builder_fn(map_spec) return road_map - def _fill_in_gaps(self, route: types.Route) -> types.Route: + def _fill_in_traffic_route_gaps(self, route: types.Route) -> types.Route: # TODO: do this at runtime so each vehicle on the flow can take a different variation of the route ? # TODO: or do it like SUMO and generate a huge *.rou.xml file instead ? road_map = self._map_for_route(route) @@ -409,7 +418,7 @@ def resolve_route(self, route, fill_in_gaps: bool) -> types.Route: smarts.sstudio.types.route.Route: A complete route listing all road segments it passes through. """ if not isinstance(route, types.RandomRoute): - return self._fill_in_gaps(route) if fill_in_gaps else route + return self._fill_in_traffic_route_gaps(route) if fill_in_gaps else route if not self._random_route_generator: road_map = self._map_for_route(route) diff --git a/smarts/sstudio/genscenario.py b/smarts/sstudio/genscenario.py index 0b43d12f1c..7734cd17bc 100644 --- a/smarts/sstudio/genscenario.py +++ b/smarts/sstudio/genscenario.py @@ -30,7 +30,7 @@ import sqlite3 from dataclasses import dataclass, replace from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import cloudpickle import yaml @@ -264,8 +264,7 @@ def gen_scenario( if isinstance(mission, types.GroupedLapMission): gen_group_laps( scenario=output_dir, - begin=mission.route.begin, - end=mission.route.end, + route=mission.route, grid_offset=mission.offset, used_lanes=mission.lanes, vehicle_count=mission.actor_count, @@ -484,8 +483,7 @@ def gen_agent_missions( def gen_group_laps( scenario: str, - begin: Tuple[str, int, Any], - end: Tuple[str, int, Any], + route: Union[types.Route, Literal[types.AUTO]], grid_offset: int, used_lanes: int, vehicle_count: int, @@ -514,8 +512,11 @@ def gen_group_laps( """ _check_if_called_externally() - start_road_id, start_lane, start_offset = begin - end_road_id, end_lane, end_offset = end + if route is types.AUTO or route.begin is types.AUTO or route.end is types.AUTO: + raise ValueError("Automatic routes are not implemented.") + + start_road_id, start_lane, start_offset = route.begin + end_road_id, end_lane, end_offset = route.end missions = [] for i in range(vehicle_count): @@ -639,17 +640,23 @@ def _validate_entry_tactic(mission): return z_edge, _, _ = mission.entry_tactic.zone.start - if isinstance(mission, types.EndlessMission): + if isinstance(mission, types.EndlessMission) and mission.start != types.AUTO: edge, _, _ = mission.start - assert ( - edge == z_edge - ), f"Zone edge `{z_edge}` is not the same edge as `types.EndlessMission` start edge `{edge}`" + assert edge == z_edge, ( + f"Zone edge `{z_edge}` is not the same edge as `types.EndlessMission` start edge `{edge}`." + "Perhaps you wish to use `types.AUTO`?" + ) - elif isinstance(mission, (types.Mission, types.LapMission)): + elif ( + isinstance(mission, (types.Mission, types.LapMission)) + and mission.route != types.AUTO + and mission.route.begin != types.AUTO + ): edge, _, _ = mission.route.begin - assert ( - edge == z_edge - ), f"Zone edge `{z_edge}` is not the same edge as `types.Mission` route begin edge `{edge}`" + assert edge == z_edge, ( + f"Zone edge `{z_edge}` is not the same edge as `types.Mission` route begin edge `{edge}`." + "Perhaps you wish to use `types.AUTO`?" + ) def gen_traffic_histories( diff --git a/smarts/sstudio/types/__init__.py b/smarts/sstudio/types/__init__.py index c706e64cf7..641177ca58 100644 --- a/smarts/sstudio/types/__init__.py +++ b/smarts/sstudio/types/__init__.py @@ -17,13 +17,13 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +from smarts.primatives.constants import * from smarts.sstudio.types.actor.social_agent_actor import * from smarts.sstudio.types.actor.traffic_actor import * from smarts.sstudio.types.actor.traffic_engine_actor import * from smarts.sstudio.types.bubble import * from smarts.sstudio.types.bubble_limits import * from smarts.sstudio.types.condition import * -from smarts.sstudio.types.constants import * from smarts.sstudio.types.dataset import * from smarts.sstudio.types.distribution import * from smarts.sstudio.types.entry_tactic import * diff --git a/smarts/sstudio/types/bubble_limits.py b/smarts/sstudio/types/bubble_limits.py index a4255bff2a..945b467b2b 100644 --- a/smarts/sstudio/types/bubble_limits.py +++ b/smarts/sstudio/types/bubble_limits.py @@ -23,7 +23,7 @@ from dataclasses import dataclass -from smarts.sstudio.types.constants import MAX +from smarts.primatives.constants import MAX @dataclass(frozen=True) diff --git a/smarts/sstudio/types/mission.py b/smarts/sstudio/types/mission.py index d65450b643..b0b5157b36 100644 --- a/smarts/sstudio/types/mission.py +++ b/smarts/sstudio/types/mission.py @@ -22,9 +22,9 @@ import sys import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Literal, Optional, Tuple, Union -from smarts.sstudio.types.constants import MISSING +from smarts.primatives.constants import SmartsLiteral from smarts.sstudio.types.entry_tactic import EntryTactic from smarts.sstudio.types.route import JunctionEdgeIDResolver, RandomRoute, Route @@ -55,7 +55,7 @@ class Mission: via: Tuple[Via, ...] = () """Points on an road that an actor must pass through""" - start_time: float = MISSING + start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING """The earliest simulation time that this mission starts but may start later in couple with `entry_tactic`. """ @@ -87,7 +87,7 @@ class EndlessMission: """ via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: float = MISSING + start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" @@ -107,17 +107,20 @@ class LapMission: """ route: Route - """The route for the actor to attempt to follow""" + """The route for the actor to attempt to follow. This cannot have automatic values.""" num_laps: int """The amount of times to repeat the mission""" via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: float = MISSING + start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" def __post_init__(self): + assert isinstance(self.route, Route) + assert self.route.begin != SmartsLiteral.AUTO + assert self.route.end != SmartsLiteral.AUTO if self.start_time != sys.maxsize: warnings.warn( "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", diff --git a/smarts/sstudio/types/route.py b/smarts/sstudio/types/route.py index 300f565401..e9522086fa 100644 --- a/smarts/sstudio/types/route.py +++ b/smarts/sstudio/types/route.py @@ -22,10 +22,11 @@ from dataclasses import dataclass, field -from typing import Any, Optional, Tuple +from typing import Any, Literal, Optional, Tuple, Union from smarts.core import gen_id from smarts.core.utils.file import pickle_hash_int +from smarts.primatives.constants import SmartsLiteral from smarts.sstudio.types.map_spec import MapSpec @@ -58,7 +59,7 @@ class Route: """ ## road, lane index, offset - begin: Tuple[str, int, Any] + begin: Union[Tuple[str, int, Any], Literal[SmartsLiteral.AUTO]] """The (road, lane_index, offset) details of the start location for the route. road: @@ -69,7 +70,7 @@ class Route: The offset in meters into the lane. Also acceptable\\: "max", "random" """ ## road, lane index, offset - end: Tuple[str, int, Any] + end: Union[Tuple[str, int, Any], Literal[SmartsLiteral.AUTO]] """The (road, lane_index, offset) details of the end location for the route. road: From ed6ce97e5c82d13403606e088cc95d6695b2609a Mon Sep 17 00:00:00 2001 From: Tucker Date: Mon, 24 Apr 2023 09:59:01 -0400 Subject: [PATCH 2/5] Clarify test --- envision/tests/test_data_formatter.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/envision/tests/test_data_formatter.py b/envision/tests/test_data_formatter.py index 3226a5407f..21c6b0f504 100644 --- a/envision/tests/test_data_formatter.py +++ b/envision/tests/test_data_formatter.py @@ -233,26 +233,25 @@ def complex_data(): def test_covered_data_format(covered_data): - for item in covered_data: + for unformatted, formatted in covered_data: es = EnvisionDataFormatter(EnvisionDataFormatterArgs(None)) - vt = item[0] - _formatter_map[type(vt)](vt, es) + _formatter_map[type(unformatted)](unformatted, es) data = es.resolve() - assert data == item[1] + assert data == formatted assert data == unpack(data) def test_primitive_data_format(primitive_data): - for item in primitive_data: - vt = item[0] + for unformatted, formatted in primitive_data: + es = EnvisionDataFormatter(EnvisionDataFormatterArgs(None)) - es.add_any(vt) + es.add_any(unformatted) data = es.resolve() - assert data == item[1] + assert data == formatted assert data == unpack(data) @@ -276,14 +275,13 @@ def test_layer(): def test_complex_data(complex_data): - for item in complex_data: - vt = item[0] + for unformatted, formatted in complex_data: es = EnvisionDataFormatter(EnvisionDataFormatterArgs(None)) - es.add_any(vt) + es.add_any(unformatted) data = es.resolve() - assert data == item[1] + assert data == formatted assert data == unpack(data) From 0b53bb0812966f392610585b4cfd9e371806a9c1 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 9 Jun 2023 18:30:18 +0000 Subject: [PATCH 3/5] Update literals. --- envision/tests/test_data_formatter.py | 2 +- smarts/core/plan.py | 13 +++++-------- smarts/core/trap_manager.py | 5 +++-- smarts/primatives/constants.py | 14 ++++++++------ smarts/sstudio/types/entry_tactic.py | 7 ++++--- smarts/sstudio/types/map_spec.py | 5 +++-- smarts/sstudio/types/mission.py | 12 ++++++------ smarts/sstudio/types/route.py | 6 +++--- 8 files changed, 33 insertions(+), 31 deletions(-) diff --git a/envision/tests/test_data_formatter.py b/envision/tests/test_data_formatter.py index 21c6b0f504..b7a74cbcd1 100644 --- a/envision/tests/test_data_formatter.py +++ b/envision/tests/test_data_formatter.py @@ -245,7 +245,7 @@ def test_covered_data_format(covered_data): def test_primitive_data_format(primitive_data): for unformatted, formatted in primitive_data: - + es = EnvisionDataFormatter(EnvisionDataFormatterArgs(None)) es.add_any(unformatted) diff --git a/smarts/core/plan.py b/smarts/core/plan.py index 0b262141c9..ae9f398ccb 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -32,7 +32,7 @@ from smarts.core.coordinates import Dimensions, Heading, Point, Pose, RefLinePoint from smarts.core.road_map import RoadMap from smarts.core.utils.math import min_angles_difference_signed, vec_to_radians -from smarts.primatives.constants import SmartsLiteral +from smarts.primatives.constants import MISSING from smarts.sstudio.types import EntryTactic, TrapEntryTactic @@ -191,7 +191,7 @@ def _drove_off_map(self, veh_pos: Point, veh_heading: float) -> bool: def default_entry_tactic(default_entry_speed: Optional[float] = None) -> EntryTactic: """The default tactic the simulation will use to acquire an actor for an agent.""" return TrapEntryTactic( - start_time=SmartsLiteral.MISSING, + start_time=MISSING, wait_to_hijack_limit_s=0, exclusion_prefixes=tuple(), zone=None, @@ -234,7 +234,7 @@ class Mission: # An optional list of road IDs between the start and end goal that we want to # ensure the mission includes route_vias: Tuple[str, ...] = field(default_factory=tuple) - start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING + start_time: Union[float, Literal[MISSING]] = MISSING entry_tactic: Optional[EntryTactic] = None via: Tuple[Via, ...] = () # if specified, will use vehicle_spec to build the vehicle (for histories) @@ -284,12 +284,9 @@ def random_endless_mission( return Mission.endless_mission(start_pose=target_pose) def __post_init__(self): - if ( - self.entry_tactic is not None - and self.entry_tactic.start_time != SmartsLiteral.MISSING - ): + if self.entry_tactic is not None and self.entry_tactic.start_time != MISSING: object.__setattr__(self, "start_time", self.entry_tactic.start_time) - elif self.start_time == SmartsLiteral.MISSING: + elif self.start_time == MISSING: object.__setattr__(self, "start_time", 0.1) diff --git a/smarts/core/trap_manager.py b/smarts/core/trap_manager.py index f3a8638bd3..ba00bd1dd5 100644 --- a/smarts/core/trap_manager.py +++ b/smarts/core/trap_manager.py @@ -32,6 +32,7 @@ from smarts.core.utils.file import replace from smarts.core.utils.math import clip, squared_dist from smarts.core.vehicle import Vehicle +from smarts.primatives.constants import AUTO from smarts.sstudio.types import MapZone, PositionalZone, TrapEntryTactic @@ -355,11 +356,11 @@ def _mission2trap(self, road_map, mission: Mission, default_zone_dist: float = 6 default_entry_speed = entry_tactic.default_entry_speed n_lane = None - if default_entry_speed is None: + if default_entry_speed is AUTO: n_lane = road_map.nearest_lane(mission.start.point) default_entry_speed = n_lane.speed_limit if n_lane is not None else 0 - if zone is None: + if zone is AUTO: n_lane = n_lane or road_map.nearest_lane(mission.start.point) if n_lane is None: zone = PositionalZone(mission.start.position[:2], size=(3, 3)) diff --git a/smarts/primatives/constants.py b/smarts/primatives/constants.py index a3f24df5fb..b11a8bc63d 100644 --- a/smarts/primatives/constants.py +++ b/smarts/primatives/constants.py @@ -20,16 +20,18 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -import sys from enum import Enum +from typing import Final class SmartsLiteral(Enum): - AUTO = "auto" - MAX = sys.maxsize - MISSING = sys.maxsize + AUTO: Final = "auto" + MAX: Final = 9223372036854775807 + MISSING: Final = MAX + NONE: Final = None -AUTO = SmartsLiteral.AUTO -MAX = SmartsLiteral.MAX +AUTO: Final = SmartsLiteral.AUTO +MAX: Final = SmartsLiteral.MAX MISSING = SmartsLiteral.MISSING +NONE: Final = SmartsLiteral.NONE diff --git a/smarts/sstudio/types/entry_tactic.py b/smarts/sstudio/types/entry_tactic.py index 33e19b449f..d078721123 100644 --- a/smarts/sstudio/types/entry_tactic.py +++ b/smarts/sstudio/types/entry_tactic.py @@ -22,9 +22,10 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Literal, Tuple, Union from smarts.core.condition_state import ConditionState +from smarts.primatives.constants import AUTO from smarts.sstudio.types.condition import ( Condition, ConditionRequires, @@ -51,11 +52,11 @@ class TrapEntryTactic(EntryTactic): wait_to_hijack_limit_s: float = 0 """The amount of seconds a hijack will wait to get a vehicle before defaulting to a new vehicle""" - zone: Optional[MapZone] = None + zone: Union[MapZone, Literal[AUTO]] = AUTO """The zone of the hijack area""" exclusion_prefixes: Tuple[str, ...] = tuple() """The prefixes of vehicles to avoid hijacking""" - default_entry_speed: Optional[float] = None + default_entry_speed: Union[float, Literal[AUTO]] = AUTO """The speed that the vehicle starts at when the hijack limit expiry emits a new vehicle""" condition: Condition = LiteralCondition(ConditionState.TRUE) """A condition that is used to add additional exclusions.""" diff --git a/smarts/sstudio/types/map_spec.py b/smarts/sstudio/types/map_spec.py index bbfce962a7..7ebb31fb28 100644 --- a/smarts/sstudio/types/map_spec.py +++ b/smarts/sstudio/types/map_spec.py @@ -29,10 +29,11 @@ # The idea here is that anything in SMARTS that needs to use a RoadMap # can call this builder to get or create one as necessary. from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Literal, Optional, Tuple, Union from smarts.core.default_map_builder import get_road_map from smarts.core.road_map import RoadMap +from smarts.primatives.constants import AUTO MapBuilder = Callable[[Any], Tuple[Optional[RoadMap], Optional[str]]] @@ -45,7 +46,7 @@ class MapSpec: """A path or URL or name uniquely designating the map source.""" lanepoint_spacing: float = 1.0 """The default distance between pre-generated Lane Points (Waypoints).""" - default_lane_width: Optional[float] = None + default_lane_width: Union[float, Literal[AUTO]] = AUTO """If specified, the default width (in meters) of lanes on this map.""" shift_to_origin: bool = False """If True, upon creation a map whose bounding-box does not intersect with diff --git a/smarts/sstudio/types/mission.py b/smarts/sstudio/types/mission.py index b0b5157b36..431a34b34e 100644 --- a/smarts/sstudio/types/mission.py +++ b/smarts/sstudio/types/mission.py @@ -24,7 +24,7 @@ from dataclasses import dataclass from typing import Literal, Optional, Tuple, Union -from smarts.primatives.constants import SmartsLiteral +from smarts.primatives.constants import AUTO, MISSING from smarts.sstudio.types.entry_tactic import EntryTactic from smarts.sstudio.types.route import JunctionEdgeIDResolver, RandomRoute, Route @@ -55,7 +55,7 @@ class Mission: via: Tuple[Via, ...] = () """Points on an road that an actor must pass through""" - start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING + start_time: Union[float, Literal[MISSING]] = MISSING """The earliest simulation time that this mission starts but may start later in couple with `entry_tactic`. """ @@ -87,7 +87,7 @@ class EndlessMission: """ via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING + start_time: Union[float, Literal[MISSING]] = MISSING """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" @@ -112,15 +112,15 @@ class LapMission: """The amount of times to repeat the mission""" via: Tuple[Via, ...] = () """Points on a road that an actor must pass through""" - start_time: Union[float, Literal[SmartsLiteral.MISSING]] = SmartsLiteral.MISSING + start_time: Union[float, Literal[MISSING]] = MISSING """The earliest simulation time that this mission starts""" entry_tactic: Optional[EntryTactic] = None """A specific tactic the mission should employ to start the mission""" def __post_init__(self): assert isinstance(self.route, Route) - assert self.route.begin != SmartsLiteral.AUTO - assert self.route.end != SmartsLiteral.AUTO + assert self.route.begin != AUTO + assert self.route.end != AUTO if self.start_time != sys.maxsize: warnings.warn( "`start_time` is deprecated. Instead use `entry_tactic=EntryTactic(start_time=...)`.", diff --git a/smarts/sstudio/types/route.py b/smarts/sstudio/types/route.py index e9522086fa..66c5675f13 100644 --- a/smarts/sstudio/types/route.py +++ b/smarts/sstudio/types/route.py @@ -26,7 +26,7 @@ from smarts.core import gen_id from smarts.core.utils.file import pickle_hash_int -from smarts.primatives.constants import SmartsLiteral +from smarts.primatives.constants import AUTO from smarts.sstudio.types.map_spec import MapSpec @@ -59,7 +59,7 @@ class Route: """ ## road, lane index, offset - begin: Union[Tuple[str, int, Any], Literal[SmartsLiteral.AUTO]] + begin: Union[Tuple[str, int, Any], Literal[AUTO]] """The (road, lane_index, offset) details of the start location for the route. road: @@ -70,7 +70,7 @@ class Route: The offset in meters into the lane. Also acceptable\\: "max", "random" """ ## road, lane index, offset - end: Union[Tuple[str, int, Any], Literal[SmartsLiteral.AUTO]] + end: Union[Tuple[str, int, Any], Literal[AUTO]] """The (road, lane_index, offset) details of the end location for the route. road: From d3c803c5699674318f524e1a987c8b0edffcd872 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 9 Jun 2023 18:41:13 +0000 Subject: [PATCH 4/5] Add INHERIT as an option. --- smarts/core/plan.py | 12 ++++++++---- smarts/primatives/constants.py | 10 ++++++---- smarts/sstudio/types/actor/social_agent_actor.py | 6 +++--- smarts/sstudio/types/mission.py | 2 +- smarts/sstudio/types/route.py | 2 +- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/smarts/core/plan.py b/smarts/core/plan.py index ae9f398ccb..d5412a7eb3 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -76,10 +76,8 @@ def from_pose(cls, pose: Pose): @dataclass(frozen=True) -class AutomaticStart(StartBase): - """Generates a start""" - - pass +class InheritedStart(StartBase): + """A starting state that inherits from the original vehicle.""" @dataclass(frozen=True, unsafe_hash=True) @@ -95,6 +93,12 @@ def is_reached(self, vehicle_state) -> bool: return False +@dataclass(frozen=True) +class InheritedGoal(Goal): + """Describes a goal that is inherited from the vehicle (or original dataset).""" + + pass + @dataclass(frozen=True, unsafe_hash=True) class AutomaticGoal(Goal): """A goal that determines an end result from pre-existing vehicle and mission values.""" diff --git a/smarts/primatives/constants.py b/smarts/primatives/constants.py index b11a8bc63d..2bc5a95b4e 100644 --- a/smarts/primatives/constants.py +++ b/smarts/primatives/constants.py @@ -25,13 +25,15 @@ class SmartsLiteral(Enum): - AUTO: Final = "auto" - MAX: Final = 9223372036854775807 - MISSING: Final = MAX - NONE: Final = None + AUTO = "auto" + INHERIT = ... + MAX = 9223372036854775807 + MISSING = MAX + NONE = None AUTO: Final = SmartsLiteral.AUTO +INHERIT: Final = SmartsLiteral.INHERIT MAX: Final = SmartsLiteral.MAX MISSING = SmartsLiteral.MISSING NONE: Final = SmartsLiteral.NONE diff --git a/smarts/sstudio/types/actor/social_agent_actor.py b/smarts/sstudio/types/actor/social_agent_actor.py index 3c176c0145..3bd3d4e3f1 100644 --- a/smarts/sstudio/types/actor/social_agent_actor.py +++ b/smarts/sstudio/types/actor/social_agent_actor.py @@ -31,8 +31,8 @@ @dataclass(frozen=True) class SocialAgentActor(Actor): - """Used as a description/spec for zoo traffic actors. These actors use a - pre-trained model to understand how to act in the environment. + """Used as a description/spec for zoo traffic actors. These actors are controlled by a + pre-trained model that understands how to behave in the environment. """ # A pre-registered zoo identifying tag you provide to help SMARTS identify the @@ -40,7 +40,7 @@ class SocialAgentActor(Actor): agent_locator: str """The locator reference to the zoo registration call. Expects a string in the format of 'path.to.file:locator-name' where the path to the registration call is in the form - `{PYTHONPATH}[n]/path/to/file.py` + ``{PYTHONPATH}[n]/path/to/file.py``. """ policy_kwargs: Dict[str, Any] = field(default_factory=dict) """Additional keyword arguments to be passed to the constructed class overriding the diff --git a/smarts/sstudio/types/mission.py b/smarts/sstudio/types/mission.py index 431a34b34e..3a701414dd 100644 --- a/smarts/sstudio/types/mission.py +++ b/smarts/sstudio/types/mission.py @@ -75,7 +75,7 @@ def __post_init__(self): class EndlessMission: """The descriptor for an actor's mission that has no end.""" - begin: Tuple[str, int, float] + begin: Union[Tuple[str, int, float], Literal[AUTO]] """The (road, lane_index, offset) details of the start location for the route. road: diff --git a/smarts/sstudio/types/route.py b/smarts/sstudio/types/route.py index 66c5675f13..7f68fe85e1 100644 --- a/smarts/sstudio/types/route.py +++ b/smarts/sstudio/types/route.py @@ -86,7 +86,7 @@ class Route: """The ids of roads that must be included in the route between `begin` and `end`.""" map_spec: Optional[MapSpec] = None - """All routes are relative to a road map. If not specified here, + """All routes are relative to a road map. If not specified here, the default map_spec for the scenario is used.""" @property From cd2ca1146530e18375085f30ca67da8caa4d71d9 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 9 Jun 2023 19:19:53 +0000 Subject: [PATCH 5/5] Fix format tests. --- docs/conf.py | 5 +++-- smarts/core/plan.py | 1 + smarts/primatives/constants.py | 4 ++++ smarts/sstudio/generators.py | 4 ++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 186800f968..4178e113f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -124,6 +124,7 @@ } nitpick_ignore_regex = { (r"py:.*", r"av2\..*"), + (r"py:.*", r"configparser\..*"), (r"py:.*", r"google\.protobuf\..*"), (r"py:.*", r"grpc\..*"), (r"py:.*", r"gym\..*"), @@ -142,9 +143,9 @@ (r"py:.*", r"tornado\..*"), (r"py:.*", r"traci\..*"), (r"py:.*", r"typing(_extensions)?\..*"), - (r"py:.*", r"configparser\..*"), - (r"py:class", r".*\.?T"), (r"py:class", r".*\.?S"), + (r"py:class", r".*\.?T"), + (r"py:class", r".*typing.Literal\[\].*"), } # -- Options for spelling ---------------------------------------------------- diff --git a/smarts/core/plan.py b/smarts/core/plan.py index d5412a7eb3..7c238ff925 100644 --- a/smarts/core/plan.py +++ b/smarts/core/plan.py @@ -99,6 +99,7 @@ class InheritedGoal(Goal): pass + @dataclass(frozen=True, unsafe_hash=True) class AutomaticGoal(Goal): """A goal that determines an end result from pre-existing vehicle and mission values.""" diff --git a/smarts/primatives/constants.py b/smarts/primatives/constants.py index 2bc5a95b4e..ba4e5fb547 100644 --- a/smarts/primatives/constants.py +++ b/smarts/primatives/constants.py @@ -25,6 +25,10 @@ class SmartsLiteral(Enum): + """Constants that SMARTS uses. This is intended to constant type the + values. + """ + AUTO = "auto" INHERIT = ... MAX = 9223372036854775807 diff --git a/smarts/sstudio/generators.py b/smarts/sstudio/generators.py index 9571ecf706..74187d9419 100644 --- a/smarts/sstudio/generators.py +++ b/smarts/sstudio/generators.py @@ -338,10 +338,10 @@ def write_trip_xml(self, traffic: types.Traffic, doc: Doc, fill_in_gaps: bool): # create multiple traffic flows. Since IDs can't be reused, we also unique # them here. options: Dict[str, Union[str, int, float]] = {} - if len(route.via): - options["via"] = " ".join(route.via) for trip_idx, trip in enumerate(traffic.trips): route = resolved_routes[trip.route] + if len(route.via): + options["via"] = " ".join(route.via) actor = trip.actor doc.stag( "vehicle",