diff --git a/code/aind_auto_train/auto_train_manager.py b/code/aind_auto_train/auto_train_manager.py index d64512d..0ee5fc9 100644 --- a/code/aind_auto_train/auto_train_manager.py +++ b/code/aind_auto_train/auto_train_manager.py @@ -186,7 +186,8 @@ def _get_next_stage_suggested_on_last_session(self, subject_id, session) -> str: ).iloc[0]['current_stage_actual'] def _get_current_stages(self, subject_id, session) -> dict: - current_stage_suggested = 'STAGE_1' if session == 1 \ + # Hardcode first suggested stage here. Should be extract from the first stage of a curriculum. + current_stage_suggested = 'STAGE_1_WARMUP' if session == 1 \ else self._get_next_stage_suggested_on_last_session(subject_id, session) if self.if_simulation_mode: @@ -210,7 +211,7 @@ def _get_current_stages(self, subject_id, session) -> dict: 'if_closed_loop': False} # Throw a warning if the fist actual stage is not STAGE_1 (but still use current_stage_actual) - if session == 1 and current_stage_actual != 'STAGE_1': + if session == 1 and 'STAGE_1' not in current_stage_actual: logger.warning( f'First stage is not STAGE_1 for subject {subject_id}!')