From 3bf4322b69665ef7dc6cb38de95a1b16e4aa4a65 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Mon, 27 Feb 2023 19:55:38 +0000 Subject: [PATCH] minor improvements --- code/Home.py | 94 ++++++++++++++++++++++++++---------------- code/streamlit_util.py | 8 +--- 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/code/Home.py b/code/Home.py index 08806cd..527e1cc 100644 --- a/code/Home.py +++ b/code/Home.py @@ -123,26 +123,7 @@ def draw_session_plots(keys_to_draw_session): [1.5, 1], # columns in the second row ] - draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix - (0, 0), # location (row_idx, column_idx) - dict(other_patterns=['model_best', 'model_None'])), - '2. Lick times': ('lick_psth', - (1, 0), - {}), - '3. Logistic regression on choice': ('logistic_regression', - (1, 1), - dict(crop=(0, 0, 1200, 2000))), - '4. Win-stay-lose-shift prob.': ('wsls', - (1, 1), - dict(crop=(0, 0, 1200, 600))), - '5. Linear regression on RT': ('linear_regression_rt', - (1, 0), - dict()), - } - - cols_option = st.columns([3, 0.5, 1]) - selected_draw_types = cols_option[0].multiselect('Which plot(s) to draw?', draw_type_mapper.keys(), default=draw_type_mapper.keys()) - num_cols = cols_option[1].number_input('Number of columns', 1, 10, 2) + # cols_option = st.columns([3, 0.5, 1]) container_session_all_in_one = st.container() with container_session_all_in_one: @@ -152,30 +133,30 @@ def draw_session_plots(keys_to_draw_session): st.write(f'Loading selected {len(keys_to_draw_session)} sessions...') my_bar = st.columns((1, 7))[0].progress(0) - major_cols = st.columns([1] * num_cols) + major_cols = st.columns([1] * st.session_state.num_cols) if not isinstance(keys_to_draw_session, list): # Turn dataframe to list, if necessary keys_to_draw_session = keys_to_draw_session.to_dict(orient='records') for i, key in enumerate(keys_to_draw_session): - this_major_col = major_cols[i % num_cols] + this_major_col = major_cols[i % st.session_state.num_cols] # setting up layout for each session rows = [] with this_major_col: - st.markdown(f'''

{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''', + st.markdown(f'''

{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''', unsafe_allow_html=True) - if len(selected_draw_types) > 1: # more than one types, use the pre-defined layout + if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout for row, column_setting in enumerate(layout_definition): rows.append(this_major_col.columns(column_setting)) else: # else, put it in the whole column rows = this_major_col.columns([1]) st.markdown("---") - for draw_type in draw_type_mapper: - if draw_type not in selected_draw_types: continue # To keep the draw order defined by draw_type_mapper - prefix, position, setting = draw_type_mapper[draw_type] - this_col = rows[position[0]][position[1]] if len(selected_draw_types) > 1 else rows[0] + for draw_type in st.session_state.draw_type_mapper: + if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper + prefix, position, setting = st.session_state.draw_type_mapper[draw_type] + this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] show_img_by_key_and_prefix(key, column=this_col, prefix=prefix, @@ -350,7 +331,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col): fig.update_layout( # width=1300, - height=850, + # height=850, xaxis_title=x_name, yaxis_title=y_name, # xaxis_range=[0, min(100, df[x_name].max())], @@ -361,9 +342,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col): ) # st.plotly_chart(fig) - selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870) + selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870, override_width=1400) return selected_sessions_from_plot + +def session_plot_settings(): + st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', st.session_state.draw_type_mapper.keys(), default=st.session_state.draw_type_mapper.keys()) + st.session_state.num_cols = st.columns([1, 3])[0].number_input('Number of columns', 1, 10, 2) + def population_analysis(): @@ -393,7 +379,19 @@ def population_analysis(): smooth_factor = s_cols[0].slider('Smooth factor', 1, 20, 5, disabled=not ((if_aggr_each_group and aggr_method_group=='lowess') or (if_aggr_all and aggr_method_all=='lowess'))) - + for i in range(7): st.write('\n') + + st.markdown("***") + st.markdown('##### Click or box/lasso select session(s) from the plots to draw 👉') + session_plot_settings() + + with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=False): + if st.button('clear selection'): + st.session_state.df_selected_from_plotly = pd.DataFrame() + + with st.expander('show sessions', expanded=False): + st.dataframe(st.session_state.df_selected_from_plotly) + names = {('session', 'foraging_eff'): 'Foraging efficiency', ('session', 'finished'): 'Finished trials', @@ -442,11 +440,30 @@ def init(): st.session_state.df_selected_from_plotly = pd.DataFrame() + # Init session states # add some model fitting params to session if not 'model_id' in st.session_state: st.session_state.model_id = 21 selected_id = st.session_state.model_id + st.session_state.draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix + (0, 0), # location (row_idx, column_idx) + dict(other_patterns=['model_best', 'model_None'])), + '2. Lick times': ('lick_psth', + (1, 0), + {}), + '3. Logistic regression on choice': ('logistic_regression', + (1, 1), + dict(crop=(0, 0, 1200, 2000))), + '4. Win-stay-lose-shift prob.': ('wsls', + (1, 1), + dict(crop=(0, 0, 1200, 600))), + '5. Linear regression on RT': ('linear_regression_rt', + (1, 0), + dict()), + } + + # process dfs df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}') valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)] to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field] @@ -458,6 +475,7 @@ def init(): st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()] + def app(): @@ -472,6 +490,9 @@ def app(): st.cache_data.clear() init() st.experimental_rerun() + + st.markdown('---') + st.write('Han Hou @ 2023\nv1.0.0') with st.container(): @@ -496,17 +517,17 @@ def app(): st.session_state.aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered) - chosen_id = stx.tab_bar(data=[ - stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"), - stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"), - ], default="tab1") + # chosen_id = stx.tab_bar(data=[ + # stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"), + # stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"), + # ], default="tab1") + chosen_id = "tab1" placeholder = st.container() if chosen_id == "tab1": with placeholder: df_selected_from_plotly = population_analysis() - st.markdown('##### Select session(s) from the plots above to draw') if len(st.session_state.df_selected_from_plotly): draw_session_plots(st.session_state.df_selected_from_plotly) @@ -519,6 +540,7 @@ def app(): with placeholder: selected_keys_from_aggrid = st.session_state.aggrid_outputs['selected_rows'] st.markdown('##### Select session(s) from the table above to draw') + session_plot_settings() draw_session_plots(selected_keys_from_aggrid) # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000) diff --git a/code/streamlit_util.py b/code/streamlit_util.py index a5e1fd3..48533e2 100644 --- a/code/streamlit_util.py +++ b/code/streamlit_util.py @@ -274,10 +274,4 @@ def add_session_filter(): with st.expander("Behavioral session filter", expanded=True): st.session_state.df_session_filtered = filter_dataframe(df=st.session_state.df['sessions']) st.markdown(f"{len(st.session_state.df_session_filtered)} sessions filtered (use_s3 = {st.session_state.use_s3})") - - with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=True): - if st.button('clear selection'): - st.session_state.df_selected_from_plotly = pd.DataFrame() - - with st.expander('show sessions', expanded=False): - st.dataframe(st.session_state.df_selected_from_plotly) \ No newline at end of file + \ No newline at end of file