diff --git a/code/Home.py b/code/Home.py
index 09f41fc..33d3637 100644
--- a/code/Home.py
+++ b/code/Home.py
@@ -40,8 +40,12 @@
from util.streamlit import (filter_dataframe, aggrid_interactive_table_session,
aggrid_interactive_table_curriculum, add_session_filter, data_selector,
- _sync_widget_with_query, add_xy_selector, add_xy_setting, add_auto_train_manager,
+ add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper,
_plot_population_x_y)
+from util.url_query_helper import (
+ sync_widget_with_query, slider_wrapper_for_url_query,
+)
+
import extra_streamlit_components as stx
from aind_auto_train.curriculum_manager import CurriculumManager
@@ -59,9 +63,11 @@
'filter_foraging_eff': [0.0, None],
'filter_task': ['all'],
+ 'table_height': 300,
+
'tab_id': 'tab_session_x_y',
'x_y_plot_xname': 'session',
- 'x_y_plot_yname': 'foraging_eff',
+ 'x_y_plot_yname': 'foraging_performance_random_seed',
'x_y_plot_group_by': 'h2o',
'x_y_plot_if_show_dots': True,
'x_y_plot_if_aggr_each_group': True,
@@ -73,9 +79,17 @@
'x_y_plot_q_quantiles_group': 20,
'x_y_plot_if_use_x_quantile_all': False,
'x_y_plot_q_quantiles_all': 20,
+ 'x_y_plot_if_show_diagonal': False,
'x_y_plot_dot_size': 10,
- 'x_y_plot_dot_opacity': 0.5,
+ 'x_y_plot_dot_opacity': 0.3,
'x_y_plot_line_width': 2.0,
+ 'x_y_plot_figure_width': 1300,
+ 'x_y_plot_figure_height': 900,
+ 'x_y_plot_font_size_scale': 1.0,
+
+ 'x_y_plot_size_mapper': 'finished_trials',
+ 'x_y_plot_size_mapper_gamma': 1.0,
+ 'x_y_plot_size_mapper_range': [3, 20],
'session_plot_mode': 'sessions selected from table or plot',
@@ -282,7 +296,42 @@ def draw_session_plots(df_to_draw_session):
my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))
+def draw_session_plots_quick_preview(df_to_draw_session):
+    """Draw a compact preview of a single session in a new ``st.container()``.
+
+    Only the first record of ``df_to_draw_session`` is previewed. A header with
+    the ``h2o`` name, the session number and the session date is written first,
+    followed by those quick-preview plot types that are also present in
+    ``st.session_state.selected_draw_types``.
+
+    Args:
+        df_to_draw_session: DataFrame whose first row provides the ``h2o``,
+            ``session`` and ``session_date`` fields of the session to preview.
+    """
+    # Layout of the preview: one full-width row, then a row split in two
+    layout_definition = [[1],   # columns in the first row
+                         [1, 1],
+                         ]
+    draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']
+
+    container_session_all_in_one = st.container()
+
+    key = df_to_draw_session.to_dict(orient='records')[0]
+
+    with container_session_all_in_one:
+        try:
+            date_str = key["session_date"].strftime('%Y-%m-%d')
+        except AttributeError:
+            # No strftime() on this value: treat it as a string and keep the part before "T"
+            date_str = key["session_date"].split("T")[0]
+
+        st.markdown(f'''
+{key["h2o"]}, Session {int(key["session"])}, {date_str}''',
+                    unsafe_allow_html=True)
+
+        # One list of column containers per layout row
+        rows = [st.columns(column_setting) for column_setting in layout_definition]
+
+        # Does not change inside the loop below, so evaluate it once
+        use_predefined_layout = len(st.session_state.selected_draw_types) > 1
+
+        for draw_type in draw_types_quick_preview:
+            if draw_type not in st.session_state.selected_draw_types:
+                continue  # skip plot types that are not currently selected
+            prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
+            # rows[i] is the list returned by st.columns(), so a single column
+            # container has to be picked in both branches
+            this_col = rows[position[0]][position[1]] if use_predefined_layout else rows[0][0]
+            show_session_level_img_by_key_and_prefix(key,
+                                                     column=this_col,
+                                                     prefix=prefix,
+                                                     **setting)
+
def draw_mice_plots(df_to_draw_mice):
@@ -335,7 +384,7 @@ def draw_mice_plots(df_to_draw_mice):
def session_plot_settings(need_click=True):
st.markdown('##### Show plots for individual sessions ')
- cols = st.columns([2, 1])
+ cols = st.columns([2, 1, 6])
session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar']
st.session_state.selected_draw_sessions = cols[0].selectbox(f'Which session(s) to draw?',
@@ -351,7 +400,7 @@ def session_plot_settings(need_click=True):
n_session_to_draw = len(st.session_state.df_selected_from_plotly) \
if 'selected from table or plot' in st.session_state.selected_draw_sessions \
else len(st.session_state.df_session_filtered)
- st.markdown(f'{n_session_to_draw} sessions to draw')
+ cols[0].markdown(f'{n_session_to_draw} sessions to draw')
st.session_state.num_cols = cols[1].number_input('number of columns', 1, 10,
3 if 'num_cols' not in st.session_state else st.session_state.num_cols)
@@ -365,61 +414,38 @@ def session_plot_settings(need_click=True):
""",
unsafe_allow_html=True,
)
- st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?',
+ st.session_state.selected_draw_types = cols[2].multiselect('Which plot(s) to draw?',
st.session_state.draw_type_mapper_session_level.keys(),
default=st.session_state.draw_type_mapper_session_level.keys()
if 'selected_draw_types' not in st.session_state else
st.session_state.selected_draw_types)
if need_click:
- draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True)
- else:
- draw_it = True
- return draw_it
-
-def mouse_plot_settings(need_click=True):
- st.markdown('##### Show plots for individual mice ')
- cols = st.columns([2, 1])
- st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?',
- [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)',
- f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'],
- index=0
- )
- st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10,
- 3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice)
- st.markdown(
- """
- """,
- unsafe_allow_html=True,
- )
- st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?',
- st.session_state.draw_type_mapper_mouse_level.keys(),
- default=st.session_state.draw_type_mapper_mouse_level.keys()
- if 'selected_draw_types_mice' not in st.session_state else
- st.session_state.selected_draw_types_mice)
- if need_click:
- draw_it = st.button('Show me all mice!', use_container_width=True)
+ draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True, type="primary")
+ draw_it_now_override = cols[1].checkbox('Auto draw')
else:
draw_it = True
- return draw_it
+ draw_it_now_override = True
+ return draw_it | draw_it_now_override
def plot_x_y_session():
- cols = st.columns([4, 10])
+ cols = st.columns([1, 1, 1])
with cols[0]:
-
x_name, y_name, group_by = add_xy_selector(if_bonsai=True)
+ with cols[1]:
(if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
- if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor,
- dot_size, dot_opacity, line_width) = add_xy_setting()
-
-
+ if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
+ dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting()
+
+ if st.session_state.x_y_plot_if_show_dots:
+ with cols[2]:
+ size_mapper, size_mapper_range, size_mapper_gamma = add_dot_property_mapper()
+ else:
+ size_mapper = 'None'
+ size_mapper_range, size_mapper_gamma = None, None
# If no sessions are selected, use all filtered entries
# df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered
@@ -432,33 +458,55 @@ def plot_x_y_session():
df_selected_from_plotly = pd.DataFrame()
# for i, (title, (x_name, y_name)) in enumerate(names.items()):
# with cols[i]:
- with cols[1]:
- fig = _plot_population_x_y(df=df_x_y_session,
- x_name=x_name, y_name=y_name,
- group_by=group_by,
- smooth_factor=smooth_factor,
- if_show_dots=if_show_dots,
- if_aggr_each_group=if_aggr_each_group,
- if_aggr_all=if_aggr_all,
- aggr_method_group=aggr_method_group,
- aggr_method_all=aggr_method_all,
- if_use_x_quantile_group=if_use_x_quantile_group,
- q_quantiles_group=q_quantiles_group,
- if_use_x_quantile_all=if_use_x_quantile_all,
- q_quantiles_all=q_quantiles_all,
- title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
- states = st.session_state.df_selected_from_plotly,
- dot_size=dot_size,
- dot_opacity=dot_opacity,
- line_width=line_width)
+
+ if hasattr(st.session_state, 'x_y_plot_figure_width'):
+ _x_y_plot_scale = st.session_state.x_y_plot_figure_width / 1300
+ cols = st.columns([1 * _x_y_plot_scale, 0.7])
+ else:
+ cols = st.columns([1, 0.7])
+ with cols[0]:
+ fig = _plot_population_x_y(df=df_x_y_session.copy(),
+ x_name=x_name, y_name=y_name,
+ group_by=group_by,
+ smooth_factor=smooth_factor,
+ if_show_dots=if_show_dots,
+ if_aggr_each_group=if_aggr_each_group,
+ if_aggr_all=if_aggr_all,
+ aggr_method_group=aggr_method_group,
+ aggr_method_all=aggr_method_all,
+ if_use_x_quantile_group=if_use_x_quantile_group,
+ q_quantiles_group=q_quantiles_group,
+ if_use_x_quantile_all=if_use_x_quantile_all,
+ q_quantiles_all=q_quantiles_all,
+ title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
+ states = st.session_state.df_selected_from_plotly,
+ if_show_diagonal=if_show_diagonal,
+ dot_size_base=dot_size,
+ dot_size_mapping_name=size_mapper,
+ dot_size_mapping_range=size_mapper_range,
+ dot_size_mapping_gamma=size_mapper_gamma,
+ dot_opacity=dot_opacity,
+ line_width=line_width,
+ x_y_plot_figure_width=x_y_plot_figure_width,
+ x_y_plot_figure_height=x_y_plot_figure_height,
+ font_size_scale=font_size_scale,
+ )
# st.plotly_chart(fig)
selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True,
override_height=fig.layout.height * 1.1, override_width=fig.layout.width)
+
+ with cols[1]:
+ st.markdown('#### 👀 Quick preview')
+ st.markdown('###### Click on one session to preview here, or Box/Lasso select multiple sessions to draw them in the section below')
+ st.markdown('(sometimes you have to click twice...)')
if len(selected):
df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1),
on=[x_name, y_name], how='inner')
+ if len(st.session_state.df_selected_from_plotly) == 1:
+ with cols[1]:
+ draw_session_plots_quick_preview(st.session_state.df_selected_from_plotly)
return df_selected_from_plotly, cols
@@ -477,7 +525,7 @@ def init():
# Set session state from URL
for key, default in to_sync_with_url_query.items():
- _sync_widget_with_query(key, default)
+ sync_widget_with_query(key, default)
df = load_data(['sessions',
])
@@ -573,6 +621,13 @@ def init():
# map user_name
st.session_state.df['sessions_bonsai']['user_name'] = st.session_state.df['sessions_bonsai']['user_name'].apply(_user_name_mapper)
+ # fill nan for autotrain fields
+ filled_values = {'curriculum_name': 'None',
+ 'curriculum_version': 'None',
+ 'curriculum_schema_version': 'None',
+ 'current_stage_actual': 'None'}
+ st.session_state.df['sessions_bonsai'].fillna(filled_values, inplace=True)
+
# foraging performance = foraing_eff * finished_rate
if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns:
st.session_state.df['sessions_bonsai']['foraging_performance'] = \
@@ -634,8 +689,15 @@ def app():
init()
st.rerun()
- table_height = cols[3].slider('Table height', 100, 2000, 400, 50, key='table_height')
-
+ table_height = slider_wrapper_for_url_query(st_prefix=cols[3],
+ label='Table height',
+ min_value=0,
+ max_value=2000,
+ default=300,
+ step=50,
+ key='table_height',
+ )
+
# aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units'])
# st.session_state.df_session_filtered = aggrid_outputs['data']
@@ -673,8 +735,8 @@ def app():
with placeholder:
df_selected_from_plotly, x_y_cols = plot_x_y_session()
- with x_y_cols[0]:
- for i in range(7): st.write('\n')
+ # Add session_plot_setting
+ with st.columns([1, 0.5])[0]:
st.markdown("***")
if_draw_all_sessions = session_plot_settings()
@@ -722,7 +784,7 @@ def app():
elif chosen_id == "tab_session_inspector":
with placeholder:
- cols = st.columns([6, 3, 7])
+ cols = st.columns([1, 0.5])
with cols[0]:
if_draw_all_sessions = session_plot_settings(need_click=False)
df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered
@@ -832,6 +894,8 @@ def app():
# Add debug info
if chosen_id != "tab_auto_train_curriculum":
+ for _ in range(10): st.write('\n')
+ st.markdown('---\n##### Debug zone')
with st.expander('CO processing NWB errors', expanded=False):
error_file = cache_folder + 'error_files.json'
if fs.exists(error_file):
diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py
index cf910ec..f11d1be 100644
--- a/code/pages/1_Old mice.py
+++ b/code/pages/1_Old mice.py
@@ -19,8 +19,10 @@
from streamlit_plotly_events import plotly_events
from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector,
- add_xy_selector, _sync_widget_with_query, add_xy_setting, add_auto_train_manager,
+ add_xy_selector, add_xy_setting, add_auto_train_manager,
_plot_population_x_y)
+from util.url_query_helper import sync_widget_with_query
+
import extra_streamlit_components as stx
from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
@@ -41,7 +43,7 @@
'tab_id': 'tab_session_x_y',
'x_y_plot_xname': 'session',
- 'x_y_plot_yname': 'foraging_eff',
+ 'x_y_plot_yname': 'foraging_performance',
'x_y_plot_group_by': 'h2o',
'x_y_plot_if_show_dots': True,
'x_y_plot_if_aggr_each_group': True,
@@ -354,8 +356,8 @@ def plot_x_y_session():
x_name, y_name, group_by = add_xy_selector(if_bonsai=False)
(if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
- if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor,
- dot_size, dot_opacity, line_width) = add_xy_setting()
+ if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
+ dot_size, dot_opacity, line_width, _, _, _) = add_xy_setting()
# If no sessions are selected, use all filtered entries
@@ -385,7 +387,7 @@ def plot_x_y_session():
q_quantiles_all=q_quantiles_all,
title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name,
states = st.session_state.df_selected_from_plotly,
- dot_size=dot_size,
+ dot_size_base=dot_size,
dot_opacity=dot_opacity,
line_width=line_width)
@@ -410,7 +412,7 @@ def init():
# Set session state from URL
for key, default in to_sync_with_url_query.items():
- _sync_widget_with_query(key, default)
+ sync_widget_with_query(key, default)
df = load_data(['sessions',
'logistic_regression_hattori',
diff --git a/code/util/streamlit.py b/code/util/streamlit.py
index 54bcbb6..662bdf3 100644
--- a/code/util/streamlit.py
+++ b/code/util/streamlit.py
@@ -1,4 +1,4 @@
-from email import header
+from collections import OrderedDict
import pandas as pd
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder
@@ -7,6 +7,7 @@
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
+ is_string_dtype,
is_object_dtype,
)
import streamlit.components.v1 as components
@@ -20,6 +21,8 @@
import statsmodels.api as sm
from scipy.stats import linregress
+from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query
+
custom_css = {
".ag-root.ag-unselectable.ag-layout-normal": {"font-size": "15px !important",
@@ -176,7 +179,7 @@ def cache_widget(field, clear=None):
# Clear cache if needed
if clear:
if clear in st.session_state: del st.session_state[clear]
-
+
# def dec_cache_widget_state(widget, ):
@@ -390,6 +393,24 @@ def add_session_filter(if_bonsai=False, url_query={}):
default_filters=['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'],
url_query=url_query)
+@st.cache_data(ttl=3600*24)
+def _get_grouped_by_fields(if_bonsai):
+    """Return the column names offered in the "grouped by" selector.
+
+    For ``if_bonsai`` the fixed defaults are extended with every column of
+    ``st.session_state.df_session_filtered`` that is categorical, or that has
+    fewer than 20 unique values and whose name contains none of the excluded
+    keywords. Duplicates are removed while keeping first-seen order.
+    Otherwise a fixed list is returned.
+
+    NOTE(review): the result is cached (ttl of 24 h) while the body reads
+    ``st.session_state``; confirm the cached options cannot go stale when the
+    filtered table changes.
+    """
+    if not if_bonsai:
+        return ['h2o', 'task', 'photostim_location', 'weekday',
+                'headbar', 'user_name', 'sex', 'rig']
+
+    df_filtered = st.session_state.df_session_filtered
+    excluded_keywords = ('date', 'time', 'session', 'finished', 'foraging_eff')
+
+    candidates = ['h2o', 'task', 'user_name', 'rig', 'weekday']
+    for col in df_filtered.columns:
+        # `and` binds tighter than `or`: categorical columns are always kept,
+        # the keyword exclusion only applies to the low-cardinality test
+        if is_categorical_dtype(df_filtered[col]) or (
+            df_filtered[col].nunique() < 20
+            and not any(keyword in col for keyword in excluded_keywords)
+        ):
+            candidates.append(col)
+
+    return list(OrderedDict.fromkeys(candidates))  # Remove duplicates
+
+
def add_xy_selector(if_bonsai):
with st.expander("Select axes", expanded=True):
# with st.form("axis_selection"):
@@ -403,165 +424,211 @@ def add_xy_selector(if_bonsai):
else st.session_state.session_stats_names.index('session'),
key='x_y_plot_xname'
)
- y_name = cols[0].selectbox("y axis",
- st.session_state.session_stats_names,
- index=st.session_state.session_stats_names.index(st.session_state['x_y_plot_yname'])
- if 'x_y_plot_yname' in st.session_state else
- st.session_state.session_stats_names.index(st.query_params['x_y_plot_yname'])
- if 'x_y_plot_yname' in st.query_params
- else st.session_state.session_stats_names.index('foraging_eff'),
- key='x_y_plot_yname')
-
- if if_bonsai:
- options = ['h2o', 'task', 'user_name', 'rig', 'weekday']
- else:
- options = ['h2o', 'task', 'photostim_location', 'weekday',
- 'headbar', 'user_name', 'sex', 'rig']
-
- group_by = cols[0].selectbox("grouped by",
- options=options,
- index=options.index(st.session_state['x_y_plot_group_by'])
- if 'x_y_plot_group_by' in st.session_state else
- options.index(st.query_params['x_y_plot_group_by'])
- if 'x_y_plot_group_by' in st.query_params
- else 0,
- key='x_y_plot_group_by')
-
- # st.form_submit_button("update axes")
+ y_name = selectbox_wrapper_for_url_query(
+ cols[0],
+ label="y axis",
+ options=st.session_state.session_stats_names,
+ key="x_y_plot_yname",
+ default=st.session_state.session_stats_names.index('foraging_eff')
+ )
+
+
+ group_by = selectbox_wrapper_for_url_query(
+ cols[0],
+ label="grouped by",
+ options=_get_grouped_by_fields(if_bonsai),
+ key="x_y_plot_group_by",
+ default=0,
+ )
+
+ # st.form_submit_button("update axes")
return x_name, y_name, group_by
-
def add_xy_setting():
with st.expander('Plot settings', expanded=True):
s_cols = st.columns([1, 1, 1])
# if_plot_only_selected_from_dataframe = s_cols[0].checkbox('Only selected', False)
- if_show_dots = s_cols[0].checkbox('Show data points',
- value=st.session_state['x_y_plot_if_show_dots']
- if 'x_y_plot_if_show_dots' in st.session_state else
- st.query_params['x_y_plot_if_show_dots'].lower()=='true'
- if 'x_y_plot_if_show_dots' in st.query_params
- else True,
- key='x_y_plot_if_show_dots')
-
-
- if_aggr_each_group = s_cols[1].checkbox('Aggr each group',
- value=st.session_state['x_y_plot_if_aggr_each_group']
- if 'x_y_plot_if_aggr_each_group' in st.session_state else
- st.query_params['x_y_plot_if_aggr_each_group'].lower()=='true'
- if 'x_y_plot_if_aggr_each_group' in st.query_params
- else True,
- key='x_y_plot_if_aggr_each_group')
+ if_show_dots = checkbox_wrapper_for_url_query(s_cols[0],
+ label='Show data points',
+ key='x_y_plot_if_show_dots',
+ default=True)
+
+ if_aggr_each_group = checkbox_wrapper_for_url_query(s_cols[1],
+ label='Aggr each group',
+ key='x_y_plot_if_aggr_each_group',
+ default=True)
aggr_methods = ['mean', 'mean +/- sem', 'lowess', 'running average', 'linear fit']
- aggr_method_group = s_cols[1].selectbox('aggr method group',
- options=aggr_methods,
- index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_group'])
- if 'x_y_plot_aggr_method_group' in st.session_state else
- aggr_methods.index(st.query_params['x_y_plot_aggr_method_group'])
- if 'x_y_plot_aggr_method_group' in st.query_params
- else 2,
- key='x_y_plot_aggr_method_group',
- disabled=not if_aggr_each_group)
+ aggr_method_group = selectbox_wrapper_for_url_query(s_cols[1],
+ label='aggr method group',
+ options=aggr_methods,
+ key='x_y_plot_aggr_method_group',
+ default=2,
+ disabled=not if_aggr_each_group)
- if_use_x_quantile_group = s_cols[1].checkbox('Use quantiles of x ',
- value=st.session_state['x_y_plot_if_use_x_quantile_group']
- if 'x_y_plot_if_use_x_quantile_group' in st.session_state else
- st.query_params['x_y_plot_if_use_x_quantile_group'].lower()=='true'
- if 'x_y_plot_if_use_x_quantile_group' in st.query_params
- else False,
- key='x_y_plot_if_use_x_quantile_group',
- disabled='mean' not in aggr_method_group)
+ if_use_x_quantile_group = checkbox_wrapper_for_url_query(s_cols[1],
+ label='Use quantiles of x',
+ key='x_y_plot_if_use_x_quantile_group',
+ default=False,
+ disabled='mean' not in aggr_method_group)
- q_quantiles_group = s_cols[1].slider('Number of quantiles ', 1, 100,
- value=st.session_state['x_y_plot_q_quantiles_group']
- if 'x_y_plot_q_quantiles_group' in st.session_state else
- int(st.query_params['x_y_plot_q_quantiles_group'])
- if 'x_y_plot_q_quantiles_group' in st.query_params
- else 20,
- key='x_y_plot_q_quantiles_group',
- disabled=not if_use_x_quantile_group
- )
+ q_quantiles_group = slider_wrapper_for_url_query(s_cols[1],
+ label='Number of quantiles ',
+ min_value=1,
+ max_value=100,
+ key='x_y_plot_q_quantiles_group',
+ default=20,
+ disabled=not if_use_x_quantile_group
+ )
- if_aggr_all = s_cols[2].checkbox('Aggr all',
- value=st.session_state['x_y_plot_if_aggr_all']
- if 'x_y_plot_if_aggr_all' in st.session_state else
- st.query_params['x_y_plot_if_aggr_all'].lower()=='true'
- if 'x_y_plot_if_aggr_all' in st.query_params
- else True,
- key='x_y_plot_if_aggr_all',
- )
+ if_aggr_all = checkbox_wrapper_for_url_query(s_cols[2],
+ label='Aggr all',
+ key='x_y_plot_if_aggr_all',
+ default=True)
# st.session_state.if_aggr_all_cache = if_aggr_all # Have to use another variable to store this explicitly (my cache_widget somehow doesn't work with checkbox)
- aggr_method_all = s_cols[2].selectbox('aggr method all', aggr_methods,
- index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_all'])
- if 'x_y_plot_aggr_method_all' in st.session_state else
- aggr_methods.index(st.query_params['x_y_plot_aggr_method_all'])
- if 'x_y_plot_aggr_method_all' in st.query_params
- else 1,
- key='x_y_plot_aggr_method_all',
- disabled=not if_aggr_all)
-
- if_use_x_quantile_all = s_cols[2].checkbox('Use quantiles of x',
- value=st.session_state['x_y_plot_if_use_x_quantile_all']
- if 'x_y_plot_if_use_x_quantile_all' in st.session_state else
- st.query_params['x_y_plot_if_use_x_quantile_all'].lower()=='true'
- if 'x_y_plot_if_use_x_quantile_all' in st.query_params
- else False,
- key='x_y_plot_if_use_x_quantile_all',
- disabled='mean' not in aggr_method_all,
- )
- q_quantiles_all = s_cols[2].slider('number of quantiles', 1, 100,
- value=st.session_state['x_y_plot_q_quantiles_all']
- if 'x_y_plot_q_quantiles_all' in st.session_state else
- int(st.query_params['x_y_plot_q_quantiles_all'])
- if 'x_y_plot_q_quantiles_all' in st.query_params
- else 20,
- key='x_y_plot_q_quantiles_all',
- disabled=not if_use_x_quantile_all
- )
-
- smooth_factor = s_cols[0].slider('smooth factor', 1, 20,
- value=st.session_state['x_y_plot_smooth_factor']
- if 'x_y_plot_smooth_factor' in st.session_state else
- int(st.query_params['x_y_plot_smooth_factor'])
- if 'x_y_plot_smooth_factor' in st.query_params
- else 5,
- key='x_y_plot_smooth_factor',
- disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess'))
- or (if_aggr_all and aggr_method_all in ('running average', 'lowess')))
- )
+ aggr_method_all = selectbox_wrapper_for_url_query(s_cols[2],
+ label='aggr method all',
+ options=aggr_methods,
+ key='x_y_plot_aggr_method_all',
+ default=1,
+ disabled=not if_aggr_all)
+
+ if_use_x_quantile_all = checkbox_wrapper_for_url_query(s_cols[2],
+ label='Use quantiles of x',
+ key='x_y_plot_if_use_x_quantile_all',
+ default=False,
+ disabled='mean' not in aggr_method_all)
+
- c = st.columns([1, 1, 1])
- dot_size = c[0].slider('dot size', 1, 30,
- step=1,
- value=st.session_state['x_y_plot_dot_size']
- if 'x_y_plot_dot_size' in st.session_state else
- int(st.query_params['x_y_plot_dot_size'])
- if 'x_y_plot_dot_size' in st.query_params
- else 10,
- key='x_y_plot_dot_size'
- )
- dot_opacity = c[1].slider('opacity', 0.0, 1.0,
- step=0.05,
- value=st.session_state['x_y_plot_dot_opacity']
- if 'x_y_plot_dot_opacity' in st.session_state else
- float(st.query_params['x_y_plot_dot_opacity'])
- if 'x_y_plot_dot_opacity' in st.query_params
- else 0.5,
- key='x_y_plot_dot_opacity')
-
- line_width = c[2].slider('line width', 0.0, 5.0,
- step=0.25,
- value=st.session_state['x_y_plot_line_width']
- if 'x_y_plot_line_width' in st.session_state else
- float(st.query_params['x_y_plot_line_width'])
- if 'x_y_plot_line_width' in st.query_params
- else 2.0,
- key='x_y_plot_line_width')
+ q_quantiles_all = slider_wrapper_for_url_query(s_cols[2],
+ label='Number of quantiles',
+ min_value=1,
+ max_value=100,
+ key='x_y_plot_q_quantiles_all',
+ default=20,
+ disabled=not if_use_x_quantile_all
+ )
+
+ smooth_factor = slider_wrapper_for_url_query(s_cols[0],
+ label='smooth factor',
+ min_value=1,
+ max_value=20,
+ key='x_y_plot_smooth_factor',
+ default=5,
+ disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess'))
+ or (if_aggr_all and aggr_method_all in ('running average', 'lowess')))
+ )
+
+ if_show_diagonal = checkbox_wrapper_for_url_query(s_cols[0],
+ label='Show diagonal line',
+ key='x_y_plot_if_show_diagonal',
+ default=False)
+
+ with st.expander('Misc', expanded=False):
+ c = st.columns([1, 1, 1])
+ dot_size = slider_wrapper_for_url_query(c[0],
+ label='dot size',
+ min_value=1,
+ max_value=30,
+ key='x_y_plot_dot_size',
+ default=10)
+
+ dot_opacity = slider_wrapper_for_url_query(c[1],
+ label='opacity',
+ min_value=0.0,
+ max_value=1.0,
+ step=0.05,
+ key='x_y_plot_dot_opacity',
+ default=0.5)
+
+ line_width = slider_wrapper_for_url_query(c[2],
+ label='line width',
+ min_value=0.0,
+ max_value=5.0,
+ step=0.25,
+ key='x_y_plot_line_width',
+ default=2.0)
+
+ figure_width = slider_wrapper_for_url_query(c[0],
+ label='figure width',
+ min_value=500,
+ max_value=2500,
+ key='x_y_plot_figure_width',
+ default=1300)
+
+ figure_height = slider_wrapper_for_url_query(c[1],
+ label='figure height',
+ min_value=500,
+ max_value=2500,
+ key='x_y_plot_figure_height',
+ default=900)
+
+ font_size_scale = slider_wrapper_for_url_query(c[2],
+ label='font size',
+ min_value=0.0,
+ max_value=2.0,
+ step=0.1,
+ key='x_y_plot_font_size_scale',
+ default=1.0)
return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group,
- if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor,
- dot_size, dot_opacity, line_width)
+ if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal,
+ dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale)
+
+@st.cache_data(ttl=24*3600)
+def _get_min_max(x, size_mapper_gamma):
+    """Return the 5th and 95th percentiles of ``x ** size_mapper_gamma``.
+
+    NaN entries are ignored, so a few missing values do not turn both
+    percentiles into NaN.
+
+    Args:
+        x: numeric values; may contain NaN.
+        size_mapper_gamma: exponent applied to ``x`` before taking percentiles.
+
+    Returns:
+        Tuple ``(p5, p95)`` of the gamma-transformed values.
+    """
+    x_gamma_all = x ** size_mapper_gamma
+    # nanpercentile instead of percentile: the latter returns NaN as soon as
+    # one input value is NaN
+    return np.nanpercentile(x_gamma_all, 5), np.nanpercentile(x_gamma_all, 95)
+
+def _size_mapping(x, size_mapper_range, size_mapper_gamma):
+    """Map the values in ``x`` to marker sizes.
+
+    ``x`` is divided by the 0.95 quantile of its non-NaN entries, raised to
+    ``size_mapper_gamma`` and scaled into ``size_mapper_range`` using the
+    percentile span returned by ``_get_min_max``. Entries that are NaN in
+    ``x`` or end up with a NaN size get size 0.
+
+    Args:
+        x: numeric values to map; may contain NaN. Assumed to support boolean
+            mask indexing (e.g. a pandas Series) - TODO confirm against callers.
+        size_mapper_range: sequence whose first two items are the sizes at the
+            lower and upper end of the mapping.
+        size_mapper_gamma: exponent applied to the normalized values.
+
+    Returns:
+        Marker sizes with the same shape as ``x``.
+    """
+    size_low, size_high = size_mapper_range[0], size_mapper_range[1]
+    missing = np.isnan(x)
+
+    x = x / np.quantile(x[~missing], 0.95)
+    x_gamma = x ** size_mapper_gamma
+    min_x, max_x = _get_min_max(x, size_mapper_gamma)
+    span = max_x - min_x
+
+    if np.isfinite(span) and span != 0:
+        # NOTE(review): x_gamma is not offset by min_x, so the 5th percentile
+        # is not pinned to size_low - confirm this is intended
+        sizes = size_low + x_gamma / span * (size_high - size_low)
+    else:
+        # Degenerate span (e.g. gamma == 0 or a constant column): dividing
+        # would give inf/NaN sizes, so draw every valid dot at mid-range size
+        sizes = x_gamma * 0 + (size_low + size_high) / 2
+
+    sizes[missing | np.isnan(sizes)] = 0
+    return sizes
+
+def add_dot_property_mapper():
+    """Add the widgets that map a numeric column to the dot size.
+
+    Inside an expander, a selectbox offers 'None' plus every name in
+    ``st.session_state.session_stats_names`` whose column in
+    ``st.session_state.df_session_filtered`` is numeric. The range and gamma
+    sliders are only created when ``st.session_state.x_y_plot_size_mapper``
+    is not 'None'.
+
+    Returns:
+        Tuple ``(size_mapper, size_mapper_range, size_mapper_gamma)``; the last
+        two are ``None`` when no size mapper is selected.
+    """
+    with st.expander('Data point property mapper', expanded=True):
+        cols = st.columns([2, 1, 1])
+
+        # 'None' first, then all numeric columns
+        available_size_cols = ['None']
+        for col in st.session_state.session_stats_names:
+            if is_numeric_dtype(st.session_state.df_session_filtered[col]):
+                available_size_cols.append(col)
+
+        size_mapper = selectbox_wrapper_for_url_query(
+            cols[0],
+            label="dot size mapper",
+            options=available_size_cols,
+            key='x_y_plot_size_mapper',
+            default=0,
+        )
+
+        # Stay None unless a mapper column is selected
+        size_mapper_range, size_mapper_gamma = None, None
+
+        if st.session_state.x_y_plot_size_mapper != 'None':
+            size_mapper_range = slider_wrapper_for_url_query(
+                cols[1],
+                label="size range",
+                min_value=0,
+                max_value=100,
+                key='x_y_plot_size_mapper_range',
+                default=(0, 50),
+            )
+
+            size_mapper_gamma = slider_wrapper_for_url_query(
+                cols[2],
+                label="size gamma",
+                min_value=0.0,
+                max_value=2.0,
+                key='x_y_plot_size_mapper_gamma',
+                default=1.0,
+                step=0.1,
+            )
+
+    return size_mapper, size_mapper_range, size_mapper_gamma
+
def data_selector():
@@ -690,15 +757,22 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by='
if_use_x_quantile_all=False,
q_quantiles_all=20,
title='',
- dot_size=10,
+ if_show_diagonal=False,
+ dot_size_base=10,
+ dot_size_mapping_name='None',
+ dot_size_mapping_range=None,
+ dot_size_mapping_gamma=None,
dot_opacity=0.4,
line_width=2,
+ x_y_plot_figure_width=1300,
+ x_y_plot_figure_height=900,
+ font_size_scale=1.0,
**kwarg):
-
- def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width):
+
+ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, **kwarg):
x = df_this.sort_values(x_name)[x_name].astype(float)
y = df_this.sort_values(x_name)[y_name].astype(float)
-
+
n_mice = len(df_this['h2o'].unique())
n_sessions = len(df_this.groupby(['h2o', 'session']).count())
n_str = f' ({n_mice} mice, {n_sessions} sessions)' if group_by !='h2o' else f' ({n_sessions} sessions)'
@@ -714,12 +788,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
line_width=line_width,
opacity=1,
hoveron='points+fills', # Scattergl doesn't support this
+ hoverinfo='skip',
+ **kwarg,
))
-
+
elif aggr_method == 'lowess':
x_new = np.linspace(x.min(), x.max(), 200)
lowess = sm.nonparametric.lowess(y, x, frac=smooth_factor/20)
-
+
fig.add_trace(go.Scatter(
x=lowess[:, 0],
y=lowess[:, 1],
@@ -730,25 +806,27 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
marker_color=col,
opacity=1,
hoveron='points+fills', # Scattergl doesn't support this
+ hoverinfo='skip',
+ **kwarg,
))
-
+
elif aggr_method in ('mean +/- sem', 'mean'):
-
+
# Re-bin x if use quantiles of x
if if_use_x_quantile:
df_this[f'{x_name}_quantile'] = pd.qcut(df_this[x_name], q=q_quantiles, labels=False, duplicates='drop')
-
+
mean = df_this.groupby(f'{x_name}_quantile')[y_name].mean()
sem = df_this.groupby(f'{x_name}_quantile')[y_name].sem()
valid_y = mean.notna()
mean = mean[valid_y]
sem = sem[valid_y]
sem[~sem.notna()] = 0
-
+
x = df_this.groupby(f'{x_name}_quantile')[x_name].median() # Use median of x in each quantile as x
y_upper = mean + sem
y_lower = mean - sem
-
+
else:
# mean and sem groupby x_name
mean = df_this.groupby(x_name)[y_name].mean()
@@ -757,11 +835,11 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
mean = mean[valid_y]
sem = sem[valid_y]
sem[~sem.notna()] = 0
-
+
x = mean.index
y_upper = mean + sem
y_lower = mean - sem
-
+
fig.add_trace(go.Scatter(
x=x,
y=mean,
@@ -775,8 +853,10 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
marker_color=col,
opacity=1,
hoveron='points+fills', # Scattergl doesn't support this
+ hoverinfo='skip',
+ **kwarg,
))
-
+
if 'sem' in aggr_method:
fig.add_trace(go.Scatter(
# name='Upper Bound',
@@ -802,7 +882,6 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
showlegend=False,
hoverinfo='skip'
))
-
elif aggr_method == 'linear fit':
# perform linear regression
@@ -818,19 +897,40 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
line=dict(dash='dot' if p_value > 0.05 else 'solid',
width=line_width if p_value > 0.05 else line_width*1.5),
legendgroup=f'group_{group}',
- # hoverinfo='skip'
+ hoverinfo='skip'
)
)
except:
pass
-
+
fig = go.Figure()
col_map = px.colors.qualitative.Plotly
-
+
+ # Add some more columns
+ if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns:
+ df['dot_size'] = _size_mapping(df[dot_size_mapping_name], dot_size_mapping_range, dot_size_mapping_gamma)
+ else:
+ df['dot_size'] = dot_size_base
+
+ # Turn column of group_by to string if it's not
+ if not is_string_dtype(df[group_by]):
+ df[group_by] = df[group_by].astype(str)
+
+ # Add a diagonal line first
+ if if_show_diagonal:
+ _min = df[x_name].values.ravel().min()
+ _max = df[y_name].values.ravel().max()
+ fig.add_trace(go.Scattergl(x=[_min, _max],
+ y=[_min, _max],
+ mode='lines',
+ line=dict(dash='dash', color='black', width=2),
+ showlegend=False)
+ )
+
for i, group in enumerate(df.sort_values(group_by)[group_by].unique()):
this_session = df.query(f'{group_by} == "{group}"').sort_values('session')
col = col_map[i%len(col_map)]
-
+
if if_show_dots:
if not len(st.session_state.df_selected_from_plotly):
this_session['colors'] = col # all use normal colors
@@ -840,7 +940,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
merged.loc[merged.subject_id_y.notna(), 'colors'] = col # only use normal colors for the selected dots
this_session['colors'] = merged.colors.values
this_session = pd.concat([this_session.query('colors != "lightgrey"'), this_session.query('colors == "lightgrey"')]) # make sure the real color goes first
-
+
fig.add_trace(go.Scattergl(
x=this_session[x_name],
y=this_session[y_name],
@@ -849,77 +949,70 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q
showlegend=not if_aggr_each_group,
mode="markers",
line_width=line_width,
- marker_size=dot_size,
+ marker_size=this_session['dot_size'],
marker_color=this_session['colors'],
- opacity=dot_opacity, # 0.5 if if_aggr_each_group else 0.8,
- text=this_session['session'],
- hovertemplate = '
%{customdata[0]}, Session %{text}' +
- '
%s = %%{x}' % (x_name) +
- '
%s = %%{y}' % (y_name),
- # '%{name}',
- customdata=np.stack((this_session.h2o, this_session.session), axis=-1),
+ opacity=dot_opacity,
+ hovertemplate = '%{customdata[0]}, %{customdata[1]}, Session %{customdata[2]}'
+ '
%{customdata[3]}, %{customdata[4]}'
+ '
Task: %{customdata[5]}'
+ '
AutoTrain: %{customdata[7]} @ %{customdata[6]}'
+ f'
{"-"*10}
X: {x_name} = %{{x}}'
+ f'
Y: {y_name} = %{{y}}'
+ + (f'
Size: {dot_size_mapping_name} = %{{customdata[8]}}'
+ if dot_size_mapping_name !='None'
+ else '')
+ + '',
+ customdata=np.stack((this_session.h2o, # 0
+ this_session.session_date.dt.strftime('%Y-%m-%d'), # 1
+ this_session.session, # 2
+ this_session.rig, # 3
+ this_session.user_name, # 4
+ this_session.task, # 5
+ this_session.curriculum_name
+ if 'curriculum_name' in this_session.columns
+ else ['None'] * len(this_session.h2o), # 6
+ this_session.current_stage_actual
+ if 'current_stage_actual' in this_session.columns
+ else ['None'] * len(this_session.h2o), # 7
+ this_session[dot_size_mapping_name]
+ if dot_size_mapping_name !='None'
+ else [np.nan] * len(this_session.h2o), # 8
+ ), axis=-1),
unselected=dict(marker_color='lightgrey')
))
-
+
if if_aggr_each_group:
_add_agg(this_session, x_name, y_name, group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, col, line_width=line_width)
-
if if_aggr_all:
_add_agg(df, x_name, y_name, 'all', aggr_method_all, if_use_x_quantile_all, q_quantiles_all, 'rgb(0, 0, 0)', line_width=line_width*1.5)
-
n_mice = len(df['h2o'].unique())
n_sessions = len(df.groupby(['h2o', 'session']).count())
-
+
fig.update_layout(
- width=1300,
- height=900,
- xaxis_title=x_name,
- yaxis_title=y_name,
- font=dict(size=25),
- hovermode='closest',
- legend={'traceorder':'reversed'},
- legend_font_size=15,
- title=f'{title}, {n_mice} mice, {n_sessions} sessions',
- dragmode='zoom', # 'select',
- margin=dict(l=130, r=50, b=130, t=100),
- )
+ width=x_y_plot_figure_width,
+ height=x_y_plot_figure_height,
+ xaxis_title=x_name,
+ yaxis_title=y_name,
+ font=dict(size=24 * font_size_scale),
+ hovermode="closest",
+ hoverlabel=dict(font_size=17 * font_size_scale),
+ legend={"traceorder": "reversed"},
+ legend_font_size=20 * font_size_scale,
+ title=f"{title}, {n_mice} mice, {n_sessions} sessions",
+ dragmode="zoom", # 'select',
+ margin=dict(l=130 * font_size_scale,
+ r=50 * font_size_scale,
+ b=130 * font_size_scale,
+ t=100 * font_size_scale,
+ ),
+ )
fig.update_xaxes(showline=True, linewidth=2, linecolor='black',
# range=[1, min(100, df[x_name].max())],
ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ')
-
+
fig.update_yaxes(showline=True, linewidth=2, linecolor='black',
title_standoff=40,
ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ')
return fig
-
-def _sync_widget_with_query(key, default):
- if key in st.query_params:
- # always get all query params as a list
- q_all = st.query_params.get_all(key)
-
- # convert type according to default
- list_default = default if isinstance(default, list) else [default]
- for d in list_default:
- _type = type(d)
- if _type: break # The first non-None type
-
- if _type == bool:
- q_all_correct_type = [q.lower() == 'true' for q in q_all]
- else:
- q_all_correct_type = [_type(q) for q in q_all]
-
- # flatten list if only one element
- if not isinstance(default, list):
- q_all_correct_type = q_all_correct_type[0]
-
- try:
- st.session_state[key] = q_all_correct_type
- except:
- print(f'Failed to set {key} to {q_all_correct_type}')
- else:
- try:
- st.session_state[key] = default
- except:
- print(f'Failed to set {key} to {default}')
diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py
new file mode 100644
index 0000000..cb323b6
--- /dev/null
+++ b/code/util/url_query_helper.py
@@ -0,0 +1,82 @@
+import streamlit as st
+
+def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs):
+    """Draw a checkbox on `st_prefix`, seeding its value from saved state."""
+    # Precedence for the initial value: session state, then URL query, then
+    # `default`. A query value counts as True only if it reads 'true'.
+    if key in st.session_state:
+        initial = st.session_state[key]
+    elif key in st.query_params:
+        initial = st.query_params[key].lower() == 'true'
+    else:
+        initial = default
+    return st_prefix.checkbox(label, value=initial, key=key, **kwargs)
+
+def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs):
+    """Draw a selectbox seeded from session state, else the URL query.
+
+    `default` is the index used when `key` is in neither place, and also
+    when the stored/URL value is not one of `options` (e.g. a stale link).
+    """
+    index = default
+    for source in (st.session_state, st.query_params):
+        if key in source:
+            # `options.index` raises ValueError for an unknown value; a
+            # hand-edited URL must not crash the page, so keep `default`.
+            index = options.index(source[key]) if source[key] in options else default
+            break
+    return st_prefix.selectbox(label, options=options, index=index, key=key, **kwargs)
+
+def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs):
+    """Draw a slider on `st_prefix`, seeding its value from saved state.
+
+    The initial value is taken from session state if `key` is there, else
+    from the URL query, else from `default`. URL values are cast with the
+    type of `min_value`; one value gives a scalar, several give a list.
+    """
+    in_url = key in st.query_params
+    if in_url:
+        # Parsed even when session state wins, so a malformed URL value
+        # raises here regardless of which source ends up being used.
+        cast = type(min_value)
+        from_url = [cast(raw) for raw in st.query_params.get_all(key)]
+        if len(from_url) == 1:
+            from_url = from_url[0]
+    if key in st.session_state:
+        initial = st.session_state[key]
+    elif in_url:
+        initial = from_url
+    else:
+        initial = default
+    return st_prefix.slider(label, min_value, max_value, value=initial, key=key, **kwargs)
+
+
+def sync_widget_with_query(key, default):
+    """Set `st.session_state[key]` from the URL query, else from `default`.
+
+    Query values are strings; they are cast to the type of the first
+    non-None element of `default` (`str` if there is none), with bools read
+    as the case-insensitive text 'true'. A list `default` yields a list,
+    any other `default` yields the first query value only.
+    """
+    value = default
+    if key in st.query_params:
+        # Always read every occurrence of the parameter, as a list.
+        raw_values = st.query_params.get_all(key)
+
+        # `type(d)` is truthy even for None, so test the element itself.
+        candidates = default if isinstance(default, list) else [default]
+        _type = next((type(d) for d in candidates if d is not None), str)
+        if _type is bool:
+            value = [raw.lower() == 'true' for raw in raw_values]
+        else:
+            value = [_type(raw) for raw in raw_values]
+
+        # Flatten to a scalar when the default is not a list.
+        if not isinstance(default, list):
+            value = value[0]
+
+    try:
+        st.session_state[key] = value
+    except Exception:  # best effort: report the failure and carry on
+        print(f'Failed to set {key} to {value}')
\ No newline at end of file