From 1be5c1247664dea5da08f93b5899f13d4fd51562 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Thu, 4 Apr 2024 22:50:55 -0700 Subject: [PATCH 01/20] refactor: add urlquery wrapper for checkbox --- code/Home.py | 3 +- code/pages/1_Old mice.py | 2 +- code/util/streamlit.py | 89 ++++++++++++++++++++-------------------- 3 files changed, 48 insertions(+), 46 deletions(-) diff --git a/code/Home.py b/code/Home.py index 09f41fc..f3b21dc 100644 --- a/code/Home.py +++ b/code/Home.py @@ -448,7 +448,8 @@ def plot_x_y_session(): q_quantiles_all=q_quantiles_all, title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, states = st.session_state.df_selected_from_plotly, - dot_size=dot_size, + dot_size_base=dot_size, + dot_size_mapping_name='session', dot_opacity=dot_opacity, line_width=line_width) diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index cf910ec..eda6805 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -385,7 +385,7 @@ def plot_x_y_session(): q_quantiles_all=q_quantiles_all, title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, states = st.session_state.df_selected_from_plotly, - dot_size=dot_size, + dot_size_base=dot_size, dot_opacity=dot_opacity, line_width=line_width) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 54bcbb6..2a6c6ae 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -429,28 +429,32 @@ def add_xy_selector(if_bonsai): # st.form_submit_button("update axes") return x_name, y_name, group_by - + +def _checkbox_wrapper_for_url_query(st_prefix, label, key, default_value, **kwargs): + return st_prefix.checkbox( + label, + value=st.session_state[key] + if key in st.session_state else + st.query_params[key].lower()=='true' + if key in st.query_params + else default_value, + key=key, + **kwargs, + ) def add_xy_setting(): with st.expander('Plot settings', expanded=True): s_cols = st.columns([1, 1, 1]) # if_plot_only_selected_from_dataframe = s_cols[0].checkbox('Only selected', False) - if_show_dots = s_cols[0].checkbox('Show data points', - value=st.session_state['x_y_plot_if_show_dots'] - if 'x_y_plot_if_show_dots' in st.session_state else - st.query_params['x_y_plot_if_show_dots'].lower()=='true' - if 'x_y_plot_if_show_dots' in st.query_params - else True, - key='x_y_plot_if_show_dots') - - - if_aggr_each_group = s_cols[1].checkbox('Aggr each group', - value=st.session_state['x_y_plot_if_aggr_each_group'] - if 'x_y_plot_if_aggr_each_group' in st.session_state else - st.query_params['x_y_plot_if_aggr_each_group'].lower()=='true' - if 'x_y_plot_if_aggr_each_group' in st.query_params - else True, - key='x_y_plot_if_aggr_each_group') + if_show_dots = _checkbox_wrapper_for_url_query(s_cols[0], + label='Show data points', + key='x_y_plot_if_show_dots', + default_value=True) + + if_aggr_each_group = _checkbox_wrapper_for_url_query(s_cols[1], + label='Aggr each group', + key='x_y_plot_if_aggr_each_group', + default_value=True) aggr_methods = ['mean', 'mean +/- sem', 'lowess', 'running average', 'linear fit'] aggr_method_group = s_cols[1].selectbox('aggr method group', @@ -463,14 +467,11 @@ def add_xy_setting(): key='x_y_plot_aggr_method_group', disabled=not if_aggr_each_group) - if_use_x_quantile_group = s_cols[1].checkbox('Use quantiles of x ', - value=st.session_state['x_y_plot_if_use_x_quantile_group'] - if 'x_y_plot_if_use_x_quantile_group' in st.session_state else - st.query_params['x_y_plot_if_use_x_quantile_group'].lower()=='true' - if 'x_y_plot_if_use_x_quantile_group' in st.query_params - else False, - key='x_y_plot_if_use_x_quantile_group', - disabled='mean' not in aggr_method_group) + if_use_x_quantile_group = _checkbox_wrapper_for_url_query(s_cols[1], + label='Use quantiles of x', + key='x_y_plot_if_use_x_quantile_group', + default_value=False, + disabled='mean' not in aggr_method_group) q_quantiles_group = s_cols[1].slider('Number of quantiles ', 1, 100, value=st.session_state['x_y_plot_q_quantiles_group'] @@ -482,14 +483,10 @@ def add_xy_setting(): disabled=not if_use_x_quantile_group ) - if_aggr_all = s_cols[2].checkbox('Aggr all', - value=st.session_state['x_y_plot_if_aggr_all'] - if 'x_y_plot_if_aggr_all' in st.session_state else - st.query_params['x_y_plot_if_aggr_all'].lower()=='true' - if 'x_y_plot_if_aggr_all' in st.query_params - else True, - key='x_y_plot_if_aggr_all', - ) + if_aggr_all = _checkbox_wrapper_for_url_query(s_cols[2], + label='Aggr all', + key='x_y_plot_if_aggr_all', + default_value=True) # st.session_state.if_aggr_all_cache = if_aggr_all # Have to use another variable to store this explicitly (my cache_widget somehow doesn't work with checkbox) aggr_method_all = s_cols[2].selectbox('aggr method all', aggr_methods, @@ -501,15 +498,13 @@ def add_xy_setting(): key='x_y_plot_aggr_method_all', disabled=not if_aggr_all) - if_use_x_quantile_all = s_cols[2].checkbox('Use quantiles of x', - value=st.session_state['x_y_plot_if_use_x_quantile_all'] - if 'x_y_plot_if_use_x_quantile_all' in st.session_state else - st.query_params['x_y_plot_if_use_x_quantile_all'].lower()=='true' - if 'x_y_plot_if_use_x_quantile_all' in st.query_params - else False, - key='x_y_plot_if_use_x_quantile_all', - disabled='mean' not in aggr_method_all, - ) + if_use_x_quantile_all = _checkbox_wrapper_for_url_query(s_cols[2], + label='Use quantiles of x', + key='x_y_plot_if_use_x_quantile_all', + default_value=False, + disabled='mean' not in aggr_method_all) + + q_quantiles_all = s_cols[2].slider('number of quantiles', 1, 100, value=st.session_state['x_y_plot_q_quantiles_all'] if 'x_y_plot_q_quantiles_all' in st.session_state else @@ -690,7 +685,8 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' if_use_x_quantile_all=False, q_quantiles_all=20, title='', - dot_size=10, + dot_size_base=10, + dot_size_mapping_name=None, dot_opacity=0.4, line_width=2, **kwarg): @@ -827,6 +823,11 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q fig = go.Figure() col_map = px.colors.qualitative.Plotly + if dot_size_mapping_name is not None and dot_size_mapping_name in df.columns: + df['dot_size'] = df[dot_size_mapping_name] + else: + df['dot_size'] = dot_size_base + for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] @@ -849,7 +850,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q showlegend=not if_aggr_each_group, mode="markers", line_width=line_width, - marker_size=dot_size, + marker_size=this_session['dot_size'], marker_color=this_session['colors'], opacity=dot_opacity, # 0.5 if if_aggr_each_group else 0.8, text=this_session['session'], From 73b484f876e891312b6a8328d37ac0e08412290d Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Thu, 4 Apr 2024 23:11:09 -0700 Subject: [PATCH 02/20] refactor: all url query wrappers --- code/util/streamlit.py | 174 +++++++++++++++------------------- code/util/url_query_helper.py | 44 +++++++++ 2 files changed, 120 insertions(+), 98 deletions(-) create mode 100644 code/util/url_query_helper.py diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 2a6c6ae..82dbb6a 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -20,6 +20,8 @@ import statsmodels.api as sm from scipy.stats import linregress +from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query + custom_css = { ".ag-root.ag-unselectable.ag-layout-normal": {"font-size": "15px !important", @@ -176,7 +178,7 @@ def cache_widget(field, clear=None): # Clear cache if needed if clear: if clear in st.session_state: del st.session_state[clear] - + # def dec_cache_widget_state(widget, ): @@ -430,133 +432,109 @@ def add_xy_selector(if_bonsai): # st.form_submit_button("update axes") return x_name, y_name, group_by -def _checkbox_wrapper_for_url_query(st_prefix, label, key, default_value, **kwargs): - return st_prefix.checkbox( - label, - value=st.session_state[key] - if key in st.session_state else - st.query_params[key].lower()=='true' - if key in st.query_params - else default_value, - key=key, - **kwargs, - ) - def add_xy_setting(): with st.expander('Plot settings', expanded=True): s_cols = st.columns([1, 1, 1]) # if_plot_only_selected_from_dataframe = s_cols[0].checkbox('Only selected', False) - if_show_dots = _checkbox_wrapper_for_url_query(s_cols[0], + if_show_dots = checkbox_wrapper_for_url_query(s_cols[0], label='Show data points', key='x_y_plot_if_show_dots', - default_value=True) + default=True) - if_aggr_each_group = _checkbox_wrapper_for_url_query(s_cols[1], + if_aggr_each_group = checkbox_wrapper_for_url_query(s_cols[1], label='Aggr each group', key='x_y_plot_if_aggr_each_group', - default_value=True) + default=True) aggr_methods = ['mean', 'mean +/- sem', 'lowess', 'running average', 'linear fit'] - aggr_method_group = s_cols[1].selectbox('aggr method group', - options=aggr_methods, - index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_group']) - if 'x_y_plot_aggr_method_group' in st.session_state else - aggr_methods.index(st.query_params['x_y_plot_aggr_method_group']) - if 'x_y_plot_aggr_method_group' in st.query_params - else 2, - key='x_y_plot_aggr_method_group', - disabled=not if_aggr_each_group) + aggr_method_group = selectbox_wrapper_for_url_query(s_cols[1], + label='aggr method group', + options=aggr_methods, + key='x_y_plot_aggr_method_group', + default=2, + disabled=not if_aggr_each_group) - if_use_x_quantile_group = _checkbox_wrapper_for_url_query(s_cols[1], + if_use_x_quantile_group = checkbox_wrapper_for_url_query(s_cols[1], label='Use quantiles of x', key='x_y_plot_if_use_x_quantile_group', - default_value=False, + default=False, disabled='mean' not in aggr_method_group) - q_quantiles_group = s_cols[1].slider('Number of quantiles ', 1, 100, - value=st.session_state['x_y_plot_q_quantiles_group'] - if 'x_y_plot_q_quantiles_group' in st.session_state else - int(st.query_params['x_y_plot_q_quantiles_group']) - if 'x_y_plot_q_quantiles_group' in st.query_params - else 20, - key='x_y_plot_q_quantiles_group', - disabled=not if_use_x_quantile_group - ) + q_quantiles_group = slider_wrapper_for_url_query(s_cols[1], + label='Number of quantiles ', + min_value=1, + max_value=100, + key='x_y_plot_q_quantiles_group', + default=20, + disabled=not if_use_x_quantile_group + ) - if_aggr_all = _checkbox_wrapper_for_url_query(s_cols[2], + if_aggr_all = checkbox_wrapper_for_url_query(s_cols[2], label='Aggr all', key='x_y_plot_if_aggr_all', - default_value=True) + default=True) # st.session_state.if_aggr_all_cache = if_aggr_all # Have to use another variable to store this explicitly (my cache_widget somehow doesn't work with checkbox) - aggr_method_all = s_cols[2].selectbox('aggr method all', aggr_methods, - index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_all']) - if 'x_y_plot_aggr_method_all' in st.session_state else - aggr_methods.index(st.query_params['x_y_plot_aggr_method_all']) - if 'x_y_plot_aggr_method_all' in st.query_params - else 1, - key='x_y_plot_aggr_method_all', - disabled=not if_aggr_all) - - if_use_x_quantile_all = _checkbox_wrapper_for_url_query(s_cols[2], + aggr_method_all = selectbox_wrapper_for_url_query(s_cols[2], + label='aggr method all', + options=aggr_methods, + key='x_y_plot_aggr_method_all', + default=1, + disabled=not if_aggr_all) + + if_use_x_quantile_all = checkbox_wrapper_for_url_query(s_cols[2], label='Use quantiles of x', key='x_y_plot_if_use_x_quantile_all', - default_value=False, + default=False, disabled='mean' not in aggr_method_all) - q_quantiles_all = s_cols[2].slider('number of quantiles', 1, 100, - value=st.session_state['x_y_plot_q_quantiles_all'] - if 'x_y_plot_q_quantiles_all' in st.session_state else - int(st.query_params['x_y_plot_q_quantiles_all']) - if 'x_y_plot_q_quantiles_all' in st.query_params - else 20, - key='x_y_plot_q_quantiles_all', - disabled=not if_use_x_quantile_all - ) - - smooth_factor = s_cols[0].slider('smooth factor', 1, 20, - value=st.session_state['x_y_plot_smooth_factor'] - if 'x_y_plot_smooth_factor' in st.session_state else - int(st.query_params['x_y_plot_smooth_factor']) - if 'x_y_plot_smooth_factor' in st.query_params - else 5, - key='x_y_plot_smooth_factor', - disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess')) - or (if_aggr_all and aggr_method_all in ('running average', 'lowess'))) - ) + q_quantiles_all = slider_wrapper_for_url_query(s_cols[2], + label='Number of quantiles', + min_value=1, + max_value=100, + key='x_y_plot_q_quantiles_all', + default=20, + disabled=not if_use_x_quantile_all + ) + + smooth_factor = slider_wrapper_for_url_query(s_cols[0], + label='smooth factor', + min_value=1, + max_value=20, + key='x_y_plot_smooth_factor', + default=5, + disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess')) + or (if_aggr_all and aggr_method_all in ('running average', 'lowess'))) + ) c = st.columns([1, 1, 1]) - dot_size = c[0].slider('dot size', 1, 30, - step=1, - value=st.session_state['x_y_plot_dot_size'] - if 'x_y_plot_dot_size' in st.session_state else - int(st.query_params['x_y_plot_dot_size']) - if 'x_y_plot_dot_size' in st.query_params - else 10, - key='x_y_plot_dot_size' - ) - dot_opacity = c[1].slider('opacity', 0.0, 1.0, - step=0.05, - value=st.session_state['x_y_plot_dot_opacity'] - if 'x_y_plot_dot_opacity' in st.session_state else - float(st.query_params['x_y_plot_dot_opacity']) - if 'x_y_plot_dot_opacity' in st.query_params - else 0.5, - key='x_y_plot_dot_opacity') - - line_width = c[2].slider('line width', 0.0, 5.0, - step=0.25, - value=st.session_state['x_y_plot_line_width'] - if 'x_y_plot_line_width' in st.session_state else - float(st.query_params['x_y_plot_line_width']) - if 'x_y_plot_line_width' in st.query_params - else 2.0, - key='x_y_plot_line_width') + dot_size = slider_wrapper_for_url_query(c[0], + label='dot size', + min_value=1, + max_value=30, + key='x_y_plot_dot_size', + default=10) + + dot_opacity = slider_wrapper_for_url_query(c[1], + label='opacity', + min_value=0.0, + max_value=1.0, + step=0.05, + key='x_y_plot_dot_opacity', + default=0.5) + + line_width = slider_wrapper_for_url_query(c[2], + label='line width', + min_value=0.0, + max_value=5.0, + step=0.25, + key='x_y_plot_line_width', + default=2.0) return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, + dot_size, dot_opacity, line_width) def data_selector(): diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py new file mode 100644 index 0000000..80e5d5d --- /dev/null +++ b/code/util/url_query_helper.py @@ -0,0 +1,44 @@ +import streamlit as st + +def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): + return st_prefix.checkbox( + label, + value=st.session_state[key] + if key in st.session_state else + st.query_params[key].lower()=='true' + if key in st.query_params + else default, + key=key, + **kwargs, + ) + +def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs): + return st_prefix.selectbox( + label, + options=options, + index=( + options.index(st.session_state[key]) + if key in st.session_state + else options.index(st.query_params[key]) + if key in st.query_params + else default + ), + key=key, + **kwargs, + ) + +def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs): + return st_prefix.slider( + label, + min_value, + max_value, + value=( + st.session_state[key] + if key in st.session_state + else int(st.query_params[key]) + if key in st.query_params + else default + ), + key=key, + **kwargs, + ) \ No newline at end of file From e0b35072585bda8f9dff3a7612b8b614f41fb4f6 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Thu, 4 Apr 2024 23:21:18 -0700 Subject: [PATCH 03/20] refactor: move sync_widget_with_query to helper.py --- code/Home.py | 8 ++-- code/pages/1_Old mice.py | 6 ++- code/util/streamlit.py | 72 +++++++++++------------------------ code/util/url_query_helper.py | 33 +++++++++++++++- 4 files changed, 63 insertions(+), 56 deletions(-) diff --git a/code/Home.py b/code/Home.py index f3b21dc..6ad1679 100644 --- a/code/Home.py +++ b/code/Home.py @@ -40,8 +40,10 @@ from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, aggrid_interactive_table_curriculum, add_session_filter, data_selector, - _sync_widget_with_query, add_xy_selector, add_xy_setting, add_auto_train_manager, + add_xy_selector, add_xy_setting, add_auto_train_manager, _plot_population_x_y) +from util.url_query_helper import sync_widget_with_query + import extra_streamlit_components as stx from aind_auto_train.curriculum_manager import CurriculumManager @@ -413,7 +415,7 @@ def plot_x_y_session(): with cols[0]: - x_name, y_name, group_by = add_xy_selector(if_bonsai=True) + x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=True) (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, @@ -478,7 +480,7 @@ def init(): # Set session state from URL for key, default in to_sync_with_url_query.items(): - _sync_widget_with_query(key, default) + sync_widget_with_query(key, default) df = load_data(['sessions', ]) diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index eda6805..9d6b026 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -19,8 +19,10 @@ from streamlit_plotly_events import plotly_events from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector, - add_xy_selector, _sync_widget_with_query, add_xy_setting, add_auto_train_manager, + add_xy_selector, add_xy_setting, add_auto_train_manager, _plot_population_x_y) +from util.url_query_helper import sync_widget_with_query + import extra_streamlit_components as stx from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager @@ -410,7 +412,7 @@ def init(): # Set session state from URL for key, default in to_sync_with_url_query.items(): - _sync_widget_with_query(key, default) + sync_widget_with_query(key, default) df = load_data(['sessions', 'logistic_regression_hattori', diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 82dbb6a..2dfc524 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -405,32 +405,32 @@ def add_xy_selector(if_bonsai): else st.session_state.session_stats_names.index('session'), key='x_y_plot_xname' ) - y_name = cols[0].selectbox("y axis", - st.session_state.session_stats_names, - index=st.session_state.session_stats_names.index(st.session_state['x_y_plot_yname']) - if 'x_y_plot_yname' in st.session_state else - st.session_state.session_stats_names.index(st.query_params['x_y_plot_yname']) - if 'x_y_plot_yname' in st.query_params - else st.session_state.session_stats_names.index('foraging_eff'), - key='x_y_plot_yname') - + y_name = selectbox_wrapper_for_url_query( + cols[0], + label="y axis", + options=st.session_state.session_stats_names, + key="x_y_plot_yname", + default=st.session_state.session_stats_names.index('foraging_performance') + ) + if if_bonsai: options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] else: options = ['h2o', 'task', 'photostim_location', 'weekday', 'headbar', 'user_name', 'sex', 'rig'] - - group_by = cols[0].selectbox("grouped by", - options=options, - index=options.index(st.session_state['x_y_plot_group_by']) - if 'x_y_plot_group_by' in st.session_state else - options.index(st.query_params['x_y_plot_group_by']) - if 'x_y_plot_group_by' in st.query_params - else 0, - key='x_y_plot_group_by') - - # st.form_submit_button("update axes") - return x_name, y_name, group_by + + group_by = selectbox_wrapper_for_url_query( + cols[0], + label="grouped by", + options=options, + key="x_y_plot_group_by", + default=0, + ) + + size_mapper = 1 + + # st.form_submit_button("update axes") + return x_name, y_name, group_by, size_mapper def add_xy_setting(): with st.expander('Plot settings', expanded=True): @@ -873,32 +873,4 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') return fig -def _sync_widget_with_query(key, default): - if key in st.query_params: - # always get all query params as a list - q_all = st.query_params.get_all(key) - - # convert type according to default - list_default = default if isinstance(default, list) else [default] - for d in list_default: - _type = type(d) - if _type: break # The first non-None type - - if _type == bool: - q_all_correct_type = [q.lower() == 'true' for q in q_all] - else: - q_all_correct_type = [_type(q) for q in q_all] - - # flatten list if only one element - if not isinstance(default, list): - q_all_correct_type = q_all_correct_type[0] - - try: - st.session_state[key] = q_all_correct_type - except: - print(f'Failed to set {key} to {q_all_correct_type}') - else: - try: - st.session_state[key] = default - except: - print(f'Failed to set {key} to {default}') + diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 80e5d5d..ab78d89 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -41,4 +41,35 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de ), key=key, **kwargs, - ) \ No newline at end of file + ) + + +def sync_widget_with_query(key, default): + if key in st.query_params: + # always get all query params as a list + q_all = st.query_params.get_all(key) + + # convert type according to default + list_default = default if isinstance(default, list) else [default] + for d in list_default: + _type = type(d) + if _type: break # The first non-None type + + if _type == bool: + q_all_correct_type = [q.lower() == 'true' for q in q_all] + else: + q_all_correct_type = [_type(q) for q in q_all] + + # flatten list if only one element + if not isinstance(default, list): + q_all_correct_type = q_all_correct_type[0] + + try: + st.session_state[key] = q_all_correct_type + except: + print(f'Failed to set {key} to {q_all_correct_type}') + else: + try: + st.session_state[key] = default + except: + print(f'Failed to set {key} to {default}') \ No newline at end of file From 2d8a01f9bf7011d2086ac902330af7e9627d8ea5 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Thu, 4 Apr 2024 23:44:36 -0700 Subject: [PATCH 04/20] feat: add "dot size mapper" button --- code/Home.py | 8 ++++---- code/pages/1_Old mice.py | 4 ++-- code/util/streamlit.py | 23 +++++++++++++++++------ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/code/Home.py b/code/Home.py index 6ad1679..b7b17ff 100644 --- a/code/Home.py +++ b/code/Home.py @@ -63,7 +63,7 @@ 'tab_id': 'tab_session_x_y', 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_eff', + 'x_y_plot_yname': 'foraging_performance_random_seed', 'x_y_plot_group_by': 'h2o', 'x_y_plot_if_show_dots': True, 'x_y_plot_if_aggr_each_group': True, @@ -79,6 +79,8 @@ 'x_y_plot_dot_opacity': 0.5, 'x_y_plot_line_width': 2.0, + 'x_y_plot_size_mapper': 'None', + 'session_plot_mode': 'sessions selected from table or plot', 'auto_training_history_x_axis': 'date', @@ -420,8 +422,6 @@ def plot_x_y_session(): (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, dot_size, dot_opacity, line_width) = add_xy_setting() - - # If no sessions are selected, use all filtered entries # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered @@ -451,7 +451,7 @@ def plot_x_y_session(): title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, states = st.session_state.df_selected_from_plotly, dot_size_base=dot_size, - dot_size_mapping_name='session', + dot_size_mapping_name=size_mapper, dot_opacity=dot_opacity, line_width=line_width) diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index 9d6b026..195147f 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -43,7 +43,7 @@ 'tab_id': 'tab_session_x_y', 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_eff', + 'x_y_plot_yname': 'foraging_performance', 'x_y_plot_group_by': 'h2o', 'x_y_plot_if_show_dots': True, 'x_y_plot_if_aggr_each_group': True, @@ -353,7 +353,7 @@ def plot_x_y_session(): cols = st.columns([4, 10]) with cols[0]: - x_name, y_name, group_by = add_xy_selector(if_bonsai=False) + x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=False) (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 2dfc524..4e5f3d6 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -410,7 +410,7 @@ def add_xy_selector(if_bonsai): label="y axis", options=st.session_state.session_stats_names, key="x_y_plot_yname", - default=st.session_state.session_stats_names.index('foraging_performance') + default=st.session_state.session_stats_names.index('foraging_eff') ) if if_bonsai: @@ -427,7 +427,20 @@ def add_xy_selector(if_bonsai): default=0, ) - size_mapper = 1 + # Get all columns that are numeric + available_size_cols = ['None'] + [ + col + for col in st.session_state.session_stats_names + if is_numeric_dtype(st.session_state.df["sessions_bonsai"][col]) + ] + + size_mapper = selectbox_wrapper_for_url_query( + cols[0], + label="dot size mapper", + options=available_size_cols, + key='x_y_plot_size_mapper', + default=0, + ) # st.form_submit_button("update axes") return x_name, y_name, group_by, size_mapper @@ -664,7 +677,7 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' q_quantiles_all=20, title='', dot_size_base=10, - dot_size_mapping_name=None, + dot_size_mapping_name='None', dot_opacity=0.4, line_width=2, **kwarg): @@ -801,7 +814,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q fig = go.Figure() col_map = px.colors.qualitative.Plotly - if dot_size_mapping_name is not None and dot_size_mapping_name in df.columns: + if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns: df['dot_size'] = df[dot_size_mapping_name] else: df['dot_size'] = dot_size_base @@ -872,5 +885,3 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q title_standoff=40, ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') return fig - - From 55531f7a347ee146b65cc0f1dd4e348d28f1fe37 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 00:17:37 -0700 Subject: [PATCH 05/20] fix: flexible typing in slider --- code/util/url_query_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index ab78d89..8926326 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -35,7 +35,7 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de value=( st.session_state[key] if key in st.session_state - else int(st.query_params[key]) + else type(min_value)(st.query_params[key]) if key in st.query_params else default ), From 989018f9043432ec5c1847a86d2ec724da7b154d Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 00:42:33 -0700 Subject: [PATCH 06/20] feat: improve hover label of x_y_session --- code/Home.py | 38 +++++++------- code/util/streamlit.py | 116 +++++++++++++++++++++++++---------------- 2 files changed, 91 insertions(+), 63 deletions(-) diff --git a/code/Home.py b/code/Home.py index b7b17ff..aa74706 100644 --- a/code/Home.py +++ b/code/Home.py @@ -435,25 +435,25 @@ def plot_x_y_session(): # for i, (title, (x_name, y_name)) in enumerate(names.items()): # with cols[i]: with cols[1]: - fig = _plot_population_x_y(df=df_x_y_session, - x_name=x_name, y_name=y_name, - group_by=group_by, - smooth_factor=smooth_factor, - if_show_dots=if_show_dots, - if_aggr_each_group=if_aggr_each_group, - if_aggr_all=if_aggr_all, - aggr_method_group=aggr_method_group, - aggr_method_all=aggr_method_all, - if_use_x_quantile_group=if_use_x_quantile_group, - q_quantiles_group=q_quantiles_group, - if_use_x_quantile_all=if_use_x_quantile_all, - q_quantiles_all=q_quantiles_all, - title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, - states = st.session_state.df_selected_from_plotly, - dot_size_base=dot_size, - dot_size_mapping_name=size_mapper, - dot_opacity=dot_opacity, - line_width=line_width) + fig = _plot_population_x_y(df=df_x_y_session.copy(), + x_name=x_name, y_name=y_name, + group_by=group_by, + smooth_factor=smooth_factor, + if_show_dots=if_show_dots, + if_aggr_each_group=if_aggr_each_group, + if_aggr_all=if_aggr_all, + aggr_method_group=aggr_method_group, + aggr_method_all=aggr_method_all, + if_use_x_quantile_group=if_use_x_quantile_group, + q_quantiles_group=q_quantiles_group, + if_use_x_quantile_all=if_use_x_quantile_all, + q_quantiles_all=q_quantiles_all, + title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, + states = st.session_state.df_selected_from_plotly, + dot_size_base=dot_size, + dot_size_mapping_name=size_mapper, + dot_opacity=dot_opacity, + line_width=line_width) # st.plotly_chart(fig) selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True, diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 4e5f3d6..feaf29c 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -663,6 +663,12 @@ def add_auto_train_manager(): with st.expander('Automatic training manager', expanded=True): st.dataframe(df_training_manager, height=3000) +def _create_autotrain_string(row): + if pd.notna(row['curriculum_name']): + return f"{row['current_stage_actual']} @ {row['curriculum_name']}" + else: + return "None" + @st.cache_data(ttl=3600*24) def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by='h2o', smooth_factor=5, @@ -681,11 +687,11 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' dot_opacity=0.4, line_width=2, **kwarg): - - def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width): + + def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, **kwarg): x = df_this.sort_values(x_name)[x_name].astype(float) y = df_this.sort_values(x_name)[y_name].astype(float) - + n_mice = len(df_this['h2o'].unique()) n_sessions = len(df_this.groupby(['h2o', 'session']).count()) n_str = f' ({n_mice} mice, {n_sessions} sessions)' if group_by !='h2o' else f' ({n_sessions} sessions)' @@ -701,12 +707,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line_width=line_width, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + elif aggr_method == 'lowess': x_new = np.linspace(x.min(), x.max(), 200) lowess = sm.nonparametric.lowess(y, x, frac=smooth_factor/20) - + fig.add_trace(go.Scatter( x=lowess[:, 0], y=lowess[:, 1], @@ -717,25 +725,27 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + elif aggr_method in ('mean +/- sem', 'mean'): - + # Re-bin x if use quantiles of x if if_use_x_quantile: df_this[f'{x_name}_quantile'] = pd.qcut(df_this[x_name], q=q_quantiles, labels=False, duplicates='drop') - + mean = df_this.groupby(f'{x_name}_quantile')[y_name].mean() sem = df_this.groupby(f'{x_name}_quantile')[y_name].sem() valid_y = mean.notna() mean = mean[valid_y] sem = sem[valid_y] sem[~sem.notna()] = 0 - + x = df_this.groupby(f'{x_name}_quantile')[x_name].median() # Use median of x in each quantile as x y_upper = mean + sem y_lower = mean - sem - + else: # mean and sem groupby x_name mean = df_this.groupby(x_name)[y_name].mean() @@ -744,11 +754,11 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q mean = mean[valid_y] sem = sem[valid_y] sem[~sem.notna()] = 0 - + x = mean.index y_upper = mean + sem y_lower = mean - sem - + fig.add_trace(go.Scatter( x=x, y=mean, @@ -762,8 +772,10 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + if 'sem' in aggr_method: fig.add_trace(go.Scatter( # name='Upper Bound', @@ -789,7 +801,6 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q showlegend=False, hoverinfo='skip' )) - elif aggr_method == 'linear fit': # perform linear regression @@ -805,24 +816,27 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line=dict(dash='dot' if p_value > 0.05 else 'solid', width=line_width if p_value > 0.05 else line_width*1.5), legendgroup=f'group_{group}', - # hoverinfo='skip' + hoverinfo='skip' ) ) except: pass - + fig = go.Figure() col_map = px.colors.qualitative.Plotly - + + # Add some more columns if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns: df['dot_size'] = df[dot_size_mapping_name] else: df['dot_size'] = dot_size_base - + + df['auto_train_string'] = df.apply(_create_autotrain_string, axis=1) + for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] - + if if_show_dots: if not len(st.session_state.df_selected_from_plotly): this_session['colors'] = col # all use normal colors @@ -832,7 +846,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q merged.loc[merged.subject_id_y.notna(), 'colors'] = col # only use normal colors for the selected dots this_session['colors'] = merged.colors.values this_session = pd.concat([this_session.query('colors != "lightgrey"'), this_session.query('colors == "lightgrey"')]) # make sure the real color goes first - + fig.add_trace(go.Scattergl( x=this_session[x_name], y=this_session[y_name], @@ -843,44 +857,58 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line_width=line_width, marker_size=this_session['dot_size'], marker_color=this_session['colors'], - opacity=dot_opacity, # 0.5 if if_aggr_each_group else 0.8, - text=this_session['session'], - hovertemplate = '
%{customdata[0]}, Session %{text}' + - '
%s = %%{x}' % (x_name) + - '
%s = %%{y}' % (y_name), - # '%{name}', - customdata=np.stack((this_session.h2o, this_session.session), axis=-1), + opacity=dot_opacity, + hovertemplate = '%{customdata[0]}, %{customdata[1]}, Session %{customdata[2]}' + '
%{customdata[3]}, %{customdata[4]}' + '
Task: %{customdata[5]}' + '
AutoTrain: %{customdata[6]}
' + f'
{"-"*10}
X: {x_name} = %{{x}}' + f'
Y: {y_name} = %{{y}}' + + (f'
Size: {dot_size_mapping_name} = %{{customdata[7]}}' + if dot_size_mapping_name !='None' + else '') + + '', + customdata=np.stack((this_session.h2o, # 0 + this_session.session_date.dt.strftime('%Y-%m-%d'), # 1 + this_session.session, # 2 + this_session.rig, # 3 + this_session.user_name, # 4 + this_session.task, # 5 + this_session.auto_train_string, # 6 + this_session[dot_size_mapping_name] + if dot_size_mapping_name !='None' + else [np.nan] * len(this_session.h2o), # 7 + ), axis=-1), unselected=dict(marker_color='lightgrey') )) - + if if_aggr_each_group: _add_agg(this_session, x_name, y_name, group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, col, line_width=line_width) - if if_aggr_all: _add_agg(df, x_name, y_name, 'all', aggr_method_all, if_use_x_quantile_all, q_quantiles_all, 'rgb(0, 0, 0)', line_width=line_width*1.5) - n_mice = len(df['h2o'].unique()) n_sessions = len(df.groupby(['h2o', 'session']).count()) - + fig.update_layout( - width=1300, - height=900, - xaxis_title=x_name, - yaxis_title=y_name, - font=dict(size=25), - hovermode='closest', - legend={'traceorder':'reversed'}, - legend_font_size=15, - title=f'{title}, {n_mice} mice, {n_sessions} sessions', - dragmode='zoom', # 'select', - margin=dict(l=130, r=50, b=130, t=100), - ) + width=1300, + height=900, + xaxis_title=x_name, + yaxis_title=y_name, + font=dict(size=25), + hovermode="closest", + hoverlabel=dict(font_size=17), + legend={"traceorder": "reversed"}, + legend_font_size=15, + title=f"{title}, {n_mice} mice, {n_sessions} sessions", + dragmode="zoom", # 'select', + margin=dict(l=130, r=50, b=130, t=100), + ) fig.update_xaxes(showline=True, linewidth=2, linecolor='black', # range=[1, min(100, df[x_name].max())], ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') - + fig.update_yaxes(showline=True, linewidth=2, linecolor='black', title_standoff=40, ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') From 66cb6a12dc369f3846a7295ae9a6a7efc85c447f Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 01:26:29 -0700 Subject: [PATCH 07/20] bug fix --- code/Home.py | 7 ++++++ code/util/streamlit.py | 50 +++++++++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/code/Home.py b/code/Home.py index aa74706..1cd625d 100644 --- a/code/Home.py +++ b/code/Home.py @@ -576,6 +576,13 @@ def init(): # map user_name st.session_state.df['sessions_bonsai']['user_name'] = st.session_state.df['sessions_bonsai']['user_name'].apply(_user_name_mapper) + # fill nan for autotrain fields + filled_values = {'curriculum_name': 'None', + 'curriculum_version': 'None', + 'curriculum_schema_version': 'None', + 'current_stage_actual': 'None'} + st.session_state.df['sessions_bonsai'].fillna(filled_values, inplace=True) + # foraging performance = foraing_eff * finished_rate if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns: st.session_state.df['sessions_bonsai']['foraging_performance'] = \ diff --git a/code/util/streamlit.py b/code/util/streamlit.py index feaf29c..d1b6583 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -1,4 +1,4 @@ -from email import header +from collections import OrderedDict import pandas as pd import streamlit as st from st_aggrid import AgGrid, GridOptionsBuilder @@ -392,6 +392,24 @@ def add_session_filter(if_bonsai=False, url_query={}): default_filters=['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'], url_query=url_query) +@st.cache_data(ttl=3600*24) +def _get_grouped_by_fields(if_bonsai): + if if_bonsai: + options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] + options += [col + for col in st.session_state.df_session_filtered.columns + if is_categorical_dtype(st.session_state.df_session_filtered[col]) + or st.session_state.df_session_filtered[col].nunique() < 20 + and not any([exclude in col for exclude in + ('date', 'time', 'session', 'finished', 'foraging_eff')]) + ] + options = list(list(OrderedDict.fromkeys(options))) # Remove duplicates + else: + options = ['h2o', 'task', 'photostim_location', 'weekday', + 'headbar', 'user_name', 'sex', 'rig'] + return options + + def add_xy_selector(if_bonsai): with st.expander("Select axes", expanded=True): # with st.form("axis_selection"): @@ -413,16 +431,11 @@ def add_xy_selector(if_bonsai): default=st.session_state.session_stats_names.index('foraging_eff') ) - if if_bonsai: - options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] - else: - options = ['h2o', 'task', 'photostim_location', 'weekday', - 'headbar', 'user_name', 'sex', 'rig'] group_by = selectbox_wrapper_for_url_query( cols[0], label="grouped by", - options=options, + options=_get_grouped_by_fields(if_bonsai), key="x_y_plot_group_by", default=0, ) @@ -431,7 +444,7 @@ def add_xy_selector(if_bonsai): available_size_cols = ['None'] + [ col for col in st.session_state.session_stats_names - if is_numeric_dtype(st.session_state.df["sessions_bonsai"][col]) + if is_numeric_dtype(st.session_state.df_session_filtered[col]) ] size_mapper = selectbox_wrapper_for_url_query( @@ -663,12 +676,6 @@ def add_auto_train_manager(): with st.expander('Automatic training manager', expanded=True): st.dataframe(df_training_manager, height=3000) -def _create_autotrain_string(row): - if pd.notna(row['curriculum_name']): - return f"{row['current_stage_actual']} @ {row['curriculum_name']}" - else: - return "None" - @st.cache_data(ttl=3600*24) def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by='h2o', smooth_factor=5, @@ -831,8 +838,6 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q else: df['dot_size'] = dot_size_base - df['auto_train_string'] = df.apply(_create_autotrain_string, axis=1) - for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] @@ -861,10 +866,10 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q hovertemplate = '%{customdata[0]}, %{customdata[1]}, Session %{customdata[2]}' '
%{customdata[3]}, %{customdata[4]}' '
Task: %{customdata[5]}' - '
AutoTrain: %{customdata[6]}
' + '
AutoTrain: %{customdata[7]} @ %{customdata[6]}' f'
{"-"*10}
X: {x_name} = %{{x}}' f'
Y: {y_name} = %{{y}}' - + (f'
Size: {dot_size_mapping_name} = %{{customdata[7]}}' + + (f'
Size: {dot_size_mapping_name} = %{{customdata[8]}}' if dot_size_mapping_name !='None' else '') + '', @@ -874,10 +879,15 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q this_session.rig, # 3 this_session.user_name, # 4 this_session.task, # 5 - this_session.auto_train_string, # 6 + this_session.curriculum_name + if 'curriculum_name' in this_session.columns + else ['None'] * len(this_session.h2o), # 6 + this_session.current_stage_actual + if 'current_stage_actual' in this_session.columns + else ['None'] * len(this_session.h2o), # 7 this_session[dot_size_mapping_name] if dot_size_mapping_name !='None' - else [np.nan] * len(this_session.h2o), # 7 + else [np.nan] * len(this_session.h2o), # 8 ), axis=-1), unselected=dict(marker_color='lightgrey') )) From f51f00e08c0f95da8100b0a44e5c45768f080785 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 01:39:53 -0700 Subject: [PATCH 08/20] fix: numeric grouped_by --- code/util/streamlit.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index d1b6583..3f8f270 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -7,6 +7,7 @@ is_categorical_dtype, is_datetime64_any_dtype, is_numeric_dtype, + is_string_dtype, is_object_dtype, ) import streamlit.components.v1 as components @@ -838,10 +839,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q else: df['dot_size'] = dot_size_base + # Turn column of group_by to string if it's not + if not is_string_dtype(df[group_by]): + df[group_by] = df[group_by].astype(str) + for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] - + if if_show_dots: if not len(st.session_state.df_selected_from_plotly): this_session['colors'] = col # all use normal colors From 9371f78fc68fb38efef8e6f2eb2ec6dd0c3a4e64 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 09:54:03 -0700 Subject: [PATCH 09/20] feat: improve show plot --- code/Home.py | 50 ++++++++++++-------------------------------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/code/Home.py b/code/Home.py index 1cd625d..0ab7db5 100644 --- a/code/Home.py +++ b/code/Home.py @@ -339,7 +339,7 @@ def draw_mice_plots(df_to_draw_mice): def session_plot_settings(need_click=True): st.markdown('##### Show plots for individual sessions ') - cols = st.columns([2, 1]) + cols = st.columns([2, 1, 6]) session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar'] st.session_state.selected_draw_sessions = cols[0].selectbox(f'Which session(s) to draw?', @@ -355,7 +355,7 @@ def session_plot_settings(need_click=True): n_session_to_draw = len(st.session_state.df_selected_from_plotly) \ if 'selected from table or plot' in st.session_state.selected_draw_sessions \ else len(st.session_state.df_session_filtered) - st.markdown(f'{n_session_to_draw} sessions to draw') + cols[0].markdown(f'{n_session_to_draw} sessions to draw') st.session_state.num_cols = cols[1].number_input('number of columns', 1, 10, 3 if 'num_cols' not in st.session_state else st.session_state.num_cols) @@ -369,46 +369,18 @@ def session_plot_settings(need_click=True): """, unsafe_allow_html=True, ) - st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', + st.session_state.selected_draw_types = cols[2].multiselect('Which plot(s) to draw?', st.session_state.draw_type_mapper_session_level.keys(), default=st.session_state.draw_type_mapper_session_level.keys() if 'selected_draw_types' not in st.session_state else st.session_state.selected_draw_types) if need_click: - draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True) + draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True, type="primary") + draw_it_now_override = cols[1].checkbox('Auto draw') else: draw_it = True - return draw_it - -def mouse_plot_settings(need_click=True): - st.markdown('##### Show plots for individual mice ') - cols = st.columns([2, 1]) - st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?', - [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)', - f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'], - index=0 - ) - st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10, - 3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice) - st.markdown( - """ - """, - unsafe_allow_html=True, - ) - st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?', - st.session_state.draw_type_mapper_mouse_level.keys(), - default=st.session_state.draw_type_mapper_mouse_level.keys() - if 'selected_draw_types_mice' not in st.session_state else - st.session_state.selected_draw_types_mice) - if need_click: - draw_it = st.button('Show me all mice!', use_container_width=True) - else: - draw_it = True - return draw_it + draw_it_now_override = True + return draw_it | draw_it_now_override def plot_x_y_session(): @@ -683,8 +655,8 @@ def app(): with placeholder: df_selected_from_plotly, x_y_cols = plot_x_y_session() - with x_y_cols[0]: - for i in range(7): st.write('\n') + # Add session_plot_setting + with st.columns([1, 0.5])[0]: st.markdown("***") if_draw_all_sessions = session_plot_settings() @@ -732,7 +704,7 @@ def app(): elif chosen_id == "tab_session_inspector": with placeholder: - cols = st.columns([6, 3, 7]) + cols = st.columns([1, 0.5]) with cols[0]: if_draw_all_sessions = session_plot_settings(need_click=False) df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered @@ -842,6 +814,8 @@ def app(): # Add debug info if chosen_id != "tab_auto_train_curriculum": + for _ in range(10): st.write('\n') + st.markdown('---\n##### Debug zone') with st.expander('CO processing NWB errors', expanded=False): error_file = cache_folder + 'error_files.json' if fs.exists(error_file): From 21bd0106ea8b0886c87866ac7dfbaad4018224c5 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 10:16:27 -0700 Subject: [PATCH 10/20] feat: flexble figure size and font --- code/Home.py | 11 ++++++++-- code/pages/1_Old mice.py | 2 +- code/util/streamlit.py | 43 +++++++++++++++++++++++++++++++++------- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/code/Home.py b/code/Home.py index 0ab7db5..efa5a90 100644 --- a/code/Home.py +++ b/code/Home.py @@ -78,6 +78,9 @@ 'x_y_plot_dot_size': 10, 'x_y_plot_dot_opacity': 0.5, 'x_y_plot_line_width': 2.0, + 'x_y_plot_figure_width': 1300, + 'x_y_plot_figure_height': 900, + 'x_y_plot_font_size_scale': 1.0, 'x_y_plot_size_mapper': 'None', @@ -393,7 +396,7 @@ def plot_x_y_session(): (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) = add_xy_setting() + dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting() # If no sessions are selected, use all filtered entries # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered @@ -425,7 +428,11 @@ def plot_x_y_session(): dot_size_base=dot_size, dot_size_mapping_name=size_mapper, dot_opacity=dot_opacity, - line_width=line_width) + line_width=line_width, + x_y_plot_figure_width=x_y_plot_figure_width, + x_y_plot_figure_height=x_y_plot_figure_height, + font_size_scale=font_size_scale, + ) # st.plotly_chart(fig) selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True, diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index 195147f..379aec7 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -357,7 +357,7 @@ def plot_x_y_session(): (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) = add_xy_setting() + dot_size, dot_opacity, line_width, _, _, _) = add_xy_setting() # If no sessions are selected, use all filtered entries diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 3f8f270..c7ec9a0 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -558,10 +558,32 @@ def add_xy_setting(): step=0.25, key='x_y_plot_line_width', default=2.0) + + figure_width = slider_wrapper_for_url_query(c[0], + label='figure width', + min_value=500, + max_value=2500, + key='x_y_plot_figure_width', + default=1300) + + figure_height = slider_wrapper_for_url_query(c[1], + label='figure height', + min_value=500, + max_value=2500, + key='x_y_plot_figure_height', + default=900) + + font_size_scale = slider_wrapper_for_url_query(c[2], + label='font size', + min_value=0.0, + max_value=2.0, + step=0.1, + key='x_y_plot_font_size_scale', + default=1.0) return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) + dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale) def data_selector(): @@ -694,6 +716,9 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' dot_size_mapping_name='None', dot_opacity=0.4, line_width=2, + x_y_plot_figure_width=1300, + x_y_plot_figure_height=900, + font_size_scale=1.0, **kwarg): def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, **kwarg): @@ -907,18 +932,22 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q n_sessions = len(df.groupby(['h2o', 'session']).count()) fig.update_layout( - width=1300, - height=900, + width=x_y_plot_figure_width, + height=x_y_plot_figure_height, xaxis_title=x_name, yaxis_title=y_name, - font=dict(size=25), + font=dict(size=24 * font_size_scale), hovermode="closest", - hoverlabel=dict(font_size=17), + hoverlabel=dict(font_size=17 * font_size_scale), legend={"traceorder": "reversed"}, - legend_font_size=15, + legend_font_size=20 * font_size_scale, title=f"{title}, {n_mice} mice, {n_sessions} sessions", dragmode="zoom", # 'select', - margin=dict(l=130, r=50, b=130, t=100), + margin=dict(l=130 * font_size_scale, + r=50 * font_size_scale, + b=130 * font_size_scale, + t=100 * font_size_scale, + ), ) fig.update_xaxes(showline=True, linewidth=2, linecolor='black', # range=[1, min(100, df[x_name].max())], From ed4828bdc36b4437f941f2c566e1b8a8fa6d85a2 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 10:24:25 -0700 Subject: [PATCH 11/20] feat: reorganize plot settings --- code/Home.py | 9 ++-- code/util/streamlit.py | 95 +++++++++++++++++++++++------------------- 2 files changed, 56 insertions(+), 48 deletions(-) diff --git a/code/Home.py b/code/Home.py index efa5a90..709a1d7 100644 --- a/code/Home.py +++ b/code/Home.py @@ -388,14 +388,14 @@ def session_plot_settings(need_click=True): def plot_x_y_session(): - cols = st.columns([4, 10]) + cols = st.columns([1, 1, 1]) with cols[0]: - x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=True) + with cols[1]: (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting() # If no sessions are selected, use all filtered entries @@ -409,7 +409,7 @@ def plot_x_y_session(): df_selected_from_plotly = pd.DataFrame() # for i, (title, (x_name, y_name)) in enumerate(names.items()): # with cols[i]: - with cols[1]: + with st.columns([1])[0]: fig = _plot_population_x_y(df=df_x_y_session.copy(), x_name=x_name, y_name=y_name, group_by=group_by, @@ -425,6 +425,7 @@ def plot_x_y_session(): q_quantiles_all=q_quantiles_all, title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, states = st.session_state.df_selected_from_plotly, + if_show_diagonal=if_show_diagonal, dot_size_base=dot_size, dot_size_mapping_name=size_mapper, dot_opacity=dot_opacity, diff --git a/code/util/streamlit.py b/code/util/streamlit.py index c7ec9a0..9731102 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -535,54 +535,60 @@ def add_xy_setting(): or (if_aggr_all and aggr_method_all in ('running average', 'lowess'))) ) - c = st.columns([1, 1, 1]) - dot_size = slider_wrapper_for_url_query(c[0], - label='dot size', - min_value=1, - max_value=30, - key='x_y_plot_dot_size', - default=10) + if_show_diagonal = checkbox_wrapper_for_url_query(s_cols[0], + label='Show diagonal line', + key='x_y_plot_if_show_diagonal', + default=False) - dot_opacity = slider_wrapper_for_url_query(c[1], - label='opacity', - min_value=0.0, - max_value=1.0, - step=0.05, - key='x_y_plot_dot_opacity', - default=0.5) - - line_width = slider_wrapper_for_url_query(c[2], - label='line width', + with st.expander('Misc', expanded=False): + c = st.columns([1, 1, 1]) + dot_size = slider_wrapper_for_url_query(c[0], + label='dot size', + min_value=1, + max_value=30, + key='x_y_plot_dot_size', + default=10) + + dot_opacity = slider_wrapper_for_url_query(c[1], + label='opacity', + min_value=0.0, + max_value=1.0, + step=0.05, + key='x_y_plot_dot_opacity', + default=0.5) + + line_width = slider_wrapper_for_url_query(c[2], + label='line width', + min_value=0.0, + max_value=5.0, + step=0.25, + key='x_y_plot_line_width', + default=2.0) + + figure_width = slider_wrapper_for_url_query(c[0], + label='figure width', + min_value=500, + max_value=2500, + key='x_y_plot_figure_width', + default=1300) + + figure_height = slider_wrapper_for_url_query(c[1], + label='figure height', + min_value=500, + max_value=2500, + key='x_y_plot_figure_height', + default=900) + + font_size_scale = slider_wrapper_for_url_query(c[2], + label='font size', min_value=0.0, - max_value=5.0, - step=0.25, - key='x_y_plot_line_width', - default=2.0) - - figure_width = slider_wrapper_for_url_query(c[0], - label='figure width', - min_value=500, - max_value=2500, - key='x_y_plot_figure_width', - default=1300) - - figure_height = slider_wrapper_for_url_query(c[1], - label='figure height', - min_value=500, - max_value=2500, - key='x_y_plot_figure_height', - default=900) - - font_size_scale = slider_wrapper_for_url_query(c[2], - label='font size', - min_value=0.0, - max_value=2.0, - step=0.1, - key='x_y_plot_font_size_scale', - default=1.0) + max_value=2.0, + step=0.1, + key='x_y_plot_font_size_scale', + default=1.0) return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale) def data_selector(): @@ -712,6 +718,7 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' if_use_x_quantile_all=False, q_quantiles_all=20, title='', + if_show_diagonal=False, dot_size_base=10, dot_size_mapping_name='None', dot_opacity=0.4, From 4d83ff6174e547f35ffa2109956b3461c8503c5c Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 10:30:51 -0700 Subject: [PATCH 12/20] feat: show diagonal --- code/Home.py | 1 + code/util/streamlit.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 709a1d7..2b00e80 100644 --- a/code/Home.py +++ b/code/Home.py @@ -75,6 +75,7 @@ 'x_y_plot_q_quantiles_group': 20, 'x_y_plot_if_use_x_quantile_all': False, 'x_y_plot_q_quantiles_all': 20, + 'x_y_plot_if_show_diagonal': False, 'x_y_plot_dot_size': 10, 'x_y_plot_dot_opacity': 0.5, 'x_y_plot_line_width': 2.0, diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 9731102..6aba2b7 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -870,15 +870,26 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q df['dot_size'] = df[dot_size_mapping_name] else: df['dot_size'] = dot_size_base - + # Turn column of group_by to string if it's not if not is_string_dtype(df[group_by]): df[group_by] = df[group_by].astype(str) + # Add a diagonal line first + if if_show_diagonal: + _min = df[x_name].values.ravel().min() + _max = df[y_name].values.ravel().max() + fig.add_trace(go.Scattergl(x=[_min, _max], + y=[_min, _max], + mode='lines', + line=dict(dash='dash', color='black', width=2), + showlegend=False) + ) + for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] - + if if_show_dots: if not len(st.session_state.df_selected_from_plotly): this_session['colors'] = col # all use normal colors From a77c7f898be94db51506bf9705be7914fdb4e7e9 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 10:37:16 -0700 Subject: [PATCH 13/20] feat: add table height to url --- code/Home.py | 17 ++++++++++++++--- code/pages/1_Old mice.py | 2 +- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/code/Home.py b/code/Home.py index 2b00e80..e8569bf 100644 --- a/code/Home.py +++ b/code/Home.py @@ -42,7 +42,9 @@ aggrid_interactive_table_curriculum, add_session_filter, data_selector, add_xy_selector, add_xy_setting, add_auto_train_manager, _plot_population_x_y) -from util.url_query_helper import sync_widget_with_query +from util.url_query_helper import ( + sync_widget_with_query, slider_wrapper_for_url_query, +) import extra_streamlit_components as stx @@ -61,6 +63,8 @@ 'filter_foraging_eff': [0.0, None], 'filter_task': ['all'], + 'table_height': 300, + 'tab_id': 'tab_session_x_y', 'x_y_plot_xname': 'session', 'x_y_plot_yname': 'foraging_performance_random_seed', @@ -625,8 +629,15 @@ def app(): init() st.rerun() - table_height = cols[3].slider('Table height', 100, 2000, 400, 50, key='table_height') - + table_height = slider_wrapper_for_url_query(st_prefix=cols[3], + label='Table height', + min_value=0, + max_value=2000, + default=300, + step=50, + key='table_height', + ) + # aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units']) # st.session_state.df_session_filtered = aggrid_outputs['data'] diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index 379aec7..b25c24b 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -356,7 +356,7 @@ def plot_x_y_session(): x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=False) (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, dot_size, dot_opacity, line_width, _, _, _) = add_xy_setting() From 52030a9c4663f797a7fb90167e8dda0ff30982ea Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 11:01:20 -0700 Subject: [PATCH 14/20] feat: add quick preview --- code/Home.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index e8569bf..135b6c0 100644 --- a/code/Home.py +++ b/code/Home.py @@ -242,7 +242,7 @@ def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRendere return StreamlitRenderer(df, spec=spec, debug=False, **kwargs) -def draw_session_plots(df_to_draw_session): +def draw_session_plots(df_to_draw_session, if_quick_preview=False): # Setting up layout for each session layout_definition = [[1], # columns in the first row @@ -294,7 +294,42 @@ def draw_session_plots(df_to_draw_session): my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100)) +def draw_session_plots_quick_preview(df_to_draw_session): + + # Setting up layout for each session + layout_definition = [[1], # columns in the first row + [1, 1], + ] + draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] + + container_session_all_in_one = st.container() + + key = df_to_draw_session.to_dict(orient='records')[0] + + with container_session_all_in_one: + try: + date_str = key["session_date"].strftime('%Y-%m-%d') + except: + date_str = key["session_date"].split("T")[0] + + st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str}''', + unsafe_allow_html=True) + + rows = [] + for row, column_setting in enumerate(layout_definition): + rows.append(st.columns(column_setting)) + + for draw_type in draw_types_quick_preview: + if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level + prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] + this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] + show_session_level_img_by_key_and_prefix(key, + column=this_col, + prefix=prefix, + **setting) + + def draw_mice_plots(df_to_draw_mice): @@ -414,7 +449,9 @@ def plot_x_y_session(): df_selected_from_plotly = pd.DataFrame() # for i, (title, (x_name, y_name)) in enumerate(names.items()): # with cols[i]: - with st.columns([1])[0]: + + cols = st.columns([1, 0.7]) + with cols[0]: fig = _plot_population_x_y(df=df_x_y_session.copy(), x_name=x_name, y_name=y_name, group_by=group_by, @@ -443,10 +480,17 @@ def plot_x_y_session(): # st.plotly_chart(fig) selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=fig.layout.height * 1.1, override_width=fig.layout.width) + + with cols[1]: + st.markdown('#### 👀 Quick preview') + st.markdown('Click on one session to preview here (or Box/Lasso select multiple sessions to draw them in the section below)') if len(selected): df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1), on=[x_name, y_name], how='inner') + if len(st.session_state.df_selected_from_plotly) == 1: + with cols[1]: + draw_session_plots_quick_preview(st.session_state.df_selected_from_plotly) return df_selected_from_plotly, cols From 19fbeaee705b9942e0b640479283949b5e994189 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 11:12:32 -0700 Subject: [PATCH 15/20] feat: adaptive col scale for quick preview --- code/Home.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 135b6c0..2829c43 100644 --- a/code/Home.py +++ b/code/Home.py @@ -242,7 +242,7 @@ def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRendere return StreamlitRenderer(df, spec=spec, debug=False, **kwargs) -def draw_session_plots(df_to_draw_session, if_quick_preview=False): +def draw_session_plots(df_to_draw_session): # Setting up layout for each session layout_definition = [[1], # columns in the first row @@ -450,7 +450,11 @@ def plot_x_y_session(): # for i, (title, (x_name, y_name)) in enumerate(names.items()): # with cols[i]: - cols = st.columns([1, 0.7]) + if hasattr(st.session_state, 'x_y_plot_figure_width'): + _x_y_plot_scale = st.session_state.x_y_plot_figure_width / 1300 + cols = st.columns([1 * _x_y_plot_scale, 0.7]) + else: + cols = st.columns([1, 0.7]) with cols[0]: fig = _plot_population_x_y(df=df_x_y_session.copy(), x_name=x_name, y_name=y_name, From 1d25456e85e275604f42bde1857dbd012d6b26ab Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 11:15:57 -0700 Subject: [PATCH 16/20] feat: add instruction --- code/Home.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/code/Home.py b/code/Home.py index 2829c43..265c104 100644 --- a/code/Home.py +++ b/code/Home.py @@ -487,7 +487,8 @@ def plot_x_y_session(): with cols[1]: st.markdown('#### 👀 Quick preview') - st.markdown('Click on one session to preview here (or Box/Lasso select multiple sessions to draw them in the section below)') + st.markdown('###### Click on one session to preview here, or Box/Lasso select multiple sessions to draw them in the section below') + st.markdown('(sometimes you have to click twice...)') if len(selected): df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1), From 1fc628348edd68b6847613cbe8996ee883ef1f5a Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 12:02:49 -0700 Subject: [PATCH 17/20] feat: add dot size mapper range and gamma --- code/Home.py | 15 ++++++- code/pages/1_Old mice.py | 2 +- code/util/streamlit.py | 75 +++++++++++++++++++++++++++-------- code/util/url_query_helper.py | 7 +++- 4 files changed, 78 insertions(+), 21 deletions(-) diff --git a/code/Home.py b/code/Home.py index 265c104..467a88a 100644 --- a/code/Home.py +++ b/code/Home.py @@ -40,7 +40,7 @@ from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, aggrid_interactive_table_curriculum, add_session_filter, data_selector, - add_xy_selector, add_xy_setting, add_auto_train_manager, + add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper, _plot_population_x_y) from util.url_query_helper import ( sync_widget_with_query, slider_wrapper_for_url_query, @@ -88,6 +88,8 @@ 'x_y_plot_font_size_scale': 1.0, 'x_y_plot_size_mapper': 'None', + 'x_y_plot_size_mapper_gamma': 1.0, + 'x_y_plot_size_mapper_range': [0, 50], 'session_plot_mode': 'sessions selected from table or plot', @@ -431,13 +433,20 @@ def plot_x_y_session(): cols = st.columns([1, 1, 1]) with cols[0]: - x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=True) + x_name, y_name, group_by = add_xy_selector(if_bonsai=True) with cols[1]: (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting() + if st.session_state.x_y_plot_if_show_dots: + with cols[2]: + size_mapper, size_mapper_range, size_mapper_gamma = add_dot_property_mapper() + else: + size_mapper = 'None' + size_mapper_range, size_mapper_gamma = None, None + # If no sessions are selected, use all filtered entries # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered df_x_y_session = st.session_state.df_session_filtered @@ -474,6 +483,8 @@ def plot_x_y_session(): if_show_diagonal=if_show_diagonal, dot_size_base=dot_size, dot_size_mapping_name=size_mapper, + dot_size_mapping_range=size_mapper_range, + dot_size_mapping_gamma=size_mapper_gamma, dot_opacity=dot_opacity, line_width=line_width, x_y_plot_figure_width=x_y_plot_figure_width, diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index b25c24b..f11d1be 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -353,7 +353,7 @@ def plot_x_y_session(): cols = st.columns([4, 10]) with cols[0]: - x_name, y_name, group_by, size_mapper = add_xy_selector(if_bonsai=False) + x_name, y_name, group_by = add_xy_selector(if_bonsai=False) (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 6aba2b7..662bdf3 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -441,23 +441,8 @@ def add_xy_selector(if_bonsai): default=0, ) - # Get all columns that are numeric - available_size_cols = ['None'] + [ - col - for col in st.session_state.session_stats_names - if is_numeric_dtype(st.session_state.df_session_filtered[col]) - ] - - size_mapper = selectbox_wrapper_for_url_query( - cols[0], - label="dot size mapper", - options=available_size_cols, - key='x_y_plot_size_mapper', - default=0, - ) - # st.form_submit_button("update axes") - return x_name, y_name, group_by, size_mapper + return x_name, y_name, group_by def add_xy_setting(): with st.expander('Plot settings', expanded=True): @@ -591,6 +576,60 @@ def add_xy_setting(): if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale) +@st.cache_data(ttl=24*3600) +def _get_min_max(x, size_mapper_gamma): + x_gamma_all = x ** size_mapper_gamma + return np.percentile(x_gamma_all, 5), np.percentile(x_gamma_all, 95) + +def _size_mapping(x, size_mapper_range, size_mapper_gamma): + x = x / np.quantile(x[~np.isnan(x)], 0.95) + x_gamma = x**size_mapper_gamma + min_x, max_x = _get_min_max(x, size_mapper_gamma) + sizes = size_mapper_range[0] + x_gamma / (max_x - min_x) * (size_mapper_range[1] - size_mapper_range[0]) + sizes[np.isnan(sizes)] = 0 + return sizes + +def add_dot_property_mapper(): + with st.expander('Data point property mapper', expanded=True): + cols = st.columns([2, 1, 1]) + + # Get all columns that are numeric + available_size_cols = ['None'] + [ + col + for col in st.session_state.session_stats_names + if is_numeric_dtype(st.session_state.df_session_filtered[col]) + ] + + size_mapper = selectbox_wrapper_for_url_query( + cols[0], + label="dot size mapper", + options=available_size_cols, + key='x_y_plot_size_mapper', + default=0, + ) + + if st.session_state.x_y_plot_size_mapper != 'None': + size_mapper_range = slider_wrapper_for_url_query(cols[1], + label="size range", + min_value=0, + max_value=100, + key='x_y_plot_size_mapper_range', + default=(0, 50), + ) + + size_mapper_gamma = slider_wrapper_for_url_query(cols[2], + label="size gamma", + min_value=0.0, + max_value=2.0, + key='x_y_plot_size_mapper_gamma', + default=1.0, + step=0.1) + else: + size_mapper_range, size_mapper_gamma = None, None + + return size_mapper, size_mapper_range, size_mapper_gamma + + def data_selector(): with st.expander(f'Session selector', expanded=True): @@ -721,6 +760,8 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' if_show_diagonal=False, dot_size_base=10, dot_size_mapping_name='None', + dot_size_mapping_range=None, + dot_size_mapping_gamma=None, dot_opacity=0.4, line_width=2, x_y_plot_figure_width=1300, @@ -867,7 +908,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q # Add some more columns if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns: - df['dot_size'] = df[dot_size_mapping_name] + df['dot_size'] = _size_mapping(df[dot_size_mapping_name], dot_size_mapping_range, dot_size_mapping_gamma) else: df['dot_size'] = dot_size_base diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 8926326..19a70a0 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -28,6 +28,11 @@ def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **k ) def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs): + + # Parse range from URL, compatible with both one or two values + if key in st.query_params: + parse_range_from_url = [type(min_value)(v) for v in st.query_params.get_all(key)] + return st_prefix.slider( label, min_value, @@ -35,7 +40,7 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de value=( st.session_state[key] if key in st.session_state - else type(min_value)(st.query_params[key]) + else parse_range_from_url if key in st.query_params else default ), From bf9214b62b3d20ec0bc04d81e89da02878a01d72 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 12:08:45 -0700 Subject: [PATCH 18/20] feat: modify default size mapper --- code/Home.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 467a88a..2f57666 100644 --- a/code/Home.py +++ b/code/Home.py @@ -87,9 +87,9 @@ 'x_y_plot_figure_height': 900, 'x_y_plot_font_size_scale': 1.0, - 'x_y_plot_size_mapper': 'None', + 'x_y_plot_size_mapper': 'finished_trials', 'x_y_plot_size_mapper_gamma': 1.0, - 'x_y_plot_size_mapper_range': [0, 50], + 'x_y_plot_size_mapper_range': [5, 30], 'session_plot_mode': 'sessions selected from table or plot', From 4f174badc7fb0dddbc2fc624d5e9ea0d668b7df2 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 12:12:10 -0700 Subject: [PATCH 19/20] feat: update default --- code/Home.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/Home.py b/code/Home.py index 2f57666..33d3637 100644 --- a/code/Home.py +++ b/code/Home.py @@ -81,7 +81,7 @@ 'x_y_plot_q_quantiles_all': 20, 'x_y_plot_if_show_diagonal': False, 'x_y_plot_dot_size': 10, - 'x_y_plot_dot_opacity': 0.5, + 'x_y_plot_dot_opacity': 0.3, 'x_y_plot_line_width': 2.0, 'x_y_plot_figure_width': 1300, 'x_y_plot_figure_height': 900, @@ -89,7 +89,7 @@ 'x_y_plot_size_mapper': 'finished_trials', 'x_y_plot_size_mapper_gamma': 1.0, - 'x_y_plot_size_mapper_range': [5, 30], + 'x_y_plot_size_mapper_range': [3, 20], 'session_plot_mode': 'sessions selected from table or plot', From 57cba2761a3af78a1658086c3f2e892fe128bda2 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Fri, 5 Apr 2024 12:21:37 -0700 Subject: [PATCH 20/20] minor fix --- code/util/url_query_helper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 19a70a0..cb323b6 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -32,6 +32,8 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de # Parse range from URL, compatible with both one or two values if key in st.query_params: parse_range_from_url = [type(min_value)(v) for v in st.query_params.get_all(key)] + if len(parse_range_from_url) == 1: + parse_range_from_url = parse_range_from_url[0] return st_prefix.slider( label,