Skip to content

Commit

Permalink
Merge pull request #70 from AllenNeuralDynamics/han_add_hmm_glm
Browse files Browse the repository at this point in the history
feat: add HMM-GLM for Faeze
  • Loading branch information
hanhou authored Jul 6, 2024
2 parents a491b00 + e6a0fd5 commit 0223066
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions code/pages/2_HMM-GLM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Visualize HMM-GLM results from Faeze
Han, ChatGPT
"""
import os
from PIL import Image

import streamlit as st
import s3fs
import streamlit_nested_layout

# Set up the S3 bucket and prefix
bucket_name = "s3://aind-behavior-data/faeze/HMM-GLM"

# Initialize s3fs
s3 = s3fs.S3FileSystem(anon=False)

# Function to open image from S3
@st.cache_data(ttl=24*3600)
def open_image_from_s3(file_key, crop=None, **kwargs):
with s3.open(f'{file_key}', 'rb') as f:
img = Image.open(f)
img = img.crop(crop)

st.image(img, **kwargs)

# Get the list of data folders
data_folders = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/*')
]

# Data folder selection dropdown
with st.sidebar:
widget_data_folder = st.container()
widget_mouse = st.container()
widget_n_states = st.container()
widget_model_comparison = st.container()

if st.button('Reload data from S3'):
st.cache_data.clear()
st.rerun()

data_folder_selected = widget_data_folder.selectbox('Select Data Folder', data_folders)

if data_folder_selected:
# Get the list of mice folders
mice = [os.path.basename(f) for f in
s3.glob(f'{bucket_name}/{data_folder_selected}/*')
if not os.path.basename(f).startswith('.')
]
# Mouse selection dropdown
mouse_selected = widget_mouse.selectbox('Select Mouse', mice)

# Show mouse-wise figures
if mouse_selected:
mouse_folder = f'{bucket_name}/{data_folder_selected}/{mouse_selected}'

with widget_model_comparison:
fig_model_comparisons = ['AIC.png', 'BIC.png', 'LL.png']

for i, fig_model_comparison in enumerate(fig_model_comparisons):
img = open_image_from_s3(
f'{mouse_folder}/{fig_model_comparison}',
caption=fig_model_comparison,
)

# Number of states selection
num_states = widget_n_states.selectbox('Select Number of States',
['two_states', 'three_states', 'four_states'],
index=2, # Default shows four states
)

if num_states:
num_states_folder = f'{mouse_folder}/{num_states}'
fig_states = ['GLM_Weights.png', 'GLM_TM.png', 'frac_occupancy.png', 'frac_occupancy_of_sessions.png']

cols = st.columns([1, 0.2, 1])
with cols[0]:
open_image_from_s3(f'{num_states_folder}/GLM_Weights.png',
caption='GLM weights')
with cols[2]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy_of_sessions.png',
caption='Fraction occupancy over sessions')
with cols[1]:
open_image_from_s3(f'{num_states_folder}/frac_occupancy.png',
caption='Fraction occupancy (all)')
open_image_from_s3(f'{num_states_folder}/GLM_TM.png',
caption='Fraction occupancy (all)')

# Grouped by selection
with st.container(height=2000):
cols = st.columns([1.5, 1, 4])
grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states'])
num_cols = cols[1].number_input(
label='number of columns',
min_value=1,
max_value=10,
value=2,
)

if grouped_by:
fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png')
cols = st.columns(num_cols)

for i, fig_session in enumerate(fig_sessions):
with cols[i % num_cols]:
img = open_image_from_s3(fig_session, caption=fig_session)

0 comments on commit 0223066

Please sign in to comment.