Skip to content

Commit

Permalink
Fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
mrica-equinor committed Oct 10, 2024
1 parent 696eb73 commit d5cebfb
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 74 deletions.
9 changes: 5 additions & 4 deletions src/isar/apis/models/start_mission_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from robot_interface.models.inspection.inspection import Inspection
from robot_interface.models.mission.mission import Mission
from robot_interface.models.mission.task import (
TASKS,
DockingProcedure,
Localize,
RecordAudio,
Expand Down Expand Up @@ -65,10 +66,10 @@ class StartMissionDefinition(BaseModel):


def to_isar_mission(start_mission_definition: StartMissionDefinition) -> Mission:
isar_tasks: List[Task] = []
isar_tasks: List[TASKS] = []

for start_mission_task_definition in start_mission_definition.tasks:
task: Task = create_isar_task(start_mission_task_definition)
task: TASKS = create_isar_task(start_mission_task_definition)
if start_mission_task_definition.id:
task.id = start_mission_task_definition.id
isar_tasks.append(task)
Expand Down Expand Up @@ -114,7 +115,7 @@ def to_isar_mission(start_mission_definition: StartMissionDefinition) -> Mission
return isar_mission


def check_for_duplicate_ids(items: List[Task]):
def check_for_duplicate_ids(items: List[TASKS]):
duplicate_ids = get_duplicate_ids(items=items)
if len(duplicate_ids) > 0:
raise MissionPlannerError(
Expand Down Expand Up @@ -213,7 +214,7 @@ def create_dock_task() -> DockingProcedure:
return DockingProcedure(behavior="dock")


def get_duplicate_ids(items: List[Task]) -> List[str]:
def get_duplicate_ids(items: List[TASKS]) -> List[str]:
unique_ids: List[str] = []
duplicate_ids: List[str] = []
for item in items:
Expand Down
50 changes: 2 additions & 48 deletions src/isar/mission_planner/local_planner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import List
from typing import List, Optional

from alitra import Frame
from injector import inject
Expand All @@ -13,18 +13,7 @@
)
from isar.services.readers.base_reader import BaseReader, BaseReaderError
from robot_interface.models.mission.mission import Mission
from robot_interface.models.mission.task import (
DockingProcedure,
Localize,
MoveArm,
RecordAudio,
ReturnToHome,
TakeImage,
TakeThermalImage,
TakeThermalVideo,
TakeVideo,
Task,
)


logger = logging.getLogger("api")

Expand Down Expand Up @@ -52,48 +41,13 @@ def get_mission(self, mission_id) -> Mission:
def read_mission_from_file(mission_path: Path) -> Mission:
mission_dict: dict = BaseReader.read_json(location=mission_path)

mission_tasks: List[Task] = []
task_dataclass: Task = None

for task in mission_dict["tasks"]:
if task["type"] == "return_to_home":
task_dataclass = ReturnToHome
elif task["type"] == "localize":
task_dataclass = Localize
elif task["type"] == "move_arm":
task_dataclass = MoveArm
elif task["type"] == "take_image":
task_dataclass = TakeImage
elif task["type"] == "take_thermal_image":
task_dataclass = TakeThermalImage
elif task["type"] == "take_video":
task_dataclass = TakeVideo
elif task["type"] == "take_thermal_video":
task_dataclass = TakeThermalVideo
elif task["type"] == "record_audio":
task_dataclass = RecordAudio
elif task["type"] == "docking_procedure":
task_dataclass = DockingProcedure

if task_dataclass:
task: Task = BaseReader.dict_to_dataclass(
dataclass_dict=task,
target_dataclass=task_dataclass,
cast_config=[Frame],
strict_config=True,
)
mission_tasks.append(task)

mission_dict["tasks"] = []
mission: Mission = BaseReader.dict_to_dataclass(
dataclass_dict=mission_dict,
target_dataclass=Mission,
cast_config=[Frame],
strict_config=True,
)

mission.tasks = mission_tasks

return mission

def get_predefined_missions(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions src/isar/mission_planner/sequential_task_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
TaskSelectorInterface,
TaskSelectorStop,
)
from robot_interface.models.mission.task import Task
from robot_interface.models.mission.task import TASKS, Task


class SequentialTaskSelector(TaskSelectorInterface):
def __init__(self) -> None:
super().__init__()
self._iterator: Iterator = None

def initialize(self, tasks: List[Task]) -> None:
def initialize(self, tasks: List[TASKS]) -> None:
super().initialize(tasks=tasks)
self._iterator = iter(self.tasks)

Expand Down
8 changes: 4 additions & 4 deletions src/isar/mission_planner/task_selector_interface.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from abc import ABCMeta, abstractmethod
from typing import List

from robot_interface.models.mission.task import Task
from robot_interface.models.mission.task import TASKS, Task


class TaskSelectorInterface(metaclass=ABCMeta):
def __init__(self) -> None:
self.tasks: List[Task] = None
self.tasks: List[TASKS] = None

def initialize(self, tasks: List[Task]) -> None:
def initialize(self, tasks: List[TASKS]) -> None:
self.tasks = tasks

@abstractmethod
def next_task(self) -> Task:
def next_task(self) -> TASKS:
"""
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions src/robot_interface/models/mission/mission.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from robot_interface.models.exceptions.robot_exceptions import ErrorMessage
from robot_interface.models.mission.status import MissionStatus
from robot_interface.models.mission.task import Task
from robot_interface.models.mission.task import TASKS
from robot_interface.utilities.uuid_string_factory import uuid4_string


@dataclass
class Mission:
tasks: List[Task]
tasks: List[TASKS]
id: str = field(default_factory=uuid4_string, init=True)
name: str = ""
start_pose: Optional[Pose] = None
Expand Down
31 changes: 22 additions & 9 deletions src/robot_interface/models/mission/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Iterator, Literal, Optional, Type
from typing import Iterator, Literal, Optional, Type, Union

from alitra import Pose, Position

Expand Down Expand Up @@ -72,7 +72,7 @@ class DockingProcedure(Task):
"""

behavior: Literal["dock", "undock"] = field(default=None, init=True)
type: str = TaskTypes.DockingProcedure
type: Literal["docking_procedure"] = TaskTypes.DockingProcedure


@dataclass
Expand All @@ -92,7 +92,7 @@ class Localize(Task):
"""

localization_pose: Pose = field(default=None, init=True)
type: str = TaskTypes.Localize
type: Literal["localize"] = TaskTypes.Localize


@dataclass
Expand All @@ -102,7 +102,7 @@ class MoveArm(Task):
"""

arm_pose: str = field(default=None, init=True)
type: str = TaskTypes.MoveArm
type: Literal["move_arm"] = TaskTypes.MoveArm


@dataclass
Expand All @@ -112,7 +112,7 @@ class TakeImage(InspectionTask):
"""

target: Position = field(default=None, init=True)
type: str = TaskTypes.TakeImage
type: Literal["take_image"] = TaskTypes.TakeImage

@staticmethod
def get_inspection_type() -> Type[Inspection]:
Expand All @@ -126,7 +126,7 @@ class TakeThermalImage(InspectionTask):
"""

target: Position = field(default=None, init=True)
type: str = TaskTypes.TakeThermalImage
type: Literal["take_thermal_image"] = TaskTypes.TakeThermalImage

@staticmethod
def get_inspection_type() -> Type[Inspection]:
Expand All @@ -143,7 +143,7 @@ class TakeVideo(InspectionTask):

target: Position = field(default=None, init=True)
duration: float = field(default=None, init=True)
type: str = TaskTypes.TakeVideo
type: Literal["take_video"] = TaskTypes.TakeVideo

@staticmethod
def get_inspection_type() -> Type[Inspection]:
Expand All @@ -160,7 +160,7 @@ class TakeThermalVideo(InspectionTask):

target: Position = field(default=None, init=True)
duration: float = field(default=None, init=True)
type: str = TaskTypes.TakeThermalVideo
type: Literal["take_thermal_video"] = TaskTypes.TakeThermalVideo

@staticmethod
def get_inspection_type() -> Type[Inspection]:
Expand All @@ -177,8 +177,21 @@ class RecordAudio(InspectionTask):

target: Position = field(default=None, init=True)
duration: float = field(default=None, init=True)
type: str = TaskTypes.RecordAudio
type: Literal["record_audio"] = TaskTypes.RecordAudio

@staticmethod
def get_inspection_type() -> Type[Inspection]:
return Audio


TASKS = Union[
ReturnToHome,
Localize,
MoveArm,
TakeImage,
TakeThermalImage,
TakeVideo,
TakeThermalVideo,
RecordAudio,
DockingProcedure,
]
7 changes: 4 additions & 3 deletions tests/isar/mission/test_mission.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from alitra import Frame, Orientation, Pose, Position

from isar.services.readers.base_reader import BaseReader
from robot_interface.models.mission.mission import Mission, Task
from robot_interface.models.mission.mission import Mission
from robot_interface.models.mission.task import (
TASKS,
TakeImage,
TakeThermalImage,
ReturnToHome,
Expand Down Expand Up @@ -126,8 +127,8 @@ def test_mission_definition() -> None:

assert len(loaded_mission.tasks) == len(expected_mission.tasks)
for i_task in range(len(loaded_mission.tasks)):
loaded_task: Task = loaded_mission.tasks[i_task]
expected_task: Task = expected_mission.tasks[i_task]
loaded_task: TASKS = loaded_mission.tasks[i_task]
expected_task: TASKS = expected_mission.tasks[i_task]

assert loaded_task.status == expected_task.status
assert loaded_task.tag_id == expected_task.tag_id
4 changes: 2 additions & 2 deletions tests/isar/models/test_start_mission_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
to_isar_mission,
)
from robot_interface.models.mission.mission import Mission
from robot_interface.models.mission.task import Task
from robot_interface.models.mission.task import TASKS, Task

task_1: Task = Task(tag_id=None, id="123")
task_2: Task = Task(tag_id=None, id="123")
Expand All @@ -33,7 +33,7 @@
),
],
)
def test_duplicate_id_check(item_list: List[Task], expected_boolean: bool):
def test_duplicate_id_check(item_list: List[TASKS], expected_boolean: bool):
duplicates: List[str] = get_duplicate_ids(item_list)
has_duplicates: bool = len(duplicates) > 0
assert has_duplicates == expected_boolean
Expand Down

0 comments on commit d5cebfb

Please sign in to comment.