Skip to content

Commit

Permalink
add curriculum_description to df_curriculums
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 20, 2023
1 parent e12e4b2 commit aeb25e6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
30 changes: 19 additions & 11 deletions code/aind_auto_train/curriculum_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def df_curriculums(self) -> pd.DataFrame:
"""

df_curriculums = pd.DataFrame(columns=[
'curriculum_task', 'curriculum_version', 'curriculum_schema_version'])
'curriculum_task',
'curriculum_version',
'curriculum_schema_version',
'curriculum_description'])
for f in self.json_files:
match = re.search(r'(.+)_curriculum_v([\d.]+)_schema_v([\d.]+)\.json',
os.path.basename(f))
Expand All @@ -55,12 +58,16 @@ def df_curriculums(self) -> pd.DataFrame:
f"Could not parse {os.path.basename(f)} as a curriculum json file.")
continue
curriculum_task, curriculum_version, curriculum_schema_version = match.groups()
df_curriculums = pd.concat([df_curriculums,
pd.DataFrame.from_records([dict(curriculum_task=curriculum_task,
curriculum_version=curriculum_version,
curriculum_schema_version=curriculum_schema_version)]
)
], ignore_index=True)
df_curriculums = pd.concat(
[df_curriculums,
pd.DataFrame.from_records([dict(curriculum_task=curriculum_task,
curriculum_version=curriculum_version,
curriculum_schema_version=curriculum_schema_version,
curriculum_description=json.load(
open(f, 'r'))['curriculum_description']
)]
)
], ignore_index=True)

return df_curriculums

Expand All @@ -85,7 +92,8 @@ def get_curriculum(self,
return None

# Sanity check
assert loaded_json['curriculum_task'] == curriculum_task, f"curriculum_task in json ({loaded_json['curriculum_task']}) does not match file name ({curriculum_task})!"
assert loaded_json[
'curriculum_task'] == curriculum_task, f"curriculum_task in json ({loaded_json['curriculum_task']}) does not match file name ({curriculum_task})!"
assert loaded_json['curriculum_schema_version'] == curriculum_schema_version, \
f"curriculum_schema_version in json ({loaded_json['curriculum_schema_version']}) does not match file name ({curriculum_schema_version})!"
assert loaded_json['curriculum_version'] == curriculum_version, \
Expand Down Expand Up @@ -114,13 +122,13 @@ def get_curriculum(self,
metrics_schema = inspect.getfullargspec(
curriculum.evaluate_transitions
).annotations.get('metrics', None)

metrics_schema_name = metrics_schema.__name__ if metrics_schema else None

# Check whether the required metrics schema is available
assert hasattr(task_schemas, metrics_schema_name), \
f"'{metrics_schema_name}' not found in aind_auto_train.schema.task"

metrics = getattr(task_schemas, metrics_schema_name)

return {'curriculum': curriculum,
Expand Down
27 changes: 18 additions & 9 deletions code/aind_auto_train/demo_curriculum_manager.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-12-20 15:10:38 | INFO | aind_auto_train.util.aws_util | AWS credentials not found in environment variables. Try ~/.aws/credentials...\n",
"2023-12-20 15:10:38 | INFO | aind_auto_train.util.aws_util | Found AWS credential from ~/.aws/credentials!\n",
"2023-12-20 15:10:38 | INFO | aind_auto_train.util.aws_util | 17 objects downloaded from s3://aind-behavior-data/foraging_auto_training/saved_curriculums/ to /root/capsule/scratch/tmp/\n",
"2023-12-20 15:10:38 | INFO | aind_auto_train.curriculum_manager | Found 3 curriculums in /root/capsule/scratch/tmp/\n"
"2023-12-20 15:30:05 | INFO | aind_auto_train.util.aws_util | AWS credentials not found in environment variables. Try ~/.aws/credentials...\n",
"2023-12-20 15:30:05 | INFO | aind_auto_train.util.aws_util | Found AWS credential from ~/.aws/credentials!\n",
"2023-12-20 15:30:06 | INFO | aind_auto_train.util.aws_util | 17 objects downloaded from s3://aind-behavior-data/foraging_auto_training/saved_curriculums/ to /root/capsule/scratch/tmp/\n",
"2023-12-20 15:30:06 | INFO | aind_auto_train.curriculum_manager | Found 3 curriculums in /root/capsule/scratch/tmp/\n"
]
}
],
Expand Down Expand Up @@ -96,6 +96,7 @@
" <th>curriculum_task</th>\n",
" <th>curriculum_version</th>\n",
" <th>curriculum_schema_version</th>\n",
" <th>curriculum_description</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
Expand All @@ -104,28 +105,36 @@
" <td>Coupled Baiting</td>\n",
" <td>0.1</td>\n",
" <td>0.2</td>\n",
" <td>Base curriculum for the coupled-baiting task</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Coupled Baiting</td>\n",
" <td>0.2</td>\n",
" <td>0.2</td>\n",
" <td>More stringent criteria before GRADUATED than 0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Dummy task</td>\n",
" <td>0.1</td>\n",
" <td>0.2</td>\n",
" <td></td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" curriculum_task curriculum_version curriculum_schema_version\n",
"0 Coupled Baiting 0.1 0.2\n",
"1 Coupled Baiting 0.2 0.2\n",
"2 Dummy task 0.1 0.2"
" curriculum_task curriculum_version curriculum_schema_version \\\n",
"0 Coupled Baiting 0.1 0.2 \n",
"1 Coupled Baiting 0.2 0.2 \n",
"2 Dummy task 0.1 0.2 \n",
"\n",
" curriculum_description \n",
"0 Base curriculum for the coupled-baiting task \n",
"1 More stringent criteria before GRADUATED than 0.1 \n",
"2 "
]
},
"execution_count": 4,
Expand Down

0 comments on commit aeb25e6

Please sign in to comment.