From f287c965eeedf7f38329092dc1c8b276fd4ae619 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Wed, 20 Dec 2023 00:51:57 +0000 Subject: [PATCH] bug fix dummy task (put in schema) --- .../aind_auto_train/curriculums/dummy_task.py | 27 ++----------------- code/aind_auto_train/schema/curriculum.py | 16 +++++++++-- code/aind_auto_train/schema/task.py | 10 +++++++ 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/code/aind_auto_train/curriculums/dummy_task.py b/code/aind_auto_train/curriculums/dummy_task.py index f1cff59..e3d6282 100644 --- a/code/aind_auto_train/curriculums/dummy_task.py +++ b/code/aind_auto_train/curriculums/dummy_task.py @@ -4,31 +4,8 @@ from typing import List, Dict from aind_auto_train.curriculum_manager import LOCAL_SAVED_CURRICULUM_ROOT -from aind_auto_train.schema.task import Task, TrainingStage, Metrics, TaskParas -from aind_auto_train.schema.curriculum import Curriculum, StageTransitions, TransitionRule, Decision - - -# Override the metrics class -class DummyTaskMetrics(Metrics): - dummy_metric_float: List[float] - dummy_metric_int: List[int] - -# Override the task parameters class -class DummyTaskParas(TaskParas): - dummy_para_bool: bool - dummy_para_float: float - -# Override the curriculum class -class DummyTaskCurriculum(Curriculum[DummyTaskParas, DummyTaskMetrics]): - # Override parameters - parameters: Dict[TrainingStage, DummyTaskParas] - - # Override metrics - def evaluate_transitions(self, - current_stage: TrainingStage, - metrics: DummyTaskMetrics # Note the dynamical type here - ) -> TrainingStage: - return super().evaluate_transitions(current_stage, metrics) +from aind_auto_train.schema.task import Task, TrainingStage, DummyTaskParas, DummyTaskMetrics +from aind_auto_train.schema.curriculum import DummyTaskCurriculum, StageTransitions, TransitionRule, Decision meta = dict(curriculum_version="0.1", diff --git a/code/aind_auto_train/schema/curriculum.py b/code/aind_auto_train/schema/curriculum.py index 633cd4b..84963dd 100644 --- a/code/aind_auto_train/schema/curriculum.py +++ b/code/aind_auto_train/schema/curriculum.py @@ -10,8 +10,8 @@ from pydantic.json import pydantic_encoder from aind_auto_train.schema.task import (Task, TrainingStage, - taskparas_class, DynamicForagingParas, - metrics_class, DynamicForagingMetrics) + taskparas_class, DynamicForagingParas, DummyTaskParas, + metrics_class, DynamicForagingMetrics, DummyTaskMetrics) from aind_auto_train.plot.curriculum import draw_diagram_rules, draw_diagram_paras # %% @@ -192,6 +192,18 @@ def evaluate_transitions(self, return super().evaluate_transitions(current_stage, metrics) +class DummyTaskCurriculum(Curriculum[DummyTaskParas, DummyTaskMetrics]): + # Override parameters + parameters: Dict[TrainingStage, DummyTaskParas] + + # Override metrics + def evaluate_transitions(self, + current_stage: TrainingStage, + metrics: DummyTaskMetrics # Note the dynamical type here + ) -> TrainingStage: + return super().evaluate_transitions(current_stage, metrics) + + # ------------------ Helpers ------------------ # A hack to serialize TrainingStage in the dictionary keys def transform_dict_with_enum_keys(obj): diff --git a/code/aind_auto_train/schema/task.py b/code/aind_auto_train/schema/task.py index bcadac0..6ec639c 100644 --- a/code/aind_auto_train/schema/task.py +++ b/code/aind_auto_train/schema/task.py @@ -52,6 +52,11 @@ class DynamicForagingMetrics(Metrics): foraging_efficiency: List[float] # Full history of foraging efficiency finished_trials: List[int] # Full history of finished trials +# For dummy task +class DummyTaskMetrics(Metrics): + dummy_metric_float: List[float] + dummy_metric_int: List[int] + class TaskParas(AindModel): """Parent class for TaskParas. All other task parameters should inherit from this class """ @@ -156,3 +161,8 @@ class DynamicForagingParas(TaskParas): +# For dummy task +class DummyTaskParas(TaskParas): + dummy_para_bool: bool + dummy_para_float: float +