Skip to content

Commit

Permalink
Merge pull request #71 from AllenNeuralDynamics/han_add_hmm_glm
Browse files Browse the repository at this point in the history
fix: some bugs in HMM-GLM
  • Loading branch information
hanhou authored Jul 6, 2024
2 parents 0223066 + 04c24a8 commit 59442dc
Showing 1 changed file with 99 additions and 74 deletions.
173 changes: 99 additions & 74 deletions code/pages/2_HMM-GLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,25 @@
Han, ChatGPT
"""
import os
import re
from PIL import Image

import streamlit as st
import s3fs
import streamlit_nested_layout

try:
st.set_page_config(layout="wide",
page_title='Foraging behavior browser',
page_icon=':mouse2:',
menu_items={
'Report a bug': "https://github.com/hanhou/foraging-behavior-browser/issues",
'About': "Github repo: https://github.com/hanhou/foraging-behavior-browser/"
}
)
except:
pass

# Set up the S3 bucket and prefix
bucket_name = "s3://aind-behavior-data/faeze/HMM-GLM"

Expand All @@ -23,86 +36,98 @@ def open_image_from_s3(file_key, crop=None, **kwargs):
img = img.crop(crop)

st.image(img, **kwargs)

# Get the list of data folders
data_folders = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/*')
]

# Data folder selection dropdown
with st.sidebar:
widget_data_folder = st.container()
widget_mouse = st.container()
widget_n_states = st.container()
widget_model_comparison = st.container()

if st.button('Reload data from S3'):
st.cache_data.clear()
st.rerun()

data_folder_selected = widget_data_folder.selectbox('Select Data Folder', data_folders)
# Function to extract the session number from file name
def extract_session_number(file_name):
match = re.search(r'sess_(\d+)', file_name)
return int(match.group(1)) if match else None

if data_folder_selected:
# Get the list of mice folders
mice = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/{data_folder_selected}/*')
if not os.path.basename(f).startswith('.')

# --------------- Main app -------------------
def app():
# Get the list of data folders
data_folders = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/*')
]
# Mouse selection dropdown
mouse_selected = widget_mouse.selectbox('Select Mouse', mice)

# Show mouse-wise figures
if mouse_selected:
mouse_folder = f'{bucket_name}/{data_folder_selected}/{mouse_selected}'
# Data folder selection dropdown
with st.sidebar:
widget_data_folder = st.container()
widget_mouse = st.container()
widget_n_states = st.container()
widget_model_comparison = st.container()

with widget_model_comparison:
fig_model_comparisons = ['AIC.png', 'BIC.png', 'LL.png']
height = st.slider('Session container height', 100, 3000, 700, step=100)
if st.button('Reload data from S3'):
st.cache_data.clear()
st.rerun()

data_folder_selected = widget_data_folder.selectbox('Select Data Folder', data_folders)

for i, fig_model_comparison in enumerate(fig_model_comparisons):
img = open_image_from_s3(
f'{mouse_folder}/{fig_model_comparison}',
caption=fig_model_comparison,
)
if data_folder_selected:
# Get the list of mice folders
mice = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/{data_folder_selected}/*')
if not os.path.basename(f).startswith('.')
]
# Mouse selection dropdown
mouse_selected = widget_mouse.selectbox('Select Mouse', mice)

# Number of states selection
num_states = widget_n_states.selectbox('Select Number of States',
['two_states', 'three_states', 'four_states'],
index=2, # Default shows four states
)

if num_states:
num_states_folder = f'{mouse_folder}/{num_states}'
fig_states = ['GLM_Weights.png', 'GLM_TM.png', 'frac_occupancy.png', 'frac_occupancy_of_sessions.png']

cols = st.columns([1, 0.2, 1])
with cols[0]:
open_image_from_s3(f'{num_states_folder}/GLM_Weights.png',
caption='GLM weights')
with cols[2]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy_of_sessions.png',
caption='Fraction occupancy over sessions')
with cols[1]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy.png',
caption='Fraction occupancy (all)')
open_image_from_s3(f'{num_states_folder}/GLM_TM.png',
caption='Fraction occupancy (all)')

# Grouped by selection
with st.container(height=2000):
cols = st.columns([1.5, 1, 4])
grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states'])
num_cols = cols[1].number_input(
label='number of columns',
min_value=1,
max_value=10,
value=2,
)

if grouped_by:
fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png')
cols = st.columns(num_cols)
# Show mouse-wise figures
if mouse_selected:
mouse_folder = f'{bucket_name}/{data_folder_selected}/{mouse_selected}'

with widget_model_comparison:
fig_model_comparisons = ['AIC.png', 'BIC.png', 'LL.png']

for i, fig_model_comparison in enumerate(fig_model_comparisons):
img = open_image_from_s3(
f'{mouse_folder}/{fig_model_comparison}',
caption=fig_model_comparison,
)

# Number of states selection
num_states = widget_n_states.selectbox('Select Number of States',
['two_states', 'three_states', 'four_states'],
index=2, # Default shows four states
)

if num_states:
num_states_folder = f'{mouse_folder}/{num_states}'
fig_states = ['GLM_Weights.png', 'GLM_TM.png', 'frac_occupancy.png', 'frac_occupancy_of_sessions.png']

cols = st.columns([1, 0.2, 1])
with cols[0]:
open_image_from_s3(f'{num_states_folder}/GLM_Weights.png',
caption='GLM weights')
with cols[2]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy_of_sessions.png',
caption='Fraction occupancy over sessions')
with cols[1]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy.png',
caption='Fraction occupancy (all)')
open_image_from_s3(f'{num_states_folder}/GLM_TM.png',
caption='Fraction occupancy (all)')

# Grouped by selection
with st.container(height=height):
cols = st.columns([1.5, 1, 4])
grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states'])
num_cols = cols[1].number_input(
label='number of columns',
min_value=1,
max_value=10,
value=2,
)

for i, fig_session in enumerate(fig_sessions):
with cols[i % num_cols]:
img = open_image_from_s3(fig_session, caption=fig_session)
if grouped_by:
fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png')
fig_sessions = sorted(fig_sessions, key=extract_session_number)
cols = st.columns(num_cols)

for i, fig_session in enumerate(fig_sessions):
with cols[i % num_cols]:
img = open_image_from_s3(fig_session, caption=fig_session)


app()

0 comments on commit 59442dc

Please sign in to comment.