Skip to content

Commit

Permalink
Handle missing parent and invalid categories when adding locations an…
Browse files Browse the repository at this point in the history
…d objects (#144)
  • Loading branch information
sea-bass authored Sep 6, 2023
1 parent 853a338 commit 89e385b
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 11 deletions.
7 changes: 6 additions & 1 deletion pyrobosim/pyrobosim/core/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from shapely import intersects_xy
from shapely.plotting import patch_from_polygon

from ..utils.general import EntityMetadata
from ..utils.general import EntityMetadata, InvalidEntityCategoryException
from ..utils.pose import Pose, rot2d
from ..utils.polygon import (
inflate_polygon,
Expand Down Expand Up @@ -60,6 +60,11 @@ def __init__(self, name=None, category=None, pose=None, parent=None, color=None)
self.parent = parent

self.metadata = Location.metadata.get(self.category)
if not self.metadata:
raise InvalidEntityCategoryException(
f"Invalid location category: {self.category}"
)

if color is not None:
self.viz_color = color
elif "color" in self.metadata:
Expand Down
7 changes: 6 additions & 1 deletion pyrobosim/pyrobosim/core/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from shapely.plotting import patch_from_polygon
from scipy.spatial import ConvexHull

from ..utils.general import EntityMetadata
from ..utils.general import EntityMetadata, InvalidEntityCategoryException
from ..utils.pose import Pose
from ..utils.polygon import (
convhull_to_rectangle,
Expand Down Expand Up @@ -69,6 +69,11 @@ def __init__(
self.viz_text = None

self.metadata = Object.metadata.get(self.category)
if not self.metadata:
raise InvalidEntityCategoryException(
f"Invalid object category: {self.category}"
)

if color is not None:
self.viz_color = color
elif "color" in self.metadata:
Expand Down
27 changes: 21 additions & 6 deletions pyrobosim/pyrobosim/core/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .objects import Object
from .room import Room
from .robot import Robot
from ..utils.general import InvalidEntityCategoryException
from ..utils.pose import Pose
from ..utils.knowledge import (
apply_resolution_strategy,
Expand Down Expand Up @@ -303,13 +304,20 @@ def add_location(self, **location_config):
# Else, create a location directly from the specified arguments.
if "location" in location_config:
loc = location_config["location"]
else:
elif "parent" in location_config:
if isinstance(location_config["parent"], str):
location_config["parent"] = self.get_room_by_name(
location_config["parent"]
)

loc = Location(**location_config)
try:
loc = Location(**location_config)
except InvalidEntityCategoryException as exception:
warnings.warn(exception)
return None
else:
warnings.warn("Location instance or parent must be specified.")
return None

# If the category name is empty, use "location" as the base name.
category = loc.category
Expand Down Expand Up @@ -450,11 +458,11 @@ def add_object(self, **object_config):
:return: Object instance if successfully created, else None.
:rtype: :class:`pyrobosim.core.objects.Object`
"""
# If it's on Object instance, get it from the "location" named argument.
# Else, create a location directly from the specified arguments.
# If it's an Object instance, get it from the "object" named argument.
# Else, create an object directly from the specified arguments.
if "object" in object_config:
obj = object_config["object"]
else:
elif "parent" in object_config:
parent = object_config.get("parent", None)
if isinstance(parent, str):
parent = self.get_entity_by_name(parent)
Expand All @@ -473,7 +481,14 @@ def add_object(self, **object_config):

object_config["parent"] = parent
object_config["inflation_radius"] = self.object_radius
obj = Object(**object_config)
try:
obj = Object(**object_config)
except InvalidEntityCategoryException as exception:
warnings.warn(exception)
return None
else:
warnings.warn("Object instance or parent must be specified.")
return None

# If the category name is empty, use "object" as the base name.
category = obj.category
Expand Down
6 changes: 6 additions & 0 deletions pyrobosim/pyrobosim/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def get(self, category):
return self.data.get(category, None)


class InvalidEntityCategoryException(Exception):
"""Raised when an invalid entity metadata category is used."""

pass


def replace_special_yaml_tokens(in_text, root_dir=None):
"""
Replaces special tokens permitted in our YAML specification.
Expand Down
36 changes: 35 additions & 1 deletion test/core/test_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyrobosim.core import Hallway, Object, World
from pyrobosim.utils.pose import Pose

from pyrobosim.utils.general import get_data_folder
from pyrobosim.utils.general import get_data_folder, InvalidEntityCategoryException


class TestWorldModeling:
Expand Down Expand Up @@ -104,6 +104,23 @@ def test_create_location():
assert len(TestWorldModeling.world.locations) == 2
assert TestWorldModeling.world.get_location_by_name("study_desk") == desk

# Test missing parent
with pytest.warns(UserWarning):
result = TestWorldModeling.world.add_location(
category="desk",
pose=Pose(),
)
assert result is None

# Test invalid category
with pytest.warns(UserWarning):
result = TestWorldModeling.world.add_location(
category="does_not_exist",
parent="bedroom",
pose=Pose(),
)
assert result is None

@staticmethod
@pytest.mark.dependency(
depends=[
Expand All @@ -127,6 +144,23 @@ def test_create_object():
TestWorldModeling.world.get_object_by_name("apple1") == test_obj
) # Automatic naming

# Test missing parent
with pytest.warns(UserWarning):
result = TestWorldModeling.world.add_object(
category="apple",
pose=Pose(),
)
assert result is None

# Test invalid category
with pytest.warns(UserWarning):
result = TestWorldModeling.world.add_object(
category="does_not_exist",
parent="study_desk",
pose=Pose(),
)
assert result is None

@staticmethod
@pytest.mark.dependency(
depends=[
Expand Down
60 changes: 58 additions & 2 deletions test/core/test_yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ def test_create_world_yaml_loader():
TestWorldYamlLoading.yaml_loader = WorldYamlLoader()

# Clean up metadata for test reproducibility
delattr(Location, "metadata")
delattr(Object, "metadata")
if hasattr(Location, "metadata"):
delattr(Location, "metadata")
if hasattr(Object, "metadata"):
delattr(Object, "metadata")

@staticmethod
@pytest.mark.dependency(
Expand Down Expand Up @@ -177,6 +179,33 @@ def test_create_locations_from_yaml():
loader.add_locations()
assert len(loader.world.locations) == 0

# No parent means the location is not added.
loader.data = {
"locations": [
{
"category": "table",
"pose": [0.85, -0.5, 0.0, -1.57],
}
]
}
with pytest.warns(UserWarning):
loader.add_locations()
assert len(loader.world.locations) == 0

# Invalid location category means the object is not added.
loader.data = {
"locations": [
{
"category": "does_not_exist",
"parent": "kitchen",
"pose": [0.85, -0.5, 0.0, -1.57],
}
]
}
with pytest.warns(UserWarning):
loader.add_locations()
assert len(loader.world.locations) == 0

# Load locations from a YAML specified dictionary.
locations_dict = {
"locations": [
Expand Down Expand Up @@ -222,6 +251,33 @@ def test_create_objects_from_yaml():
loader.add_objects()
assert len(loader.world.objects) == 0

# No parent means the object is not added.
loader.data = {
"objects": [
{
"category": "banana",
"pose": [3.2, 3.5, 0.0, 0.707],
}
]
}
with pytest.warns(UserWarning):
loader.add_objects()
assert len(loader.world.objects) == 0

# Invalid object category means the object is not added.
loader.data = {
"objects": [
{
"category": "does_not_exist",
"parent": "table0",
"pose": [3.2, 3.5, 0.0, 0.707],
}
]
}
with pytest.warns(UserWarning):
loader.add_objects()
assert len(loader.world.objects) == 0

# Load objects from a YAML specified dictionary.
objects_dict = {
"objects": [
Expand Down

0 comments on commit 89e385b

Please sign in to comment.