diff --git a/code/aind_auto_train/curriculum_manager.py b/code/aind_auto_train/curriculum_manager.py index 929548d..390a434 100644 --- a/code/aind_auto_train/curriculum_manager.py +++ b/code/aind_auto_train/curriculum_manager.py @@ -46,14 +46,13 @@ def df_curriculums(self) -> pd.DataFrame: """ df_curriculums = pd.DataFrame(columns=[ - 'task', 'task_schema_version', 'curriculum_version', 'curriculum_schema_version']) + 'curriculum_task', 'curriculum_version', 'curriculum_schema_version']) for f in self.json_files: - match = re.search(r'(.+)_v([\d.]+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json', + match = re.search(r'(.+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json', os.path.basename(f)) - task, task_schema_version, curriculum_version, curriculum_schema_version = match.groups() + curriculum_task, curriculum_version, curriculum_schema_version = match.groups() df_curriculums = pd.concat([df_curriculums, - pd.DataFrame.from_records([dict(task=task, - task_schema_version=task_schema_version, + pd.DataFrame.from_records([dict(curriculum_task=curriculum_task, curriculum_version=curriculum_version, curriculum_schema_version=curriculum_schema_version)] ) @@ -62,14 +61,13 @@ def df_curriculums(self) -> pd.DataFrame: return df_curriculums def get_curriculum(self, - task: Task, - task_schema_version: str, + curriculum_task: Task, curriculum_schema_version: str, curriculum_version: str ) -> dict: """ Get a curriculum from the saved_curriculums directory""" - json_name = (f"{task}_v{task_schema_version}_" + json_name = (f"{curriculum_task}_" f"curriculum_v{curriculum_version}_" f"schema_v{curriculum_schema_version}.json") @@ -83,9 +81,7 @@ def get_curriculum(self, return None # Sanity check - assert loaded_json['task'] == task, f"task in json ({loaded_json['task']}) does not match file name ({task})!" - assert loaded_json['task_schema_version'] == task_schema_version, \ - f"task_schema_version in json ({loaded_json['task_schema_version']}) does not match file name ({task_schema_version})!" + assert loaded_json['curriculum_task'] == curriculum_task, f"curriculum_task in json ({loaded_json['curriculum_task']}) does not match file name ({curriculum_task})!" assert loaded_json['curriculum_schema_version'] == curriculum_schema_version, \ f"curriculum_schema_version in json ({loaded_json['curriculum_schema_version']}) does not match file name ({curriculum_schema_version})!" assert loaded_json['curriculum_version'] == curriculum_version, \ @@ -148,13 +144,13 @@ def upload_curriculums(self): if __name__ == "__main__": curriculum_manager = CurriculumManager() logger.info(curriculum_manager.df_curriculums()) - curriculum = curriculum_manager.get_curriculum( - task='Coupled Baiting', - task_schema_version='1.0', - curriculum_schema_version='0.1', - curriculum_version='0.2' + _curr = curriculum_manager.get_curriculum( + curriculum_task='Coupled Baiting', + curriculum_version='0.2', + curriculum_schema_version='0.2', ) - curriculum.diagram_rules(render_file_format='svg') + print(_curr) + # _curr['curriculum'].diagram_rules(render_file_format='svg') # curriculum_manager.upload_curriculums() # %%