generated from AllenNeuralDynamics/aind-capsule-template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #70 from AllenNeuralDynamics/han_add_hmm_glm
feat: add HMM-GLM for Faeze
- Loading branch information
Showing
1 changed file
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Visualize HMM-GLM results from Faeze | ||
Han, ChatGPT | ||
""" | ||
import os | ||
from PIL import Image | ||
|
||
import streamlit as st | ||
import s3fs | ||
import streamlit_nested_layout | ||
|
||
# Set up the S3 bucket and prefix | ||
bucket_name = "s3://aind-behavior-data/faeze/HMM-GLM" | ||
|
||
# Initialize s3fs | ||
s3 = s3fs.S3FileSystem(anon=False) | ||
|
||
# Function to open image from S3 | ||
@st.cache_data(ttl=24*3600) | ||
def open_image_from_s3(file_key, crop=None, **kwargs): | ||
with s3.open(f'{file_key}', 'rb') as f: | ||
img = Image.open(f) | ||
img = img.crop(crop) | ||
|
||
st.image(img, **kwargs) | ||
|
||
# Get the list of data folders | ||
data_folders = [os.path.basename(f) for f in | ||
s3.glob(f'{bucket_name}/*') | ||
] | ||
|
||
# Data folder selection dropdown | ||
with st.sidebar: | ||
widget_data_folder = st.container() | ||
widget_mouse = st.container() | ||
widget_n_states = st.container() | ||
widget_model_comparison = st.container() | ||
|
||
if st.button('Reload data from S3'): | ||
st.cache_data.clear() | ||
st.rerun() | ||
|
||
data_folder_selected = widget_data_folder.selectbox('Select Data Folder', data_folders) | ||
|
||
if data_folder_selected: | ||
# Get the list of mice folders | ||
mice = [os.path.basename(f) for f in | ||
s3.glob(f'{bucket_name}/{data_folder_selected}/*') | ||
if not os.path.basename(f).startswith('.') | ||
] | ||
# Mouse selection dropdown | ||
mouse_selected = widget_mouse.selectbox('Select Mouse', mice) | ||
|
||
# Show mouse-wise figures | ||
if mouse_selected: | ||
mouse_folder = f'{bucket_name}/{data_folder_selected}/{mouse_selected}' | ||
|
||
with widget_model_comparison: | ||
fig_model_comparisons = ['AIC.png', 'BIC.png', 'LL.png'] | ||
|
||
for i, fig_model_comparison in enumerate(fig_model_comparisons): | ||
img = open_image_from_s3( | ||
f'{mouse_folder}/{fig_model_comparison}', | ||
caption=fig_model_comparison, | ||
) | ||
|
||
# Number of states selection | ||
num_states = widget_n_states.selectbox('Select Number of States', | ||
['two_states', 'three_states', 'four_states'], | ||
index=2, # Default shows four states | ||
) | ||
|
||
if num_states: | ||
num_states_folder = f'{mouse_folder}/{num_states}' | ||
fig_states = ['GLM_Weights.png', 'GLM_TM.png', 'frac_occupancy.png', 'frac_occupancy_of_sessions.png'] | ||
|
||
cols = st.columns([1, 0.2, 1]) | ||
with cols[0]: | ||
open_image_from_s3(f'{num_states_folder}/GLM_Weights.png', | ||
caption='GLM weights') | ||
with cols[2]: | ||
open_image_from_s3(f'{num_states_folder}/frac_occupancy_of_sessions.png', | ||
caption='Fraction occupancy over sessions') | ||
with cols[1]: | ||
open_image_from_s3(f'{num_states_folder}/frac_occupancy.png', | ||
caption='Fraction occupancy (all)') | ||
open_image_from_s3(f'{num_states_folder}/GLM_TM.png', | ||
caption='Fraction occupancy (all)') | ||
|
||
# Grouped by selection | ||
with st.container(height=2000): | ||
cols = st.columns([1.5, 1, 4]) | ||
grouped_by = cols[0].selectbox('Grouped By', ['grouped_by_sessions', 'grouped_by_sessions_conventional_view', 'grouped_by_states']) | ||
num_cols = cols[1].number_input( | ||
label='number of columns', | ||
min_value=1, | ||
max_value=10, | ||
value=2, | ||
) | ||
|
||
if grouped_by: | ||
fig_sessions = s3.glob(f'{num_states_folder}/{grouped_by}/*.png') | ||
cols = st.columns(num_cols) | ||
|
||
for i, fig_session in enumerate(fig_sessions): | ||
with cols[i % num_cols]: | ||
img = open_image_from_s3(fig_session, caption=fig_session) | ||
|