Skip to content

Commit

Permalink
separate current_stage_actual and current_stage_suggested
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 15, 2023
1 parent d58ccef commit 702b81a
Showing 1 changed file with 113 additions and 57 deletions.
170 changes: 113 additions & 57 deletions code/aind_auto_train/auto_train_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from aind_auto_train.plot.manager import plot_manager_all_progress

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Directory for caching df_maseter tables
task_mapper = {'coupled_block_baiting': 'Coupled Baiting',
Expand All @@ -26,11 +25,10 @@ class AutoTrainManager:
download_from_database()
upload_to_database()
"""

# Specify the Metrics subclass for a specific task
_metrics_model: metrics_class


def __init__(self,
manager_name: str,
):
Expand All @@ -42,18 +40,30 @@ def __init__(self,
self.df_behavior, self.df_manager = self.download_from_database()

# Check if all required metrics exist in df_behavior
self.task_specific_metrics_keys = set(self._metrics_model.schema()['properties'].keys()) \
- set(Metrics.schema()['properties'].keys())
self.task_specific_metrics_keys = set(self._metrics_model.model_json_schema()['properties'].keys()) \
- set(Metrics.model_json_schema()['properties'].keys())
assert all([col in self.df_behavior.columns for col in
list(self.task_specific_metrics_keys)]), "Not all required metrics exist in df_behavior!"

if self.df_manager is None: # Create a new table
logger.info('No df_manager found, creating a new one...')
# If `current_stage_actual` is not in df_behavior, we are in open loop simulation mode
if 'current_stage_actual' not in self.df_behavior:
self.if_simulation_mode = True
logger.warning(
"current_stage_actual is not in df_behavior, we are in simulation mode!")
else:
self.if_simulation_mode = False

self.df_manager = None

# Create a new table if df_manager is empty
if self.df_manager is None:
logger.warning('No df_manager found, creating a new one...')
self.df_manager = pd.DataFrame(columns=['subject_id', 'session_date', 'task',
'session', 'session_at_current_stage',
'curriculum_version', 'task_schema_version',
*self.task_specific_metrics_keys,
'metrics', 'current_stage_suggested', 'current_stage_actual',
'if_closed_loop', 'if_overriden_by_trainer',
'decision', 'next_stage_suggested'])

def download_from_database(self) -> (pd.DataFrame, pd.DataFrame):
Expand All @@ -77,7 +87,7 @@ def _count_session_at_current_stage(self,
current_stage: str) -> int:
""" Count the number of sessions at the current stage (reset after rolling back) """
session_at_current_stage = 1
for stage in reversed(df.query(f'subject_id == "{subject_id}"')['current_stage_suggested'].to_list()):
for stage in reversed(df.query(f'subject_id == "{subject_id}"')['current_stage_actual'].to_list()):
if stage == current_stage:
session_at_current_stage += 1
else:
Expand All @@ -88,7 +98,7 @@ def _count_session_at_current_stage(self,
def compute_stats(self):
"""compute simple stats"""
df_stats = self.df_manager.groupby(
['subject_id', 'current_stage_suggested'], sort=False
['subject_id', 'current_stage_actual'], sort=False
)['session'].agg([('session_spent', 'count'), # Number of sessions spent at this stage
# First entry to this stage
('first_entry', 'min'),
Expand All @@ -101,7 +111,7 @@ def compute_stats(self):

# Count the number of different decisions made at each stage
df_decision = self.df_manager.groupby(
['subject_id', 'current_stage_suggested', 'decision'], sort=False
['subject_id', 'current_stage_actual', 'decision'], sort=False
)['session'].agg('count').to_frame()

# Reorganize the table and rename the columns
Expand All @@ -112,41 +122,76 @@ def compute_stats(self):

# Merge df_decision with df_stats
df_stats = df_stats.merge(df_decision, how='left', on=[
'subject_id', 'current_stage_suggested'])
'subject_id', 'current_stage_actual'])

self.df_manager_stats = df_stats

def plot_all_progress(self, if_show_fig=True):
return plot_manager_all_progress(self, if_show_fig=if_show_fig)

def add_and_evaluate_session(self, subject_id, session):
""" Add a session to the curriculum manager and evaluate the transition """
def _get_next_stage_suggested_on_last_session(self, subject_id, session) -> str:
df_this_mouse = self.df_manager.query(f'subject_id == "{subject_id}"')

# If we don't have feedback from the GUI about the actual training stage used
if 'actual_stage' not in self.df_behavior:
if session == 1: # If this is the first session
current_stage = 'STAGE_1'
else:
# Assuming current session uses the suggested stage from the previous session
q_current_stage = df_this_mouse.query(
f"session == {session - 1}")['next_stage_suggested']

if len(q_current_stage) > 0:
current_stage = q_current_stage.iloc[0]
else: # Catch missing session or wrong session number
id_last_session = df_this_mouse[df_this_mouse.session < session]
if len(id_last_session) > 0:
id_last_session = id_last_session.session.astype(int).idxmax()
current_stage = df_this_mouse.loc[id_last_session,
'next_stage_suggested']
logger.warning(
msg=f"Cannot find subject {subject_id} session {session - 1}, "
f"use session {df_this_mouse.loc[id_last_session].session} instead")
else:
logger.error(
msg=f"Cannot find subject {subject_id} anysession < {session}")
return
q_current_stage = df_this_mouse.query(
f"session == {session - 1}")['next_stage_suggested']

if len(q_current_stage) > 0:
return q_current_stage.iloc[0]

# Catch missing session or wrong session number
id_last_session = df_this_mouse[df_this_mouse.session < session]
if len(id_last_session) > 0:
id_last_session = id_last_session.session.astype(int).idxmax()
logger.warning(
msg=f"Cannot find subject {subject_id} session {session - 1}, "
f"use session {df_this_mouse.loc[id_last_session].session} instead")

return df_this_mouse.loc[id_last_session, 'next_stage_suggested']

# Else, throw an error
logger.error(
msg=f"Cannot find subject {subject_id} any session < {session}")
return None

def _get_current_stages(self, subject_id, session) -> dict:
current_stage_suggested = 'STAGE_1' if session == 1 \
else self._get_next_stage_suggested_on_last_session(subject_id, session)

if self.if_simulation_mode:
# Assuming current session uses the suggested stage from the previous session
return {'current_stage_suggested': current_stage_suggested,
'current_stage_actual': current_stage_suggested,
'if_closed_loop': False}

# If not in simulation mode, use the actual stage
current_stage_actual = self.df_behavior.query(f'subject_id == "{subject_id}"')[
'current_stage_actual'].iloc[0]

# If current_stage_actual not in TrainingStage (including None), then we are in open loop for this specific session
if current_stage_actual not in TrainingStage.__members__:
logger.warning(
f'current_stage_actual "{current_stage_actual}" is invalid for subject {subject_id}, session {session}, we are in open loop for this session.'
)
return {'current_stage_suggested': current_stage_suggested,
'current_stage_actual': current_stage_suggested,
'if_closed_loop': False}

# Throw a warning if the fist actual stage is not STAGE_1 (but still use current_stage_actual)
if session == 1 and current_stage_actual != 'STAGE_1':
logger.warning(
f'First stage is not STAGE_1 for subject {subject_id}!')

return {'current_stage_suggested': current_stage_suggested,
'current_stage_actual': current_stage_actual,
'if_closed_loop': True}

def add_and_evaluate_session(self, subject_id, session):
""" Add a session to the curriculum manager and evaluate the transition """
# Get current stages
_current_stages = self._get_current_stages(subject_id, session)
current_stage_suggested = _current_stages['current_stage_suggested']
current_stage_actual = _current_stages['current_stage_actual']
if_closed_loop = _current_stages['if_closed_loop']

# Get metrics history (already sorted by session)
df_history = self.df_behavior.query(
Expand All @@ -161,7 +206,7 @@ def add_and_evaluate_session(self, subject_id, session):

# Count session_at_current_stage
session_at_current_stage = self._count_session_at_current_stage(
self.df_manager, subject_id, current_stage)
self.df_manager, subject_id, current_stage_actual)

# Evaluate
metrics = dict(**task_specific_metrics,
Expand All @@ -172,7 +217,7 @@ def add_and_evaluate_session(self, subject_id, session):
# Should we allow change of curriculum version during a training? maybe not...
# But we should definitely allow different curriculum versions for different
decision, next_stage_suggested = coupled_baiting_curriculum.evaluate_transitions(
current_stage=TrainingStage[current_stage],
current_stage=TrainingStage[current_stage_actual],
metrics=DynamicForagingMetrics(**metrics))

# Add to the manager
Expand All @@ -188,8 +233,13 @@ def add_and_evaluate_session(self, subject_id, session):
curriculum_version='0.1', # Allows changing curriculum during training
task_schema_version='1.0', # Allows changing task schema during training
session_at_current_stage=session_at_current_stage,
current_stage_suggested=current_stage,
**{key: df_this[key] for key in self.task_specific_metrics_keys}, # Copy task-specific metrics
current_stage_suggested=current_stage_suggested,
current_stage_actual=current_stage_actual, # Note this could be from simulation or invalid feedback-induced open loop session
if_closed_loop=if_closed_loop,
if_overriden_by_trainer=current_stage_actual != current_stage_suggested if if_closed_loop else False,

# Copy task-specific metrics
**{key: df_this[key] for key in self.task_specific_metrics_keys},
metrics=metrics,
decision=decision.name,
next_stage_suggested=next_stage_suggested.name
Expand All @@ -198,8 +248,8 @@ def add_and_evaluate_session(self, subject_id, session):

# Logging
logger.info(f"{subject_id}, {df_this.session_date}, session {session}: " +
(f"STAY at {current_stage}" if decision.name == 'STAY'
else f"{decision.name} {current_stage} --> {next_stage_suggested.name}"))
(f"STAY at {current_stage_actual}" if decision.name == 'STAY'
else f"{decision.name} {current_stage_actual} --> {next_stage_suggested.name}"))

def update(self):
"""update each mouse's training stage"""
Expand Down Expand Up @@ -232,9 +282,9 @@ def update(self):


class DynamicForagingAutoTrainManager(AutoTrainManager):

_metrics_model = DynamicForagingMetrics # Override the metrics model

def __init__(
self,
manager_name: str = 'Janelia_demo',
Expand Down Expand Up @@ -264,9 +314,9 @@ def __init__(
def download_from_database(self):
# --- load df_auto_train_manager and df_behavior from s3 ---
df_behavior = import_df_from_s3(bucket=self.df_behavior_on_s3['bucket'],
s3_path=self.df_behavior_on_s3['root'],
file_name=self.df_behavior_on_s3['file_name'],
)
s3_path=self.df_behavior_on_s3['root'],
file_name=self.df_behavior_on_s3['file_name'],
)

if df_behavior is None:
logger.error('No df_behavior found, exiting...')
Expand Down Expand Up @@ -294,9 +344,9 @@ def download_from_database(self):

# --- Load curriculum manager table; if not exist, create a new one ---
df_manager = import_df_from_s3(bucket=self.df_manager_root_on_s3['bucket'],
s3_path=self.df_manager_root_on_s3['root'],
file_name=self.df_manager_name,
)
s3_path=self.df_manager_root_on_s3['root'],
file_name=self.df_manager_name,
)

return df_behavior, df_manager

Expand All @@ -308,16 +358,22 @@ def upload_to_database(self):

for file_name, df in df_to_upload.items():
export_df_to_s3(df=df,
bucket='aind-behavior-data',
s3_path=self.df_manager_root_on_s3['root'],
file_name=file_name,
)
bucket='aind-behavior-data',
s3_path=self.df_manager_root_on_s3['root'],
file_name=file_name,
)


if __name__ == "__main__":
manager = DynamicForagingAutoTrainManager()
manager.df_manager
manager = DynamicForagingAutoTrainManager(manager_name='447_demo',
df_behavior_on_s3=dict(bucket='aind-behavior-data',
root='foraging_nwb_bonsai_processed/',
file_name='df_sessions.pkl'),
df_manager_root_on_s3=dict(bucket='aind-behavior-data',
root='foraging_auto_training/')
)
manager.update()
manager.plot_all_progress()
print(manager.df_manager)

# %%

0 comments on commit 702b81a

Please sign in to comment.