From 702b81aec2c2dd29d5cb2d260bfc98c76cf4430d Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 15 Dec 2023 10:37:18 +0000 Subject: [PATCH] separate `current_stage_actual` and `current_stage_suggested` --- code/aind_auto_train/auto_train_manager.py | 170 ++++++++++++++------- 1 file changed, 113 insertions(+), 57 deletions(-) diff --git a/code/aind_auto_train/auto_train_manager.py b/code/aind_auto_train/auto_train_manager.py index f7e637f..49c93f4 100644 --- a/code/aind_auto_train/auto_train_manager.py +++ b/code/aind_auto_train/auto_train_manager.py @@ -13,7 +13,6 @@ from aind_auto_train.plot.manager import plot_manager_all_progress logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) # Directory for caching df_maseter tables task_mapper = {'coupled_block_baiting': 'Coupled Baiting', @@ -26,11 +25,10 @@ class AutoTrainManager: download_from_database() upload_to_database() """ - + # Specify the Metrics subclass for a specific task _metrics_model: metrics_class - def __init__(self, manager_name: str, ): @@ -42,18 +40,30 @@ def __init__(self, self.df_behavior, self.df_manager = self.download_from_database() # Check if all required metrics exist in df_behavior - self.task_specific_metrics_keys = set(self._metrics_model.schema()['properties'].keys()) \ - - set(Metrics.schema()['properties'].keys()) + self.task_specific_metrics_keys = set(self._metrics_model.model_json_schema()['properties'].keys()) \ + - set(Metrics.model_json_schema()['properties'].keys()) assert all([col in self.df_behavior.columns for col in list(self.task_specific_metrics_keys)]), "Not all required metrics exist in df_behavior!" - if self.df_manager is None: # Create a new table - logger.info('No df_manager found, creating a new one...') + # If `current_stage_actual` is not in df_behavior, we are in open loop simulation mode + if 'current_stage_actual' not in self.df_behavior: + self.if_simulation_mode = True + logger.warning( + "current_stage_actual is not in df_behavior, we are in simulation mode!") + else: + self.if_simulation_mode = False + + self.df_manager = None + + # Create a new table if df_manager is empty + if self.df_manager is None: + logger.warning('No df_manager found, creating a new one...') self.df_manager = pd.DataFrame(columns=['subject_id', 'session_date', 'task', 'session', 'session_at_current_stage', 'curriculum_version', 'task_schema_version', *self.task_specific_metrics_keys, 'metrics', 'current_stage_suggested', 'current_stage_actual', + 'if_closed_loop', 'if_overriden_by_trainer', 'decision', 'next_stage_suggested']) def download_from_database(self) -> (pd.DataFrame, pd.DataFrame): @@ -77,7 +87,7 @@ def _count_session_at_current_stage(self, current_stage: str) -> int: """ Count the number of sessions at the current stage (reset after rolling back) """ session_at_current_stage = 1 - for stage in reversed(df.query(f'subject_id == "{subject_id}"')['current_stage_suggested'].to_list()): + for stage in reversed(df.query(f'subject_id == "{subject_id}"')['current_stage_actual'].to_list()): if stage == current_stage: session_at_current_stage += 1 else: @@ -88,7 +98,7 @@ def _count_session_at_current_stage(self, def compute_stats(self): """compute simple stats""" df_stats = self.df_manager.groupby( - ['subject_id', 'current_stage_suggested'], sort=False + ['subject_id', 'current_stage_actual'], sort=False )['session'].agg([('session_spent', 'count'), # Number of sessions spent at this stage # First entry to this stage ('first_entry', 'min'), @@ -101,7 +111,7 @@ def compute_stats(self): # Count the number of different decisions made at each stage df_decision = self.df_manager.groupby( - ['subject_id', 'current_stage_suggested', 'decision'], sort=False + ['subject_id', 'current_stage_actual', 'decision'], sort=False )['session'].agg('count').to_frame() # Reorganize the table and rename the columns @@ -112,41 +122,76 @@ def compute_stats(self): # Merge df_decision with df_stats df_stats = df_stats.merge(df_decision, how='left', on=[ - 'subject_id', 'current_stage_suggested']) + 'subject_id', 'current_stage_actual']) self.df_manager_stats = df_stats def plot_all_progress(self, if_show_fig=True): return plot_manager_all_progress(self, if_show_fig=if_show_fig) - def add_and_evaluate_session(self, subject_id, session): - """ Add a session to the curriculum manager and evaluate the transition """ + def _get_next_stage_suggested_on_last_session(self, subject_id, session) -> str: df_this_mouse = self.df_manager.query(f'subject_id == "{subject_id}"') - # If we don't have feedback from the GUI about the actual training stage used - if 'actual_stage' not in self.df_behavior: - if session == 1: # If this is the first session - current_stage = 'STAGE_1' - else: - # Assuming current session uses the suggested stage from the previous session - q_current_stage = df_this_mouse.query( - f"session == {session - 1}")['next_stage_suggested'] - - if len(q_current_stage) > 0: - current_stage = q_current_stage.iloc[0] - else: # Catch missing session or wrong session number - id_last_session = df_this_mouse[df_this_mouse.session < session] - if len(id_last_session) > 0: - id_last_session = id_last_session.session.astype(int).idxmax() - current_stage = df_this_mouse.loc[id_last_session, - 'next_stage_suggested'] - logger.warning( - msg=f"Cannot find subject {subject_id} session {session - 1}, " - f"use session {df_this_mouse.loc[id_last_session].session} instead") - else: - logger.error( - msg=f"Cannot find subject {subject_id} anysession < {session}") - return + q_current_stage = df_this_mouse.query( + f"session == {session - 1}")['next_stage_suggested'] + + if len(q_current_stage) > 0: + return q_current_stage.iloc[0] + + # Catch missing session or wrong session number + id_last_session = df_this_mouse[df_this_mouse.session < session] + if len(id_last_session) > 0: + id_last_session = id_last_session.session.astype(int).idxmax() + logger.warning( + msg=f"Cannot find subject {subject_id} session {session - 1}, " + f"use session {df_this_mouse.loc[id_last_session].session} instead") + + return df_this_mouse.loc[id_last_session, 'next_stage_suggested'] + + # Else, throw an error + logger.error( + msg=f"Cannot find subject {subject_id} any session < {session}") + return None + + def _get_current_stages(self, subject_id, session) -> dict: + current_stage_suggested = 'STAGE_1' if session == 1 \ + else self._get_next_stage_suggested_on_last_session(subject_id, session) + + if self.if_simulation_mode: + # Assuming current session uses the suggested stage from the previous session + return {'current_stage_suggested': current_stage_suggested, + 'current_stage_actual': current_stage_suggested, + 'if_closed_loop': False} + + # If not in simulation mode, use the actual stage + current_stage_actual = self.df_behavior.query(f'subject_id == "{subject_id}"')[ + 'current_stage_actual'].iloc[0] + + # If current_stage_actual not in TrainingStage (including None), then we are in open loop for this specific session + if current_stage_actual not in TrainingStage.__members__: + logger.warning( + f'current_stage_actual "{current_stage_actual}" is invalid for subject {subject_id}, session {session}, we are in open loop for this session.' + ) + return {'current_stage_suggested': current_stage_suggested, + 'current_stage_actual': current_stage_suggested, + 'if_closed_loop': False} + + # Throw a warning if the fist actual stage is not STAGE_1 (but still use current_stage_actual) + if session == 1 and current_stage_actual != 'STAGE_1': + logger.warning( + f'First stage is not STAGE_1 for subject {subject_id}!') + + return {'current_stage_suggested': current_stage_suggested, + 'current_stage_actual': current_stage_actual, + 'if_closed_loop': True} + + def add_and_evaluate_session(self, subject_id, session): + """ Add a session to the curriculum manager and evaluate the transition """ + # Get current stages + _current_stages = self._get_current_stages(subject_id, session) + current_stage_suggested = _current_stages['current_stage_suggested'] + current_stage_actual = _current_stages['current_stage_actual'] + if_closed_loop = _current_stages['if_closed_loop'] # Get metrics history (already sorted by session) df_history = self.df_behavior.query( @@ -161,7 +206,7 @@ def add_and_evaluate_session(self, subject_id, session): # Count session_at_current_stage session_at_current_stage = self._count_session_at_current_stage( - self.df_manager, subject_id, current_stage) + self.df_manager, subject_id, current_stage_actual) # Evaluate metrics = dict(**task_specific_metrics, @@ -172,7 +217,7 @@ def add_and_evaluate_session(self, subject_id, session): # Should we allow change of curriculum version during a training? maybe not... # But we should definitely allow different curriculum versions for different decision, next_stage_suggested = coupled_baiting_curriculum.evaluate_transitions( - current_stage=TrainingStage[current_stage], + current_stage=TrainingStage[current_stage_actual], metrics=DynamicForagingMetrics(**metrics)) # Add to the manager @@ -188,8 +233,13 @@ def add_and_evaluate_session(self, subject_id, session): curriculum_version='0.1', # Allows changing curriculum during training task_schema_version='1.0', # Allows changing task schema during training session_at_current_stage=session_at_current_stage, - current_stage_suggested=current_stage, - **{key: df_this[key] for key in self.task_specific_metrics_keys}, # Copy task-specific metrics + current_stage_suggested=current_stage_suggested, + current_stage_actual=current_stage_actual, # Note this could be from simulation or invalid feedback-induced open loop session + if_closed_loop=if_closed_loop, + if_overriden_by_trainer=current_stage_actual != current_stage_suggested if if_closed_loop else False, + + # Copy task-specific metrics + **{key: df_this[key] for key in self.task_specific_metrics_keys}, metrics=metrics, decision=decision.name, next_stage_suggested=next_stage_suggested.name @@ -198,8 +248,8 @@ def add_and_evaluate_session(self, subject_id, session): # Logging logger.info(f"{subject_id}, {df_this.session_date}, session {session}: " + - (f"STAY at {current_stage}" if decision.name == 'STAY' - else f"{decision.name} {current_stage} --> {next_stage_suggested.name}")) + (f"STAY at {current_stage_actual}" if decision.name == 'STAY' + else f"{decision.name} {current_stage_actual} --> {next_stage_suggested.name}")) def update(self): """update each mouse's training stage""" @@ -232,9 +282,9 @@ def update(self): class DynamicForagingAutoTrainManager(AutoTrainManager): - + _metrics_model = DynamicForagingMetrics # Override the metrics model - + def __init__( self, manager_name: str = 'Janelia_demo', @@ -264,9 +314,9 @@ def __init__( def download_from_database(self): # --- load df_auto_train_manager and df_behavior from s3 --- df_behavior = import_df_from_s3(bucket=self.df_behavior_on_s3['bucket'], - s3_path=self.df_behavior_on_s3['root'], - file_name=self.df_behavior_on_s3['file_name'], - ) + s3_path=self.df_behavior_on_s3['root'], + file_name=self.df_behavior_on_s3['file_name'], + ) if df_behavior is None: logger.error('No df_behavior found, exiting...') @@ -294,9 +344,9 @@ def download_from_database(self): # --- Load curriculum manager table; if not exist, create a new one --- df_manager = import_df_from_s3(bucket=self.df_manager_root_on_s3['bucket'], - s3_path=self.df_manager_root_on_s3['root'], - file_name=self.df_manager_name, - ) + s3_path=self.df_manager_root_on_s3['root'], + file_name=self.df_manager_name, + ) return df_behavior, df_manager @@ -308,16 +358,22 @@ def upload_to_database(self): for file_name, df in df_to_upload.items(): export_df_to_s3(df=df, - bucket='aind-behavior-data', - s3_path=self.df_manager_root_on_s3['root'], - file_name=file_name, - ) + bucket='aind-behavior-data', + s3_path=self.df_manager_root_on_s3['root'], + file_name=file_name, + ) if __name__ == "__main__": - manager = DynamicForagingAutoTrainManager() - manager.df_manager + manager = DynamicForagingAutoTrainManager(manager_name='447_demo', + df_behavior_on_s3=dict(bucket='aind-behavior-data', + root='foraging_nwb_bonsai_processed/', + file_name='df_sessions.pkl'), + df_manager_root_on_s3=dict(bucket='aind-behavior-data', + root='foraging_auto_training/') + ) manager.update() manager.plot_all_progress() + print(manager.df_manager) # %%