Skip to content

Commit

Permalink
Merge pull request #72 from AllenNeuralDynamics/han_add_hmm_glm
Browse files Browse the repository at this point in the history
feat: bug fix and improve grouped_by_states
  • Loading branch information
hanhou authored Jul 7, 2024
2 parents 59442dc + 212f7e9 commit 1e18227
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions code/pages/2_HMM-GLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import os
import re
import numpy as np
from PIL import Image

import streamlit as st
Expand Down Expand Up @@ -42,6 +43,9 @@ def extract_session_number(file_name):
match = re.search(r'sess_(\d+)', file_name)
return int(match.group(1)) if match else None

def extract_state_part_number(file_name):
match = re.search(r'state-(\d+)_part-(\d+)', file_name)
return int(match.group(1)), int(match.group(2)) if match else None

# --------------- Main app -------------------
def app():
Expand Down Expand Up @@ -111,23 +115,50 @@ def app():

# Grouped by selection
with st.container(height=height):
cols = st.columns([1.5, 1, 4])
cols = st.columns([2, 1, 4])
grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states'])
num_cols = cols[1].number_input(
label='number of columns',
min_value=1,
max_value=10,
value=2,
)

fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png')

if grouped_by:
fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png')
fig_sessions = sorted(fig_sessions, key=extract_session_number)
if 'by_sessions' in grouped_by:
# Choose number of columns
num_cols = cols[1].number_input(
label='number of columns',
min_value=1,
max_value=10,
value=2,
)

# Show plots
cols = st.columns(num_cols)

fig_sessions = sorted(fig_sessions, key=extract_session_number)
for i, fig_session in enumerate(fig_sessions):
with cols[i % num_cols]:
img = open_image_from_s3(fig_session, caption=fig_session)

elif 'by_states' in grouped_by:
# Interpret file names
state_part = np.array([extract_state_part_number(f)
for f in fig_sessions])
unique_states = np.unique(state_part[:, 0])

# Select states to compare
selected_state_to_compare = cols[2].multiselect('States to compare',
options=unique_states,
default=unique_states,
)

# Show plots
cols_state = st.columns(len(selected_state_to_compare))
for n, state in enumerate(selected_state_to_compare):
with cols_state[n]:
st.markdown(f'State {state}')
ids_this_state = np.where(state_part[:, 0] == state)[0]
for id in ids_this_state:
open_image_from_s3(
fig_sessions[id],
caption='',
)


app()

0 comments on commit 1e18227

Please sign in to comment.