diff --git a/code/Home.py b/code/Home.py
index cce7c18..24a418d 100644
--- a/code/Home.py
+++ b/code/Home.py
@@ -12,7 +12,7 @@
"""
-__ver__ = 'v2.2.2'
+__ver__ = 'v2.3.0'
import pandas as pd
import streamlit as st
@@ -35,66 +35,14 @@
show_debug_info,
)
from util.url_query_helper import (
- sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query
+ sync_URL_to_session_state, sync_session_state_to_URL,
+ slider_wrapper_for_url_query, checkbox_wrapper_for_url_query
)
from aind_auto_train.curriculum_manager import CurriculumManager
from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
-# Sync widgets with URL query params
-# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/
-# dict of "key": default pairs
-# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL
-to_sync_with_url_query = {
- 'if_load_bpod_sessions': False,
-
- 'filter_subject_id': '',
- 'filter_session': [0.0, None],
- 'filter_finished_trials': [0.0, None],
- 'filter_foraging_eff': [0.0, None],
- 'filter_task': ['all'],
-
- 'table_height': 300,
-
- 'tab_id': 'tab_session_x_y',
- 'x_y_plot_xname': 'session',
- 'x_y_plot_yname': 'foraging_performance_random_seed',
- 'x_y_plot_group_by': 'h2o',
- 'x_y_plot_if_show_dots': True,
- 'x_y_plot_if_aggr_each_group': True,
- 'x_y_plot_aggr_method_group': 'lowess',
- 'x_y_plot_if_aggr_all': True,
- 'x_y_plot_aggr_method_all': 'mean +/- sem',
- 'x_y_plot_smooth_factor': 5,
- 'x_y_plot_if_use_x_quantile_group': False,
- 'x_y_plot_q_quantiles_group': 20,
- 'x_y_plot_if_use_x_quantile_all': False,
- 'x_y_plot_q_quantiles_all': 20,
- 'x_y_plot_if_show_diagonal': False,
- 'x_y_plot_dot_size': 10,
- 'x_y_plot_dot_opacity': 0.3,
- 'x_y_plot_line_width': 2.0,
- 'x_y_plot_figure_width': 1300,
- 'x_y_plot_figure_height': 900,
- 'x_y_plot_font_size_scale': 1.0,
- 'x_y_plot_selected_color_map': 'Plotly',
-
- 'x_y_plot_size_mapper': 'finished_trials',
- 'x_y_plot_size_mapper_gamma': 1.0,
- 'x_y_plot_size_mapper_range': [3, 20],
-
- 'session_plot_mode': 'sessions selected from table or plot',
-
- 'auto_training_history_x_axis': 'session',
- 'auto_training_history_sort_by': 'subject_id',
- 'auto_training_history_sort_order': 'descending',
- 'auto_training_curriculum_name': 'Uncoupled Baiting',
- 'auto_training_curriculum_version': '1.0',
- 'auto_training_curriculum_schema_version': '1.0',
- }
-
-
try:
st.set_page_config(layout="wide",
page_title='Foraging behavior browser',
@@ -336,14 +284,10 @@ def init():
if key in ['selected_draw_types'] or '_changed' in key:
del st.session_state[key]
- # Set session state from URL
- for key, default in to_sync_with_url_query.items():
- sync_widget_with_query(key, default)
-
df = load_data(['sessions'], data_source='bonsai')
# --- Perform any data source-dependent preprocessing here ---
- if st.session_state.if_load_bpod_sessions:
+ if (st.session_state.if_load_bpod_sessions if 'if_load_bpod_sessions' in st.session_state else False):
df_bpod = load_data(['sessions'], data_source='bpod')
# For historial reason, the suffix of df['sessions_bonsai'] just mean the data of the Home.py page
@@ -526,6 +470,9 @@ def _get_data_source(rig):
st.session_state.df['sessions_bonsai'] = _df # Somehow _df loses the reference to the original dataframe
st.session_state.session_stats_names = [keys for keys in _df.keys()]
+
+ # Set session state from URL
+ sync_URL_to_session_state()
# Establish communication between pygwalker and streamlit
init_streamlit_comm()
@@ -790,14 +737,9 @@ def app():
st.markdown('---\n##### Debug zone')
show_debug_info()
-
# Update back to URL
- for key in to_sync_with_url_query:
- try:
- st.query_params.update({key: st.session_state[key]})
- except:
- print(f'Failed to update {key} to URL query')
+ sync_session_state_to_URL()
# st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)
diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py
deleted file mode 100644
index 25c08a0..0000000
--- a/code/pages/1_Old mice.py
+++ /dev/null
@@ -1,686 +0,0 @@
-#%%
-import pandas as pd
-import streamlit as st
-from pathlib import Path
-import glob
-import matplotlib.pyplot as plt
-import numpy as np
-from datetime import datetime
-import s3fs
-import os
-import plotly.express as px
-import plotly
-import plotly.graph_objects as go
-import statsmodels.api as sm
-
-from PIL import Image, ImageColor
-import streamlit.components.v1 as components
-import streamlit_nested_layout
-from streamlit_plotly_events import plotly_events
-
-from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector,
- add_xy_selector, add_xy_setting, add_auto_train_manager,
- _plot_population_x_y)
-from util.url_query_helper import sync_widget_with_query
-
-import extra_streamlit_components as stx
-
-from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
-from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm
-
-
-# Sync widgets with URL query params
-# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/
-# dict of "key": default pairs
-# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL
-to_sync_with_url_query = {
- 'filter_h2o': '',
- 'filter_session': [0.0, None],
- 'filter_finished_trials': [0.0, None],
- 'filter_foraging_eff': [0.0, None],
- 'filter_task': ['all'],
- 'filter_photostim_location': ['all'],
-
- 'tab_id': 'tab_session_x_y',
- 'x_y_plot_xname': 'session',
- 'x_y_plot_yname': 'foraging_performance',
- 'x_y_plot_group_by': 'h2o',
- 'x_y_plot_if_show_dots': True,
- 'x_y_plot_if_aggr_each_group': True,
- 'x_y_plot_aggr_method_group': 'lowess',
- 'x_y_plot_if_aggr_all': True,
- 'x_y_plot_aggr_method_all': 'mean +/- sem',
- 'x_y_plot_smooth_factor': 5,
- 'x_y_plot_if_use_x_quantile_group': False,
- 'x_y_plot_q_quantiles_group': 20,
- 'x_y_plot_if_use_x_quantile_all': False,
- 'x_y_plot_q_quantiles_all': 20,
- 'x_y_plot_dot_size': 7,
- 'x_y_plot_dot_opacity': 0.2,
- 'x_y_plot_line_width': 2.0,
-
- 'auto_training_history_x_axis': 'session',
- 'auto_training_history_sort_by': 'progress_to_graduated',
- 'auto_training_history_sort_order': 'descending',
- }
-
-
-if_profile = False
-
-if if_profile:
- from streamlit_profiler import Profiler
- p = Profiler()
- p.start()
-
-
-# from pipeline import experiment, ephys, lab, psth_foraging, report, foraging_analysis
-# from pipeline.plot import foraging_model_plot
-
-cache_folder = 'xxx' #'/root/capsule/data/s3/report/st_cache/'
-cache_session_level_fig_folder = 'xxx' #'/root/capsule/data/s3/report/all_units/' #
-
-if os.path.exists(cache_folder):
- st.session_state.st.session_state.use_s3 = False
-else:
- cache_folder = 'aind-behavior-data/Han/ephys/report/st_cache/'
- cache_session_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_sessions/'
- cache_mouse_level_fig_folder = 'aind-behavior-data/Han/ephys/report/all_subjects/'
-
- fs = s3fs.S3FileSystem(anon=False)
- st.session_state.use_s3 = True
-
-try:
- st.set_page_config(layout="wide",
- page_title='Foraging behavior browser',
- page_icon=':mouse2:',
- menu_items={
- 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues",
- 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/"
- }
- )
-except:
- pass
-
-if 'selected_points' not in st.session_state:
- st.session_state['selected_points'] = []
-
-
-@st.cache_data(ttl=24*3600)
-def load_data(tables=['sessions']):
- df = {}
- for table in tables:
- file_name = cache_folder + f'df_{table}.pkl'
- if st.session_state.use_s3:
- with fs.open(file_name) as f:
- df[table] = pd.read_pickle(f)
- else:
- df[table] = pd.read_pickle(file_name)
- return df
-
-def _fetch_img(glob_patterns, crop=None):
- # Fetch the img that first matches the patterns
- for pattern in glob_patterns:
- file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern)
- if len(file): break
-
- if not len(file):
- return None, None
-
- try:
- if st.session_state.use_s3:
- with fs.open(file[0]) as f:
- img = Image.open(f)
- img = img.crop(crop)
- else:
- img = Image.open(file[0])
- img = img.crop(crop)
- except:
- st.write('File found on S3 but failed to load...')
- return None, None
-
- return img, file[0]
-
-
-# @st.cache_data(ttl=24*3600, max_entries=20)
-def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs):
- try:
- sess_date_str = datetime.strftime(datetime.strptime(key['session_date'], '%Y-%m-%dT%H:%M:%S'), '%Y%m%d')
- except:
- sess_date_str = datetime.strftime(key['session_date'], '%Y%m%d')
-
- fns = [f'/{key["h2o"]}_{sess_date_str}_*{other_pattern}*' for other_pattern in other_patterns]
- glob_patterns = [cache_session_level_fig_folder + f'{prefix}/' + key["h2o"] + fn for fn in fns]
-
- img, f_name = _fetch_img(glob_patterns, crop)
-
- _f = st if column is None else column
-
- _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
- output_format='PNG',
- caption=f_name.split('/')[-1] if caption and f_name else '',
- use_column_width='always',
- **kwargs)
-
- return img
-
-def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs):
-
- fns = [f'/{key["h2o"]}_*{other_pattern}*' for other_pattern in other_patterns]
- glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/' + fn for fn in fns]
-
- img, f_name = _fetch_img(glob_patterns, crop)
-
- if img is None: # Use "not_found" image
- glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/not_found_*{other_pattern}**' for other_pattern in other_patterns]
- img, f_name = _fetch_img(glob_patterns, crop)
-
- _f = st if column is None else column
-
- _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
- output_format='PNG',
- #caption=f_name.split('/')[-1] if caption and f_name else '',
- use_column_width='always',
- **kwargs)
-
- return img
-
-# table_mapping = {
-# 'sessions': fetch_sessions,
-# 'ephys_units': fetch_ephys_units,
-# }
-
-
-def draw_session_plots(df_to_draw_session):
-
- # Setting up layout for each session
- layout_definition = [[1], # columns in the first row
- [1.5, 1], # columns in the second row
- [1, 1],
- ]
-
- # cols_option = st.columns([3, 0.5, 1])
- container_session_all_in_one = st.container()
-
- with container_session_all_in_one:
- # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True):
-
- if len(df_to_draw_session):
- st.write(f'Loading selected {len(df_to_draw_session)} sessions...')
- my_bar = st.columns((1, 7))[0].progress(0)
-
- major_cols = st.columns([1] * st.session_state.num_cols)
-
- for i, key in enumerate(df_to_draw_session.to_dict(orient='records')):
- this_major_col = major_cols[i % st.session_state.num_cols]
-
- # setting up layout for each session
- rows = []
- with this_major_col:
-
- try:
- date_str = key["session_date"].strftime('%Y-%m-%d')
- except:
- date_str = key["session_date"].split("T")[0]
-
- st.markdown(f'''
{key["h2o"]}, Session {key["session"]}, {date_str}''',
- unsafe_allow_html=True)
- if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout
- for row, column_setting in enumerate(layout_definition):
- rows.append(this_major_col.columns(column_setting))
- else: # else, put it in the whole column
- rows = this_major_col.columns([1])
- st.markdown("---")
-
- for draw_type in st.session_state.draw_type_mapper_session_level:
- if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level
- prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
- this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
- show_session_level_img_by_key_and_prefix(key,
- column=this_col,
- prefix=prefix,
- **setting)
-
- my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))
-
-
-
-def draw_mice_plots(df_to_draw_mice):
-
- # Setting up layout for each session
- layout_definition = [[1], # columns in the first row
- ]
-
- # cols_option = st.columns([3, 0.5, 1])
- container_session_all_in_one = st.container()
-
- with container_session_all_in_one:
- # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True):
-
- if len(df_to_draw_mice):
- st.write(f'Loading selected {len(df_to_draw_mice)} mice...')
- my_bar = st.columns((1, 7))[0].progress(0)
-
- major_cols = st.columns([1] * st.session_state.num_cols_mice)
-
- for i, key in enumerate(df_to_draw_mice.to_dict(orient='records')):
- this_major_col = major_cols[i % st.session_state.num_cols_mice]
-
- # setting up layout for each session
- rows = []
- with this_major_col:
- st.markdown(f'''{key["h2o"]}''',
- unsafe_allow_html=True)
- if len(st.session_state.selected_draw_types_mice) > 1: # more than one types, use the pre-defined layout
- for row, column_setting in enumerate(layout_definition):
- rows.append(this_major_col.columns(column_setting))
- else: # else, put it in the whole column
- rows = this_major_col.columns([1])
- st.markdown("---")
-
- for draw_type in st.session_state.draw_type_mapper_mouse_level:
- if draw_type not in st.session_state.selected_draw_types_mice: continue
- prefix, position, setting = st.session_state.draw_type_mapper_mouse_level[draw_type]
- this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types_mice) > 1 else rows[0]
- show_mouse_level_img_by_key_and_prefix(key,
- column=this_col,
- prefix=prefix,
- **setting)
-
-def session_plot_settings(need_click=True):
- st.markdown('##### Show plots for individual sessions ')
- cols = st.columns([2, 1])
- st.session_state.selected_draw_sessions = cols[0].selectbox('Which session(s) to draw?',
- [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly)} sessions)',
- f'filtered from sidebar ({len(st.session_state.df_session_filtered)} sessions)'],
- index=0
- )
- st.session_state.num_cols = cols[1].number_input('Number of columns', 1, 10,
- 3 if 'num_cols' not in st.session_state else st.session_state.num_cols)
-
- st.markdown(
- """
- """,
- unsafe_allow_html=True,
- )
- st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?',
- st.session_state.draw_type_mapper_session_level.keys(),
- default=st.session_state.draw_type_mapper_session_level.keys()
- if 'selected_draw_types' not in st.session_state else
- st.session_state.selected_draw_types)
- if need_click:
- draw_it = st.button('Show me all sessions!', use_container_width=True)
- else:
- draw_it = True
- return draw_it
-
-def mouse_plot_settings(need_click=True):
- st.markdown('##### Show plots for individual mice ')
- cols = st.columns([2, 1])
- st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?',
- [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)',
- f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'],
- index=0
- )
- st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10,
- 3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice)
- st.markdown(
- """
- """,
- unsafe_allow_html=True,
- )
- st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?',
- st.session_state.draw_type_mapper_mouse_level.keys(),
- default=st.session_state.draw_type_mapper_mouse_level.keys()
- if 'selected_draw_types_mice' not in st.session_state else
- st.session_state.selected_draw_types_mice)
- if need_click:
- draw_it = st.button('Show me all mice!', use_container_width=True)
- else:
- draw_it = True
- return draw_it
-
-
-def plot_x_y_session():
-
- cols = st.columns([4, 10])
-
- with cols[0]:
- x_name, y_name, group_by = add_xy_selector(if_bonsai=False)
-
- (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
- if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
- dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale, color_map) = add_xy_setting()
-
-
- # If no sessions are selected, use all filtered entries
- # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered
- df_x_y_session = st.session_state.df_session_filtered
-
- names = {('session', 'foraging_eff'): 'Foraging efficiency',
- ('session', 'finished'): 'Finished trials',
- }
-
- df_selected_from_plotly = pd.DataFrame()
- # for i, (title, (x_name, y_name)) in enumerate(names.items()):
- # with cols[i]:
- with cols[1]:
- fig = _plot_population_x_y(df=df_x_y_session,
- x_name=x_name, y_name=y_name,
- group_by=group_by,
- smooth_factor=smooth_factor,
- if_show_dots=if_show_dots,
- if_aggr_each_group=if_aggr_each_group,
- if_aggr_all=if_aggr_all,
- aggr_method_group=aggr_method_group,
- aggr_method_all=aggr_method_all,
- if_use_x_quantile_group=if_use_x_quantile_group,
- q_quantiles_group=q_quantiles_group,
- if_use_x_quantile_all=if_use_x_quantile_all,
- q_quantiles_all=q_quantiles_all,
- title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
- states = st.session_state.df_selected_from_plotly,
- dot_size_base=dot_size,
- dot_opacity=dot_opacity,
- line_width=line_width)
-
- # st.plotly_chart(fig)
- selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True,
- override_height=fig.layout.height * 1.1, override_width=fig.layout.width)
-
- if len(selected):
- df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1),
- on=[x_name, y_name], how='inner')
-
- return df_selected_from_plotly, cols
-
-
-# ------- Layout starts here -------- #
-def init():
-
- # Clear specific session state and all filters
- for key in st.session_state:
- if key in ['selected_draw_types'] or '_changed' in key:
- del st.session_state[key]
-
- # Set session state from URL
- for key, default in to_sync_with_url_query.items():
- sync_widget_with_query(key, default)
-
- df = load_data(['sessions',
- 'logistic_regression_hattori',
- 'logistic_regression_su',
- 'linear_regression_rt',
- 'model_fitting_params'])
-
- # Try to convert datetimes into a standard format (datetime, no timezone)
- df['sessions']['session_date'] = pd.to_datetime(df['sessions']['session_date'])
- # if is_datetime64_any_dtype(df[col]):
- df['sessions']['session_date'] = df['sessions']['session_date'].dt.tz_localize(None)
- df['sessions']['photostim_location'].fillna('None', inplace=True)
-
- st.session_state.df = df
- st.session_state.df_selected_from_plotly = pd.DataFrame(columns=['h2o', 'session'])
- st.session_state.df_selected_from_dataframe = pd.DataFrame(columns=['h2o', 'session'])
-
- # Init auto training database
- st.session_state.auto_train_manager = DynamicForagingAutoTrainManager(
- manager_name='Janelia_demo',
- df_behavior_on_s3=dict(bucket='aind-behavior-data',
- root='Han/ephys/report/all_sessions/export_all_nwb/',
- file_name='df_sessions.pkl'),
- df_manager_root_on_s3=dict(bucket='aind-behavior-data',
- root='foraging_auto_training/')
- )
-
- # Init session states
- to_init = [
- ['model_id', 21], # add some model fitting params to session
- ]
-
- for name, default in to_init:
- if name not in st.session_state:
- st.session_state[name] = default
-
- selected_id = st.session_state.model_id
-
- st.session_state.draw_type_mapper_session_level = {'1. Choice history': ('fitted_choice', # prefix
- (0, 0), # location (row_idx, column_idx)
- dict(other_patterns=['model_best', 'model_None'])),
- '2. Lick times': ('lick_psth',
- (1, 0),
- {}),
- '3. Win-stay-lose-shift prob.': ('wsls',
- (1, 1),
- dict(crop=(0, 0, 1200, 600))),
- '4. Linear regression on RT': ('linear_regression_rt',
- (1, 1),
- dict()),
- '5. Logistic regression on choice (Hattori)': ('logistic_regression_hattori',
- (2, 0),
- dict(crop=(0, 0, 1200, 2000))),
- '6. Logistic regression on choice (Su)': ('logistic_regression_su',
- (2, 1),
- dict(crop=(0, 0, 1200, 2000))),
- }
-
- st.session_state.draw_type_mapper_mouse_level = {'1. Model comparison': ('model_all_sessions', # prefix
- (0, 0), # location (row_idx, column_idx)
- dict(other_patterns=['comparison'],
- crop=(0, #900,
- 100, 2800, 2200))),
- '2. Model prediction accuracy': ('model_all_sessions',
- (0, 0),
- dict(other_patterns=['pred_acc'])),
- '3. Model fitted parameters': ('model_all_sessions',
- (0, 0),
- dict(other_patterns=['fitted_para'])),
- }
-
-
- # process dfs
- df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')
- valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)]
- to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field]
- st.session_state.df['sessions'].drop(st.session_state.df['sessions'].query('session < 1').index, inplace=True)
-
- st.session_state.df['sessions'] = st.session_state.df['sessions'].merge(to_add_model, on=('subject_id', 'session'), how='left')
-
- # add something else
- st.session_state.df['sessions']['abs(bias)'] = np.abs(st.session_state.df['sessions'].biasL)
-
- # delta weight
- diff_relative_weight_next_day = st.session_state.df['sessions'].set_index(
- ['session']).sort_values('session', ascending=True).groupby('h2o').apply(
- lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day")
-
- # foraging performance = foraing_eff * finished_ratio
- if 'foraging_performance' not in st.session_state.df['sessions'].columns:
- st.session_state.df['sessions']['foraging_performance'] = \
- st.session_state.df['sessions']['foraging_eff'] \
- * (1 - st.session_state.df['sessions']['ignore_rate'])
-
- # weekday
- st.session_state.df['sessions']['weekday'] = st.session_state.df['sessions'].session_date.dt.dayofweek + 1
-
- st.session_state.df['sessions'] = st.session_state.df['sessions'].merge(
- diff_relative_weight_next_day, how='left', on=['h2o', 'session'])
-
- st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()]
-
-@st.cache_resource(ttl=24*3600)
-def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer":
- return StreamlitRenderer(df, spec=spec, debug=False, **kwargs)
-
-
-def app():
- st.markdown('## Foraging Behavior Browser')
-
- with st.sidebar:
- add_session_filter()
- data_selector()
-
- st.markdown('---')
- st.markdown('#### Han Hou @ 2024 v2.0.0')
- st.markdown('[bug report / feature request](https://github.com/AllenNeuralDynamics/foraging-behavior-browser/issues)')
-
- with st.expander('Debug', expanded=False):
- st.session_state.model_id = st.selectbox('model_id', st.session_state.df['model_fitting_params'].model_id.unique())
- if st.button('Reload data from AWS S3'):
- st.cache_data.clear()
- init()
- st.rerun()
-
-
-
- with st.container():
- # col1, col2 = st.columns([1.5, 1], gap='small')
- # with col1:
- # -- 1. unit dataframe --
-
- cols = st.columns([2, 2, 2])
- cols[0].markdown(f'### Filter the sessions on the sidebar ({len(st.session_state.df_session_filtered)} filtered)')
- # if cols[1].button('Press this and then Ctrl + R to reload from S3'):
- # st.rerun()
- if cols[1].button('Reload data '):
- st.cache_data.clear()
- init()
- st.rerun()
-
- # aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units'])
- # st.session_state.df_session_filtered = aggrid_outputs['data']
-
- container_filtered_frame = st.container()
-
-
- if len(st.session_state.df_session_filtered) == 0:
- st.markdown('## No filtered results!')
- return
-
- aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered)
-
- if len(aggrid_outputs['selected_rows']) and not set(pd.DataFrame(aggrid_outputs['selected_rows']
- ).set_index(['h2o', 'session']).index
- ) == set(st.session_state.df_selected_from_dataframe.set_index(['h2o', 'session']).index):
- st.session_state.df_selected_from_dataframe = pd.DataFrame(aggrid_outputs['selected_rows'])
- st.session_state.df_selected_from_plotly = st.session_state.df_selected_from_dataframe # Sync selected on plotly
- # if st.session_state.tab_id == "tab_session_x_y":
- st.rerun()
-
- chosen_id = stx.tab_bar(data=[
- stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"),
- stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"),
- stx.TabBarItemData(id="tab_pygwalker", title="📊 PyGWalker (Tableau)", description="Interactive dataframe explorer"),
- stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"),
- stx.TabBarItemData(id="tab_mouse_inspector", title="🐭 Mouse Model Fitting", description="Mouse-level model fitting results"),
- ], default="tab_session_inspector" if 'tab_id' not in st.session_state else st.session_state.tab_id)
- # chosen_id = "tab_session_x_y"
-
- placeholder = st.container()
-
- if chosen_id == "tab_session_x_y":
- st.session_state.tab_id = chosen_id
- with placeholder:
- df_selected_from_plotly, x_y_cols = plot_x_y_session()
-
- with x_y_cols[0]:
- for i in range(7): st.write('\n')
- st.markdown("***")
- if_draw_all_sessions = session_plot_settings()
-
- df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered
-
- if if_draw_all_sessions and len(df_to_draw_sessions):
- draw_session_plots(df_to_draw_sessions)
-
- if len(df_selected_from_plotly) and not set(df_selected_from_plotly.set_index(['h2o', 'session']).index) == set(
- st.session_state.df_selected_from_plotly.set_index(['h2o', 'session']).index):
- st.session_state.df_selected_from_plotly = df_selected_from_plotly
- st.session_state.df_selected_from_dataframe = df_selected_from_plotly # Sync selected on dataframe
- st.rerun()
-
- elif chosen_id == "tab_session_inspector":
- st.session_state.tab_id = chosen_id
- with placeholder:
- with st.columns([4, 10])[0]:
- if_draw_all_sessions = session_plot_settings(need_click=False)
- df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered
-
- if if_draw_all_sessions and len(df_to_draw_sessions):
- draw_session_plots(df_to_draw_sessions)
-
- elif chosen_id == "tab_pygwalker":
- with placeholder:
- cols = st.columns([1, 4])
- cols[0].markdown('##### Exploring data using [PyGWalker](https://docs.kanaries.net/pygwalker)')
- with cols[1]:
- with st.expander('Specify PyGWalker json'):
- # Load json from ./gw_config.json
- pyg_user_json = st.text_area("Export your plot settings to json by clicking `export_code` "
- "button below and then paste your json here to reproduce your plots",
- key='pyg_walker', height=100)
-
- # If pyg_user_json is not empty, use it; otherwise, use the default gw_config.json
- if pyg_user_json:
- try:
- pygwalker_renderer = get_pyg_renderer(
- df=st.session_state.df_session_filtered,
- spec=pyg_user_json,
- )
- except:
- pygwalker_renderer = get_pyg_renderer(
- df=st.session_state.df_session_filtered,
- spec="./gw_config_old_mice.json",
- )
- else:
- pygwalker_renderer = get_pyg_renderer(
- df=st.session_state.df_session_filtered,
- spec="./gw_config_old_mice.json",
- )
-
- pygwalker_renderer.render_explore(height=1010, scrolling=False)
-
-
- elif chosen_id == "tab_auto_train_history": # Automatic training history
- st.session_state.tab_id = chosen_id
- with placeholder:
- add_auto_train_manager()
-
- elif chosen_id == "tab_mouse_inspector":
- st.session_state.tab_id = chosen_id
- with placeholder:
- with st.columns([4, 10])[0]:
- if_draw_all_mice = mouse_plot_settings(need_click=False)
- df_selected = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_mice else st.session_state.df_session_filtered
- df_to_draw_mice = df_selected.groupby('h2o').count().reset_index()
-
- if if_draw_all_mice and len(df_to_draw_mice):
- draw_mice_plots(df_to_draw_mice)
-
-
-
- # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)
-
- # Update back to URL
- for key in to_sync_with_url_query:
- try:
- st.query_params.update({key: st.session_state[key]})
- except:
- print(f'Failed to update {key} to URL query')
-
-
-if 'df' not in st.session_state or 'sessions' not in st.session_state.df.keys():
- init()
-
-app()
-
-
-if if_profile:
- p.stop()
\ No newline at end of file
diff --git a/code/pages/3_Playground.py b/code/pages/1_Playground.py
similarity index 100%
rename from code/pages/3_Playground.py
rename to code/pages/1_Playground.py
diff --git a/code/pages/2_Population analysis.py b/code/pages/2_Population analysis.py
deleted file mode 100644
index 3299816..0000000
--- a/code/pages/2_Population analysis.py
+++ /dev/null
@@ -1,281 +0,0 @@
-import streamlit as st
-
-import importlib
-import matplotlib.pyplot as plt
-import numpy as np
-from io import BytesIO
-
-from util.streamlit import filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector
-from util.population import _draw_variable_trial_back, _draw_variable_trial_back_linear_reg
-import seaborn as sns
-
-
-try:
- st.set_page_config(layout="wide",
- page_title='Foraging behavior browser',
- page_icon=':mouse2:',
- menu_items={
- 'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues",
- 'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/"
- }
- )
-except:
- pass
-
-st.session_state.use_s3 = True
-
-
-def app():
-
-
- with st.sidebar:
- add_session_filter()
- data_selector()
-
-
- # st.dataframe(st.session_state.df['logistic_regression_hattori'])
-
-
- df_to_do_population = st.session_state.df_session_filtered
- df_log_reg_hattori_to_do = df_to_do_population.merge(st.session_state.df['logistic_regression_hattori'], on=('subject_id', 'session'), how='inner')
- df_log_reg_su_to_do = df_to_do_population.merge(st.session_state.df['logistic_regression_su'], on=('subject_id', 'session'), how='inner')
-
- df_lin_reg_rt_to_do = df_to_do_population.merge(st.session_state.df['linear_regression_rt'], on=('subject_id', 'session'), how='inner')
-
- cols = st.columns([1, 3])
-
- with cols[0]:
- if st.checkbox('Plot logistic regression (non-photostim)', True):
- st.markdown('Hattori (Rewarded choice + Unrewarded choice + Choice)')
- fig = plot_logistic_regression_non_photostim(df_log_reg_hattori_to_do.query('trial_group == "all_no_stim"'))
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf, use_column_width=True)
-
- st.markdown('Su (Rewarded choice + Unrewarded choice)')
- fig = plot_logistic_regression_non_photostim(df_log_reg_su_to_do.query('trial_group == "all_no_stim"'))
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf, use_column_width=True)
-
- if st.checkbox('Plot linear regression on RT (non-photostim)', True):
- fig = plot_linear_regression_rt_non_photostim(df_lin_reg_rt_to_do.query('trial_group == "all_no_stim"'))
-
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf, use_column_width=True)
-
- with cols[1]:
- if st.checkbox('Plot logistic regression (⚡photostim sessions)', False):
- st.markdown('Hattori (Rewarded choice + Unrewarded choice + Choice)')
- with st.columns([1, 3])[0]:
- beta_names = st.multiselect('beta names', ['RewC', 'UnrC', 'C'], ['RewC', 'UnrC', 'C'])
- max_trials_back = st.slider('max trials back', 1, 5, 3)
-
- fig = plot_logistic_regression_photostim(df_log_reg_hattori_to_do.query('trial_group != "all_no_stim"'),
- beta_names=beta_names, past_trials_to_plot=range(1, max_trials_back + 1))
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf, width=3000)
-
- st.markdown('Su (Rewarded choice + Unrewarded choice)')
- with st.columns([1, 3])[0]:
- beta_names = st.multiselect('beta names ', ['RewC', 'UnrC'], ['RewC', 'UnrC'])
- max_trials_back = st.slider('max trials back ', 1, 5, 3)
-
- fig = plot_logistic_regression_photostim(df_log_reg_su_to_do.query('trial_group != "all_no_stim"'),
- beta_names=beta_names, past_trials_to_plot=range(1, max_trials_back + 1))
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf, width=3000)
-
- if st.checkbox('Plot linear regression on RT (⚡photostim sessions)', False):
- fig = plot_linear_regression_rt_photostim(df_lin_reg_rt_to_do.query('trial_group != "all_no_stim"'),
- beta_names=['reward_1', 'reward_2', 'previous_iti', 'trial_number', 'this_choice', 'constant'])
- if fig is not None:
- buf = BytesIO()
- fig.savefig(buf, format="png")
- st.image(buf)
-
-
-
-
- # st.pyplot(fig)
-
-
-@st.cache_data(ttl=3600*24)
-def plot_logistic_regression_photostim(df_all, beta_names=['RewC', 'UnrC', 'C'], past_trials_to_plot=(1, 2, 3)):
- if not len(df_all): return None
-
- df_all['error_bar'] = [[x, y] for x, y in zip(df_all['mean'] - df_all.lower_ci, df_all.upper_ci - df_all['mean'])]
- df_all['h2o_session'] = df_all[['h2o', 'session']].astype(str).apply('_'.join, axis=1)
-
- df_all.query('trial_group != "all_no_stim"', inplace=True)
-
-
- fig, axes = plt.subplots(len(past_trials_to_plot), len(beta_names) + 1,
- figsize=(10 * (len(beta_names) + 1), 5 * len(past_trials_to_plot)), constrained_layout=False,
- dpi=200,
- gridspec_kw=dict(hspace=0.3, wspace=0.3, top=0.9, bottom=0.1, left=0.1, right=0.9),
- )
- axes = np.atleast_2d(axes)
-
- for i, trials_back in enumerate(past_trials_to_plot):
- for j, name in enumerate(beta_names):
- _draw_variable_trial_back(df_all, name, trials_back, ax=axes[i, j])
- _draw_variable_trial_back(df_all, 'bias', 0, ax=axes[0, j + 1])
-
- for i in range(1, len(past_trials_to_plot)): axes[i, -1].remove()
-
- return fig
-
-
-def plot_logistic_regression_non_photostim(df_all, max_trials_back=10, ax=None):
- if not len(df_all): return None
-
- if ax is None:
- fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=500,
- gridspec_kw=dict(bottom=0.2, top=0.9))
-
- xx = np.arange(1, max_trials_back + 1)
-
- plot_spec = {'RewC': ('tab:green', 'reward choices'),
- 'UnrC': ('tab:red', 'unrewarded choices'),
- 'C': ('tab:blue', 'choices'),
- 'bias': ('k', 'right bias')}
-
- for name, (col, label) in plot_spec.items():
-
- means = [df_all.query(f'beta == "{name}" and trials_back == {t}')['mean'].mean()
- for t in ([0] if name == "bias" else xx)]
- if np.all(np.isnan(means)):
- continue
-
- cis = [df_all.query(f'beta == "{name}" and trials_back == {t}')['mean'].sem() * 1.96
- for t in ([0] if name == "bias" else xx)]
-
- ax.errorbar(x=1 if name == 'bias' else xx,
- y=means,
- yerr=cis,
- ls='-',
- marker='o',
- color=col,
- capsize=5, markeredgewidth=1,
- label=label + ' $\pm$95% CI')
-
- ax.legend()
- ax.set(xlabel='Past trials', ylabel='Logistic regression coeffs')
- ax.axhline(y=0, color='k', linestyle=':', linewidth=0.5)
- ax.set(xticks=[1, 5, 10],
- ylim=(-0.1, 2.5))
-
- sns.despine(trim=True)
-
- n_mice = len(df_all['h2o'].unique())
- n_sessions = len(df_all.groupby(['h2o', 'session']).count())
- fig.suptitle(f'Logistic regression on choice ({n_mice} mice, {n_sessions} sessions)')
-
-
- return fig
-
-
-
-# @st.cache_data(ttl=3600*24)
-def plot_linear_regression_rt_photostim(df_all, beta_names):
- if not len(df_all): return None
-
- fig, axes = plt.subplots(2, 3,
- figsize=(15, 10), constrained_layout=False,
- dpi=200,
- gridspec_kw=dict(hspace=0.3, wspace=0.3, top=0.9, bottom=0.1, left=0.1, right=0.9),
- )
- axes = np.atleast_2d(axes)
-
- for j, beta_name in enumerate(beta_names):
- _draw_variable_trial_back_linear_reg(df_all, beta_name, ax=axes.flat[j])
-
- return fig
-
-
-
-def plot_linear_regression_rt_non_photostim(df_all, ax=None):
- '''
- plot list of linear regressions
- '''
- if not len(df_all): return None
-
-
- if ax is None:
- fig, ax = plt.subplots(1, 1, figsize=(5, 4), dpi=200,
- gridspec_kw=dict(bottom=0.2, left=0.2))
-
- gs = ax._subplotspec.subgridspec(1, 2, width_ratios=[1, 3], wspace=0.5)
- ax_others = ax.get_figure().add_subplot(gs[0, 0])
- ax_reward = ax.get_figure().add_subplot(gs[0, 1])
-
- # Other paras
- other_names = ['previous_iti', 'this_choice', 'trial_number']
- xx = np.arange(len(other_names))
- means = [df_all.query(f'variable == "{name}"')['beta'].mean() for name in other_names]
- sems = [df_all.query(f'variable == "{name}"')['beta'].sem() * 1.96 for name in other_names]
- ax_others.errorbar(x=xx,
- y=means,
- yerr=sems,
- ls='none',
- marker='o',
- color='k',
- capsize=5, markeredgewidth=1,
- )
- ax_others.set(xlim=(-0.5, 2.5), ylim=(-0.3, 1))
-
- ax_others.set_xticks(range(len(other_names)))
- ax_others.set_xticklabels(other_names, rotation=45, ha='right')
- ax_others.axhline(y=0, color='k', linestyle=':', linewidth=1)
- ax_others.set(ylabel=r'Linear regression $\beta \pm$95% CI')
-
- # Back rewards
- xx = np.arange(1, 11)
- means = [df_all.query(f'variable == "reward" and trials_back == {x}')['beta'].mean() for x in xx]
- sems = [df_all.query(f'variable == "reward" and trials_back == {x}')['beta'].sem() * 1.96 for x in xx]
-
- ax_reward.errorbar(x=xx,
- y=means,
- yerr=sems,
- ls='-',
- marker='o',
- color='k',
- capsize=5,
- markeredgewidth=1,
- )
-
- ax_reward.set(xlabel='Reward of past trials')
- ax_reward.axhline(y=0, color='k', linestyle=':', linewidth=1)
- ax_reward.set(xticks=[1, 5, 10], ylim=(-0.35, 0.05))
- ax_reward.invert_yaxis()
-
- sns.despine(trim=True)
- ax.remove()
-
- n_mice = len(df_all['h2o'].unique())
- n_sessions = len(df_all.groupby(['h2o', 'session']).count())
- fig.suptitle(f'Linear regression on RT ({n_mice} mice, {n_sessions} sessions)')
-
- return fig
-
-
-
-if 'df' not in st.session_state:
- from Home import init
- init()
-
-app()
-
-# try:
-# app()
-# except:
-# st.markdown('### Something is wrong. Try going back to 🏠Home or refresh.')
\ No newline at end of file
diff --git a/code/util/streamlit.py b/code/util/streamlit.py
index 778f31e..4d412ac 100644
--- a/code/util/streamlit.py
+++ b/code/util/streamlit.py
@@ -6,10 +6,8 @@
from st_aggrid.shared import GridUpdateMode, ColumnsAutoSizeMode, DataReturnMode
from pandas.api.types import (
is_categorical_dtype,
- is_datetime64_any_dtype,
is_numeric_dtype,
is_string_dtype,
- is_object_dtype,
)
import streamlit.components.v1 as components
from streamlit_plotly_events import plotly_events
@@ -22,7 +20,13 @@
import statsmodels.api as sm
from scipy.stats import linregress
-from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query
+from .url_query_helper import (
+ checkbox_wrapper_for_url_query,
+ selectbox_wrapper_for_url_query,
+ slider_wrapper_for_url_query,
+ multiselect_wrapper_for_url_query,
+ get_filter_type,
+ )
from .plot_autotrain_manager import plot_manager_all_progress
from .aws_s3 import draw_session_plots_quick_preview
@@ -186,6 +190,7 @@ def cache_widget(field, clear=None):
# def dec_cache_widget_state(widget, ):
+
def filter_dataframe(df: pd.DataFrame,
default_filters=['h2o', 'task', 'finished_trials', 'photostim_location'],
url_query={}) -> pd.DataFrame:
@@ -214,20 +219,25 @@ def filter_dataframe(df: pd.DataFrame,
cols = st.columns([1, 1.5])
cols[0].markdown(f"Add filters")
if_reset_filters = cols[1].button(label="Reset filters")
- to_filter_columns = st.multiselect("Filter dataframe on", df.columns,
- label_visibility='collapsed',
- default=st.session_state.to_filter_columns_changed
- if 'to_filter_columns_changed' in st.session_state
- else default_filters,
- key='to_filter_columns',
- on_change=cache_widget,
- args=['to_filter_columns'])
+
+ to_filter_columns = multiselect_wrapper_for_url_query(
+ st_prefix=st,
+ label="Filter dataframe on",
+ options=df.columns,
+ default=['subject_id', 'session', 'finished_trials', 'foraging_eff', 'task'],
+ key='to_filter_columns',
+ label_visibility='collapsed',
+ )
+
for column in to_filter_columns:
if not len(df): break
left, right = st.columns((1, 20))
# Treat columns with < 10 unique values as categorical
- if is_categorical_dtype(df[column]) or df[column].nunique() < 10 and column not in ('finished', 'foraging_eff', 'session', 'finished_trials'):
+
+ filter_type = get_filter_type(df, column)
+
+ if filter_type == 'multiselect':
right.markdown(f"Filter for :red[**{column}**]")
if if_reset_filters:
@@ -243,6 +253,7 @@ def filter_dataframe(df: pd.DataFrame,
else:
default_value = list(df[column].unique())
+ default_value = [v for v in default_value if v in list(df[column].unique())]
st.session_state[f'filter_{column}'] = default_value
selected = right.multiselect(
@@ -256,7 +267,7 @@ def filter_dataframe(df: pd.DataFrame,
)
df = df[df[column].isin(selected)]
- elif is_numeric_dtype(df[column]):
+ elif filter_type == 'slider_range_float':
# fig = px.histogram(df[column], nbins=100, )
# fig.update_layout(showlegend=False, height=50)
@@ -339,7 +350,7 @@ def filter_dataframe(df: pd.DataFrame,
df = df[x.between(*user_num_input)]
- elif is_datetime64_any_dtype(df[column]):
+ elif filter_type == 'slider_range_date':
user_date_input = right.date_input(
f"Values for :red[**{column}**]",
value=st.session_state[f'filter_{column}_changed']
@@ -354,8 +365,8 @@ def filter_dataframe(df: pd.DataFrame,
user_date_input = tuple(map(pd.to_datetime, user_date_input))
start_date, end_date = user_date_input
df = df.loc[df[column].between(start_date, end_date)]
- else: # Regular string
+ elif filter_type == 'reg_ex':
if if_reset_filters:
default_value = ''
st.session_state[f'filter_{column}_changed'] = default_value
diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py
index cb323b6..50eea0b 100644
--- a/code/util/url_query_helper.py
+++ b/code/util/url_query_helper.py
@@ -1,5 +1,64 @@
import streamlit as st
+from pandas.api.types import (
+ is_categorical_dtype,
+ is_datetime64_any_dtype,
+ is_numeric_dtype,
+)
+
+# Sync widgets with URL query params
+# https://blog.streamlit.io/how-streamlit-uses-streamlit-sharing-contextual-apps/
+# dict of "key": default pairs
+# Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL
+to_sync_with_url_query_default = {
+ 'if_load_bpod_sessions': False,
+
+ 'to_filter_columns': ['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'],
+ 'filter_subject_id': '',
+ 'filter_session': [0.0, None],
+ 'filter_finished_trials': [0.0, None],
+ 'filter_foraging_eff': [0.0, None],
+ 'filter_task': ['all'],
+
+ 'table_height': 300,
+
+ 'tab_id': 'tab_session_x_y',
+ 'x_y_plot_xname': 'session',
+ 'x_y_plot_yname': 'foraging_performance_random_seed',
+ 'x_y_plot_group_by': 'h2o',
+ 'x_y_plot_if_show_dots': True,
+ 'x_y_plot_if_aggr_each_group': True,
+ 'x_y_plot_aggr_method_group': 'lowess',
+ 'x_y_plot_if_aggr_all': True,
+ 'x_y_plot_aggr_method_all': 'mean +/- sem',
+ 'x_y_plot_smooth_factor': 5,
+ 'x_y_plot_if_use_x_quantile_group': False,
+ 'x_y_plot_q_quantiles_group': 20,
+ 'x_y_plot_if_use_x_quantile_all': False,
+ 'x_y_plot_q_quantiles_all': 20,
+ 'x_y_plot_if_show_diagonal': False,
+ 'x_y_plot_dot_size': 10,
+ 'x_y_plot_dot_opacity': 0.3,
+ 'x_y_plot_line_width': 2.0,
+ 'x_y_plot_figure_width': 1300,
+ 'x_y_plot_figure_height': 900,
+ 'x_y_plot_font_size_scale': 1.0,
+ 'x_y_plot_selected_color_map': 'Plotly',
+
+ 'x_y_plot_size_mapper': 'finished_trials',
+ 'x_y_plot_size_mapper_gamma': 1.0,
+ 'x_y_plot_size_mapper_range': [3, 20],
+
+ 'session_plot_mode': 'sessions selected from table or plot',
+
+ 'auto_training_history_x_axis': 'session',
+ 'auto_training_history_sort_by': 'subject_id',
+ 'auto_training_history_sort_order': 'descending',
+ 'auto_training_curriculum_name': 'Uncoupled Baiting',
+ 'auto_training_curriculum_version': '1.0',
+ 'auto_training_curriculum_schema_version': '1.0',
+ }
+
def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs):
return st_prefix.checkbox(
label,
@@ -26,6 +85,21 @@ def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **k
key=key,
**kwargs,
)
+
+def multiselect_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs):
+ return st_prefix.multiselect(
+ label,
+ options=options,
+ default=(
+ st.session_state[key]
+ if key in st.session_state
+ else st.query_params[key]
+ if key in st.query_params
+ else default
+ ),
+ key=key,
+ **kwargs,
+ )
def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs):
@@ -51,32 +125,107 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de
)
-def sync_widget_with_query(key, default):
- if key in st.query_params:
- # always get all query params as a list
- q_all = st.query_params.get_all(key)
+def sync_URL_to_session_state():
+ """Assign session_state to sync with URL"""
+
+ to_sync_with_session_state_dynamic_filter_added = list(
+ set(
+ list(st.query_params.keys())
+ + list(to_sync_with_url_query_default.keys())
+ )
+ )
+
+ for key in to_sync_with_session_state_dynamic_filter_added:
- # convert type according to default
- list_default = default if isinstance(default, list) else [default]
- for d in list_default:
- _type = type(d)
- if _type: break # The first non-None type
+ if key in to_sync_with_url_query_default:
+ # If in default list, get default value there
+ default = to_sync_with_url_query_default[key]
+ else:
+ # Else, the user should have added a new filter column
+ # let's get the type from the dataframe directly.
+ #
+ # Also, in this case, the key must be from st.query_params,
+ # so the only purpose of getting default is to get the correct type,
+ # not its value per se.
+ filter_type = get_filter_type(st.session_state.df['sessions_bonsai'],
+ key.replace('filter_', ''))
+ if filter_type == 'slider_range_float':
+ default = [0.0, 1.0]
+ elif filter_type == 'reg_ex':
+ default = ''
+ elif filter_type == 'multiselect':
+ default = ['a', 'b']
+ else:
+ print('sync_URL_to_session_state: Unrecognized filter type')
+ continue
- if _type == bool:
- q_all_correct_type = [q.lower() == 'true' for q in q_all]
+ if key in st.query_params:
+ # always get all query params as a list
+ q_all = st.query_params.get_all(key)
+
+ # convert type according to default
+ list_default = default if isinstance(default, list) else [default]
+ for d in list_default:
+ _type = type(d)
+ if _type: break # The first non-None type
+
+ if _type == bool:
+ q_all_correct_type = [q.lower() == 'true' for q in q_all]
+ else:
+ q_all_correct_type = [_type(q)
+ if q.lower() != 'none'
+ else None
+ for q in q_all]
+
+ # flatten list if only one element
+ if not isinstance(default, list):
+ q_all_correct_type = q_all_correct_type[0]
+
+ try:
+ st.session_state[key] = q_all_correct_type
+ # print(f'sync_URL_to_session_state: Set {key} to {q_all_correct_type}')
+ except:
+ print(f'sync_URL_to_session_state: Failed to set {key} to {q_all_correct_type}')
else:
- q_all_correct_type = [_type(q) for q in q_all]
-
- # flatten list if only one element
- if not isinstance(default, list):
- q_all_correct_type = q_all_correct_type[0]
-
- try:
- st.session_state[key] = q_all_correct_type
- except:
- print(f'Failed to set {key} to {q_all_correct_type}')
- else:
+ try:
+ st.session_state[key] = default
+ except:
+ print(f'Failed to set {key} to {default}')
+
+
+def sync_session_state_to_URL():
+ # Add all 'filter_' fields to the default list
+ # so that all dynamic filters are synced with URL
+ to_sync_with_url_query_dynamic_filter_added = list(
+ set(
+ list(to_sync_with_url_query_default.keys()) +
+ [
+ filter_name for filter_name in st.session_state
+ if (
+ filter_name.startswith('filter_')
+ and not (filter_name.endswith('_changed'))
+ )
+ ]
+ )
+ )
+ for key in to_sync_with_url_query_dynamic_filter_added:
try:
- st.session_state[key] = default
+ st.query_params.update({key: st.session_state[key]})
except:
- print(f'Failed to set {key} to {default}')
\ No newline at end of file
+ print(f'Failed to update {key} to URL query')
+
+
+def get_filter_type(df, column):
+ if is_numeric_dtype(df[column]):
+ return 'slider_range_float'
+
+ if (is_categorical_dtype(df[column])
+ or df[column].nunique() < 10
+ or column in ('user_name') # pin to multiselect
+ ):
+ return 'multiselect'
+
+ if is_datetime64_any_dtype(df[column]):
+ return 'slider_range_date'
+
+ return 'reg_ex' # Default