Skip to content

Commit

Permalink
bug fix dummy task (put in schema)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 20, 2023
1 parent 29e3e8f commit f287c96
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 27 deletions.
27 changes: 2 additions & 25 deletions code/aind_auto_train/curriculums/dummy_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,8 @@
from typing import List, Dict

from aind_auto_train.curriculum_manager import LOCAL_SAVED_CURRICULUM_ROOT
from aind_auto_train.schema.task import Task, TrainingStage, Metrics, TaskParas
from aind_auto_train.schema.curriculum import Curriculum, StageTransitions, TransitionRule, Decision


# Override the metrics class
class DummyTaskMetrics(Metrics):
dummy_metric_float: List[float]
dummy_metric_int: List[int]

# Override the task parameters class
class DummyTaskParas(TaskParas):
dummy_para_bool: bool
dummy_para_float: float

# Override the curriculum class
class DummyTaskCurriculum(Curriculum[DummyTaskParas, DummyTaskMetrics]):
# Override parameters
parameters: Dict[TrainingStage, DummyTaskParas]

# Override metrics
def evaluate_transitions(self,
current_stage: TrainingStage,
metrics: DummyTaskMetrics # Note the dynamical type here
) -> TrainingStage:
return super().evaluate_transitions(current_stage, metrics)
from aind_auto_train.schema.task import Task, TrainingStage, DummyTaskParas, DummyTaskMetrics
from aind_auto_train.schema.curriculum import DummyTaskCurriculum, StageTransitions, TransitionRule, Decision


meta = dict(curriculum_version="0.1",
Expand Down
16 changes: 14 additions & 2 deletions code/aind_auto_train/schema/curriculum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from pydantic.json import pydantic_encoder

from aind_auto_train.schema.task import (Task, TrainingStage,
taskparas_class, DynamicForagingParas,
metrics_class, DynamicForagingMetrics)
taskparas_class, DynamicForagingParas, DummyTaskParas,
metrics_class, DynamicForagingMetrics, DummyTaskMetrics)
from aind_auto_train.plot.curriculum import draw_diagram_rules, draw_diagram_paras

# %%
Expand Down Expand Up @@ -192,6 +192,18 @@ def evaluate_transitions(self,
return super().evaluate_transitions(current_stage, metrics)


class DummyTaskCurriculum(Curriculum[DummyTaskParas, DummyTaskMetrics]):
# Override parameters
parameters: Dict[TrainingStage, DummyTaskParas]

# Override metrics
def evaluate_transitions(self,
current_stage: TrainingStage,
metrics: DummyTaskMetrics # Note the dynamical type here
) -> TrainingStage:
return super().evaluate_transitions(current_stage, metrics)


# ------------------ Helpers ------------------
# A hack to serialize TrainingStage in the dictionary keys
def transform_dict_with_enum_keys(obj):
Expand Down
10 changes: 10 additions & 0 deletions code/aind_auto_train/schema/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class DynamicForagingMetrics(Metrics):
foraging_efficiency: List[float] # Full history of foraging efficiency
finished_trials: List[int] # Full history of finished trials

# For dummy task
class DummyTaskMetrics(Metrics):
dummy_metric_float: List[float]
dummy_metric_int: List[int]

class TaskParas(AindModel):
"""Parent class for TaskParas. All other task parameters should inherit from this class
"""
Expand Down Expand Up @@ -156,3 +161,8 @@ class DynamicForagingParas(TaskParas):



# For dummy task
class DummyTaskParas(TaskParas):
dummy_para_bool: bool
dummy_para_float: float

0 comments on commit f287c96

Please sign in to comment.