Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Production phase 1.0 #66

Merged
merged 15 commits into from
Feb 20, 2024
2 changes: 1 addition & 1 deletion code/aind_auto_train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.3"
__version__ = "1.0.0"

import logging

Expand Down
5 changes: 2 additions & 3 deletions code/aind_auto_train/auto_train_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from aind_auto_train.schema.curriculum import TrainingStage
from aind_auto_train.schema.task import metrics_class, Metrics, DynamicForagingMetrics
from aind_auto_train.curriculums.coupled_baiting_0p1 import curriculum as coupled_baiting_curriculum
from aind_auto_train.util.aws_util import import_df_from_s3, export_df_to_s3
from aind_auto_train.plot.manager import plot_manager_all_progress
from aind_auto_train.curriculum_manager import CurriculumManager
Expand Down Expand Up @@ -234,8 +233,8 @@ def _get_curriculum_to_use(self, df_this):
f'"Coupled Baiting_curriculum_v0.1_schema_v0.2"')
return self.curriculum_manager.get_curriculum(
curriculum_name='Coupled Baiting',
curriculum_version='0.1',
curriculum_schema_version='0.3',
curriculum_version='1.0',
curriculum_schema_version='1.0',
)

def add_and_evaluate_session(self, subject_id, session):
Expand Down
15 changes: 12 additions & 3 deletions code/aind_auto_train/curriculum_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def df_curriculums(self) -> pd.DataFrame:
'curriculum_version',
'curriculum_schema_version',
'curriculum_description'])

schema_version_code_base = curriculum_schemas.Curriculum.model_fields['curriculum_schema_version'].default

for f in self.json_files:
match = re.search(r'(.+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json',
os.path.basename(f))
Expand All @@ -58,6 +61,11 @@ def df_curriculums(self) -> pd.DataFrame:
f"Could not parse {os.path.basename(f)} as a curriculum json file.")
continue
curriculum_name, curriculum_version, curriculum_schema_version = match.groups()

# Only show curriculums whose curriculum_schema_version matches the current codebase
if schema_version_code_base != curriculum_schema_version:
continue

df_curriculums = pd.concat(
[df_curriculums,
pd.DataFrame.from_records([dict(curriculum_name=curriculum_name,
Expand Down Expand Up @@ -111,7 +119,8 @@ def get_curriculum(self,
# Check the schema version
schema_version = curriculum_schema.model_fields['curriculum_schema_version'].default
assert loaded_json['curriculum_schema_version'] == schema_version, \
f"schema version in the loaded json ({loaded_json['curriculum_schema_version']}) does not match the loaded schema ({schema_version})!"
f"Schema version in the loaded json ({loaded_json['curriculum_schema_version']}) does not match the loaded schema ({schema_version})! "\
f"Please update your `aind_auto_train` repo!"

# Create the curriculum object
curriculum = curriculum_schema(**loaded_json)
Expand Down Expand Up @@ -164,8 +173,8 @@ def upload_curriculums(self):
logger.info(curriculum_manager.df_curriculums())
_curr = curriculum_manager.get_curriculum(
curriculum_name='Coupled Baiting',
curriculum_version='0.2',
curriculum_schema_version='0.3',
curriculum_version='1.0',
curriculum_schema_version='1.0',
)

print(_curr)
Expand Down

This file was deleted.

Loading