From e4af24f2e34c52c1f47f4a4b9f8e96cefa5d5d75 Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Sun, 7 Apr 2024 23:06:57 -0700 Subject: [PATCH] feat: allow user to select colormap --- code/Home.py | 5 ++++- code/util/streamlit.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/code/Home.py b/code/Home.py index 61a3f76..ec029f3 100644 --- a/code/Home.py +++ b/code/Home.py @@ -88,6 +88,7 @@ 'x_y_plot_figure_width': 1300, 'x_y_plot_figure_height': 900, 'x_y_plot_font_size_scale': 1.0, + 'x_y_plot_selected_color_map': 'Plotly', 'x_y_plot_size_mapper': 'finished_trials', 'x_y_plot_size_mapper_gamma': 1.0, @@ -426,7 +427,8 @@ def plot_x_y_session(): with cols[1]: (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, - dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, font_size_scale) = add_xy_setting() + dot_size, dot_opacity, line_width, x_y_plot_figure_width, x_y_plot_figure_height, + font_size_scale, color_map) = add_xy_setting() if st.session_state.x_y_plot_if_show_dots: with cols[2]: @@ -478,6 +480,7 @@ def plot_x_y_session(): x_y_plot_figure_width=x_y_plot_figure_width, x_y_plot_figure_height=x_y_plot_figure_height, font_size_scale=font_size_scale, + color_map=color_map, ) # st.plotly_chart(fig) diff --git a/code/util/streamlit.py b/code/util/streamlit.py index fcd1cd3..17c68f0 100644 --- a/code/util/streamlit.py +++ b/code/util/streamlit.py @@ -571,10 +571,18 @@ def add_xy_setting(): step=0.1, key='x_y_plot_font_size_scale', default=1.0) + + available_color_maps = list(px.colors.qualitative.__dict__.keys()) + available_color_maps = [c for c in available_color_maps if not c.startswith("_") and c != 'swatches'] + color_map = selectbox_wrapper_for_url_query(c[0], + label='Color map', + options=available_color_maps, + key='x_y_plot_selected_color_map', + default=available_color_maps.index('Plotly')) return (if_show_dots, if_aggr_each_group, aggr_method_group, if_use_x_quantile_group, q_quantiles_group, if_aggr_all, aggr_method_all, if_use_x_quantile_all, q_quantiles_all, smooth_factor, if_show_diagonal, - dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale) + dot_size, dot_opacity, line_width, figure_width, figure_height, font_size_scale, color_map) @st.cache_data(ttl=24*3600) def _get_min_max(x, size_mapper_gamma): @@ -767,6 +775,7 @@ def _plot_population_x_y(df, x_name='session', y_name='foraging_eff', group_by=' x_y_plot_figure_width=1300, x_y_plot_figure_height=900, font_size_scale=1.0, + color_map='Plotly', **kwarg): def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_quantiles, col, line_width, hoverinfo='skip', **kwarg): @@ -904,7 +913,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, if_use_x_quantile, q_q pass fig = go.Figure() - col_map = px.colors.qualitative.Plotly + col_map = px.colors.qualitative.__dict__[color_map] # Add some more columns if dot_size_mapping_name !='None' and dot_size_mapping_name in df.columns: