Skip to content

Commit

Permalink
minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Feb 27, 2023
1 parent 8d7e5b2 commit 3bf4322
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 43 deletions.
94 changes: 58 additions & 36 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,26 +123,7 @@ def draw_session_plots(keys_to_draw_session):
[1.5, 1], # columns in the second row
]

draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix
(0, 0), # location (row_idx, column_idx)
dict(other_patterns=['model_best', 'model_None'])),
'2. Lick times': ('lick_psth',
(1, 0),
{}),
'3. Logistic regression on choice': ('logistic_regression',
(1, 1),
dict(crop=(0, 0, 1200, 2000))),
'4. Win-stay-lose-shift prob.': ('wsls',
(1, 1),
dict(crop=(0, 0, 1200, 600))),
'5. Linear regression on RT': ('linear_regression_rt',
(1, 0),
dict()),
}

cols_option = st.columns([3, 0.5, 1])
selected_draw_types = cols_option[0].multiselect('Which plot(s) to draw?', draw_type_mapper.keys(), default=draw_type_mapper.keys())
num_cols = cols_option[1].number_input('Number of columns', 1, 10, 2)
# cols_option = st.columns([3, 0.5, 1])
container_session_all_in_one = st.container()

with container_session_all_in_one:
Expand All @@ -152,30 +133,30 @@ def draw_session_plots(keys_to_draw_session):
st.write(f'Loading selected {len(keys_to_draw_session)} sessions...')
my_bar = st.columns((1, 7))[0].progress(0)

major_cols = st.columns([1] * num_cols)
major_cols = st.columns([1] * st.session_state.num_cols)

if not isinstance(keys_to_draw_session, list): # Turn dataframe to list, if necessary
keys_to_draw_session = keys_to_draw_session.to_dict(orient='records')

for i, key in enumerate(keys_to_draw_session):
this_major_col = major_cols[i % num_cols]
this_major_col = major_cols[i % st.session_state.num_cols]

# setting up layout for each session
rows = []
with this_major_col:
st.markdown(f'''<h3 style='text-align: center; color: blue;'>{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''',
st.markdown(f'''<h3 style='text-align: center; color: orange;'>{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''',
unsafe_allow_html=True)
if len(selected_draw_types) > 1: # more than one types, use the pre-defined layout
if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout
for row, column_setting in enumerate(layout_definition):
rows.append(this_major_col.columns(column_setting))
else: # else, put it in the whole column
rows = this_major_col.columns([1])
st.markdown("---")

for draw_type in draw_type_mapper:
if draw_type not in selected_draw_types: continue # To keep the draw order defined by draw_type_mapper
prefix, position, setting = draw_type_mapper[draw_type]
this_col = rows[position[0]][position[1]] if len(selected_draw_types) > 1 else rows[0]
for draw_type in st.session_state.draw_type_mapper:
if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper
prefix, position, setting = st.session_state.draw_type_mapper[draw_type]
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
show_img_by_key_and_prefix(key,
column=this_col,
prefix=prefix,
Expand Down Expand Up @@ -350,7 +331,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):

fig.update_layout(
# width=1300,
height=850,
# height=850,
xaxis_title=x_name,
yaxis_title=y_name,
# xaxis_range=[0, min(100, df[x_name].max())],
Expand All @@ -361,9 +342,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):
)

# st.plotly_chart(fig)
selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870)
selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870, override_width=1400)

return selected_sessions_from_plot

def session_plot_settings():
st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', st.session_state.draw_type_mapper.keys(), default=st.session_state.draw_type_mapper.keys())
st.session_state.num_cols = st.columns([1, 3])[0].number_input('Number of columns', 1, 10, 2)



def population_analysis():
Expand Down Expand Up @@ -393,7 +379,19 @@ def population_analysis():

smooth_factor = s_cols[0].slider('Smooth factor', 1, 20, 5, disabled=not ((if_aggr_each_group and aggr_method_group=='lowess')
or (if_aggr_all and aggr_method_all=='lowess')))

for i in range(7): st.write('\n')

st.markdown("***")
st.markdown('##### Click or box/lasso select session(s) from the plots to draw 👉')
session_plot_settings()

with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=False):
if st.button('clear selection'):
st.session_state.df_selected_from_plotly = pd.DataFrame()

with st.expander('show sessions', expanded=False):
st.dataframe(st.session_state.df_selected_from_plotly)


names = {('session', 'foraging_eff'): 'Foraging efficiency',
('session', 'finished'): 'Finished trials',
Expand Down Expand Up @@ -442,11 +440,30 @@ def init():

st.session_state.df_selected_from_plotly = pd.DataFrame()

# Init session states
# add some model fitting params to session
if not 'model_id' in st.session_state:
st.session_state.model_id = 21
selected_id = st.session_state.model_id

st.session_state.draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix
(0, 0), # location (row_idx, column_idx)
dict(other_patterns=['model_best', 'model_None'])),
'2. Lick times': ('lick_psth',
(1, 0),
{}),
'3. Logistic regression on choice': ('logistic_regression',
(1, 1),
dict(crop=(0, 0, 1200, 2000))),
'4. Win-stay-lose-shift prob.': ('wsls',
(1, 1),
dict(crop=(0, 0, 1200, 600))),
'5. Linear regression on RT': ('linear_regression_rt',
(1, 0),
dict()),
}

# process dfs
df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')
valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)]
to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field]
Expand All @@ -458,6 +475,7 @@ def init():

st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()]




def app():
Expand All @@ -472,6 +490,9 @@ def app():
st.cache_data.clear()
init()
st.experimental_rerun()

st.markdown('---')
st.write('Han Hou @ 2023\nv1.0.0')


with st.container():
Expand All @@ -496,17 +517,17 @@ def app():

st.session_state.aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered)

chosen_id = stx.tab_bar(data=[
stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"),
stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"),
], default="tab1")
# chosen_id = stx.tab_bar(data=[
# stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"),
# stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"),
# ], default="tab1")
chosen_id = "tab1"

placeholder = st.container()

if chosen_id == "tab1":
with placeholder:
df_selected_from_plotly = population_analysis()
st.markdown('##### Select session(s) from the plots above to draw')

if len(st.session_state.df_selected_from_plotly):
draw_session_plots(st.session_state.df_selected_from_plotly)
Expand All @@ -519,6 +540,7 @@ def app():
with placeholder:
selected_keys_from_aggrid = st.session_state.aggrid_outputs['selected_rows']
st.markdown('##### Select session(s) from the table above to draw')
session_plot_settings()
draw_session_plots(selected_keys_from_aggrid)

# st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)
Expand Down
8 changes: 1 addition & 7 deletions code/streamlit_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,4 @@ def add_session_filter():
with st.expander("Behavioral session filter", expanded=True):
st.session_state.df_session_filtered = filter_dataframe(df=st.session_state.df['sessions'])
st.markdown(f"{len(st.session_state.df_session_filtered)} sessions filtered (use_s3 = {st.session_state.use_s3})")

with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=True):
if st.button('clear selection'):
st.session_state.df_selected_from_plotly = pd.DataFrame()

with st.expander('show sessions', expanded=False):
st.dataframe(st.session_state.df_selected_from_plotly)

0 comments on commit 3bf4322

Please sign in to comment.