From 3ef095c5d922b2f5812d93a431da6f1de5c986fb Mon Sep 17 00:00:00 2001 From: gurayerus Date: Wed, 18 Sep 2024 13:09:14 -0400 Subject: [PATCH] Add state to plots --- src/NiChart_Viewer/src/pages/home.py | 6 +- .../src/pages/view_plot_data.py | 141 ++++++++++-------- 2 files changed, 83 insertions(+), 64 deletions(-) diff --git a/src/NiChart_Viewer/src/pages/home.py b/src/NiChart_Viewer/src/pages/home.py index b04b3750..49ca86ff 100644 --- a/src/NiChart_Viewer/src/pages/home.py +++ b/src/NiChart_Viewer/src/pages/home.py @@ -40,8 +40,8 @@ # FIXME: temp path for running fast # Should be set as the images are created - st.session_state.dir_t1img = st.session_state.path_root + '/test/test_input/test3_nifti+roi' - st.session_state.dir_dlmuse = st.session_state.path_root + '/test/test_input/test3_nifti+roi' + st.session_state.dir_t1img = st.session_state.path_root + '/test/test3_nifti+roi' + st.session_state.dir_dlmuse = st.session_state.path_root + '/test/test3_nifti+roi' st.session_state.suffix_t1img = '_T1.nii.gz' st.session_state.suffix_dlmuse = '_T1_DLMUSE.nii.gz' @@ -64,7 +64,7 @@ st.session_state.path_csv_spare = '' ## FIXME : this is for quickly loading a test example - st.session_state.path_csv_spare = st.session_state.path_root + '/test/test_input/test3_nifti+roi/sMRI_Results_n4.csv' + st.session_state.path_csv_spare = st.session_state.path_root + '/test4_adni3/output/out_combined/MyStudy_All.csv' st.session_state.instantiated = True diff --git a/src/NiChart_Viewer/src/pages/view_plot_data.py b/src/NiChart_Viewer/src/pages/view_plot_data.py index fe2d548e..a6c3bc79 100644 --- a/src/NiChart_Viewer/src/pages/view_plot_data.py +++ b/src/NiChart_Viewer/src/pages/view_plot_data.py @@ -45,22 +45,29 @@ def add_plot(): Adds a new plot (updates a dataframe with plot ids) ''' df_p = st.session_state.plots - df_p.loc[st.session_state.pid] = [f'Plot {st.session_state.pid}'] - st.session_state.pid += 1 + plot_id = f'Plot{st.session_state.plot_index}' + df_p.loc[plot_id] = [plot_id, + st.session_state.plot_xvar, + st.session_state.plot_yvar, + st.session_state.plot_hvar, + st.session_state.plot_trend + ] + + st.session_state.plot_index += 1 # Remove a plot -def remove_plot(pid): +def remove_plot(plot_id): ''' - Removes the plot with the pid (updates the plot ids dataframe) + Removes the plot with the plot_id (updates the plot ids dataframe) ''' df_p = st.session_state.plots - df_p = df_p[df_p.PID != pid] + df_p = df_p[df_p.PID != plot_id] st.session_state.plots = df_p -def display_plot(pid): +def display_plot(plot_id): ''' - Displays the plot with the pid + Displays the plot with the plot_id ''' # Create a copy of dataframe for filtered data @@ -80,46 +87,58 @@ def display_plot(pid): # Tab 1: to set plotting parameters with ptabs[1]: - plot_type = st.selectbox("Plot Type", ["DistPlot", "RegPlot"], key=f"plot_type_{pid}") - # x_var = st.selectbox("X Var", df_filt.columns, key=f"x_var_{pid}", index=3) - # y_var = st.selectbox("Y Var", df_filt.columns, key=f"y_var_{pid}", index=8) - - # Set index for default values - x_ind = df.columns.get_loc(st.session_state.default_x_var) - y_ind = df.columns.get_loc(st.session_state.default_y_var) - hue_ind = df.columns.get_loc(st.session_state.default_hue_var) - trend_index = st.session_state.trend_types.index(st.session_state.default_trend_type) - - x_var = st.selectbox("X Var", df_filt.columns, key=f"x_var_{pid}", index = x_ind) - y_var = st.selectbox("Y Var", df_filt.columns, key=f"y_var_{pid}", index = y_ind) - st.session_state.sel_var = y_var + plot_type = st.selectbox("Plot Type", ["DistPlot", "RegPlot"], key=f"plot_type_{plot_id}") + + # Get plot params + xvar = st.session_state.plots.loc[plot_id].xvar + yvar = st.session_state.plots.loc[plot_id].yvar + hvar = st.session_state.plots.loc[plot_id].hvar + trend = st.session_state.plots.loc[plot_id].trend + + # Select plot params from the user + xind = df.columns.get_loc(xvar) + yind = df.columns.get_loc(yvar) + hind = df.columns.get_loc(hvar) + tind = st.session_state.trend_types.index(trend) + + xvar = st.selectbox("X Var", df_filt.columns, + key=f"plot_xvar_{plot_id}", index = xind) + yvar = st.selectbox("Y Var", df_filt.columns, + key=f"plot_yvar_{plot_id}", index = yind) + hvar = st.selectbox("Hue Var", df_filt.columns, + key=f"plot_hvar_{plot_id}", index = hind) + trend = st.selectbox("Trend Line", st.session_state.trend_types, + key=f"trend_type_{plot_id}", index = tind) + + # Set plot params to session_state + st.session_state.plots.loc[plot_id].xvar = xvar + st.session_state.plots.loc[plot_id].yvar = yvar + st.session_state.plots.loc[plot_id].hvar = hvar + st.session_state.plots.loc[plot_id].trend = trend - hue_var = st.selectbox("Hue Var", df_filt.columns, key=f"hue_var_{pid}", index = hue_ind) - trend_type = st.selectbox("Trend Line", st.session_state.trend_types, key=f"trend_type_{pid}", index = trend_index) # Tab 2: to set data filtering parameters with ptabs[2]: - df_filt = filter_dataframe(df, pid) + df_filt = filter_dataframe(df, plot_id) # Tab 3: to set centiles with ptabs[3]: - cent_type = st.selectbox("Centile Type", ['CN-All', 'CN-F', 'CN-M'], key=f"cent_type_{pid}") + cent_type = st.selectbox("Centile Type", ['CN-All', 'CN-F', 'CN-M'], key=f"cent_type_{plot_id}") # Tab 4: to reset parameters or to delete plot with ptabs[4]: - st.button('Delete Plot', key=f'p_delete_{pid}', - on_click=remove_plot, args=[pid]) + st.button('Delete Plot', key=f'p_delete_{plot_id}', + on_click=remove_plot, args=[plot_id]) # Main plot - if trend_type == 'none': - scatter_plot = px.scatter(df_filt, x = x_var, y = y_var, color = hue_var) + if trend == 'none': + scatter_plot = px.scatter(df_filt, x = xvar, y = yvar, color = hvar) else: - scatter_plot = px.scatter(df_filt, x = x_var, y = y_var, color = hue_var, - trendline = trend_type) + scatter_plot = px.scatter(df_filt, x = xvar, y = yvar, color = hvar, trendline = trend) # Add plot # - on_select: when clicked it will rerun and return the info - sel_info = st.plotly_chart(scatter_plot, on_select='rerun', key=f"bubble_chart_{pid}") + sel_info = st.plotly_chart(scatter_plot, on_select='rerun', key=f"bubble_chart_{plot_id}") # Detect MRID from the click info try: @@ -136,7 +155,7 @@ def display_plot(pid): # ## FIXME: this is temp (for debugging the selection of clicked subject) # st.dataframe(df_filt) -def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame: +def filter_dataframe(df: pd.DataFrame, plot_id) -> pd.DataFrame: """ Adds a UI on top of a dataframe to let viewers filter columns @@ -153,14 +172,14 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame: # Create filters selected by the user modification_container = st.container() with modification_container: - widget_no = pid + '_filter' + widget_no = plot_id + '_filter' to_filter_columns = st.multiselect("Filter dataframe on", df.columns, key = widget_no) for vno, column in enumerate(to_filter_columns): left, right = st.columns((1, 20)) left.write("↳") # Treat columns with < 10 unique values as categorical if is_categorical_dtype(df[column]) or df[column].nunique() < 10: - widget_no = pid + '_col_' + str(vno) + widget_no = plot_id + '_col_' + str(vno) user_cat_input = right.multiselect( f"Values for {column}", df[column].unique(), @@ -241,30 +260,30 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame: # Tab 0: to set plotting parameters with ptabs[1]: # Default values for plot params - st.session_state.default_hue_var = 'Sex' - - def_ind_x = 0 - if st.session_state.default_x_var in df.columns: - def_ind_x = df.columns.get_loc(st.session_state.default_x_var) - - def_ind_y = 0 - if st.session_state.default_y_var in df.columns: - def_ind_y = df.columns.get_loc(st.session_state.default_y_var) - - def_ind_hue = 0 - if st.session_state.default_hue_var in df.columns: - def_ind_hue = df.columns.get_loc(st.session_state.default_hue_var) - - st.session_state.default_x_var = st.selectbox("Default X Var", df.columns, key=f"x_var_init", - index = def_ind_x) - st.session_state.default_y_var = st.selectbox("Default Y Var", df.columns, key=f"y_var_init", - index = def_ind_y) - st.session_state.sel_var = st.session_state.default_y_var - - st.session_state.default_hue_var = st.selectbox("Default Hue Var", df.columns, key=f"hue_var_init", - index = def_ind_hue) - trend_index = st.session_state.trend_types.index(st.session_state.default_trend_type) - st.session_state.default_trend_type = st.selectbox("Default Trend Line", st.session_state.trend_types, + st.session_state.plot_hvar = 'Sex' + + plot_xvar_ind = 0 + if st.session_state.plot_xvar in df.columns: + plot_xvar_ind = df.columns.get_loc(st.session_state.plot_xvar) + + plot_yvar_ind = 0 + if st.session_state.plot_yvar in df.columns: + plot_yvar_ind = df.columns.get_loc(st.session_state.plot_yvar) + + plot_hvar_ind = 0 + if st.session_state.plot_hvar in df.columns: + plot_hvar_ind = df.columns.get_loc(st.session_state.plot_hvar) + + st.session_state.plot_xvar = st.selectbox("Default X Var", df.columns, key=f"plot_xvar_init", + index = plot_xvar_ind) + st.session_state.plot_yvar = st.selectbox("Default Y Var", df.columns, key=f"plot_yvar_init", + index = plot_yvar_ind) + st.session_state.sel_var = st.session_state.plot_yvar + + st.session_state.plot_hvar = st.selectbox("Default Hue Var", df.columns, key=f"plot_hvar_init", + index = plot_hvar_ind) + trend_index = st.session_state.trend_types.index(st.session_state.plot_trend) + st.session_state.plot_trend = st.selectbox("Default Trend Line", st.session_state.trend_types, key=f"trend_type_init", index = trend_index) # Button to add a new plot @@ -277,18 +296,18 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame: # Read plot ids df_p = st.session_state.plots - p_index = df_p.PID.tolist() + list_plots = df_p.index.tolist() plot_per_raw = st.session_state.plot_per_raw # Render plots # - iterates over plots; # - for every "plot_per_raw" plots, creates a new columns block, resets column index, and displays the plot - for i in range(0, len(p_index)): + for i, plot_ind in enumerate(list_plots): column_no = i % plot_per_raw if column_no == 0: blocks = st.columns(plot_per_raw) with blocks[column_no]: - display_plot(p_index[i]) + display_plot(plot_ind) # FIXME: this is for debugging; will be removed