From edbee92234989b4a8468ff682394698d8df0caec Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Mon, 8 Apr 2024 23:40:36 -0700 Subject: [PATCH 1/6] feat: click autotrain manager to show session --- code/util/aws_s3.py | 1 - code/util/streamlit.py | 31 ++++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py index 2cd5997..7244b7a 100644 --- a/code/util/aws_s3.py +++ b/code/util/aws_s3.py @@ -1,5 +1,4 @@ from PIL import Image -import glob import json import s3fs diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 73dc971..53b5d0c 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -1,6 +1,7 @@ from collections import OrderedDict import pandas as pd import streamlit as st +from datetime import datetime from st_aggrid import AgGrid, GridOptionsBuilder from st_aggrid.shared import GridUpdateMode, ColumnsAutoSizeMode, DataReturnMode from pandas.api.types import ( @@ -22,7 +23,7 @@ from scipy.stats import linregress from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query - +from .aws_s3 import draw_session_plots_quick_preview custom_css = { ".ag-root.ag-unselectable.ag-layout-normal": {"font-size": "15px !important", @@ -727,12 +728,28 @@ def add_auto_train_manager(): height=30 * len(df_training_manager.subject_id.unique()), ) - selected_ = plotly_events(fig_auto_train, - override_height=fig_auto_train.layout.height * 1.1, - override_width=fig_auto_train.layout.width, - click_event=False, - select_event=False, - ) + cols = st.columns([2, 1]) + with cols[0]: + selected_ = plotly_events(fig_auto_train, + override_height=fig_auto_train.layout.height * 1.1, + override_width=fig_auto_train.layout.width, + click_event=True, + select_event=False, + ) + with cols[1]: + st.markdown('#### 👀 Quick preview') + st.markdown('###### Click on one session to preview here') + if selected_: + # Some hacks to get 
back selected data + curve_number = selected_[0]['curveNumber'] + point_number = selected_[0]['pointNumber'] + this_subject = fig_auto_train['data'][curve_number] + session_date = datetime.strptime(this_subject['customdata'][point_number][1], "%Y-%m-%d") + subject_id = fig_auto_train['data'][curve_number]['name'].split(' ')[1] + + df_selected = (st.session_state.df['sessions_bonsai'].query( + f'''subject_id == "{subject_id}" and session_date == "{session_date}"''')) + draw_session_plots_quick_preview(df_selected) # -- Show dataframe -- # only show filtered subject From 25cce4dd7b2d0f4b8c836a6f7d2804884ba60b73 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Mon, 8 Apr 2024 23:47:52 -0700 Subject: [PATCH 2/6] refactor: migrate autotrain manager plotly to this repo --- code/util/plot_autotrain_manager.py | 203 ++++++++++++++++++++++++++++ code/util/streamlit.py | 5 +- 2 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 code/util/plot_autotrain_manager.py diff --git a/code/util/plot_autotrain_manager.py b/code/util/plot_autotrain_manager.py new file mode 100644 index 0000000..48bdf78 --- /dev/null +++ b/code/util/plot_autotrain_manager.py @@ -0,0 +1,203 @@ +from datetime import datetime + +import numpy as np +import plotly.graph_objects as go +import pandas as pd + +from aind_auto_train.schema.curriculum import TrainingStage +from aind_auto_train.plot.curriculum import get_stage_color_mapper + + +def plot_manager_all_progress(manager: 'AutoTrainManager', + x_axis: ['session', 'date', + 'relative_date'] = 'session', # type: ignore + sort_by: ['subject_id', 'first_date', + 'last_date', 'progress_to_graduated'] = 'subject_id', + sort_order: ['ascending', + 'descending'] = 'descending', + marker_size=10, + marker_edge_width=2, + highlight_subjects=[], + if_show_fig=True + ): + + + # %% + # Set default order + df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'], + ascending=[sort_order == 'ascending', False]) + + if not 
len(df_manager): + return None + + # Sort mice + if sort_by == 'subject_id': + subject_ids = df_manager.subject_id.unique() + elif sort_by == 'first_date': + subject_ids = df_manager.groupby('subject_id').session_date.min().sort_values( + ascending=sort_order == 'ascending').index + elif sort_by == 'last_date': + subject_ids = df_manager.groupby('subject_id').session_date.max().sort_values( + ascending=sort_order == 'ascending').index + elif sort_by == 'progress_to_graduated': + manager.compute_stats() + df_stats = manager.df_manager_stats + + # Sort by 'first_entry' of GRADUATED + subject_ids = df_stats.reset_index().set_index( + 'subject_id' + ).query( + f'current_stage_actual == "GRADUATED"' + )['first_entry'].sort_values( + ascending=sort_order != 'ascending').index.to_list() + + # Append subjects that have not graduated + subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids] + + else: + raise ValueError( + f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}') + + # Preparing the scatter plot + traces = [] + for n, subject_id in enumerate(subject_ids): + df_subject = df_manager[df_manager['subject_id'] == subject_id] + + # Get stage_color_mapper + stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__)) + + # Get h2o if available + if 'h2o' in manager.df_behavior: + h2o = manager.df_behavior[ + manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0] + else: + h2o = None + + # Handle open loop sessions + open_loop_ids = df_subject.if_closed_loop == False + color_actual = df_subject['current_stage_actual'].map( + stage_color_mapper) + color_actual[open_loop_ids] = 'lightgrey' + stage_actual = df_subject.current_stage_actual.values + stage_actual[open_loop_ids] = 'unknown (open loop)' + + # Select x + if x_axis == 'session': + x = df_subject['session'] + elif x_axis == 'date': + x = pd.to_datetime(df_subject['session_date']) + elif x_axis == 'relative_date': + 
x = pd.to_datetime(df_subject['session_date']) + x = (x - x.min()).dt.days + else: + raise ValueError( + f"x_axis can only be in ['session', 'date', 'relative_date']") + + # Cache x range + xrange_min = x.min() if n == 0 else min(x.min(), xrange_min) + xrange_max = x.max() if n == 0 else max(x.max(), xrange_max) + + traces.append(go.Scattergl( + x=x, + y=[n] * len(df_subject), + mode='markers', + marker=dict( + size=marker_size, + line=dict( + width=marker_edge_width, + color=df_subject['current_stage_suggested'].map( + stage_color_mapper) + ), + color=color_actual, + # colorbar=dict(title='Training Stage'), + ), + name=f'Mouse {subject_id}', + hovertemplate=(f"Subject {subject_id} ({h2o})" + "
<br>Session %{customdata[0]}, %{customdata[1]}" + "
<br>Curriculum: %{customdata[2]}_v%{customdata[3]}" + "
<br>Suggested: %{customdata[4]}" + "
<br>Actual: %{customdata[5]}" + "
<br>Session task: %{customdata[6]}" + "
<br>foraging_eff = %{customdata[7]}" + "
<br>finished_trials = %{customdata[8]}" + "
<br>Decision = %{customdata[9]}" + "
<br>Next suggested: %{customdata[10]}" + ""), + customdata=np.stack( + (df_subject.session, + df_subject.session_date, + df_subject.curriculum_name, + df_subject.curriculum_version, + df_subject.current_stage_suggested, + stage_actual, + df_subject.task, + np.round(df_subject.foraging_efficiency, 3), + df_subject.finished_trials, + df_subject.decision, + df_subject.next_stage_suggested, + ), axis=-1), + showlegend=False + ) + ) + + # Add "x" for open loop sessions + traces.append(go.Scattergl( + x=x[open_loop_ids], + y=[n] * len(x[open_loop_ids]), + mode='markers', + marker=dict( + size=marker_size*0.8, + symbol='x-thin', + color='black', + line_width=marker_edge_width*0.8, + ), + showlegend=False, + ) + ) + + # Create the figure + fig = go.Figure(data=traces) + fig.update_layout( + title=f"Automatic training progress ({manager.manager_name})", + xaxis_title=x_axis, + yaxis_title='Mouse', + height=1200, + ) + + # Set subject_id as y axis label + fig.update_layout( + hovermode='closest', + yaxis=dict( + tickmode='array', + tickvals=np.arange(0, n + 1), # Original y-axis values + ticktext=subject_ids, # New labels + autorange='reversed', + zeroline=False, + title='' + ) + ) + + # Highight the selected subject + for n, subject_id in enumerate(subject_ids): + if subject_id in highlight_subjects: + fig.add_shape( + type="rect", + y0=n-0.5, + y1=n+0.5, + x0=xrange_min - (1 if x_axis != 'date' else pd.Timedelta(days=1)), + x1=xrange_max + (1 if x_axis != 'date' else pd.Timedelta(days=1)), + line=dict( + width=0, + ), + fillcolor="Gray", + opacity=0.3, + layer="below" + ) + + + # Show the plot + if if_show_fig: + fig.show() + + # %% + return fig diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 53b5d0c..e4796ab 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -23,6 +23,8 @@ from scipy.stats import linregress from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query +from 
.plot_autotrain_manager import plot_manager_all_progress + from .aws_s3 import draw_session_plots_quick_preview custom_css = { @@ -710,7 +712,8 @@ def add_auto_train_manager(): else: highlight_subjects = [] - fig_auto_train = st.session_state.auto_train_manager.plot_all_progress( + fig_auto_train = plot_manager_all_progress( + st.session_state.auto_train_manager, x_axis=x_axis, sort_by=sort_by, sort_order=sort_order, From f14d46dea3b7a7353082f1061e81f686ef2f3422 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Tue, 9 Apr 2024 00:01:51 -0700 Subject: [PATCH 3/6] minor --- code/Home.py | 1 - 1 file changed, 1 deletion(-) diff --git a/code/Home.py b/code/Home.py index 442ec86..946c216 100644 --- a/code/Home.py +++ b/code/Home.py @@ -12,7 +12,6 @@ """ -# %% __ver__ = 'v2.2.0' import pandas as pd From 34868a188331e25dfee68edc6a922b6486f2606c Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Tue, 9 Apr 2024 09:36:31 -0700 Subject: [PATCH 4/6] fix: use selectbox_wrapper_for_url_query --- code/util/streamlit.py | 68 +++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index e4796ab..8f51946 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -672,37 +672,43 @@ def data_selector(): st.rerun() def add_auto_train_manager(): - + st.session_state.auto_train_manager.df_manager = st.session_state.auto_train_manager.df_manager[ st.session_state.auto_train_manager.df_manager.subject_id.astype(float) > 0] # Remove dummy mouse 0 df_training_manager = st.session_state.auto_train_manager.df_manager - + # -- Show plotly chart -- cols = st.columns([1, 1, 1, 0.7, 0.7, 3]) - options=['session', 'date', 'relative_date'] - x_axis = cols[0].selectbox('X axis', options=options, - index=options.index(st.session_state['auto_training_history_x_axis']), - key="auto_training_history_x_axis") - - options=['subject_id', - 'first_date', - 'last_date', - 'progress_to_graduated'] - 
sort_by = cols[1].selectbox('Sort by', - options=options, - index=options.index(st.session_state['auto_training_history_sort_by']), - key="auto_training_history_sort_by") - - options=['descending', 'ascending'] - sort_order = cols[2].selectbox('Sort order', - options=options, - index=options.index(st.session_state['auto_training_history_sort_order']), - key='auto_training_history_sort_order' - ) - + options = ["session", "date", "relative_date"] + x_axis = selectbox_wrapper_for_url_query( + st_prefix=cols[0], + label="X axis", + options=options, + default=options.index(st.session_state["auto_training_history_x_axis"]), + key="auto_training_history_x_axis", + ) + + options = ["subject_id", "first_date", "last_date", "progress_to_graduated"] + sort_by = selectbox_wrapper_for_url_query( + st_prefix=cols[1], + label="Sort by", + options=options, + default=options.index(st.session_state["auto_training_history_sort_by"]), + key="auto_training_history_sort_by", + ) + + options = ["descending", "ascending"] + sort_order = selectbox_wrapper_for_url_query( + st_prefix=cols[2], + label="Sort order", + options=options, + default=options.index(st.session_state["auto_training_history_sort_order"]), + key="auto_training_history_sort_order", + ) + marker_size = cols[3].number_input('Marker size', value=15, step=1) marker_edge_width = cols[4].number_input('Marker edge width', value=3, step=1) - + # Get highlighted subjects if ('filter_subject_id' in st.session_state and st.session_state['filter_subject_id']) or\ ('filter_h2o' in st.session_state and st.session_state['filter_h2o']): @@ -711,7 +717,7 @@ def add_auto_train_manager(): highlight_subjects = [str(x) for x in highlight_subjects] else: highlight_subjects = [] - + fig_auto_train = plot_manager_all_progress( st.session_state.auto_train_manager, x_axis=x_axis, @@ -722,7 +728,7 @@ def add_auto_train_manager(): highlight_subjects=highlight_subjects, if_show_fig=False ) - + fig_auto_train.update_layout( hoverlabel=dict( 
font_size=20, @@ -730,7 +736,7 @@ def add_auto_train_manager(): font=dict(size=18), height=30 * len(df_training_manager.subject_id.unique()), ) - + cols = st.columns([2, 1]) with cols[0]: selected_ = plotly_events(fig_auto_train, @@ -749,16 +755,16 @@ def add_auto_train_manager(): this_subject = fig_auto_train['data'][curve_number] session_date = datetime.strptime(this_subject['customdata'][point_number][1], "%Y-%m-%d") subject_id = fig_auto_train['data'][curve_number]['name'].split(' ')[1] - + df_selected = (st.session_state.df['sessions_bonsai'].query( f'''subject_id == "{subject_id}" and session_date == "{session_date}"''')) draw_session_plots_quick_preview(df_selected) - + # -- Show dataframe -- # only show filtered subject df_training_manager = df_training_manager[df_training_manager['subject_id'].isin( st.session_state.df_session_filtered['subject_id'].unique().astype(str))] - + # reorder columns df_training_manager = df_training_manager[['subject_id', 'session_date', 'session', 'curriculum_name', 'curriculum_version', 'curriculum_schema_version', @@ -768,7 +774,7 @@ def add_auto_train_manager(): 'foraging_efficiency', 'finished_trials', 'decision', 'next_stage_suggested' ]] - + with st.expander('Automatic training manager', expanded=True): st.dataframe(df_training_manager, height=3000) From 6a58a12d01f6962ae9ee2209bd4e24ff5bd7a7a2 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Tue, 9 Apr 2024 09:39:01 -0700 Subject: [PATCH 5/6] fix: minor --- code/util/aws_s3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py index 7244b7a..23a5c63 100644 --- a/code/util/aws_s3.py +++ b/code/util/aws_s3.py @@ -55,9 +55,8 @@ def draw_session_plots_quick_preview(df_to_draw_session): rows.append(st.columns(column_setting)) for draw_type in draw_types_quick_preview: - if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by 
st.session_state.draw_type_mapper_session_level prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] - this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] + this_col = rows[position[0]][position[1]] if len(draw_types_quick_preview) > 1 else rows[0] show_session_level_img_by_key_and_prefix( key, column=this_col, From d05450748e0b2fc465c87dd7d2d7fccd144334a9 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Tue, 9 Apr 2024 09:57:52 -0700 Subject: [PATCH 6/6] feat: add avg trial length --- code/Home.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/code/Home.py b/code/Home.py index 946c216..78cc26f 100644 --- a/code/Home.py +++ b/code/Home.py @@ -451,6 +451,15 @@ def _get_data_source(rig): _df.dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now) _df.drop(_df.query('session < 1').index, inplace=True) + # Remove abnormal values + _df.loc[_df['weight_after'] > 100, + ['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total'] + ] = np.nan + + _df.loc[_df['water_in_session_manual'] > 100, + ['water_in_session_manual', 'water_in_session_total', 'water_after_session']] = np.nan + + # # add something else # add abs(bais) to all terms that have 'bias' in name for col in _df.columns: @@ -469,6 +478,13 @@ def _get_data_source(rig): # map user_name _df['user_name'] = _df['user_name'].apply(_user_name_mapper) + # trial stats + _df['avg_trial_length_in_seconds'] = _df['session_run_time_in_min'] / _df['total_trials_with_autowater'] * 60 + + # last day's total water + _df['water_day_total_last_session'] = _df.groupby('h2o')['water_day_total'].shift(1) + _df['water_after_session_last_session'] = _df.groupby('h2o')['water_after_session'].shift(1) + # fill nan for autotrain fields filled_values = {'curriculum_name': 
'None', 'curriculum_version': 'None', @@ -479,15 +495,7 @@ def _get_data_source(rig): 'if_overriden_by_trainer': False, } _df.fillna(filled_values, inplace=True) - - # Remove abnormal values - _df.loc[_df['weight_after'] > 100, - ['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total'] - ] = np.nan - - _df.loc[_df['water_in_session_manual'] > 100, - ['water_in_session_manual', 'water_in_session_total', 'water_after_session']] = np.nan - + # foraging performance = foraing_eff * finished_rate if 'foraging_performance' not in _df.columns: _df['foraging_performance'] = \