Skip to content

Commit

Permalink
update curriculum_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 20, 2023
1 parent 18c868c commit 29e3e8f
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions code/aind_auto_train/curriculum_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ def df_curriculums(self) -> pd.DataFrame:
"""

df_curriculums = pd.DataFrame(columns=[
'task', 'task_schema_version', 'curriculum_version', 'curriculum_schema_version'])
'curriculum_task', 'curriculum_version', 'curriculum_schema_version'])
for f in self.json_files:
match = re.search(r'(.+)_v([\d.]+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json',
match = re.search(r'(.+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json',
os.path.basename(f))
task, task_schema_version, curriculum_version, curriculum_schema_version = match.groups()
curriculum_task, curriculum_version, curriculum_schema_version = match.groups()
df_curriculums = pd.concat([df_curriculums,
pd.DataFrame.from_records([dict(task=task,
task_schema_version=task_schema_version,
pd.DataFrame.from_records([dict(curriculum_task=curriculum_task,
curriculum_version=curriculum_version,
curriculum_schema_version=curriculum_schema_version)]
)
Expand All @@ -62,14 +61,13 @@ def df_curriculums(self) -> pd.DataFrame:
return df_curriculums

def get_curriculum(self,
task: Task,
task_schema_version: str,
curriculum_task: Task,
curriculum_schema_version: str,
curriculum_version: str
) -> dict:
""" Get a curriculum from the saved_curriculums directory"""

json_name = (f"{task}_v{task_schema_version}_"
json_name = (f"{curriculum_task}_"
f"curriculum_v{curriculum_version}_"
f"schema_v{curriculum_schema_version}.json")

Expand All @@ -83,9 +81,7 @@ def get_curriculum(self,
return None

# Sanity check
assert loaded_json['task'] == task, f"task in json ({loaded_json['task']}) does not match file name ({task})!"
assert loaded_json['task_schema_version'] == task_schema_version, \
f"task_schema_version in json ({loaded_json['task_schema_version']}) does not match file name ({task_schema_version})!"
assert loaded_json['curriculum_task'] == curriculum_task, f"curriculum_task in json ({loaded_json['curriculum_task']}) does not match file name ({curriculum_task})!"
assert loaded_json['curriculum_schema_version'] == curriculum_schema_version, \
f"curriculum_schema_version in json ({loaded_json['curriculum_schema_version']}) does not match file name ({curriculum_schema_version})!"
assert loaded_json['curriculum_version'] == curriculum_version, \
Expand Down Expand Up @@ -148,13 +144,13 @@ def upload_curriculums(self):
if __name__ == "__main__":
curriculum_manager = CurriculumManager()
logger.info(curriculum_manager.df_curriculums())
curriculum = curriculum_manager.get_curriculum(
task='Coupled Baiting',
task_schema_version='1.0',
curriculum_schema_version='0.1',
curriculum_version='0.2'
_curr = curriculum_manager.get_curriculum(
curriculum_task='Coupled Baiting',
curriculum_version='0.2',
curriculum_schema_version='0.2',
)

curriculum.diagram_rules(render_file_format='svg')
print(_curr)
# _curr['curriculum'].diagram_rules(render_file_format='svg')
# curriculum_manager.upload_curriculums()
# %%

0 comments on commit 29e3e8f

Please sign in to comment.