diff --git a/code/Home.py b/code/Home.py index cce7c18..24a418d 100644 --- a/code/Home.py +++ b/code/Home.py @@ -12,7 +12,7 @@ """ -__ver__ = 'v2.2.2' +__ver__ = 'v2.3.0' import pandas as pd import streamlit as st @@ -35,66 +35,14 @@ show_debug_info, ) from util.url_query_helper import ( - sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query + sync_URL_to_session_state, sync_session_state_to_URL, + slider_wrapper_for_url_query, checkbox_wrapper_for_url_query ) from aind_auto_train.curriculum_manager import CurriculumManager from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager -# Sync widgets with URL query params -# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/ -# dict of "key": default pairs -# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL -to_sync_with_url_query = { - 'if_load_bpod_sessions': False, - - 'filter_subject_id': '', - 'filter_session': [0.0, None], - 'filter_finished_trials': [0.0, None], - 'filter_foraging_eff': [0.0, None], - 'filter_task': ['all'], - - 'table_height': 300, - - 'tab_id': 'tab_session_x_y', - 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_performance_random_seed', - 'x_y_plot_group_by': 'h2o', - 'x_y_plot_if_show_dots': True, - 'x_y_plot_if_aggr_each_group': True, - 'x_y_plot_aggr_method_group': 'lowess', - 'x_y_plot_if_aggr_all': True, - 'x_y_plot_aggr_method_all': 'mean +/- sem', - 'x_y_plot_smooth_factor': 5, - 'x_y_plot_if_use_x_quantile_group': False, - 'x_y_plot_q_quantiles_group': 20, - 'x_y_plot_if_use_x_quantile_all': False, - 'x_y_plot_q_quantiles_all': 20, - 'x_y_plot_if_show_diagonal': False, - 'x_y_plot_dot_size': 10, - 'x_y_plot_dot_opacity': 0.3, - 'x_y_plot_line_width': 2.0, - 'x_y_plot_figure_width': 1300, - 'x_y_plot_figure_height': 900, - 'x_y_plot_font_size_scale': 1.0, - 'x_y_plot_selected_color_map': 'Plotly', - - 'x_y_plot_size_mapper': 
'finished_trials', - 'x_y_plot_size_mapper_gamma': 1.0, - 'x_y_plot_size_mapper_range': [3, 20], - - 'session_plot_mode': 'sessions selected from table or plot', - - 'auto_training_history_x_axis': 'session', - 'auto_training_history_sort_by': 'subject_id', - 'auto_training_history_sort_order': 'descending', - 'auto_training_curriculum_name': 'Uncoupled Baiting', - 'auto_training_curriculum_version': '1.0', - 'auto_training_curriculum_schema_version': '1.0', - } - - try: st.set_page_config(layout="wide", page_title='Foraging behavior browser', @@ -336,14 +284,10 @@ def init(): if key in ['selected_draw_types'] or '_changed' in key: del st.session_state[key] - # Set session state from URL - for key, default in to_sync_with_url_query.items(): - sync_widget_with_query(key, default) - df = load_data(['sessions'], data_source='bonsai') # --- Perform any data source-dependent preprocessing here --- - if st.session_state.if_load_bpod_sessions: + if (st.session_state.if_load_bpod_sessions if 'if_load_bpod_sessions' in st.session_state else False): df_bpod = load_data(['sessions'], data_source='bpod') # For historial reason, the suffix of df['sessions_bonsai'] just mean the data of the Home.py page @@ -526,6 +470,9 @@ def _get_data_source(rig): st.session_state.df['sessions_bonsai'] = _df # Somehow _df loses the reference to the original dataframe st.session_state.session_stats_names = [keys for keys in _df.keys()] + + # Set session state from URL + sync_URL_to_session_state() # Establish communication between pygwalker and streamlit init_streamlit_comm() @@ -790,14 +737,9 @@ def app(): st.markdown('---\n##### Debug zone') show_debug_info() - # Update back to URL - for key in to_sync_with_url_query: - try: - st.query_params.update({key: st.session_state[key]}) - except: - print(f'Failed to update {key} to URL query') + sync_session_state_to_URL() # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000) diff --git a/code/pages/1_Old mice.py 
b/code/pages/1_Old mice.py deleted file mode 100644 index 25c08a0..0000000 --- a/code/pages/1_Old mice.py +++ /dev/null @@ -1,686 +0,0 @@ -#%% -import pandas as pd -import streamlit as st -from pathlib import Path -import glob -import matplotlib.pyplot as plt -import numpy as np -from datetime import datetime -import s3fs -import os -import plotly.express as px -import plotly -import plotly.graph_objects as go -import statsmodels.api as sm - -from PIL import Image, ImageColor -import streamlit.components.v1 as components -import streamlit_nested_layout -from streamlit_plotly_events import plotly_events - -from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector, - add_xy_selector, add_xy_setting, add_auto_train_manager, - _plot_population_x_y) -from util.url_query_helper import sync_widget_with_query - -import extra_streamlit_components as stx - -from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager -from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm - - -# Sync widgets with URL query params -# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/ -# dict of "key": default pairs -# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL -to_sync_with_url_query = { - 'filter_h2o': '', - 'filter_session': [0.0, None], - 'filter_finished_trials': [0.0, None], - 'filter_foraging_eff': [0.0, None], - 'filter_task': ['all'], - 'filter_photostim_location': ['all'], - - 'tab_id': 'tab_session_x_y', - 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_performance', - 'x_y_plot_group_by': 'h2o', - 'x_y_plot_if_show_dots': True, - 'x_y_plot_if_aggr_each_group': True, - 'x_y_plot_aggr_method_group': 'lowess', - 'x_y_plot_if_aggr_all': True, - 'x_y_plot_aggr_method_all': 'mean +/- sem', - 'x_y_plot_smooth_factor': 5, - 'x_y_plot_if_use_x_quantile_group': False, - 
'x_y_plot_q_quantiles_group': 20, - 'x_y_plot_if_use_x_quantile_all': False, - 'x_y_plot_q_quantiles_all': 20, - 'x_y_plot_dot_size': 7, - 'x_y_plot_dot_opacity': 0.2, - 'x_y_plot_line_width': 2.0, - - 'auto_training_history_x_axis': 'session', - 'auto_training_history_sort_by': 'progress_to_graduated', - 'auto_training_history_sort_order': 'descending', - } - - -if_profile = False - -if if_profile: - from streamlit_profiler import Profiler - p = Profiler() - p.start() - - -# from pipeline import experiment, ephys, lab, psth_foraging, report, foraging_analysis -# from pipeline.plot import foraging_model_plot - -cache_folder = 'xxx' #'/root/capsule/data/s3/report/st_cache/' -cache_session_level_fig_folder = 'xxx' #'/root/capsule/data/s3/report/all_units/' # - -if os.path.exists(cache_folder): - st.session_state.st.session_state.use_s3 = False -else: - cache_folder = 'aind-behavior-data/Han/ephys/report/st_cache/' - cache_session_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_sessions/' - cache_mouse_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_subjects/' - - fs = s3fs.S3FileSystem(anon=False) - st.session_state.use_s3 = True - -try: - st.set_page_config(layout="wide", - page_title='Foraging behavior browser', - page_icon=':mouse2:', - menu_items={ - 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues", - 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/" - } - ) -except: - pass - -if 'selected_points' not in st.session_state: - st.session_state['selected_points'] = [] - - -@st.cache_data(ttl=24*3600) -def load_data(tables=['sessions']): - df = {} - for table in tables: - file_name = cache_folder + f'df_{table}.pkl' - if st.session_state.use_s3: - with fs.open(file_name) as f: - df[table] = pd.read_pickle(f) - else: - df[table] = pd.read_pickle(file_name) - return df - -def _fetch_img(glob_patterns, crop=None): - # Fetch the img that first matches the patterns - for pattern in 
glob_patterns: - file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern) - if len(file): break - - if not len(file): - return None, None - - try: - if st.session_state.use_s3: - with fs.open(file[0]) as f: - img = Image.open(f) - img = img.crop(crop) - else: - img = Image.open(file[0]) - img = img.crop(crop) - except: - st.write('File found on S3 but failed to load...') - return None, None - - return img, file[0] - - -# @st.cache_data(ttl=24*3600, max_entries=20) -def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs): - try: - sess_date_str = datetime.strftime(datetime.strptime(key['session_date'], '%Y-%m-%dT%H:%M:%S'), '%Y%m%d') - except: - sess_date_str = datetime.strftime(key['session_date'], '%Y%m%d') - - fns = [f'/{key["h2o"]}_{sess_date_str}_*{other_pattern}*' for other_pattern in other_patterns] - glob_patterns = [cache_session_level_fig_folder + f'{prefix}/' + key["h2o"] + fn for fn in fns] - - img, f_name = _fetch_img(glob_patterns, crop) - - _f = st if column is None else column - - _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png", - output_format='PNG', - caption=f_name.split('/')[-1] if caption and f_name else '', - use_column_width='always', - **kwargs) - - return img - -def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs): - - fns = [f'/{key["h2o"]}_*{other_pattern}*' for other_pattern in other_patterns] - glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/' + fn for fn in fns] - - img, f_name = _fetch_img(glob_patterns, crop) - - if img is None: # Use "not_found" image - glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/not_found_*{other_pattern}**' for other_pattern in other_patterns] - img, f_name = _fetch_img(glob_patterns, crop) - - _f = st if column is None else column - - _f.image(img if img is not None else 
"https://cdn-icons-png.flaticon.com/512/3585/3585596.png", - output_format='PNG', - #caption=f_name.split('/')[-1] if caption and f_name else '', - use_column_width='always', - **kwargs) - - return img - -# table_mapping = { -# 'sessions': fetch_sessions, -# 'ephys_units': fetch_ephys_units, -# } - - -def draw_session_plots(df_to_draw_session): - - # Setting up layout for each session - layout_definition = [[1], # columns in the first row - [1.5, 1], # columns in the second row - [1, 1], - ] - - # cols_option = st.columns([3, 0.5, 1]) - container_session_all_in_one = st.container() - - with container_session_all_in_one: - # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True): - - if len(df_to_draw_session): - st.write(f'Loading selected {len(df_to_draw_session)} sessions...') - my_bar = st.columns((1, 7))[0].progress(0) - - major_cols = st.columns([1] * st.session_state.num_cols) - - for i, key in enumerate(df_to_draw_session.to_dict(orient='records')): - this_major_col = major_cols[i % st.session_state.num_cols] - - # setting up layout for each session - rows = [] - with this_major_col: - - try: - date_str = key["session_date"].strftime('%Y-%m-%d') - except: - date_str = key["session_date"].split("T")[0] - - st.markdown(f'''

{key["h2o"]}, Session {key["session"]}, {date_str}''', - unsafe_allow_html=True) - if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout - for row, column_setting in enumerate(layout_definition): - rows.append(this_major_col.columns(column_setting)) - else: # else, put it in the whole column - rows = this_major_col.columns([1]) - st.markdown("---") - - for draw_type in st.session_state.draw_type_mapper_session_level: - if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level - prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] - this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] - show_session_level_img_by_key_and_prefix(key, - column=this_col, - prefix=prefix, - **setting) - - my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100)) - - - -def draw_mice_plots(df_to_draw_mice): - - # Setting up layout for each session - layout_definition = [[1], # columns in the first row - ] - - # cols_option = st.columns([3, 0.5, 1]) - container_session_all_in_one = st.container() - - with container_session_all_in_one: - # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True): - - if len(df_to_draw_mice): - st.write(f'Loading selected {len(df_to_draw_mice)} mice...') - my_bar = st.columns((1, 7))[0].progress(0) - - major_cols = st.columns([1] * st.session_state.num_cols_mice) - - for i, key in enumerate(df_to_draw_mice.to_dict(orient='records')): - this_major_col = major_cols[i % st.session_state.num_cols_mice] - - # setting up layout for each session - rows = [] - with this_major_col: - st.markdown(f'''

{key["h2o"]}''', - unsafe_allow_html=True) - if len(st.session_state.selected_draw_types_mice) > 1: # more than one types, use the pre-defined layout - for row, column_setting in enumerate(layout_definition): - rows.append(this_major_col.columns(column_setting)) - else: # else, put it in the whole column - rows = this_major_col.columns([1]) - st.markdown("---") - - for draw_type in st.session_state.draw_type_mapper_mouse_level: - if draw_type not in st.session_state.selected_draw_types_mice: continue - prefix, position, setting = st.session_state.draw_type_mapper_mouse_level[draw_type] - this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types_mice) > 1 else rows[0] - show_mouse_level_img_by_key_and_prefix(key, - column=this_col, - prefix=prefix, - **setting) - -def session_plot_settings(need_click=True): - st.markdown('##### Show plots for individual sessions ') - cols = st.columns([2, 1]) - st.session_state.selected_draw_sessions = cols[0].selectbox('Which session(s) to draw?', - [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly)} sessions)', - f'filtered from sidebar ({len(st.session_state.df_session_filtered)} sessions)'], - index=0 - ) - st.session_state.num_cols = cols[1].number_input('Number of columns', 1, 10, - 3 if 'num_cols' not in st.session_state else st.session_state.num_cols) - - st.markdown( - """ - """, - unsafe_allow_html=True, - ) - st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', - st.session_state.draw_type_mapper_session_level.keys(), - default=st.session_state.draw_type_mapper_session_level.keys() - if 'selected_draw_types' not in st.session_state else - st.session_state.selected_draw_types) - if need_click: - draw_it = st.button('Show me all sessions!', use_container_width=True) - else: - draw_it = True - return draw_it - -def mouse_plot_settings(need_click=True): - st.markdown('##### Show plots for individual mice ') - cols = st.columns([2, 1]) - 
st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?', - [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)', - f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'], - index=0 - ) - st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10, - 3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice) - st.markdown( - """ - """, - unsafe_allow_html=True, - ) - st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?', - st.session_state.draw_type_mapper_mouse_level.keys(), - default=st.session_state.draw_type_mapper_mouse_level.keys() - if 'selected_draw_types_mice' not in st.session_state else - st.session_state.selected_draw_types_mice) - if need_click: - draw_it = st.button('Show me all mice!', use_container_width=True) - else: - draw_it = True - return draw_it - - -def plot_x_y_session(): - - cols = st.columns([4, 10]) - - with cols[0]: - x_name, y_name, group_by = add_xy_selector(if_bonsai=False) - - (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, - dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale, color_map) = add_xy_setting() - - - # If no sessions are selected, use all filtered entries - # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered - df_x_y_session = st.session_state.df_session_filtered - - names = {('session', 'foraging_eff'): 'Foraging efficiency', - ('session', 'finished'): 'Finished trials', - } - - df_selected_from_plotly = pd.DataFrame() - # for i, (title, (x_name, y_name)) in enumerate(names.items()): - # with cols[i]: - with cols[1]: - fig = _plot_population_x_y(df=df_x_y_session, - 
x_name=x_name, y_name=y_name, - group_by=group_by, - smooth_factor=smooth_factor, - if_show_dots=if_show_dots, - if_aggr_each_group=if_aggr_each_group, - if_aggr_all=if_aggr_all, - aggr_method_group=aggr_method_group, - aggr_method_all=aggr_method_all, - if_use_x_quantile_group=if_use_x_quantile_group, - q_quantiles_group=q_quantiles_group, - if_use_x_quantile_all=if_use_x_quantile_all, - q_quantiles_all=q_quantiles_all, - title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, - states = st.session_state.df_selected_from_plotly, - dot_size_base=dot_size, - dot_opacity=dot_opacity, - line_width=line_width) - - # st.plotly_chart(fig) - selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True, - override_height=fig.layout.height * 1.1, override_width=fig.layout.width) - - if len(selected): - df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1), - on=[x_name, y_name], how='inner') - - return df_selected_from_plotly, cols - - -# ------- Layout starts here -------- # -def init(): - - # Clear specific session state and all filters - for key in st.session_state: - if key in ['selected_draw_types'] or '_changed' in key: - del st.session_state[key] - - # Set session state from URL - for key, default in to_sync_with_url_query.items(): - sync_widget_with_query(key, default) - - df = load_data(['sessions', - 'logistic_regression_hattori', - 'logistic_regression_su', - 'linear_regression_rt', - 'model_fitting_params']) - - # Try to convert datetimes into a standard format (datetime, no timezone) - df['sessions']['session_date'] = pd.to_datetime(df['sessions']['session_date']) - # if is_datetime64_any_dtype(df[col]): - df['sessions']['session_date'] = df['sessions']['session_date'].dt.tz_localize(None) - df['sessions']['photostim_location'].fillna('None', inplace=True) - - st.session_state.df = df - st.session_state.df_selected_from_plotly = pd.DataFrame(columns=['h2o', 
'session']) - st.session_state.df_selected_from_dataframe = pd.DataFrame(columns=['h2o', 'session']) - - # Init auto training database - st.session_state.auto_train_manager = DynamicForagingAutoTrainManager( - manager_name='Janelia_demo', - df_behavior_on_s3=dict(bucket='aind-behavior-data', - root='Han/ephys/report/all_sessions/export_all_nwb/', - file_name='df_sessions.pkl'), - df_manager_root_on_s3=dict(bucket='aind-behavior-data', - root='foraging_auto_training/') - ) - - # Init session states - to_init = [ - ['model_id', 21], # add some model fitting params to session - ] - - for name, default in to_init: - if name not in st.session_state: - st.session_state[name] = default - - selected_id = st.session_state.model_id - - st.session_state.draw_type_mapper_session_level = {'1. Choice history': ('fitted_choice', # prefix - (0, 0), # location (row_idx, column_idx) - dict(other_patterns=['model_best', 'model_None'])), - '2. Lick times': ('lick_psth', - (1, 0), - {}), - '3. Win-stay-lose-shift prob.': ('wsls', - (1, 1), - dict(crop=(0, 0, 1200, 600))), - '4. Linear regression on RT': ('linear_regression_rt', - (1, 1), - dict()), - '5. Logistic regression on choice (Hattori)': ('logistic_regression_hattori', - (2, 0), - dict(crop=(0, 0, 1200, 2000))), - '6. Logistic regression on choice (Su)': ('logistic_regression_su', - (2, 1), - dict(crop=(0, 0, 1200, 2000))), - } - - st.session_state.draw_type_mapper_mouse_level = {'1. Model comparison': ('model_all_sessions', # prefix - (0, 0), # location (row_idx, column_idx) - dict(other_patterns=['comparison'], - crop=(0, #900, - 100, 2800, 2200))), - '2. Model prediction accuracy': ('model_all_sessions', - (0, 0), - dict(other_patterns=['pred_acc'])), - '3. 
Model fitted parameters': ('model_all_sessions', - (0, 0), - dict(other_patterns=['fitted_para'])), - } - - - # process dfs - df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}') - valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)] - to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field] - st.session_state.df['sessions'].drop(st.session_state.df['sessions'].query('session < 1').index, inplace=True) - - st.session_state.df['sessions'] = st.session_state.df['sessions'].merge(to_add_model, on=('subject_id', 'session'), how='left') - - # add something else - st.session_state.df['sessions']['abs(bias)'] = np.abs(st.session_state.df['sessions'].biasL) - - # delta weight - diff_relative_weight_next_day = st.session_state.df['sessions'].set_index( - ['session']).sort_values('session', ascending=True).groupby('h2o').apply( - lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day") - - # foraging performance = foraing_eff * finished_ratio - if 'foraging_performance' not in st.session_state.df['sessions'].columns: - st.session_state.df['sessions']['foraging_performance'] = \ - st.session_state.df['sessions']['foraging_eff'] \ - * (1 - st.session_state.df['sessions']['ignore_rate']) - - # weekday - st.session_state.df['sessions']['weekday'] = st.session_state.df['sessions'].session_date.dt.dayofweek + 1 - - st.session_state.df['sessions'] = st.session_state.df['sessions'].merge( - diff_relative_weight_next_day, how='left', on=['h2o', 'session']) - - st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()] - -@st.cache_resource(ttl=24*3600) -def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer": - return StreamlitRenderer(df, spec=spec, debug=False, **kwargs) - - -def app(): - st.markdown('## Foraging Behavior Browser') - - with st.sidebar: - 
add_session_filter() - data_selector() - - st.markdown('---') - st.markdown('#### Han Hou @ 2024 v2.0.0') - st.markdown('[bug report / feature request](https://github.com/AllenNeuralDynamics/foraging-behavior-browser/issues)') - - with st.expander('Debug', expanded=False): - st.session_state.model_id = st.selectbox('model_id', st.session_state.df['model_fitting_params'].model_id.unique()) - if st.button('Reload data from AWS S3'): - st.cache_data.clear() - init() - st.rerun() - - - - with st.container(): - # col1, col2 = st.columns([1.5, 1], gap='small') - # with col1: - # -- 1. unit dataframe -- - - cols = st.columns([2, 2, 2]) - cols[0].markdown(f'### Filter the sessions on the sidebar ({len(st.session_state.df_session_filtered)} filtered)') - # if cols[1].button('Press this and then Ctrl + R to reload from S3'): - # st.rerun() - if cols[1].button('Reload data '): - st.cache_data.clear() - init() - st.rerun() - - # aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units']) - # st.session_state.df_session_filtered = aggrid_outputs['data'] - - container_filtered_frame = st.container() - - - if len(st.session_state.df_session_filtered) == 0: - st.markdown('## No filtered results!') - return - - aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered) - - if len(aggrid_outputs['selected_rows']) and not set(pd.DataFrame(aggrid_outputs['selected_rows'] - ).set_index(['h2o', 'session']).index - ) == set(st.session_state.df_selected_from_dataframe.set_index(['h2o', 'session']).index): - st.session_state.df_selected_from_dataframe = pd.DataFrame(aggrid_outputs['selected_rows']) - st.session_state.df_selected_from_plotly = st.session_state.df_selected_from_dataframe # Sync selected on plotly - # if st.session_state.tab_id == "tab_session_x_y": - st.rerun() - - chosen_id = stx.tab_bar(data=[ - stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"), - 
stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"), - stx.TabBarItemData(id="tab_pygwalker", title="📊 PyGWalker (Tableau)", description="Interactive dataframe explorer"), - stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"), - stx.TabBarItemData(id="tab_mouse_inspector", title="🐭 Mouse Model Fitting", description="Mouse-level model fitting results"), - ], default="tab_session_inspector" if 'tab_id' not in st.session_state else st.session_state.tab_id) - # chosen_id = "tab_session_x_y" - - placeholder = st.container() - - if chosen_id == "tab_session_x_y": - st.session_state.tab_id = chosen_id - with placeholder: - df_selected_from_plotly, x_y_cols = plot_x_y_session() - - with x_y_cols[0]: - for i in range(7): st.write('\n') - st.markdown("***") - if_draw_all_sessions = session_plot_settings() - - df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered - - if if_draw_all_sessions and len(df_to_draw_sessions): - draw_session_plots(df_to_draw_sessions) - - if len(df_selected_from_plotly) and not set(df_selected_from_plotly.set_index(['h2o', 'session']).index) == set( - st.session_state.df_selected_from_plotly.set_index(['h2o', 'session']).index): - st.session_state.df_selected_from_plotly = df_selected_from_plotly - st.session_state.df_selected_from_dataframe = df_selected_from_plotly # Sync selected on dataframe - st.rerun() - - elif chosen_id == "tab_session_inspector": - st.session_state.tab_id = chosen_id - with placeholder: - with st.columns([4, 10])[0]: - if_draw_all_sessions = session_plot_settings(need_click=False) - df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered - - if if_draw_all_sessions and 
len(df_to_draw_sessions): - draw_session_plots(df_to_draw_sessions) - - elif chosen_id == "tab_pygwalker": - with placeholder: - cols = st.columns([1, 4]) - cols[0].markdown('##### Exploring data using [PyGWalker](https://docs.kanaries.net/pygwalker)') - with cols[1]: - with st.expander('Specify PyGWalker json'): - # Load json from ./gw_config.json - pyg_user_json = st.text_area("Export your plot settings to json by clicking `export_code` " - "button below and then paste your json here to reproduce your plots", - key='pyg_walker', height=100) - - # If pyg_user_json is not empty, use it; otherwise, use the default gw_config.json - if pyg_user_json: - try: - pygwalker_renderer = get_pyg_renderer( - df=st.session_state.df_session_filtered, - spec=pyg_user_json, - ) - except: - pygwalker_renderer = get_pyg_renderer( - df=st.session_state.df_session_filtered, - spec="./gw_config_old_mice.json", - ) - else: - pygwalker_renderer = get_pyg_renderer( - df=st.session_state.df_session_filtered, - spec="./gw_config_old_mice.json", - ) - - pygwalker_renderer.render_explore(height=1010, scrolling=False) - - - elif chosen_id == "tab_auto_train_history": # Automatic training history - st.session_state.tab_id = chosen_id - with placeholder: - add_auto_train_manager() - - elif chosen_id == "tab_mouse_inspector": - st.session_state.tab_id = chosen_id - with placeholder: - with st.columns([4, 10])[0]: - if_draw_all_mice = mouse_plot_settings(need_click=False) - df_selected = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_mice else st.session_state.df_session_filtered - df_to_draw_mice = df_selected.groupby('h2o').count().reset_index() - - if if_draw_all_mice and len(df_to_draw_mice): - draw_mice_plots(df_to_draw_mice) - - - - # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000) - - # Update back to URL - for key in to_sync_with_url_query: - try: - st.query_params.update({key: st.session_state[key]}) - 
except: - print(f'Failed to update {key} to URL query') - - -if 'df' not in st.session_state or 'sessions' not in st.session_state.df.keys(): - init() - -app() - - -if if_profile: - p.stop() \ No newline at end of file diff --git a/code/pages/3_Playground.py b/code/pages/1_Playground.py similarity index 100% rename from code/pages/3_Playground.py rename to code/pages/1_Playground.py diff --git a/code/pages/2_Population analysis.py b/code/pages/2_Population analysis.py deleted file mode 100644 index 3299816..0000000 --- a/code/pages/2_Population analysis.py +++ /dev/null @@ -1,281 +0,0 @@ -import streamlit as st - -import importlib -import matplotlib.pyplot as plt -import numpy as np -from io import BytesIO - -from util.streamlit import filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector -from util.population import _draw_variable_trial_back, _draw_variable_trial_back_linear_reg -import seaborn as sns - - -try: - st.set_page_config(layout="wide", - page_title='Foraging behavior browser', - page_icon=':mouse2:', - menu_items={ - 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues", - 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/" - } - ) -except: - pass - -st.session_state.use_s3 = True - - -def app(): - - - with st.sidebar: - add_session_filter() - data_selector() - - - # st.dataframe(st.session_state.df['logistic_regression_hattori']) - - - df_to_do_population = st.session_state.df_session_filtered - df_log_reg_hattori_to_do = df_to_do_population.merge(st.session_state.df['logistic_regression_hattori'], on=('subject_id', 'session'), how='inner') - df_log_reg_su_to_do = df_to_do_population.merge(st.session_state.df['logistic_regression_su'], on=('subject_id', 'session'), how='inner') - - df_lin_reg_rt_to_do = df_to_do_population.merge(st.session_state.df['linear_regression_rt'], on=('subject_id', 'session'), how='inner') - - cols = st.columns([1, 3]) - - with cols[0]: - if 
st.checkbox('Plot logistic regression (non-photostim)', True): - st.markdown('Hattori (Rewarded choice + Unrewarded choice + Choice)') - fig = plot_logistic_regression_non_photostim(df_log_reg_hattori_to_do.query('trial_group == "all_no_stim"')) - if fig is not None: - buf = BytesIO() - fig.savefig(buf, format="png") - st.image(buf, use_column_width=True) - - st.markdown('Su (Rewarded choice + Unrewarded choice)') - fig = plot_logistic_regression_non_photostim(df_log_reg_su_to_do.query('trial_group == "all_no_stim"')) - if fig is not None: - buf = BytesIO() - fig.savefig(buf, format="png") - st.image(buf, use_column_width=True) - - if st.checkbox('Plot linear regression on RT (non-photostim)', True): - fig = plot_linear_regression_rt_non_photostim(df_lin_reg_rt_to_do.query('trial_group == "all_no_stim"')) - - if fig is not None: - buf = BytesIO() - fig.savefig(buf, format="png") - st.image(buf, use_column_width=True) - - with cols[1]: - if st.checkbox('Plot logistic regression (⚡photostim sessions)', False): - st.markdown('Hattori (Rewarded choice + Unrewarded choice + Choice)') - with st.columns([1, 3])[0]: - beta_names = st.multiselect('beta names', ['RewC', 'UnrC', 'C'], ['RewC', 'UnrC', 'C']) - max_trials_back = st.slider('max trials back', 1, 5, 3) - - fig = plot_logistic_regression_photostim(df_log_reg_hattori_to_do.query('trial_group != "all_no_stim"'), - beta_names=beta_names, past_trials_to_plot=range(1, max_trials_back + 1)) - if fig is not None: - buf = BytesIO() - fig.savefig(buf, format="png") - st.image(buf, width=3000) - - st.markdown('Su (Rewarded choice + Unrewarded choice)') - with st.columns([1, 3])[0]: - beta_names = st.multiselect('beta names ', ['RewC', 'UnrC'], ['RewC', 'UnrC']) - max_trials_back = st.slider('max trials back ', 1, 5, 3) - - fig = plot_logistic_regression_photostim(df_log_reg_su_to_do.query('trial_group != "all_no_stim"'), - beta_names=beta_names, past_trials_to_plot=range(1, max_trials_back + 1)) - if fig is not None: - buf = 
BytesIO() - fig.savefig(buf, format="png") - st.image(buf, width=3000) - - if st.checkbox('Plot linear regression on RT (⚡photostim sessions)', False): - fig = plot_linear_regression_rt_photostim(df_lin_reg_rt_to_do.query('trial_group != "all_no_stim"'), - beta_names=['reward_1', 'reward_2', 'previous_iti', 'trial_number', 'this_choice', 'constant']) - if fig is not None: - buf = BytesIO() - fig.savefig(buf, format="png") - st.image(buf) - - - - - # st.pyplot(fig) - - -@st.cache_data(ttl=3600*24) -def plot_logistic_regression_photostim(df_all, beta_names=['RewC', 'UnrC', 'C'], past_trials_to_plot=(1, 2, 3)): - if not len(df_all): return None - - df_all['error_bar'] = [[x, y] for x, y in zip(df_all['mean'] - df_all.lower_ci, df_all.upper_ci - df_all['mean'])] - df_all['h2o_session'] = df_all[['h2o', 'session']].astype(str).apply('_'.join, axis=1) - - df_all.query('trial_group != "all_no_stim"', inplace=True) - - - fig, axes = plt.subplots(len(past_trials_to_plot), len(beta_names) + 1, - figsize=(10 * (len(beta_names) + 1), 5 * len(past_trials_to_plot)), constrained_layout=False, - dpi=200, - gridspec_kw=dict(hspace=0.3, wspace=0.3, top=0.9, bottom=0.1, left=0.1, right=0.9), - ) - axes = np.atleast_2d(axes) - - for i, trials_back in enumerate(past_trials_to_plot): - for j, name in enumerate(beta_names): - _draw_variable_trial_back(df_all, name, trials_back, ax=axes[i, j]) - _draw_variable_trial_back(df_all, 'bias', 0, ax=axes[0, j + 1]) - - for i in range(1, len(past_trials_to_plot)): axes[i, -1].remove() - - return fig - - -def plot_logistic_regression_non_photostim(df_all, max_trials_back=10, ax=None): - if not len(df_all): return None - - if ax is None: - fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=500, - gridspec_kw=dict(bottom=0.2, top=0.9)) - - xx = np.arange(1, max_trials_back + 1) - - plot_spec = {'RewC': ('tab:green', 'reward choices'), - 'UnrC': ('tab:red', 'unrewarded choices'), - 'C': ('tab:blue', 'choices'), - 'bias': ('k', 'right bias')} - - for 
name, (col, label) in plot_spec.items(): - - means = [df_all.query(f'beta == "{name}" and trials_back == {t}')['mean'].mean() - for t in ([0] if name == "bias" else xx)] - if np.all(np.isnan(means)): - continue - - cis = [df_all.query(f'beta == "{name}" and trials_back == {t}')['mean'].sem() * 1.96 - for t in ([0] if name == "bias" else xx)] - - ax.errorbar(x=1 if name == 'bias' else xx, - y=means, - yerr=cis, - ls='-', - marker='o', - color=col, - capsize=5, markeredgewidth=1, - label=label + ' $\pm$95% CI') - - ax.legend() - ax.set(xlabel='Past trials', ylabel='Logistic regression coeffs') - ax.axhline(y=0, color='k', linestyle=':', linewidth=0.5) - ax.set(xticks=[1, 5, 10], - ylim=(-0.1, 2.5)) - - sns.despine(trim=True) - - n_mice = len(df_all['h2o'].unique()) - n_sessions = len(df_all.groupby(['h2o', 'session']).count()) - fig.suptitle(f'Logistic regression on choice ({n_mice} mice, {n_sessions} sessions)') - - - return fig - - - -# @st.cache_data(ttl=3600*24) -def plot_linear_regression_rt_photostim(df_all, beta_names): - if not len(df_all): return None - - fig, axes = plt.subplots(2, 3, - figsize=(15, 10), constrained_layout=False, - dpi=200, - gridspec_kw=dict(hspace=0.3, wspace=0.3, top=0.9, bottom=0.1, left=0.1, right=0.9), - ) - axes = np.atleast_2d(axes) - - for j, beta_name in enumerate(beta_names): - _draw_variable_trial_back_linear_reg(df_all, beta_name, ax=axes.flat[j]) - - return fig - - - -def plot_linear_regression_rt_non_photostim(df_all, ax=None): - ''' - plot list of linear regressions - ''' - if not len(df_all): return None - - - if ax is None: - fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=200, - gridspec_kw=dict(bottom=0.2, left=0.2)) - - gs = ax._subplotspec.subgridspec(1, 2, width_ratios=[1, 3], wspace=0.5) - ax_others = ax.get_figure().add_subplot(gs[0, 0]) - ax_reward = ax.get_figure().add_subplot(gs[0, 1]) - - # Other paras - other_names = ['previous_iti', 'this_choice', 'trial_number'] - xx = np.arange(len(other_names)) - means = 
[df_all.query(f'variable == "{name}"')['beta'].mean() for name in other_names] - sems = [df_all.query(f'variable == "{name}"')['beta'].sem() * 1.96 for name in other_names] - ax_others.errorbar(x=xx, - y=means, - yerr=sems, - ls='none', - marker='o', - color='k', - capsize=5, markeredgewidth=1, - ) - ax_others.set(xlim=(-0.5, 2.5), ylim=(-0.3, 1)) - - ax_others.set_xticks(range(len(other_names))) - ax_others.set_xticklabels(other_names, rotation=45, ha='right') - ax_others.axhline(y=0, color='k', linestyle=':', linewidth=1) - ax_others.set(ylabel=r'Linear regression $\beta \pm$95% CI') - - # Back rewards - xx = np.arange(1, 11) - means = [df_all.query(f'variable == "reward" and trials_back == {x}')['beta'].mean() for x in xx] - sems = [df_all.query(f'variable == "reward" and trials_back == {x}')['beta'].sem() * 1.96 for x in xx] - - ax_reward.errorbar(x=xx, - y=means, - yerr=sems, - ls='-', - marker='o', - color='k', - capsize=5, - markeredgewidth=1, - ) - - ax_reward.set(xlabel='Reward of past trials') - ax_reward.axhline(y=0, color='k', linestyle=':', linewidth=1) - ax_reward.set(xticks=[1, 5, 10], ylim=(-0.35, 0.05)) - ax_reward.invert_yaxis() - - sns.despine(trim=True) - ax.remove() - - n_mice = len(df_all['h2o'].unique()) - n_sessions = len(df_all.groupby(['h2o', 'session']).count()) - fig.suptitle(f'Linear regression on RT ({n_mice} mice, {n_sessions} sessions)') - - return fig - - - -if 'df' not in st.session_state: - from Home import init - init() - -app() - -# try: -# app() -# except: -# st.markdown('### Something is wrong. 
Try going back to 🏠Home or refresh.') \ No newline at end of file diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 778f31e..4d412ac 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -6,10 +6,8 @@ from st_aggrid.shared import GridUpdateMode, ColumnsAutoSizeMode, DataReturnMode from pandas.api.types import ( is_categorical_dtype, - is_datetime64_any_dtype, is_numeric_dtype, is_string_dtype, - is_object_dtype, ) import streamlit.components.v1 as components from streamlit_plotly_events import plotly_events @@ -22,7 +20,13 @@ import statsmodels.api as sm from scipy.stats import linregress -from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query +from .url_query_helper import ( + checkbox_wrapper_for_url_query, + selectbox_wrapper_for_url_query, + slider_wrapper_for_url_query, + multiselect_wrapper_for_url_query, + get_filter_type, + ) from .plot_autotrain_manager import plot_manager_all_progress from .aws_s3 import draw_session_plots_quick_preview @@ -186,6 +190,7 @@ def cache_widget(field, clear=None): # def dec_cache_widget_state(widget, ): + def filter_dataframe(df: pd.DataFrame, default_filters=['h2o', 'task', 'finished_trials', 'photostim_location'], url_query={}) -> pd.DataFrame: @@ -214,20 +219,25 @@ def filter_dataframe(df: pd.DataFrame, cols = st.columns([1, 1.5]) cols[0].markdown(f"Add filters") if_reset_filters = cols[1].button(label="Reset filters") - to_filter_columns = st.multiselect("Filter dataframe on", df.columns, - label_visibility='collapsed', - default=st.session_state.to_filter_columns_changed - if 'to_filter_columns_changed' in st.session_state - else default_filters, - key='to_filter_columns', - on_change=cache_widget, - args=['to_filter_columns']) + + to_filter_columns = multiselect_wrapper_for_url_query( + st_prefix=st, + label="Filter dataframe on", + options=df.columns, + default=['subject_id', 'session', 'finished_trials', 'foraging_eff', 
'task'], + key='to_filter_columns', + label_visibility='collapsed', + ) + for column in to_filter_columns: if not len(df): break left, right = st.columns((1, 20)) # Treat columns with < 10 unique values as categorical - if is_categorical_dtype(df[column]) or df[column].nunique() < 10 and column not in ('finished', 'foraging_eff', 'session', 'finished_trials'): + + filter_type = get_filter_type(df, column) + + if filter_type == 'multiselect': right.markdown(f"Filter for :red[**{column}**]") if if_reset_filters: @@ -243,6 +253,7 @@ def filter_dataframe(df: pd.DataFrame, else: default_value = list(df[column].unique()) + default_value = [v for v in default_value if v in list(df[column].unique())] st.session_state[f'filter_{column}'] = default_value selected = right.multiselect( @@ -256,7 +267,7 @@ def filter_dataframe(df: pd.DataFrame, ) df = df[df[column].isin(selected)] - elif is_numeric_dtype(df[column]): + elif filter_type == 'slider_range_float': # fig = px.histogram(df[column], nbins=100, ) # fig.update_layout(showlegend=False, height=50) @@ -339,7 +350,7 @@ def filter_dataframe(df: pd.DataFrame, df = df[x.between(*user_num_input)] - elif is_datetime64_any_dtype(df[column]): + elif filter_type == 'slider_range_date': user_date_input = right.date_input( f"Values for :red[**{column}**]", value=st.session_state[f'filter_{column}_changed'] @@ -354,8 +365,8 @@ def filter_dataframe(df: pd.DataFrame, user_date_input = tuple(map(pd.to_datetime, user_date_input)) start_date, end_date = user_date_input df = df.loc[df[column].between(start_date, end_date)] - else: # Regular string + elif filter_type == 'reg_ex': if if_reset_filters: default_value = '' st.session_state[f'filter_{column}_changed'] = default_value diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index cb323b6..50eea0b 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -1,5 +1,64 @@ import streamlit as st +from pandas.api.types import ( + 
is_categorical_dtype, + is_datetime64_any_dtype, + is_numeric_dtype, +) + +# Sync widgets with URL query params +# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/ +# dict of "key": default pairs +# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL +to_sync_with_url_query_default = { + 'if_load_bpod_sessions': False, + + 'to_filter_columns': ['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'], + 'filter_subject_id': '', + 'filter_session': [0.0, None], + 'filter_finished_trials': [0.0, None], + 'filter_foraging_eff': [0.0, None], + 'filter_task': ['all'], + + 'table_height': 300, + + 'tab_id': 'tab_session_x_y', + 'x_y_plot_xname': 'session', + 'x_y_plot_yname': 'foraging_performance_random_seed', + 'x_y_plot_group_by': 'h2o', + 'x_y_plot_if_show_dots': True, + 'x_y_plot_if_aggr_each_group': True, + 'x_y_plot_aggr_method_group': 'lowess', + 'x_y_plot_if_aggr_all': True, + 'x_y_plot_aggr_method_all': 'mean +/- sem', + 'x_y_plot_smooth_factor': 5, + 'x_y_plot_if_use_x_quantile_group': False, + 'x_y_plot_q_quantiles_group': 20, + 'x_y_plot_if_use_x_quantile_all': False, + 'x_y_plot_q_quantiles_all': 20, + 'x_y_plot_if_show_diagonal': False, + 'x_y_plot_dot_size': 10, + 'x_y_plot_dot_opacity': 0.3, + 'x_y_plot_line_width': 2.0, + 'x_y_plot_figure_width': 1300, + 'x_y_plot_figure_height': 900, + 'x_y_plot_font_size_scale': 1.0, + 'x_y_plot_selected_color_map': 'Plotly', + + 'x_y_plot_size_mapper': 'finished_trials', + 'x_y_plot_size_mapper_gamma': 1.0, + 'x_y_plot_size_mapper_range': [3, 20], + + 'session_plot_mode': 'sessions selected from table or plot', + + 'auto_training_history_x_axis': 'session', + 'auto_training_history_sort_by': 'subject_id', + 'auto_training_history_sort_order': 'descending', + 'auto_training_curriculum_name': 'Uncoupled Baiting', + 'auto_training_curriculum_version': '1.0', + 'auto_training_curriculum_schema_version': '1.0', + 
def multiselect_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs):
    """Create a multiselect whose initial selection is restored from state or URL.

    The initial selection is taken, in order of priority, from
    ``st.session_state[key]``, from the URL query parameter ``key``, and
    finally from ``default``.

    Args:
        st_prefix: Object providing ``.multiselect`` (``st`` or a container).
        label: Widget label.
        options: Selectable options.
        key: Widget key; also the session-state / URL parameter name.
        default: Selection used when neither session state nor URL has ``key``.
        **kwargs: Forwarded unchanged to ``st_prefix.multiselect``.

    Returns:
        The value returned by ``st_prefix.multiselect``.
    """
    if key in st.session_state:
        initial = st.session_state[key]
    elif key in st.query_params:
        # Indexing query_params gives a single string (the last value of a
        # repeated parameter); a multiselect needs every value.
        initial = st.query_params.get_all(key)
    else:
        initial = default

    return st_prefix.multiselect(
        label,
        options=options,
        default=initial,
        key=key,
        **kwargs,
    )


def _placeholder_default_for_dynamic_filter(key):
    """Return a type-only placeholder default for a ``filter_<column>`` key.

    Returns None when ``key`` is not a ``filter_`` key, when the sessions
    dataframe or the column is unavailable, or when the column's filter type
    has no URL representation handled here.
    """
    prefix = 'filter_'
    if not key.startswith(prefix):
        return None
    column = key[len(prefix):]

    try:
        df = st.session_state.df['sessions_bonsai']
    except (AttributeError, KeyError):
        return None
    if column not in df.columns:
        return None

    # Only the element types matter here, not the values themselves.
    placeholders = {
        'slider_range_float': [0.0, 1.0],
        'reg_ex': '',
        'multiselect': ['a', 'b'],
    }
    return placeholders.get(get_filter_type(df, column))


def _cast_query_values(raw_values, default):
    """Convert raw URL strings to the type implied by ``default``.

    The target type is the type of the first non-None element of ``default``
    (``str`` if there is none). The string 'none' (any case) maps to None for
    non-bool targets. A non-list ``default`` yields a scalar (first value).

    Raises:
        ValueError, TypeError: if a raw string cannot be converted.
    """
    template = default if isinstance(default, list) else [default]
    target_type = next((type(d) for d in template if d is not None), str)

    if target_type is bool:
        values = [raw.lower() == 'true' for raw in raw_values]
    else:
        values = [
            None if raw.lower() == 'none' else target_type(raw)
            for raw in raw_values
        ]

    return values if isinstance(default, list) else values[0]


def sync_URL_to_session_state():
    """Populate ``st.session_state`` from the URL query, falling back to defaults.

    Every key of ``to_sync_with_url_query_default`` and every key present in
    ``st.query_params`` is considered. Keys found in the URL are converted to
    the type of their default; keys absent from the URL receive the default.
    URL keys that are not in the default table are treated as dynamically
    added ``filter_<column>`` filters; anything else is skipped so that an
    arbitrary or malformed URL cannot crash the page.
    """
    all_keys = set(st.query_params.keys()) | set(to_sync_with_url_query_default)

    for key in all_keys:
        is_known = key in to_sync_with_url_query_default
        if is_known:
            default = to_sync_with_url_query_default[key]
        else:
            # Not in the table, so the key came from the URL: the default is
            # only needed to learn the target type, not for its value.
            default = _placeholder_default_for_dynamic_filter(key)
            if default is None:
                print('sync_URL_to_session_state: Unrecognized filter type')
                continue

        # Copy list defaults so session state never aliases the module-level table.
        fallback = list(default) if isinstance(default, list) else default

        if key in st.query_params:
            try:
                value = _cast_query_values(st.query_params.get_all(key), default)
            except (TypeError, ValueError) as err:
                print(f'sync_URL_to_session_state: Cannot parse URL value of {key}: {err}')
                if not is_known:
                    continue  # a placeholder default must never be stored
                value = fallback
            failure_message = f'sync_URL_to_session_state: Failed to set {key} to {value}'
        else:
            value = fallback
            failure_message = f'Failed to set {key} to {value}'

        try:
            st.session_state[key] = value
        except Exception:
            # Best effort: a rejected assignment must not abort the whole sync.
            print(failure_message)


def sync_session_state_to_URL():
    """Write the synced session-state entries back into the URL query.

    Covers every key of ``to_sync_with_url_query_default`` plus every
    session-state key that starts with ``filter_`` and does not end with
    ``_changed``, so dynamically added filters are reflected in the URL too.
    """
    dynamic_filter_keys = {
        name
        for name in st.session_state
        if name.startswith('filter_') and not name.endswith('_changed')
    }

    for key in set(to_sync_with_url_query_default) | dynamic_filter_keys:
        try:
            st.query_params.update({key: st.session_state[key]})
        except Exception:
            # Best effort: e.g. a default key that is not in session state yet.
            print(f'Failed to update {key} to URL query')


def get_filter_type(df, column):
    """Decide which filter widget type fits ``df[column]``.

    Checks are applied in this order:
        1. numeric dtype -> 'slider_range_float'
        2. categorical dtype, fewer than 10 unique values, or a column pinned
           to multiselect -> 'multiselect'
        3. datetime dtype -> 'slider_range_date'
        4. anything else -> 'reg_ex'

    Args:
        df: DataFrame holding the column.
        column: Column name.

    Returns:
        One of 'slider_range_float', 'multiselect', 'slider_range_date', 'reg_ex'.
    """
    # Local import: the file only imports helper predicates from pandas.api.types.
    from pandas import CategoricalDtype

    # Trailing comma matters: ('user_name') is a plain string, and `in` on a
    # string is a substring test that would also match e.g. 'name' or 'user'.
    pinned_to_multiselect = ('user_name',)

    series = df[column]

    if is_numeric_dtype(series):
        return 'slider_range_float'

    if (isinstance(series.dtype, CategoricalDtype)
            or series.nunique() < 10
            or column in pinned_to_multiselect):
        return 'multiselect'

    if is_datetime64_any_dtype(series):
        return 'slider_range_date'

    return 'reg_ex'  # Default