Skip to content

Commit

Permalink
lower the bar for finished trials
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Feb 20, 2024
1 parent c4bfe0e commit 5d9b2dd
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions code/aind_auto_train/curriculums/coupled_baiting_1p0.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,19 +259,19 @@
TransitionRule(
decision=Decision.PROGRESS,
to_stage=TrainingStage.STAGE_FINAL,
condition_description="Finished trials >= 400 and efficiency >= 0.7",
condition_description="Finished trials >= 350 and efficiency >= 0.7",
condition="""lambda metrics:
metrics.finished_trials[-1] >= 400
metrics.finished_trials[-1] >= 350
and
metrics.foraging_efficiency[-1] >= 0.7
""",
),
TransitionRule(
decision=Decision.ROLLBACK,
to_stage=TrainingStage.STAGE_2,
condition_description="Finished trials < 200 or efficiency < 0.6",
condition_description="Finished trials < 250 or efficiency < 0.6",
condition="""lambda metrics:
metrics.finished_trials[-1] < 200
metrics.finished_trials[-1] < 250
or
metrics.foraging_efficiency[-1] < 0.6
""",
Expand Down Expand Up @@ -343,24 +343,24 @@
decision=Decision.PROGRESS,
to_stage=TrainingStage.GRADUATED,
condition_description=("For recent 5 sessions,"
"mean finished trials >= 500 and mean efficiency >= 0.7 "
"mean finished trials >= 400 and mean efficiency >= 0.7 "
"and total sessions >= 10 and sessions at final >= 5"),
condition="""lambda metrics:
metrics.session_total >= 10
and
metrics.session_at_current_stage >= 5
and
np.mean(metrics.finished_trials[-5:]) >= 500
np.mean(metrics.finished_trials[-5:]) >= 400
and
np.mean(metrics.foraging_efficiency[-5:]) >= 0.7
""",
),
TransitionRule(
decision=Decision.ROLLBACK,
to_stage=TrainingStage.STAGE_3,
condition_description="For recent 2 sessions, mean finished trials < 400 or efficiency < 0.6",
condition_description="For recent 2 sessions, mean finished trials < 300 or efficiency < 0.6",
condition="""lambda metrics:
np.mean(metrics.finished_trials[-2:]) < 400
np.mean(metrics.finished_trials[-2:]) < 300
or
np.mean(metrics.foraging_efficiency[-2:]) < 0.6
""",
Expand Down

0 comments on commit 5d9b2dd

Please sign in to comment.