diff --git a/code/pages/2_HMM-GLM.py b/code/pages/2_HMM-GLM.py index 8e2c67f..f3e0ac5 100644 --- a/code/pages/2_HMM-GLM.py +++ b/code/pages/2_HMM-GLM.py @@ -4,6 +4,7 @@ """ import os import re +import numpy as np from PIL import Image import streamlit as st @@ -42,6 +43,9 @@ def extract_session_number(file_name): match = re.search(r'sess_(\d+)', file_name) return int(match.group(1)) if match else None +def extract_state_part_number(file_name): + match = re.search(r'state-(\d+)_part-(\d+)', file_name) + return int(match.group(1)), int(match.group(2)) if match else None # --------------- Main app ------------------- def app(): @@ -111,23 +115,50 @@ def app(): # Grouped by selection with st.container(height=height): - cols = st.columns([1.5, 1, 4]) + cols = st.columns([2, 1, 4]) grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states']) - num_cols = cols[1].number_input( - label='number of columns', - min_value=1, - max_value=10, - value=2, - ) + + fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png') - if grouped_by: - fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png') - fig_sessions = sorted(fig_sessions, key=extract_session_number) + if 'by_sessions' in grouped_by: + # Choose number of columns + num_cols = cols[1].number_input( + label='number of columns', + min_value=1, + max_value=10, + value=2, + ) + + # Show plots cols = st.columns(num_cols) - + fig_sessions = sorted(fig_sessions, key=extract_session_number) for i, fig_session in enumerate(fig_sessions): with cols[i % num_cols]: img = open_image_from_s3(fig_session, caption=fig_session) + + elif 'by_states' in grouped_by: + # Interpret file names + state_part = np.array([extract_state_part_number(f) + for f in fig_sessions]) + unique_states = np.unique(state_part[:, 0]) + + # Select states to compare + selected_state_to_compare = cols[2].multiselect('States to compare', + options=unique_states, + default=unique_states, + ) + + # Show plots + cols_state = st.columns(len(selected_state_to_compare)) + for n, state in enumerate(selected_state_to_compare): + with cols_state[n]: + st.markdown(f'State {state}') + ids_this_state = np.where(state_part[:, 0] == state)[0] + for id in ids_this_state: + open_image_from_s3( + fig_sessions[id], + caption='', + ) app() \ No newline at end of file