diff --git a/code/Home.py b/code/Home.py index 71a809a..14f2433 100644 --- a/code/Home.py +++ b/code/Home.py @@ -280,7 +280,7 @@ def show_curriculums(): pass # ------- Layout starts here -------- # -def init(): +def init(if_load_docDB=True): # Clear specific session state and all filters for key in st.session_state: @@ -385,6 +385,11 @@ def _get_data_source(rig): _df.loc[_df['water_in_session_manual'] > 100, ['water_in_session_manual', 'water_in_session_total', 'water_after_session']] = np.nan + _df.loc[(_df['duration_iti_median'] < 0) | (_df['duration_iti_mean'] < 0), + ['duration_iti_median', 'duration_iti_mean', 'duration_iti_std', 'duration_iti_min', 'duration_iti_max']] = np.nan + + _df.loc[_df['invalid_lick_ratio'] < 0, + ['invalid_lick_ratio']]= np.nan # # add something else # add abs(bais) to all terms that have 'bias' in name @@ -449,21 +454,22 @@ def _get_data_source(rig): # --- Load data from docDB --- - _df = merge_in_df_docDB(_df) - - # add docDB_status column - _df["docDB_status"] = _df.apply( - lambda row: ( - "0_not uploaded" - if pd.isnull(row["session_loc"]) - else ( - "1_uploaded but not processed" - if pd.isnull(row["processed_session_loc"]) - else "2_uploaded and processed" - ) - ), - axis=1, - ) + if if_load_docDB: + _df = merge_in_df_docDB(_df) + + # add docDB_status column + _df["docDB_status"] = _df.apply( + lambda row: ( + "0_not uploaded" + if pd.isnull(row["session_loc"]) + else ( + "1_uploaded but not processed" + if pd.isnull(row["processed_session_loc"]) + else "2_uploaded and processed" + ) + ), + axis=1, + ) st.session_state.df['sessions_bonsai'] = _df # Somehow _df loses the reference to the original dataframe st.session_state.session_stats_names = [keys for keys in _df.keys()] @@ -753,9 +759,10 @@ def app(): # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000) -ok = True -if 'df' not in st.session_state or 'sessions_bonsai' not in st.session_state.df.keys(): - ok = init() +if __name__ == "__main__": + ok = True + if 'df' not in st.session_state or 'sessions_bonsai' not in st.session_state.df.keys(): + ok = init() -if ok: - app() + if ok: + app() diff --git a/code/__init__.py b/code/__init__.py index 0f9ce39..005fd5f 100644 --- a/code/__init__.py +++ b/code/__init__.py @@ -1 +1 @@ -__ver__ = 'v2.5.5' +__ver__ = 'v2.5.6' diff --git a/code/pages/1_Basic behavior analysis.py b/code/pages/1_Basic behavior analysis.py new file mode 100644 index 0000000..9c3b752 --- /dev/null +++ b/code/pages/1_Basic behavior analysis.py @@ -0,0 +1,382 @@ +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import s3fs +import streamlit as st +import matplotlib.pyplot as plt +import matplotlib +from plotly.subplots import make_subplots +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from streamlit_plotly_events import plotly_events +from util.aws_s3 import load_data +from util.streamlit import add_session_filter, data_selector, add_footnote +from scipy.stats import gaussian_kde +import streamlit_nested_layout + +import extra_streamlit_components as stx + +from util.url_query_helper import ( + checkbox_wrapper_for_url_query, + multiselect_wrapper_for_url_query, + number_input_wrapper_for_url_query, + slider_wrapper_for_url_query, + sync_session_state_to_URL, + sync_URL_to_session_state, + to_sync_with_url_query_default, +) + + +from Home import init + +ss = st.session_state + +fs = s3fs.S3FileSystem(anon=False) +cache_folder = "aind-behavior-data/foraging_nwb_bonsai_processed/" + +try: + st.set_page_config( + layout="wide", + page_title="Foraging behavior browser", + page_icon=":mouse2:", + menu_items={ + "Report a bug": "https://github.com/AllenNeuralDynamics/foraging-behavior-browser/issues", + "About": "Github repo: https://github.com/AllenNeuralDynamics/foraging-behavior-browser", + }, + ) +except: + pass + +# Sort stages in the desired order +STAGE_ORDER = [ + "STAGE_1_WARMUP", + "STAGE_1", + "STAGE_2", + "STAGE_3", + "STAGE_4", + "STAGE_FINAL", + "GRADUATED", +] + +@st.cache_data() +def get_stage_color_mapper(stage_list): + # Mapping stages to colors from red to green, return rgb values + # Interpolate between red and green using the number of stages + cmap = plt.cm.get_cmap('RdYlGn', 100) + stage_color_mapper = { + stage: matplotlib.colors.rgb2hex( + cmap(i / (len(stage_list) - 1))) + for i, stage in enumerate(stage_list) + } + return stage_color_mapper + +STAGE_COLOR_MAPPER = get_stage_color_mapper(STAGE_ORDER) + +@st.cache_data() +def _get_metadata_col(): + df = load_data()["sessions_bonsai"] + + # -- get cols -- + col_task = [ + s + for s in df.metadata.columns + if not any( + ss in s + for ss in [ + "lickspout", + "weight", + "water", + "time", + "rig", + "user_name", + "experiment", + "task", + "notes", + "laser", + "commit", + "repo", + "branch", + ] # exclude some columns + ) + ] + [ + 'avg_trial_length_in_seconds', + 'weight_after_ratio', + ] + + col_perf = [ + s + for s in df.session_stats.columns + if not any(ss in s for ss in ["performance"]) + ] + [ + #TODO: build column groups in Home.py. Now I'm hardcoding. + 'abs(bias_naive)', + 'abs(logistic_Su2022_bias)', + 'logistic_Su2022_RewC_amp', + 'logistic_Su2022_RewC_tau', + 'logistic_Su2022_UnrC_amp', + 'logistic_Su2022_UnrC_tau', + 'logistic_Su2022_bias', + 'logistic_Su2022_score_mean', + ] + return col_perf, col_task + +COL_PERF, COL_TASK = _get_metadata_col() + +def app(): + with st.sidebar: + add_session_filter(if_bonsai=True) + data_selector() + add_footnote() + + if not hasattr(ss, "df"): + st.write("##### Data not loaded yet, start from Home:") + st.page_link("Home.py", label="Home", icon="🏠") + return + + # === Main tabs === + chosen_id = stx.tab_bar( + data=[ + stx.TabBarItemData( + id="tab_stage", + title="Training stages", + description="Compare across training stages", + ), + stx.TabBarItemData( + id="tab_PCA", + title="Learning trajectory", + description="PCA on performance and task parameters", + ), + ], + default=( + st.query_params["tab_id_learning_trajectory"] + if "tab_id_learning_trajectory" in st.query_params + else st.session_state.tab_id_learning_trajectory + ), + ) + + placeholder = st.container() + st.session_state.tab_id_learning_trajectory = chosen_id + + if chosen_id == "tab_PCA": + do_pca( + ss.df_session_filtered.loc[:, ["subject_id", "session"] + COL_PERF], + "performance", + ) + do_pca( + ss.df_session_filtered.loc[:, ["subject_id", "session"] + COL_TASK], "task" + ) + elif chosen_id == "tab_stage": + st.markdown("### Distributions of metrics and/or parameters grouped by training stages") + metrics_grouped_by_stages(df=ss.df_session_filtered) + + # Update back to URL + sync_session_state_to_URL() + + +def do_pca(df, name): + df = df.dropna(axis=0, how="any") + df = df[~df.isin([np.nan, np.inf, -np.inf]).any(axis=1)] + + df_to_pca = df.drop(columns=["subject_id", "session"]) + df_to_pca = df_to_pca.select_dtypes(include=[np.number, float, int]) + + # Standardize the features + x = StandardScaler().fit_transform(df_to_pca) + + # Apply PCA + pca = PCA(n_components=10) # Reduce to 2 dimensions for visualization + principalComponents = pca.fit_transform(x) + + # Create a new DataFrame with the principal components + principalDf = pd.DataFrame(data=principalComponents) + principalDf.index = df.set_index(["subject_id", "session"]).index + + principalDf.reset_index(inplace=True) + + # -- trajectory -- + st.markdown(f"### PCA on {name} metrics") + fig = go.Figure() + + for mouse_id in principalDf["subject_id"].unique(): + subset = principalDf[principalDf["subject_id"] == mouse_id] + + # Add a 3D scatter plot for the current group + fig.add_trace( + go.Scatter3d( + x=subset[0], + y=subset[1], + z=subset[2], + mode="lines+markers", + marker=dict(size=subset["session"].apply(lambda x: 5 + 15 * (x / 20))), + name=f"{mouse_id}", # Name the trace for the legend + ) + ) + + fig.update_layout( + title=name, + scene=dict(xaxis_title="Dim1", yaxis_title="Dim2", zaxis_title="Dim3"), + width=1300, + height=1000, + font_size=15, + ) + st.plotly_chart(fig) + + # -- variance explained -- + var_explained = pca.explained_variance_ratio_ + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=np.arange(1, len(var_explained) + 1), + y=np.cumsum(var_explained), + ) + ) + fig.update_layout( + title="Variance Explained", + yaxis=dict(range=[0, 1]), + width=300, + height=400, + font_size=15, + ) + st.plotly_chart(fig) + + # -- pca components -- + pca_components = pd.DataFrame(pca.components_, columns=df_to_pca.columns) + pca_components + fig = make_subplots(rows=3, cols=1) + + # In vertical subplots, each subplot show the components of a principal component + for i in range(3): + fig.add_trace( + go.Bar( + x=pca_components.columns, + y=pca_components.loc[i], + name=f"PC{i+1}", + ), + row=i + 1, + col=1, + ) + + fig.update_xaxes(showticklabels=i == 2, row=i + 1, col=1) + + fig.update_layout( + title="PCA weights", + width=1000, + height=800, + font_size=20, + ) + st.plotly_chart(fig) + + +def metrics_grouped_by_stages(df): + + df["current_stage_actual"] = pd.Categorical( + df["current_stage_actual"], categories=STAGE_ORDER, ordered=True + ) + df = df.sort_values("current_stage_actual") + + # Multiselect for choosing numeric columns + selected_perf_columns = multiselect_wrapper_for_url_query( + st, + label= "Animal performance metrics to plot", + options=COL_PERF, + default=to_sync_with_url_query_default["stage_distribution_selected_perf_columns"], + key='stage_distribution_selected_perf_columns', + ) + selected_task_columns = multiselect_wrapper_for_url_query( + st, + label= "Task parameters to plot", + options=COL_TASK, + default=to_sync_with_url_query_default["stage_distribution_selected_task_columns"], + key='stage_distribution_selected_task_columns', + ) + selected_columns = selected_perf_columns + selected_task_columns + + # Checkbox to use density or not + use_kernel_smooth = st.checkbox("Use Kernel Smoothing", value=True) + if use_kernel_smooth: + use_density = False + bins = 100 + else: + bins = st.columns([1, 5])[0].slider("Number of bins", 10, 100, 20, 5) + use_density = st.checkbox("Use Density", value=False) + + num_plot_cols = st.columns([1, 7])[0].slider("Number of plotting columns", 1, 5, 4) + st.markdown("---") + + # Create a density plot for each selected column grouped by 'current_stage_actual' + unique_curriculum_name = ['Uncoupled Without Baiting', 'Uncoupled Baiting', 'Coupled Baiting'] + for curriculum_name in [name for name in unique_curriculum_name if name != "None"]: + st.markdown(f"### Curriculum name: {curriculum_name}") + + # Columns to plot + cols = st.columns([1] * num_plot_cols) + for n, column in enumerate(selected_columns): + with cols[n % num_plot_cols]: + st.write(f'''
Animal performance: {column}''' + if column in COL_PERF + else f"
Task parameter: {column}", + unsafe_allow_html=True) + fig = _plot_histograms( + df[df["curriculum_name"] == curriculum_name], + column, + bins, + use_kernel_smooth, + use_density, + ) + st.plotly_chart(fig, use_container_width=True) + + st.markdown("---") + + +@st.cache_data() +def _plot_histograms(df, column, bins, use_kernel_smooth, use_density): + fig = go.Figure() + + stage_data_all = df[column].dropna() + stage_data_all = stage_data_all[~stage_data_all.isin([np.inf, -np.inf])] + bin_edges = np.linspace(stage_data_all.min(), stage_data_all.max(), bins) + + for stage in df["current_stage_actual"].cat.categories: + if stage not in df["current_stage_actual"].unique(): + continue + stage_data = df[df["current_stage_actual"] == stage][[column, "subject_id"]].dropna() + n_sessions = len(stage_data) + n_mice = len(stage_data["subject_id"].unique()) + + stage_data = stage_data[column] + if use_kernel_smooth: + if len(stage_data.unique()) == 1: + # Handle case with only one unique value + unique_value = stage_data.iloc[0] + # Create a small range around the unique value for KDE + kde = lambda x: np.exp(-((x - unique_value) ** 2) / (unique_value/100)) # Fallback + else: + kde = gaussian_kde(stage_data) + y_vals = kde(bin_edges) + else: + y_vals, _ = np.histogram(stage_data, bins=bin_edges, density=use_density) + percentiles = [(np.sum(stage_data <= x) / len(stage_data)) * 100 for x in bin_edges[1:]] + customdata = np.array([percentiles]).T + + fig.add_trace( + go.Scatter( + x=(bin_edges[1:] + bin_edges[:-1]) / 2, + y=y_vals, + mode="lines", + line=dict(color=STAGE_COLOR_MAPPER[stage]), + name=f"{stage}
({n_mice} mice, {n_sessions} sessions)", + customdata=customdata, + hovertemplate=f"Percentile: %{{customdata[0]:.2f}}%
" + ) + ) + fig.update_layout( + xaxis_title=column, + yaxis_title="Kernel density" if use_kernel_smooth else "Density" if use_density else "Count", + hovermode="x unified", + ) + return fig + +if "df" not in st.session_state or "sessions_bonsai" not in st.session_state.df.keys(): + init(if_load_docDB=False) + +app() diff --git a/code/pages/1_Learning trajectory.py b/code/pages/1_Learning trajectory.py deleted file mode 100644 index d8fa143..0000000 --- a/code/pages/1_Learning trajectory.py +++ /dev/null @@ -1,137 +0,0 @@ -import numpy as np -import pandas as pd -import plotly.graph_objects as go -import s3fs -import streamlit as st -from plotly.subplots import make_subplots -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler -from streamlit_plotly_events import plotly_events -from util.aws_s3 import load_data -from util.streamlit import add_session_filter, data_selector - -ss = st.session_state - -fs = s3fs.S3FileSystem(anon=False) -cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/' - - -def app(): - - with st.sidebar: - add_session_filter(if_bonsai=True) - data_selector() - - if not hasattr(ss, 'df'): - st.write('##### Data not loaded yet, start from Home:') - st.page_link('Home.py', label='Home', icon="🏠") - return - - df = load_data()['sessions_bonsai'] - - # -- get cols -- - col_task = [s for s in df.metadata.columns - if not any(ss in s for ss in ['lickspout', 'weight', 'water', 'time', 'rig', - 'user_name', 'experiment', 'task', 'notes', 'laser'] - ) - ] - - col_perf = [s for s in df.session_stats.columns - if not any(ss in s for ss in ['performance'] - ) - ] - - do_pca(ss.df_session_filtered.loc[:, ['subject_id', 'session'] + col_perf], 'performance') - do_pca(ss.df_session_filtered.loc[:, ['subject_id', 'session'] + col_task], 'task') - - -def do_pca(df, name): - df = df.dropna(axis=0, how='any') - df = df[~df.isin([np.nan, np.inf, -np.inf]).any(axis=1)] - - df_to_pca = df.drop(columns=['subject_id', 'session']) - df_to_pca = df_to_pca.select_dtypes(include=[np.number, float, int]) - - # Standardize the features - x = StandardScaler().fit_transform(df_to_pca) - - # Apply PCA - pca = PCA(n_components=10) # Reduce to 2 dimensions for visualization - principalComponents = pca.fit_transform(x) - - # Create a new DataFrame with the principal components - principalDf = pd.DataFrame(data=principalComponents) - principalDf.index = df.set_index(['subject_id', 'session']).index - - principalDf.reset_index(inplace=True) - - # -- trajectory -- - st.markdown(f'### PCA on {name} metrics') - fig = go.Figure() - - for mouse_id in principalDf['subject_id'].unique(): - subset = principalDf[principalDf['subject_id'] == mouse_id] - - # Add a 3D scatter plot for the current group - fig.add_trace(go.Scatter3d( - x=subset[0], - y=subset[1], - z=subset[2], - mode='lines+markers', - marker=dict(size=subset['session'].apply( - lambda x: 5 + 15*(x/20))), - name=f'{mouse_id}', # Name the trace for the legend - )) - - fig.update_layout(title=name, - scene=dict( - xaxis_title='Dim1', - yaxis_title='Dim2', - zaxis_title='Dim3' - ), - width=1300, - height=1000, - font_size=15, - ) - st.plotly_chart(fig) - - # -- variance explained -- - var_explained = pca.explained_variance_ratio_ - fig = go.Figure() - fig.add_trace(go.Scatter( - x=np.arange(1, len(var_explained)+1), - y=np.cumsum(var_explained), - ) - ) - fig.update_layout(title='Variance Explained', - yaxis=dict(range=[0, 1]), - width=300, - height=400, - font_size=15, - ) - st.plotly_chart(fig) - - # -- pca components -- - pca_components = pd.DataFrame(pca.components_, - columns=df_to_pca.columns) - pca_components - fig = make_subplots(rows=3, cols=1) - - # In vertical subplots, each subplot show the components of a principal component - for i in range(3): - fig.add_trace(go.Bar( - x=pca_components.columns, - y=pca_components.loc[i], - name=f'PC{i+1}', - ), row=i+1, col=1) - - fig.update_xaxes(showticklabels=i==2, row=i+1, col=1) - - fig.update_layout(title='PCA weights', - width=1000, - height=800, - font_size=20, - ) - st.plotly_chart(fig) - -app() \ No newline at end of file diff --git a/code/util/streamlit.py b/code/util/streamlit.py index 0849ff3..5ef2b06 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -691,13 +691,14 @@ def add_dot_property_mapper(): def data_selector(): - with st.expander(f'Session selector', expanded=True): - # --- add a download button --- - _add_download_filtered_session() - + with st.expander(f'Session selector', expanded=True): with st.expander(f"Filtered: {len(st.session_state.df_session_filtered)} sessions, " f"{len(st.session_state.df_session_filtered.h2o.unique())} mice", expanded=False): st.dataframe(st.session_state.df_session_filtered) + + # --- add a download button --- + with st.columns([1, 10])[1]: + _add_download_filtered_session() # cols = st.columns([4, 1]) # with cols[0].expander(f"From dataframe: {len(st.session_state.df_selected_from_dataframe)} sessions", expanded=False): diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index e1fa419..f9cc99f 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -9,56 +9,78 @@ # dict of "key": default pairs # Note: When creating the widget, add argument "value"/"index" as well as "key" for all widgets you want to sync with URL to_sync_with_url_query_default = { - 'if_load_bpod_sessions': False, + "if_load_bpod_sessions": False, + "to_filter_columns": [ + "subject_id", + "task", + "session", + "finished_trials", + "foraging_eff", + ], + "filter_subject_id": "", + "filter_session": [0.0, None], + "filter_finished_trials": [0.0, None], + "filter_foraging_eff": [0.0, None], + "filter_task": ["all"], + "table_height": 300, + "tab_id": "tab_auto_train_history", + "x_y_plot_xname": "session", + "x_y_plot_yname": "foraging_performance_random_seed", + "x_y_plot_group_by": "h2o", + "x_y_plot_if_show_dots": True, + "x_y_plot_if_aggr_each_group": True, + "x_y_plot_aggr_method_group": "lowess", + "x_y_plot_if_aggr_all": True, + "x_y_plot_aggr_method_all": "mean +/- sem", + "x_y_plot_smooth_factor": 5, + "x_y_plot_if_use_x_quantile_group": False, + "x_y_plot_q_quantiles_group": 20, + "x_y_plot_if_use_x_quantile_all": False, + "x_y_plot_q_quantiles_all": 20, + "x_y_plot_if_show_diagonal": False, + "x_y_plot_dot_size": 10, + "x_y_plot_dot_opacity": 0.3, + "x_y_plot_line_width": 2.0, + "x_y_plot_figure_width": 1300, + "x_y_plot_figure_height": 900, + "x_y_plot_font_size_scale": 1.0, + "x_y_plot_selected_color_map": "Plotly", + "x_y_plot_size_mapper": "finished_trials", + "x_y_plot_size_mapper_gamma": 1.0, + "x_y_plot_size_mapper_range": [3, 20], + "session_plot_mode": "sessions selected from table or plot", + "session_plot_selected_draw_types": list(draw_type_mapper_session_level.keys()), + "session_plot_number_cols": 3, + "auto_training_history_x_axis": "date", + "auto_training_history_sort_by": "first_date", + "auto_training_history_sort_order": "descending", + "auto_training_curriculum_name": "Uncoupled Baiting", + "auto_training_curriculum_version": "1.0", + "auto_training_curriculum_schema_version": "1.0", + "auto_training_history_recent_weeks": 8, - 'to_filter_columns': ['subject_id', 'task', 'session', 'finished_trials', 'foraging_eff'], - 'filter_subject_id': '', - 'filter_session': [0.0, None], - 'filter_finished_trials': [0.0, None], - 'filter_foraging_eff': [0.0, None], - 'filter_task': ['all'], - - 'table_height': 300, - - 'tab_id': 'tab_auto_train_history', - 'x_y_plot_xname': 'session', - 'x_y_plot_yname': 'foraging_performance_random_seed', - 'x_y_plot_group_by': 'h2o', - 'x_y_plot_if_show_dots': True, - 'x_y_plot_if_aggr_each_group': True, - 'x_y_plot_aggr_method_group': 'lowess', - 'x_y_plot_if_aggr_all': True, - 'x_y_plot_aggr_method_all': 'mean +/- sem', - 'x_y_plot_smooth_factor': 5, - 'x_y_plot_if_use_x_quantile_group': False, - 'x_y_plot_q_quantiles_group': 20, - 'x_y_plot_if_use_x_quantile_all': False, - 'x_y_plot_q_quantiles_all': 20, - 'x_y_plot_if_show_diagonal': False, - 'x_y_plot_dot_size': 10, - 'x_y_plot_dot_opacity': 0.3, - 'x_y_plot_line_width': 2.0, - 'x_y_plot_figure_width': 1300, - 'x_y_plot_figure_height': 900, - 'x_y_plot_font_size_scale': 1.0, - 'x_y_plot_selected_color_map': 'Plotly', - - 'x_y_plot_size_mapper': 'finished_trials', - 'x_y_plot_size_mapper_gamma': 1.0, - 'x_y_plot_size_mapper_range': [3, 20], - - 'session_plot_mode': 'sessions selected from table or plot', - 'session_plot_selected_draw_types': list(draw_type_mapper_session_level.keys()), - 'session_plot_number_cols': 3, - - 'auto_training_history_x_axis': 'date', - 'auto_training_history_sort_by': 'first_date', - 'auto_training_history_sort_order': 'descending', - 'auto_training_curriculum_name': 'Uncoupled Baiting', - 'auto_training_curriculum_version': '1.0', - 'auto_training_curriculum_schema_version': '1.0', - 'auto_training_history_recent_weeks': 8, - } + "tab_id_learning_trajectory": "tab_stage", + "stage_distribution_selected_perf_columns": [ + "finished_trials", + "finished_rate", + "foraging_eff_random_seed", + "abs(logistic_Su2022_bias)", + 'logistic_Su2022_RewC_amp', + 'logistic_Su2022_RewC_tau', + 'logistic_Su2022_UnrC_amp', + 'logistic_Su2022_UnrC_tau', + 'logistic_Su2022_score_mean', + "early_lick_rate", + "invalid_lick_ratio", + "double_dipping_rate_finished_trials", + ], + "stage_distribution_selected_task_columns": [ + "effective_block_length_median", + "duration_iti_mean", + "p_reward_contrast_mean", + "weight_after_ratio", + ], +} def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): return st_prefix.checkbox( @@ -86,7 +108,7 @@ def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **k key=key, **kwargs, ) - + def multiselect_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs): return st_prefix.multiselect( label, @@ -124,8 +146,8 @@ def slider_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, de key=key, **kwargs, ) - - + + def number_input_wrapper_for_url_query(st_prefix, label, min_value, max_value, key, default, **kwargs): return st_prefix.number_input( label=label, @@ -141,8 +163,8 @@ def number_input_wrapper_for_url_query(st_prefix, label, min_value, max_value, k key=key, **kwargs, ) - - + + def sync_URL_to_session_state(): """Assign session_state to sync with URL""" @@ -209,7 +231,7 @@ def sync_URL_to_session_state(): st.session_state[key] = default except: print(f'Failed to set {key} to {default}') - + def sync_session_state_to_URL(): # Add all 'filter_' fields to the default list @@ -231,8 +253,8 @@ def sync_session_state_to_URL(): st.query_params.update({key: st.session_state[key]}) except: print(f'Failed to update {key} to URL query') - - + + def get_filter_type(df, column): if is_numeric_dtype(df[column]): return 'slider_range_float'