Skip to content

Commit

Permalink
Edit state vars to connect image viewer to plotting view
Browse files Browse the repository at this point in the history
  • Loading branch information
gurayerus committed Sep 15, 2024
1 parent 1c8affa commit e4f0d52
Show file tree
Hide file tree
Showing 23 changed files with 2,454 additions and 923 deletions.
5 changes: 5 additions & 0 deletions src/NiChart_Viewer/src/pages/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
# st.session_state.in_csv_MUSE = f'{st.session_state.dir_root}/test/test_input/test2_rois/Study1/Study1_DLMUSE.csv'
# st.session_state.in_csv_Demog = f'{st.session_state.dir_root}/test/test_input/test2_rois/Study1/Study1_Demog.csv'

# FIXME: Set these variables when the images are loaded or computed
st.session_state.dir_t1img = st.session_state.dir_root + '/test/test_input/test3_nifti+roi'
st.session_state.dir_dlmuse = st.session_state.dir_root + '/test/test_input/test3_nifti+roi'
st.session_state.suffix_t1img = '_T1.nii.gz'
st.session_state.suffix_dlmuse = '_T1_DLMUSE.nii.gz'

# Path to out folder
st.session_state.out_dir = ''
Expand Down
130 changes: 74 additions & 56 deletions src/NiChart_Viewer/src/pages/view_img.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pandas as pd
import streamlit as st
from pandas.api.types import (
Expand All @@ -18,11 +19,16 @@
# st.session_state.pid = 1
# st.session_state.instantiated = True

# Parameters for viewer
VIEWS = ["axial", "sagittal", "coronal"]
VIEW_AXES = [0, 1, 2]
VIEW_OTHER_AXES = [(1,2), (0,2), (0,1)]
MASK_COLOR = (0, 255, 0) # RGB format

def reorient_nifti(nii_in, ref_orient = 'LPS'):
'''
Initial img is reoriented to a standard orientation
'''

# Find transform from current (approximate) orientation to
# target, in nibabel orientation matrix and affine forms
Expand All @@ -40,6 +46,7 @@ def crop_image(img, mask):
'''
Crop img to the foreground of the mask
'''

# Detect bounding box
nz = np.nonzero(mask)
mn = np.min(nz, axis=1)
Expand Down Expand Up @@ -74,7 +81,9 @@ def crop_image(img, mask):
def detect_mask_bounds(mask):
'''
Detect the mask start, end and center in each view
Used later to set the slider in the image viewer
'''

mask_bounds = np.zeros([3,3]).astype(int)
for i, axis in enumerate(VIEW_AXES):
mask_bounds[i,0] = 0
Expand All @@ -91,6 +100,7 @@ def show_nifti(img, view, sel_axis_bounds):
'''
Displays the nifti img
'''

# Create a slider to select the slice index
slice_index = st.slider(f"{view}", 0, sel_axis_bounds[1] - 1,
value=sel_axis_bounds[2], key = f'slider_{view}')
Expand All @@ -105,6 +115,10 @@ def show_nifti(img, view, sel_axis_bounds):

@st.cache_data
def prep_images(f_img, f_mask, sel_roi_ind):
'''
Read images from files and create 3D matrices for display
'''

# Read nifti
nii_img = nib.load(f_img)
nii_mask = nib.load(f_mask)
Expand All @@ -127,83 +141,87 @@ def prep_images(f_img, f_mask, sel_roi_ind):
img = np.stack((img,)*3, axis=-1)

img_masked = img.copy()
img_masked[mask == 1] = mask_color
img_masked[mask == 1] = MASK_COLOR

# Scale values
img = img / img.max()
img_masked = img_masked / img_masked.max()

return img, mask, img_masked

# # Config page
# st.set_page_config(page_title="DataFrame Demo", page_icon="📊", layout='wide')

# FIXME: Input data is hardcoded here for now
# fname = "../examples/test_input/vTest1/Study1/StudyTest1_DLMUSE_All.csv"
fname = "../examples/test_input3/ROIS_tmp.csv"
df = pd.read_csv(fname)

f1 = "../examples/test_input3/IXI002-Guys-0828_T1.nii.gz"
f2 = "../examples/test_input3/IXI002-Guys-0828_T1_DLMUSE.nii.gz"

sel_roi = 'Ventricles'
mask_color = (0, 255, 0) # RGB format

# Select roi index
# FIXME: This will be read from file
dict_roi = {'Ventricles':51, 'Hippocampus_R':100, 'Hippocampus_L':48}
sel_roi_ind = dict_roi[sel_roi]

# Process image and mask to prepare final 3d matrix to display
img, mask, img_masked = prep_images(f1, f2, sel_roi_ind)

# Read dataframe with subject mrids
df = pd.read_csv(st.session_state.fname_subj_list)

# Page controls in side bar
with st.sidebar:

# Show selected id (while providing the user the option to select it from the list of all MRIDs)
# - get the selected id from the session_state
# - create a selectbox with all MRIDs
# -- initialize it with the selected id if it's set
# -- initialize it with the first id if not
sel_id = st.session_state.sel_id
if sel_id == '':
sel_ind = 0
sel_type = '(auto)'
else:
sel_ind = df.MRID.tolist().index(sel_id)
sel_type = '(user)'
sel_id = st.selectbox("Select Subject", df.MRID.tolist(), key=f"selbox_mrid", index = sel_ind)
with st.container(border=True):

# st.sidebar.warning('Selected subject: ' + mrid)
st.warning(f'Selected {sel_type}: {sel_id}')
# Selection of MRID
sel_id = st.session_state.sel_id
if sel_id == '':
sel_ind = 0
sel_type = '(auto)'
else:
sel_ind = df.MRID.tolist().index(sel_id)
sel_type = '(user)'
sel_id = st.selectbox("Select Subject", df.MRID.tolist(), key=f"selbox_mrid", index = sel_ind)

st.write('---')
# st.sidebar.warning('Selected subject: ' + mrid)
st.warning(f'Selected {sel_type}: {sel_id}')

## FIXME: read list of rois from dictionary
## show the roi selected in the plot
sel_roi = st.selectbox("Select ROI", list(dict_roi.keys()), key=f"selbox_rois", index = 0)
## FIXME: read list of rois from dictionary
## show the roi selected in the plot
sel_roi = st.selectbox("Select ROI", list(dict_roi.keys()), key=f"selbox_rois", index = 0)

st.write('---')
with st.container(border=True):

# Create a list of checkbox options
#list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS[0])
list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS)
# Create a list of checkbox options
#list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS[0])
list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS)

# View hide overlay
is_show_overlay = st.checkbox('Show overlay', True)
# View hide overlay
is_show_overlay = st.checkbox('Show overlay', True)

# Print the selected options (optional)
if list_orient:
st.write("Selected options:", list_orient)
# Print the selected options (optional)
if list_orient:
st.write("Selected options:", list_orient)

# Detect mask bounds and center in each view
mask_bounds = detect_mask_bounds(mask)
sel_roi = st.session_state.sel_roi

# Show images
blocks = st.columns(len(list_orient))
for i, tmp_orient in enumerate(list_orient):
with blocks[i]:
if is_show_overlay == False:
show_nifti(img, tmp_orient, mask_bounds[i,:])
else:
show_nifti(img_masked, tmp_orient, mask_bounds[i,:])
# Select roi index
sel_roi_ind = dict_roi[sel_roi]


# File names for img and mask
f_img = os.path.join(st.session_state.dir_t1img, sel_id + st.session_state.suffix_t1img)
f_mask = os.path.join(st.session_state.dir_dlmuse, sel_id + st.session_state.suffix_dlmuse)

if os.path.exists(f_img) & os.path.exists(f_mask):

# Process image and mask to prepare final 3d matrix to display
img, mask, img_masked = prep_images(f_img, f_mask, sel_roi_ind)

# Detect mask bounds and center in each view
mask_bounds = detect_mask_bounds(mask)

# Show images
blocks = st.columns(len(list_orient))
for i, tmp_orient in enumerate(list_orient):
with blocks[i]:
if is_show_overlay == False:
show_nifti(img, tmp_orient, mask_bounds[i,:])
else:
show_nifti(img_masked, tmp_orient, mask_bounds[i,:])

else:
if not os.path.exists(f_img):
st.sidebar.warning(f'Image not found: {f_img}')
else:
st.sidebar.warning(f'Mask not found: {f_mask}')

5 changes: 5 additions & 0 deletions src/NiChart_Viewer/src/pages/view_plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def display_plot(pid):

x_var = st.selectbox("X Var", df_filt.columns, key=f"x_var_{pid}", index = x_ind)
y_var = st.selectbox("Y Var", df_filt.columns, key=f"y_var_{pid}", index = y_ind)
st.session_state.sel_roi = y_var

hue_var = st.selectbox("Hue Var", df_filt.columns, key=f"hue_var_{pid}", index = hue_ind)
trend_type = st.selectbox("Trend Line", st.session_state.trend_types, key=f"trend_type_{pid}", index = trend_index)

Expand Down Expand Up @@ -219,6 +221,7 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:
spare_csv = st.sidebar.text_input("Enter the name of the ROI csv file:",
value = st.session_state.in_csv_sMRI,
label_visibility="collapsed")
st.session_state.fname_subj_list = spare_csv

if os.path.exists(spare_csv):

Expand Down Expand Up @@ -260,6 +263,8 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:
index = def_ind_x)
st.session_state.default_y_var = st.selectbox("Default Y Var", df.columns, key=f"y_var_init",
index = def_ind_y)
st.session_state.sel_roi = st.session_state.default_y_var

st.session_state.default_hue_var = st.selectbox("Default Hue Var", df.columns, key=f"hue_var_init",
index = def_ind_hue)
trend_index = st.session_state.trend_types.index(st.session_state.default_trend_type)
Expand Down
Loading

0 comments on commit e4f0d52

Please sign in to comment.