Skip to content

Commit

Permalink
Merge pull request #69 from AllenNeuralDynamics/han_fix_first_warmup
Browse files Browse the repository at this point in the history
fix: set the default first stage to STAGE_1_WARMUP
  • Loading branch information
hanhou authored Apr 1, 2024
2 parents 3a1e234 + 7a424fc commit 6d068b9
Show file tree
Hide file tree
Showing 5 changed files with 1,482 additions and 266 deletions.
5 changes: 3 additions & 2 deletions code/aind_auto_train/auto_train_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def _get_next_stage_suggested_on_last_session(self, subject_id, session) -> str:
).iloc[0]['current_stage_actual']

def _get_current_stages(self, subject_id, session) -> dict:
current_stage_suggested = 'STAGE_1' if session == 1 \
# Hardcode first suggested stage here. Should be extract from the first stage of a curriculum.
current_stage_suggested = 'STAGE_1_WARMUP' if session == 1 \
else self._get_next_stage_suggested_on_last_session(subject_id, session)

if self.if_simulation_mode:
Expand All @@ -210,7 +211,7 @@ def _get_current_stages(self, subject_id, session) -> dict:
'if_closed_loop': False}

# Throw a warning if the fist actual stage is not STAGE_1 (but still use current_stage_actual)
if session == 1 and current_stage_actual != 'STAGE_1':
if session == 1 and 'STAGE_1' not in current_stage_actual:
logger.warning(
f'First stage is not STAGE_1 for subject {subject_id}!')

Expand Down
8 changes: 5 additions & 3 deletions code/aind_auto_train/curriculum_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def get_curriculum(self,

# Check the schema version
schema_version = curriculum_schema.model_fields['curriculum_schema_version'].default
assert loaded_json['curriculum_schema_version'] == schema_version, \
f"Schema version in the loaded json ({loaded_json['curriculum_schema_version']}) does not match the loaded schema ({schema_version})! "\
f"Please update your `aind_auto_train` repo!"

if loaded_json['curriculum_schema_version'] != schema_version:
logger.error(f"Schema version in the loaded json ({loaded_json['curriculum_schema_version']}) does not match the loaded schema ({schema_version})! "
f"You're either using an outdated `aind_auto_train` repo or loading an outdated curriculum!")
return None

# Create the curriculum object
curriculum = curriculum_schema(**loaded_json)
Expand Down
Loading

0 comments on commit 6d068b9

Please sign in to comment.