diff --git a/code/Home.py b/code/Home.py index 09f41fc..33d3637 100644 --- a/code/Home.py +++ b/code/Home.py @@ -40,8 +40,12 @@ from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, aggrid_interactive_table_curriculum, add_session_filter, data_selector, - _sync_widget_with_query, add_xy_selector, add_xy_setting, add_auto_train_manager, + add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper, _plot_population_x_y) +from util.url_query_helper import ( + sync_widget_with_query, slider_wrapper_for_url_query, +) + import extra_streamlit_components as stx from aind_auto_train.curriculum_manager import CurriculumManager @@ -59,9 +63,11 @@ 'filter_foraging_eff': [0.0, None], 'filter_task': ['all'], + 'table_height': 300, + 'tab_id': 'tab_session_x_y', 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_eff', + 'x_y_plot_yname': 'foraging_performance_random_seed', 'x_y_plot_group_by': 'h2o', 'x_y_plot_if_show_dots': True, 'x_y_plot_if_aggr_each_group': True, @@ -73,9 +79,17 @@ 'x_y_plot_q_quantiles_group': 20, 'x_y_plot_if_use_x_quantile_all': False, 'x_y_plot_q_quantiles_all': 20, + 'x_y_plot_if_show_diagonal': False, 'x_y_plot_dot_size': 10, - 'x_y_plot_dot_opacity': 0.5, + 'x_y_plot_dot_opacity': 0.3, 'x_y_plot_line_width': 2.0, + 'x_y_plot_figure_width': 1300, + 'x_y_plot_figure_height': 900, + 'x_y_plot_font_size_scale': 1.0, + + 'x_y_plot_size_mapper': 'finished_trials', + 'x_y_plot_size_mapper_gamma': 1.0, + 'x_y_plot_size_mapper_range': [3, 20], 'session_plot_mode': 'sessions selected from table or plot', @@ -282,7 +296,42 @@ def draw_session_plots(df_to_draw_session): my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100)) +def draw_session_plots_quick_preview(df_to_draw_session): + + # Setting up layout for each session + layout_definition = [[1], # columns in the first row + [1, 1], + ] + draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] + + container_session_all_in_one = st.container() + + key = df_to_draw_session.to_dict(orient='records')[0] + + with container_session_all_in_one: + try: + date_str = key["session_date"].strftime('%Y-%m-%d') + except: + date_str = key["session_date"].split("T")[0] + + st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str}''', + unsafe_allow_html=True) + + rows = [] + for row, column_setting in enumerate(layout_definition): + rows.append(st.columns(column_setting)) + + for draw_type in draw_types_quick_preview: + if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level + prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] + this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] + show_session_level_img_by_key_and_prefix(key, + column=this_col, + prefix=prefix, + **setting) + + def draw_mice_plots(df_to_draw_mice): @@ -335,7 +384,7 @@ def draw_mice_plots(df_to_draw_mice): def session_plot_settings(need_click=True): st.markdown('##### Show plots for individual sessions ') - cols = st.columns([2, 1]) + cols = st.columns([2, 1, 6]) session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar'] st.session_state.selected_draw_sessions = cols[0].selectbox(f'Which session(s) to draw?', @@ -351,7 +400,7 @@ def session_plot_settings(need_click=True): n_session_to_draw = len(st.session_state.df_selected_from_plotly) \ if 'selected from table or plot' in st.session_state.selected_draw_sessions \ else len(st.session_state.df_session_filtered) - st.markdown(f'{n_session_to_draw} sessions to draw') + cols[0].markdown(f'{n_session_to_draw} sessions to draw') st.session_state.num_cols = cols[1].number_input('number of columns', 1, 10, 3 if 'num_cols' not in st.session_state else st.session_state.num_cols) @@ -365,61 +414,38 @@ def session_plot_settings(need_click=True): """, unsafe_allow_html=True, ) - st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', + st.session_state.selected_draw_types = cols[2].multiselect('Which plot(s) to draw?', st.session_state.draw_type_mapper_session_level.keys(), default=st.session_state.draw_type_mapper_session_level.keys() if 'selected_draw_types' not in st.session_state else st.session_state.selected_draw_types) if need_click: - draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True) - else: - draw_it = True - return draw_it - -def mouse_plot_settings(need_click=True): - st.markdown('##### Show plots for individual mice ') - cols = st.columns([2, 1]) - st.session_state.selected_draw_mice = cols[0].selectbox('Which mice to draw?', - [f'selected from table/plot ({len(st.session_state.df_selected_from_plotly.h2o.unique())} mice)', - f'filtered from sidebar ({len(st.session_state.df_session_filtered.h2o.unique())} mice)'], - index=0 - ) - st.session_state.num_cols_mice = cols[1].number_input('Number of columns', 1, 10, - 3 if 'num_cols_mice' not in st.session_state else st.session_state.num_cols_mice) - st.markdown( - """ - """, - unsafe_allow_html=True, - ) - st.session_state.selected_draw_types_mice = st.multiselect('Which plot(s) to draw?', - st.session_state.draw_type_mapper_mouse_level.keys(), - default=st.session_state.draw_type_mapper_mouse_level.keys() - if 'selected_draw_types_mice' not in st.session_state else - st.session_state.selected_draw_types_mice) - if need_click: - draw_it = st.button('Show me all mice!', use_container_width=True) + draw_it = st.button(f'Show me all {n_session_to_draw} sessions!', use_container_width=True, type="primary") + draw_it_now_override = cols[1].checkbox('Auto draw') else: draw_it = True - return draw_it + draw_it_now_override = True + return draw_it | draw_it_now_override def plot_x_y_session(): - cols = st.columns([4, 10]) + cols = st.columns([1, 1, 1]) with cols[0]: - x_name, y_name, group_by = add_xy_selector(if_bonsai=True) + with cols[1]: (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) = add_xy_setting() - - + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, + dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting() + + if st.session_state.x_y_plot_if_show_dots: + with cols[2]: + size_mapper, size_mapper_range, size_mapper_gamma = add_dot_property_mapper() + else: + size_mapper = 'None' + size_mapper_range, size_mapper_gamma = None, None # If no sessions are selected, use all filtered entries # df_x_y_session = st.session_state.df_selected_from_dataframe if if_plot_only_selected_from_dataframe else st.session_state.df_session_filtered @@ -432,33 +458,55 @@ def plot_x_y_session(): df_selected_from_plotly = pd.DataFrame() # for i, (title, (x_name, y_name)) in enumerate(names.items()): # with cols[i]: - with cols[1]: - fig = _plot_population_x_y(df=df_x_y_session, - x_name=x_name, y_name=y_name, - group_by=group_by, - smooth_factor=smooth_factor, - if_show_dots=if_show_dots, - if_aggr_each_group=if_aggr_each_group, - if_aggr_all=if_aggr_all, - aggr_method_group=aggr_method_group, - aggr_method_all=aggr_method_all, - if_use_x_quantile_group=if_use_x_quantile_group, - q_quantiles_group=q_quantiles_group, - if_use_x_quantile_all=if_use_x_quantile_all, - q_quantiles_all=q_quantiles_all, - title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, - states = st.session_state.df_selected_from_plotly, - dot_size=dot_size, - dot_opacity=dot_opacity, - line_width=line_width) + + if hasattr(st.session_state, 'x_y_plot_figure_width'): + _x_y_plot_scale = st.session_state.x_y_plot_figure_width / 1300 + cols = st.columns([1 * _x_y_plot_scale, 0.7]) + else: + cols = st.columns([1, 0.7]) + with cols[0]: + fig = _plot_population_x_y(df=df_x_y_session.copy(), + x_name=x_name, y_name=y_name, + group_by=group_by, + smooth_factor=smooth_factor, + if_show_dots=if_show_dots, + if_aggr_each_group=if_aggr_each_group, + if_aggr_all=if_aggr_all, + aggr_method_group=aggr_method_group, + aggr_method_all=aggr_method_all, + if_use_x_quantile_group=if_use_x_quantile_group, + q_quantiles_group=q_quantiles_group, + if_use_x_quantile_all=if_use_x_quantile_all, + q_quantiles_all=q_quantiles_all, + title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, + states = st.session_state.df_selected_from_plotly, + if_show_diagonal=if_show_diagonal, + dot_size_base=dot_size, + dot_size_mapping_name=size_mapper, + dot_size_mapping_range=size_mapper_range, + dot_size_mapping_gamma=size_mapper_gamma, + dot_opacity=dot_opacity, + line_width=line_width, + x_y_plot_figure_width=x_y_plot_figure_width, + x_y_plot_figure_height=x_y_plot_figure_height, + font_size_scale=font_size_scale, + ) # st.plotly_chart(fig) selected = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=fig.layout.height * 1.1, override_width=fig.layout.width) + + with cols[1]: + st.markdown('#### 👀 Quick preview') + st.markdown('###### Click on one session to preview here, or Box/Lasso select multiple sessions to draw them in the section below') + st.markdown('(sometimes you have to click twice...)') if len(selected): df_selected_from_plotly = df_x_y_session.merge(pd.DataFrame(selected).rename({'x': x_name, 'y': y_name}, axis=1), on=[x_name, y_name], how='inner') + if len(st.session_state.df_selected_from_plotly) == 1: + with cols[1]: + draw_session_plots_quick_preview(st.session_state.df_selected_from_plotly) return df_selected_from_plotly, cols @@ -477,7 +525,7 @@ def init(): # Set session state from URL for key, default in to_sync_with_url_query.items(): - _sync_widget_with_query(key, default) + sync_widget_with_query(key, default) df = load_data(['sessions', ]) @@ -573,6 +621,13 @@ def init(): # map user_name st.session_state.df['sessions_bonsai']['user_name'] = st.session_state.df['sessions_bonsai']['user_name'].apply(_user_name_mapper) + # fill nan for autotrain fields + filled_values = {'curriculum_name': 'None', + 'curriculum_version': 'None', + 'curriculum_schema_version': 'None', + 'current_stage_actual': 'None'} + st.session_state.df['sessions_bonsai'].fillna(filled_values, inplace=True) + # foraging performance = foraing_eff * finished_rate if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns: st.session_state.df['sessions_bonsai']['foraging_performance'] = \ @@ -634,8 +689,15 @@ def app(): init() st.rerun() - table_height = cols[3].slider('Table height', 100, 2000, 400, 50, key='table_height') - + table_height = slider_wrapper_for_url_query(st_prefix=cols[3], + label='Table height', + min_value=0, + max_value=2000, + default=300, + step=50, + key='table_height', + ) + # aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units']) # st.session_state.df_session_filtered = aggrid_outputs['data'] @@ -673,8 +735,8 @@ def app(): with placeholder: df_selected_from_plotly, x_y_cols = plot_x_y_session() - with x_y_cols[0]: - for i in range(7): st.write('\n') + # Add session_plot_setting + with st.columns([1, 0.5])[0]: st.markdown("***") if_draw_all_sessions = session_plot_settings() @@ -722,7 +784,7 @@ def app(): elif chosen_id == "tab_session_inspector": with placeholder: - cols = st.columns([6, 3, 7]) + cols = st.columns([1, 0.5]) with cols[0]: if_draw_all_sessions = session_plot_settings(need_click=False) df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered @@ -832,6 +894,8 @@ def app(): # Add debug info if chosen_id != "tab_auto_train_curriculum": + for _ in range(10): st.write('\n') + st.markdown('---\n##### Debug zone') with st.expander('CO processing NWB errors', expanded=False): error_file = cache_folder + 'error_files.json' if fs.exists(error_file): diff --git a/code/pages/1_Old mice.py b/code/pages/1_Old mice.py index cf910ec..f11d1be 100644 --- a/code/pages/1_Old mice.py +++ b/code/pages/1_Old mice.py @@ -19,8 +19,10 @@ from streamlit_plotly_events import plotly_events from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, add_session_filter, data_selector, - add_xy_selector, _sync_widget_with_query, add_xy_setting, add_auto_train_manager, + add_xy_selector, add_xy_setting, add_auto_train_manager, _plot_population_x_y) +from util.url_query_helper import sync_widget_with_query + import extra_streamlit_components as stx from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager @@ -41,7 +43,7 @@ 'tab_id': 'tab_session_x_y', 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_eff', + 'x_y_plot_yname': 'foraging_performance', 'x_y_plot_group_by': 'h2o', 'x_y_plot_if_show_dots': True, 'x_y_plot_if_aggr_each_group': True, @@ -354,8 +356,8 @@ def plot_x_y_session(): x_name, y_name, group_by = add_xy_selector(if_bonsai=False) (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) = add_xy_setting() + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, + dot_size, dot_opacity, line_width, _, _, _) = add_xy_setting() # If no sessions are selected, use all filtered entries @@ -385,7 +387,7 @@ def plot_x_y_session(): q_quantiles_all=q_quantiles_all, title=names[(x_name, y_name)] if (x_name, y_name) in names else y_name, states = st.session_state.df_selected_from_plotly, - dot_size=dot_size, + dot_size_base=dot_size, dot_opacity=dot_opacity, line_width=line_width) @@ -410,7 +412,7 @@ def init(): # Set session state from URL for key, default in to_sync_with_url_query.items(): - _sync_widget_with_query(key, default) + sync_widget_with_query(key, default) df = load_data(['sessions', 'logistic_regression_hattori', diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 54bcbb6..662bdf3 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -1,4 +1,4 @@ -from email import header +from collections import OrderedDict import pandas as pd import streamlit as st from st_aggrid import AgGrid, GridOptionsBuilder @@ -7,6 +7,7 @@ is_categorical_dtype, is_datetime64_any_dtype, is_numeric_dtype, + is_string_dtype, is_object_dtype, ) import streamlit.components.v1 as components @@ -20,6 +21,8 @@ import statsmodels.api as sm from scipy.stats import linregress +from .url_query_helper import checkbox_wrapper_for_url_query, selectbox_wrapper_for_url_query, slider_wrapper_for_url_query + custom_css = { ".ag-root.ag-unselectable.ag-layout-normal": {"font-size": "15px !important", @@ -176,7 +179,7 @@ def cache_widget(field, clear=None): # Clear cache if needed if clear: if clear in st.session_state: del st.session_state[clear] - + # def dec_cache_widget_state(widget, ): @@ -390,6 +393,24 @@ def add_session_filter(if_bonsai=False, url_query={}): default_filters=['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'], url_query=url_query) +@st.cache_data(ttl=3600*24) +def _get_grouped_by_fields(if_bonsai): + if if_bonsai: + options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] + options += [col + for col in st.session_state.df_session_filtered.columns + if is_categorical_dtype(st.session_state.df_session_filtered[col]) + or st.session_state.df_session_filtered[col].nunique() < 20 + and not any([exclude in col for exclude in + ('date', 'time', 'session', 'finished', 'foraging_eff')]) + ] + options = list(list(OrderedDict.fromkeys(options))) # Remove duplicates + else: + options = ['h2o', 'task', 'photostim_location', 'weekday', + 'headbar', 'user_name', 'sex', 'rig'] + return options + + def add_xy_selector(if_bonsai): with st.expander("Select axes", expanded=True): # with st.form("axis_selection"): @@ -403,165 +424,211 @@ def add_xy_selector(if_bonsai): else st.session_state.session_stats_names.index('session'), key='x_y_plot_xname' ) - y_name = cols[0].selectbox("y axis", - st.session_state.session_stats_names, - index=st.session_state.session_stats_names.index(st.session_state['x_y_plot_yname']) - if 'x_y_plot_yname' in st.session_state else - st.session_state.session_stats_names.index(st.query_params['x_y_plot_yname']) - if 'x_y_plot_yname' in st.query_params - else st.session_state.session_stats_names.index('foraging_eff'), - key='x_y_plot_yname') - - if if_bonsai: - options = ['h2o', 'task', 'user_name', 'rig', 'weekday'] - else: - options = ['h2o', 'task', 'photostim_location', 'weekday', - 'headbar', 'user_name', 'sex', 'rig'] - - group_by = cols[0].selectbox("grouped by", - options=options, - index=options.index(st.session_state['x_y_plot_group_by']) - if 'x_y_plot_group_by' in st.session_state else - options.index(st.query_params['x_y_plot_group_by']) - if 'x_y_plot_group_by' in st.query_params - else 0, - key='x_y_plot_group_by') - - # st.form_submit_button("update axes") + y_name = selectbox_wrapper_for_url_query( + cols[0], + label="y axis", + options=st.session_state.session_stats_names, + key="x_y_plot_yname", + default=st.session_state.session_stats_names.index('foraging_eff') + ) + + + group_by = selectbox_wrapper_for_url_query( + cols[0], + label="grouped by", + options=_get_grouped_by_fields(if_bonsai), + key="x_y_plot_group_by", + default=0, + ) + + # st.form_submit_button("update axes") return x_name, y_name, group_by - def add_xy_setting(): with st.expander('Plot settings', expanded=True): s_cols = st.columns([1, 1, 1]) # if_plot_only_selected_from_dataframe = s_cols[0].checkbox('Only selected', False) - if_show_dots = s_cols[0].checkbox('Show data points', - value=st.session_state['x_y_plot_if_show_dots'] - if 'x_y_plot_if_show_dots' in st.session_state else - st.query_params['x_y_plot_if_show_dots'].lower()=='true' - if 'x_y_plot_if_show_dots' in st.query_params - else True, - key='x_y_plot_if_show_dots') - - - if_aggr_each_group = s_cols[1].checkbox('Aggr each group', - value=st.session_state['x_y_plot_if_aggr_each_group'] - if 'x_y_plot_if_aggr_each_group' in st.session_state else - st.query_params['x_y_plot_if_aggr_each_group'].lower()=='true' - if 'x_y_plot_if_aggr_each_group' in st.query_params - else True, - key='x_y_plot_if_aggr_each_group') + if_show_dots = checkbox_wrapper_for_url_query(s_cols[0], + label='Show data points', + key='x_y_plot_if_show_dots', + default=True) + + if_aggr_each_group = checkbox_wrapper_for_url_query(s_cols[1], + label='Aggr each group', + key='x_y_plot_if_aggr_each_group', + default=True) aggr_methods = ['mean', 'mean +/- sem', 'lowess', 'running average', 'linear fit'] - aggr_method_group = s_cols[1].selectbox('aggr method group', - options=aggr_methods, - index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_group']) - if 'x_y_plot_aggr_method_group' in st.session_state else - aggr_methods.index(st.query_params['x_y_plot_aggr_method_group']) - if 'x_y_plot_aggr_method_group' in st.query_params - else 2, - key='x_y_plot_aggr_method_group', - disabled=not if_aggr_each_group) + aggr_method_group = selectbox_wrapper_for_url_query(s_cols[1], + label='aggr method group', + options=aggr_methods, + key='x_y_plot_aggr_method_group', + default=2, + disabled=not if_aggr_each_group) - if_use_x_quantile_group = s_cols[1].checkbox('Use quantiles of x ', - value=st.session_state['x_y_plot_if_use_x_quantile_group'] - if 'x_y_plot_if_use_x_quantile_group' in st.session_state else - st.query_params['x_y_plot_if_use_x_quantile_group'].lower()=='true' - if 'x_y_plot_if_use_x_quantile_group' in st.query_params - else False, - key='x_y_plot_if_use_x_quantile_group', - disabled='mean' not in aggr_method_group) + if_use_x_quantile_group = checkbox_wrapper_for_url_query(s_cols[1], + label='Use quantiles of x', + key='x_y_plot_if_use_x_quantile_group', + default=False, + disabled='mean' not in aggr_method_group) - q_quantiles_group = s_cols[1].slider('Number of quantiles ', 1, 100, - value=st.session_state['x_y_plot_q_quantiles_group'] - if 'x_y_plot_q_quantiles_group' in st.session_state else - int(st.query_params['x_y_plot_q_quantiles_group']) - if 'x_y_plot_q_quantiles_group' in st.query_params - else 20, - key='x_y_plot_q_quantiles_group', - disabled=not if_use_x_quantile_group - ) + q_quantiles_group = slider_wrapper_for_url_query(s_cols[1], + label='Number of quantiles ', + min_value=1, + max_value=100, + key='x_y_plot_q_quantiles_group', + default=20, + disabled=not if_use_x_quantile_group + ) - if_aggr_all = s_cols[2].checkbox('Aggr all', - value=st.session_state['x_y_plot_if_aggr_all'] - if 'x_y_plot_if_aggr_all' in st.session_state else - st.query_params['x_y_plot_if_aggr_all'].lower()=='true' - if 'x_y_plot_if_aggr_all' in st.query_params - else True, - key='x_y_plot_if_aggr_all', - ) + if_aggr_all = checkbox_wrapper_for_url_query(s_cols[2], + label='Aggr all', + key='x_y_plot_if_aggr_all', + default=True) # st.session_state.if_aggr_all_cache = if_aggr_all # Have to use another variable to store this explicitly (my cache_widget somehow doesn't work with checkbox) - aggr_method_all = s_cols[2].selectbox('aggr method all', aggr_methods, - index=aggr_methods.index(st.session_state['x_y_plot_aggr_method_all']) - if 'x_y_plot_aggr_method_all' in st.session_state else - aggr_methods.index(st.query_params['x_y_plot_aggr_method_all']) - if 'x_y_plot_aggr_method_all' in st.query_params - else 1, - key='x_y_plot_aggr_method_all', - disabled=not if_aggr_all) - - if_use_x_quantile_all = s_cols[2].checkbox('Use quantiles of x', - value=st.session_state['x_y_plot_if_use_x_quantile_all'] - if 'x_y_plot_if_use_x_quantile_all' in st.session_state else - st.query_params['x_y_plot_if_use_x_quantile_all'].lower()=='true' - if 'x_y_plot_if_use_x_quantile_all' in st.query_params - else False, - key='x_y_plot_if_use_x_quantile_all', - disabled='mean' not in aggr_method_all, - ) - q_quantiles_all = s_cols[2].slider('number of quantiles', 1, 100, - value=st.session_state['x_y_plot_q_quantiles_all'] - if 'x_y_plot_q_quantiles_all' in st.session_state else - int(st.query_params['x_y_plot_q_quantiles_all']) - if 'x_y_plot_q_quantiles_all' in st.query_params - else 20, - key='x_y_plot_q_quantiles_all', - disabled=not if_use_x_quantile_all - ) - - smooth_factor = s_cols[0].slider('smooth factor', 1, 20, - value=st.session_state['x_y_plot_smooth_factor'] - if 'x_y_plot_smooth_factor' in st.session_state else - int(st.query_params['x_y_plot_smooth_factor']) - if 'x_y_plot_smooth_factor' in st.query_params - else 5, - key='x_y_plot_smooth_factor', - disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess')) - or (if_aggr_all and aggr_method_all in ('running average', 'lowess'))) - ) + aggr_method_all = selectbox_wrapper_for_url_query(s_cols[2], + label='aggr method all', + options=aggr_methods, + key='x_y_plot_aggr_method_all', + default=1, + disabled=not if_aggr_all) + + if_use_x_quantile_all = checkbox_wrapper_for_url_query(s_cols[2], + label='Use quantiles of x', + key='x_y_plot_if_use_x_quantile_all', + default=False, + disabled='mean' not in aggr_method_all) + - c = st.columns([1, 1, 1]) - dot_size = c[0].slider('dot size', 1, 30, - step=1, - value=st.session_state['x_y_plot_dot_size'] - if 'x_y_plot_dot_size' in st.session_state else - int(st.query_params['x_y_plot_dot_size']) - if 'x_y_plot_dot_size' in st.query_params - else 10, - key='x_y_plot_dot_size' - ) - dot_opacity = c[1].slider('opacity', 0.0, 1.0, - step=0.05, - value=st.session_state['x_y_plot_dot_opacity'] - if 'x_y_plot_dot_opacity' in st.session_state else - float(st.query_params['x_y_plot_dot_opacity']) - if 'x_y_plot_dot_opacity' in st.query_params - else 0.5, - key='x_y_plot_dot_opacity') - - line_width = c[2].slider('line width', 0.0, 5.0, - step=0.25, - value=st.session_state['x_y_plot_line_width'] - if 'x_y_plot_line_width' in st.session_state else - float(st.query_params['x_y_plot_line_width']) - if 'x_y_plot_line_width' in st.query_params - else 2.0, - key='x_y_plot_line_width') + q_quantiles_all = slider_wrapper_for_url_query(s_cols[2], + label='Number of quantiles', + min_value=1, + max_value=100, + key='x_y_plot_q_quantiles_all', + default=20, + disabled=not if_use_x_quantile_all + ) + + smooth_factor = slider_wrapper_for_url_query(s_cols[0], + label='smooth factor', + min_value=1, + max_value=20, + key='x_y_plot_smooth_factor', + default=5, + disabled= not ((if_aggr_each_group and aggr_method_group in ('running average', 'lowess')) + or (if_aggr_all and aggr_method_all in ('running average', 'lowess'))) + ) + + if_show_diagonal = checkbox_wrapper_for_url_query(s_cols[0], + label='Show diagonal line', + key='x_y_plot_if_show_diagonal', + default=False) + + with st.expander('Misc', expanded=False): + c = st.columns([1, 1, 1]) + dot_size = slider_wrapper_for_url_query(c[0], + label='dot size', + min_value=1, + max_value=30, + key='x_y_plot_dot_size', + default=10) + + dot_opacity = slider_wrapper_for_url_query(c[1], + label='opacity', + min_value=0.0, + max_value=1.0, + step=0.05, + key='x_y_plot_dot_opacity', + default=0.5) + + line_width = slider_wrapper_for_url_query(c[2], + label='line width', + min_value=0.0, + max_value=5.0, + step=0.25, + key='x_y_plot_line_width', + default=2.0) + + figure_width = slider_wrapper_for_url_query(c[0], + label='figure width', + min_value=500, + max_value=2500, + key='x_y_plot_figure_width', + default=1300) + + figure_height = slider_wrapper_for_url_query(c[1], + label='figure height', + min_value=500, + max_value=2500, + key='x_y_plot_figure_height', + default=900) + + font_size_scale = slider_wrapper_for_url_query(c[2], + label='font size', + min_value=0.0, + max_value=2.0, + step=0.1, + key='x_y_plot_font_size_scale', + default=1.0) return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, - if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, - dot_size, dot_opacity, line_width) + if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, + dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale) + +@st.cache_data(ttl=24*3600) +def _get_min_max(x, size_mapper_gamma): + x_gamma_all = x ** size_mapper_gamma + return np.percentile(x_gamma_all, 5), np.percentile(x_gamma_all, 95) + +def _size_mapping(x, size_mapper_range, size_mapper_gamma): + x = x / np.quantile(x[~np.isnan(x)], 0.95) + x_gamma = x**size_mapper_gamma + min_x, max_x = _get_min_max(x, size_mapper_gamma) + sizes = size_mapper_range[0] + x_gamma / (max_x - min_x) * (size_mapper_range[1] - size_mapper_range[0]) + sizes[np.isnan(sizes)] = 0 + return sizes + +def add_dot_property_mapper(): + with st.expander('Data point property mapper', expanded=True): + cols = st.columns([2, 1, 1]) + + # Get all columns that are numeric + available_size_cols = ['None'] + [ + col + for col in st.session_state.session_stats_names + if is_numeric_dtype(st.session_state.df_session_filtered[col]) + ] + + size_mapper = selectbox_wrapper_for_url_query( + cols[0], + label="dot size mapper", + options=available_size_cols, + key='x_y_plot_size_mapper', + default=0, + ) + + if st.session_state.x_y_plot_size_mapper != 'None': + size_mapper_range = slider_wrapper_for_url_query(cols[1], + label="size range", + min_value=0, + max_value=100, + key='x_y_plot_size_mapper_range', + default=(0, 50), + ) + + size_mapper_gamma = slider_wrapper_for_url_query(cols[2], + label="size gamma", + min_value=0.0, + max_value=2.0, + key='x_y_plot_size_mapper_gamma', + default=1.0, + step=0.1) + else: + size_mapper_range, size_mapper_gamma = None, None + + return size_mapper, size_mapper_range, size_mapper_gamma + def data_selector(): @@ -690,15 +757,22 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' if_use_x_quantile_all=False, q_quantiles_all=20, title='', - dot_size=10, + if_show_diagonal=False, + dot_size_base=10, + dot_size_mapping_name='None', + dot_size_mapping_range=None, + dot_size_mapping_gamma=None, dot_opacity=0.4, line_width=2, + x_y_plot_figure_width=1300, + x_y_plot_figure_height=900, + font_size_scale=1.0, **kwarg): - - def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width): + + def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, **kwarg): x = df_this.sort_values(x_name)[x_name].astype(float) y = df_this.sort_values(x_name)[y_name].astype(float) - + n_mice = len(df_this['h2o'].unique()) n_sessions = len(df_this.groupby(['h2o', 'session']).count()) n_str = f' ({n_mice} mice, {n_sessions} sessions)' if group_by !='h2o' else f' ({n_sessions} sessions)' @@ -714,12 +788,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line_width=line_width, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + elif aggr_method == 'lowess': x_new = np.linspace(x.min(), x.max(), 200) lowess = sm.nonparametric.lowess(y, x, frac=smooth_factor/20) - + fig.add_trace(go.Scatter( x=lowess[:, 0], y=lowess[:, 1], @@ -730,25 +806,27 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + elif aggr_method in ('mean +/- sem', 'mean'): - + # Re-bin x if use quantiles of x if if_use_x_quantile: df_this[f'{x_name}_quantile'] = pd.qcut(df_this[x_name], q=q_quantiles, labels=False, duplicates='drop') - + mean = df_this.groupby(f'{x_name}_quantile')[y_name].mean() sem = df_this.groupby(f'{x_name}_quantile')[y_name].sem() valid_y = mean.notna() mean = mean[valid_y] sem = sem[valid_y] sem[~sem.notna()] = 0 - + x = df_this.groupby(f'{x_name}_quantile')[x_name].median() # Use median of x in each quantile as x y_upper = mean + sem y_lower = mean - sem - + else: # mean and sem groupby x_name mean = df_this.groupby(x_name)[y_name].mean() @@ -757,11 +835,11 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q mean = mean[valid_y] sem = sem[valid_y] sem[~sem.notna()] = 0 - + x = mean.index y_upper = mean + sem y_lower = mean - sem - + fig.add_trace(go.Scatter( x=x, y=mean, @@ -775,8 +853,10 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q marker_color=col, opacity=1, hoveron='points+fills', # Scattergl doesn't support this + hoverinfo='skip', + **kwarg, )) - + if 'sem' in aggr_method: fig.add_trace(go.Scatter( # name='Upper Bound', @@ -802,7 +882,6 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q showlegend=False, hoverinfo='skip' )) - elif aggr_method == 'linear fit': # perform linear regression @@ -818,19 +897,40 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q line=dict(dash='dot' if p_value > 0.05 else 'solid', width=line_width if p_value > 0.05 else line_width*1.5), legendgroup=f'group_{group}', - # hoverinfo='skip' + hoverinfo='skip' ) ) except: pass - + fig = go.Figure() col_map = px.colors.qualitative.Plotly - + + # Add some more columns + if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns: + df['dot_size'] = _size_mapping(df[dot_size_mapping_name], dot_size_mapping_range, dot_size_mapping_gamma) + else: + df['dot_size'] = dot_size_base + + # Turn column of group_by to string if it's not + if not is_string_dtype(df[group_by]): + df[group_by] = df[group_by].astype(str) + + # Add a diagonal line first + if if_show_diagonal: + _min = df[x_name].values.ravel().min() + _max = df[y_name].values.ravel().max() + fig.add_trace(go.Scattergl(x=[_min, _max], + y=[_min, _max], + mode='lines', + line=dict(dash='dash', color='black', width=2), + showlegend=False) + ) + for i, group in enumerate(df.sort_values(group_by)[group_by].unique()): this_session = df.query(f'{group_by} == "{group}"').sort_values('session') col = col_map[i%len(col_map)] - + if if_show_dots: if not len(st.session_state.df_selected_from_plotly): this_session['colors'] = col # all use normal colors @@ -840,7 +940,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q merged.loc[merged.subject_id_y.notna(), 'colors'] = col # only use normal colors for the selected dots this_session['colors'] = merged.colors.values this_session = pd.concat([this_session.query('colors != "lightgrey"'), this_session.query('colors == "lightgrey"')]) # make sure the real color goes first - + fig.add_trace(go.Scattergl( x=this_session[x_name], y=this_session[y_name], @@ -849,77 +949,70 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q showlegend=not if_aggr_each_group, mode="markers", line_width=line_width, - marker_size=dot_size, + marker_size=this_session['dot_size'], marker_color=this_session['colors'], - opacity=dot_opacity, # 0.5 if if_aggr_each_group else 0.8, - text=this_session['session'], - hovertemplate = '
%{customdata[0]}, Session %{text}' + - '
%s = %%{x}' % (x_name) + - '
%s = %%{y}' % (y_name), - # '%{name}', - customdata=np.stack((this_session.h2o, this_session.session), axis=-1), + opacity=dot_opacity, + hovertemplate = '%{customdata[0]}, %{customdata[1]}, Session %{customdata[2]}' + '
%{customdata[3]}, %{customdata[4]}' + '
Task: %{customdata[5]}' + '
AutoTrain: %{customdata[7]} @ %{customdata[6]}
' + f'
{"-"*10}
X: {x_name} = %{{x}}' + f'
Y: {y_name} = %{{y}}' + + (f'
Size: {dot_size_mapping_name} = %{{customdata[8]}}' + if dot_size_mapping_name !='None' + else '') + + '', + customdata=np.stack((this_session.h2o, # 0 + this_session.session_date.dt.strftime('%Y-%m-%d'), # 1 + this_session.session, # 2 + this_session.rig, # 3 + this_session.user_name, # 4 + this_session.task, # 5 + this_session.curriculum_name + if 'curriculum_name' in this_session.columns + else ['None'] * len(this_session.h2o), # 6 + this_session.current_stage_actual + if 'current_stage_actual' in this_session.columns + else ['None'] * len(this_session.h2o), # 7 + this_session[dot_size_mapping_name] + if dot_size_mapping_name !='None' + else [np.nan] * len(this_session.h2o), # 8 + ), axis=-1), unselected=dict(marker_color='lightgrey') )) - + if if_aggr_each_group: _add_agg(this_session, x_name, y_name, group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, col, line_width=line_width) - if if_aggr_all: _add_agg(df, x_name, y_name, 'all', aggr_method_all, if_use_x_quantile_all, q_quantiles_all, 'rgb(0, 0, 0)', line_width=line_width*1.5) - n_mice = len(df['h2o'].unique()) n_sessions = len(df.groupby(['h2o', 'session']).count()) - + fig.update_layout( - width=1300, - height=900, - xaxis_title=x_name, - yaxis_title=y_name, - font=dict(size=25), - hovermode='closest', - legend={'traceorder':'reversed'}, - legend_font_size=15, - title=f'{title}, {n_mice} mice, {n_sessions} sessions', - dragmode='zoom', # 'select', - margin=dict(l=130, r=50, b=130, t=100), - ) + width=x_y_plot_figure_width, + height=x_y_plot_figure_height, + xaxis_title=x_name, + yaxis_title=y_name, + font=dict(size=24 * font_size_scale), + hovermode="closest", + hoverlabel=dict(font_size=17 * font_size_scale), + legend={"traceorder": "reversed"}, + legend_font_size=20 * font_size_scale, + title=f"{title}, {n_mice} mice, {n_sessions} sessions", + dragmode="zoom", # 'select', + margin=dict(l=130 * font_size_scale, + r=50 * font_size_scale, + b=130 * font_size_scale, + t=100 * font_size_scale, + ), + ) fig.update_xaxes(showline=True, linewidth=2, linecolor='black', # range=[1, min(100, df[x_name].max())], ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') - + fig.update_yaxes(showline=True, linewidth=2, linecolor='black', title_standoff=40, ticks = "outside", tickcolor='black', ticklen=10, tickwidth=2, ticksuffix=' ') return fig - -def _sync_widget_with_query(key, default): - if key in st.query_params: - # always get all query params as a list - q_all = st.query_params.get_all(key) - - # convert type according to default - list_default = default if isinstance(default, list) else [default] - for d in list_default: - _type = type(d) - if _type: break # The first non-None type - - if _type == bool: - q_all_correct_type = [q.lower() == 'true' for q in q_all] - else: - q_all_correct_type = [_type(q) for q in q_all] - - # flatten list if only one element - if not isinstance(default, list): - q_all_correct_type = q_all_correct_type[0] - - try: - st.session_state[key] = q_all_correct_type - except: - print(f'Failed to set {key} to {q_all_correct_type}') - else: - try: - st.session_state[key] = default - except: - print(f'Failed to set {key} to {default}') diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py new file mode 100644 index 0000000..cb323b6 --- /dev/null +++ b/code/util/url_query_helper.py @@ -0,0 +1,82 @@ +import streamlit as st + +def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): + return st_prefix.checkbox( + label, + value=st.session_state[key] + if key in st.session_state else + st.query_params[key].lower()=='true' + if key in st.query_params + else default, + key=key, + **kwargs, + ) + +def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs): + return st_prefix.selectbox( + label, + options=options, + index=( + options.index(st.session_state[key]) + if key in st.session_state + else options.index(st.query_params[key]) + if key in st.query_params + else default + ), + key=key, + **kwargs, + ) + +def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs): + + # Parse range from URL, compatible with both one or two values + if key in st.query_params: + parse_range_from_url = [type(min_value)(v) for v in st.query_params.get_all(key)] + if len(parse_range_from_url) == 1: + parse_range_from_url = parse_range_from_url[0] + + return st_prefix.slider( + label, + min_value, + max_value, + value=( + st.session_state[key] + if key in st.session_state + else parse_range_from_url + if key in st.query_params + else default + ), + key=key, + **kwargs, + ) + + +def sync_widget_with_query(key, default): + if key in st.query_params: + # always get all query params as a list + q_all = st.query_params.get_all(key) + + # convert type according to default + list_default = default if isinstance(default, list) else [default] + for d in list_default: + _type = type(d) + if _type: break # The first non-None type + + if _type == bool: + q_all_correct_type = [q.lower() == 'true' for q in q_all] + else: + q_all_correct_type = [_type(q) for q in q_all] + + # flatten list if only one element + if not isinstance(default, list): + q_all_correct_type = q_all_correct_type[0] + + try: + st.session_state[key] = q_all_correct_type + except: + print(f'Failed to set {key} to {q_all_correct_type}') + else: + try: + st.session_state[key] = default + except: + print(f'Failed to set {key} to {default}') \ No newline at end of file