diff --git a/code/aind_auto_train/curriculums/coupled_baiting_1p0.py b/code/aind_auto_train/curriculums/coupled_baiting_1p0.py index 1cd9bed..57fdb9d 100644 --- a/code/aind_auto_train/curriculums/coupled_baiting_1p0.py +++ b/code/aind_auto_train/curriculums/coupled_baiting_1p0.py @@ -259,9 +259,9 @@ TransitionRule( decision=Decision.PROGRESS, to_stage=TrainingStage.STAGE_FINAL, - condition_description="Finished trials >= 400 and efficiency >= 0.7", + condition_description="Finished trials >= 350 and efficiency >= 0.7", condition="""lambda metrics: - metrics.finished_trials[-1] >= 400 + metrics.finished_trials[-1] >= 350 and metrics.foraging_efficiency[-1] >= 0.7 """, @@ -269,9 +269,9 @@ TransitionRule( decision=Decision.ROLLBACK, to_stage=TrainingStage.STAGE_2, - condition_description="Finished trials < 200 or efficiency < 0.6", + condition_description="Finished trials < 250 or efficiency < 0.6", condition="""lambda metrics: - metrics.finished_trials[-1] < 200 + metrics.finished_trials[-1] < 250 or metrics.foraging_efficiency[-1] < 0.6 """, @@ -343,14 +343,14 @@ decision=Decision.PROGRESS, to_stage=TrainingStage.GRADUATED, condition_description=("For recent 5 sessions," - "mean finished trials >= 500 and mean efficiency >= 0.7 " + "mean finished trials >= 400 and mean efficiency >= 0.7 " "and total sessions >= 10 and sessions at final >= 5"), condition="""lambda metrics: metrics.session_total >= 10 and metrics.session_at_current_stage >= 5 and - np.mean(metrics.finished_trials[-5:]) >= 500 + np.mean(metrics.finished_trials[-5:]) >= 400 and np.mean(metrics.foraging_efficiency[-5:]) >= 0.7 """, @@ -358,9 +358,9 @@ TransitionRule( decision=Decision.ROLLBACK, to_stage=TrainingStage.STAGE_3, - condition_description="For recent 2 sessions, mean finished trials < 400 or efficiency < 0.6", + condition_description="For recent 2 sessions, mean finished trials < 300 or efficiency < 0.6", condition="""lambda metrics: - np.mean(metrics.finished_trials[-2:]) < 400 + np.mean(metrics.finished_trials[-2:]) < 300 or np.mean(metrics.foraging_efficiency[-2:]) < 0.6 """,