Skip to content

Commit

Permalink
Merge pull request #62 from AllenNeuralDynamics/han_improve_autotrain…
Browse files Browse the repository at this point in the history
…_history

feat: add quick preview to the autotrain manager
  • Loading branch information
hanhou authored Apr 10, 2024
2 parents 830df81 + d054507 commit 7afee3f
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 50 deletions.
27 changes: 17 additions & 10 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"""

# %%
__ver__ = 'v2.2.0'

import pandas as pd
Expand Down Expand Up @@ -452,6 +451,15 @@ def _get_data_source(rig):
_df.dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now)
_df.drop(_df.query('session < 1').index, inplace=True)

# Remove abnormal values
_df.loc[_df['weight_after'] > 100,
['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total']
] = np.nan

_df.loc[_df['water_in_session_manual'] > 100,
['water_in_session_manual', 'water_in_session_total', 'water_after_session']] = np.nan


# # add something else
# add abs(bais) to all terms that have 'bias' in name
for col in _df.columns:
Expand All @@ -470,6 +478,13 @@ def _get_data_source(rig):
# map user_name
_df['user_name'] = _df['user_name'].apply(_user_name_mapper)

# trial stats
_df['avg_trial_length_in_seconds'] = _df['session_run_time_in_min'] / _df['total_trials_with_autowater'] * 60

# last day's total water
_df['water_day_total_last_session'] = _df.groupby('h2o')['water_day_total'].shift(1)
_df['water_after_session_last_session'] = _df.groupby('h2o')['water_after_session'].shift(1)

# fill nan for autotrain fields
filled_values = {'curriculum_name': 'None',
'curriculum_version': 'None',
Expand All @@ -480,15 +495,7 @@ def _get_data_source(rig):
'if_overriden_by_trainer': False,
}
_df.fillna(filled_values, inplace=True)

# Remove abnormal values
_df.loc[_df['weight_after'] > 100,
['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total']
] = np.nan

_df.loc[_df['water_in_session_manual'] > 100,
['water_in_session_manual', 'water_in_session_total', 'water_after_session']] = np.nan


# foraging performance = foraing_eff * finished_rate
if 'foraging_performance' not in _df.columns:
_df['foraging_performance'] = \
Expand Down
4 changes: 1 addition & 3 deletions code/util/aws_s3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from PIL import Image
import glob
import json

import s3fs
Expand Down Expand Up @@ -56,9 +55,8 @@ def draw_session_plots_quick_preview(df_to_draw_session):
rows.append(st.columns(column_setting))

for draw_type in draw_types_quick_preview:
if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level
prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
this_col = rows[position[0]][position[1]] if len(draw_types_quick_preview) > 1 else rows[0]
show_session_level_img_by_key_and_prefix(
key,
column=this_col,
Expand Down
203 changes: 203 additions & 0 deletions code/util/plot_autotrain_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from datetime import datetime

import numpy as np
import plotly.graph_objects as go
import pandas as pd

from aind_auto_train.schema.curriculum import TrainingStage
from aind_auto_train.plot.curriculum import get_stage_color_mapper


def plot_manager_all_progress(manager: 'AutoTrainManager',
x_axis: ['session', 'date',
'relative_date'] = 'session', # type: ignore
sort_by: ['subject_id', 'first_date',
'last_date', 'progress_to_graduated'] = 'subject_id',
sort_order: ['ascending',
'descending'] = 'descending',
marker_size=10,
marker_edge_width=2,
highlight_subjects=[],
if_show_fig=True
):


# %%
# Set default order
df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'],
ascending=[sort_order == 'ascending', False])

if not len(df_manager):
return None

# Sort mice
if sort_by == 'subject_id':
subject_ids = df_manager.subject_id.unique()
elif sort_by == 'first_date':
subject_ids = df_manager.groupby('subject_id').session_date.min().sort_values(
ascending=sort_order == 'ascending').index
elif sort_by == 'last_date':
subject_ids = df_manager.groupby('subject_id').session_date.max().sort_values(
ascending=sort_order == 'ascending').index
elif sort_by == 'progress_to_graduated':
manager.compute_stats()
df_stats = manager.df_manager_stats

# Sort by 'first_entry' of GRADUATED
subject_ids = df_stats.reset_index().set_index(
'subject_id'
).query(
f'current_stage_actual == "GRADUATED"'
)['first_entry'].sort_values(
ascending=sort_order != 'ascending').index.to_list()

# Append subjects that have not graduated
subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids]

else:
raise ValueError(
f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}')

# Preparing the scatter plot
traces = []
for n, subject_id in enumerate(subject_ids):
df_subject = df_manager[df_manager['subject_id'] == subject_id]

# Get stage_color_mapper
stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__))

# Get h2o if available
if 'h2o' in manager.df_behavior:
h2o = manager.df_behavior[
manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0]
else:
h2o = None

# Handle open loop sessions
open_loop_ids = df_subject.if_closed_loop == False
color_actual = df_subject['current_stage_actual'].map(
stage_color_mapper)
color_actual[open_loop_ids] = 'lightgrey'
stage_actual = df_subject.current_stage_actual.values
stage_actual[open_loop_ids] = 'unknown (open loop)'

# Select x
if x_axis == 'session':
x = df_subject['session']
elif x_axis == 'date':
x = pd.to_datetime(df_subject['session_date'])
elif x_axis == 'relative_date':
x = pd.to_datetime(df_subject['session_date'])
x = (x - x.min()).dt.days
else:
raise ValueError(
f"x_axis can only be in ['session', 'date', 'relative_date']")

# Cache x range
xrange_min = x.min() if n == 0 else min(x.min(), xrange_min)
xrange_max = x.max() if n == 0 else max(x.max(), xrange_max)

traces.append(go.Scattergl(
x=x,
y=[n] * len(df_subject),
mode='markers',
marker=dict(
size=marker_size,
line=dict(
width=marker_edge_width,
color=df_subject['current_stage_suggested'].map(
stage_color_mapper)
),
color=color_actual,
# colorbar=dict(title='Training Stage'),
),
name=f'Mouse {subject_id}',
hovertemplate=(f"<b>Subject {subject_id} ({h2o})</b>"
"<br><b>Session %{customdata[0]}, %{customdata[1]}</b>"
"<br>Curriculum: <b>%{customdata[2]}_v%{customdata[3]}</b>"
"<br>Suggested: <b>%{customdata[4]}</b>"
"<br>Actual: <b>%{customdata[5]}</b>"
"<br>Session task: <b>%{customdata[6]}</b>"
"<br>foraging_eff = %{customdata[7]}"
"<br>finished_trials = %{customdata[8]}"
"<br>Decision = <b>%{customdata[9]}</b>"
"<br>Next suggested: <b>%{customdata[10]}</b>"
"<extra></extra>"),
customdata=np.stack(
(df_subject.session,
df_subject.session_date,
df_subject.curriculum_name,
df_subject.curriculum_version,
df_subject.current_stage_suggested,
stage_actual,
df_subject.task,
np.round(df_subject.foraging_efficiency, 3),
df_subject.finished_trials,
df_subject.decision,
df_subject.next_stage_suggested,
), axis=-1),
showlegend=False
)
)

# Add "x" for open loop sessions
traces.append(go.Scattergl(
x=x[open_loop_ids],
y=[n] * len(x[open_loop_ids]),
mode='markers',
marker=dict(
size=marker_size*0.8,
symbol='x-thin',
color='black',
line_width=marker_edge_width*0.8,
),
showlegend=False,
)
)

# Create the figure
fig = go.Figure(data=traces)
fig.update_layout(
title=f"Automatic training progress ({manager.manager_name})",
xaxis_title=x_axis,
yaxis_title='Mouse',
height=1200,
)

# Set subject_id as y axis label
fig.update_layout(
hovermode='closest',
yaxis=dict(
tickmode='array',
tickvals=np.arange(0, n + 1), # Original y-axis values
ticktext=subject_ids, # New labels
autorange='reversed',
zeroline=False,
title=''
)
)

# Highight the selected subject
for n, subject_id in enumerate(subject_ids):
if subject_id in highlight_subjects:
fig.add_shape(
type="rect",
y0=n-0.5,
y1=n+0.5,
x0=xrange_min - (1 if x_axis != 'date' else pd.Timedelta(days=1)),
x1=xrange_max + (1 if x_axis != 'date' else pd.Timedelta(days=1)),
line=dict(
width=0,
),
fillcolor="Gray",
opacity=0.3,
layer="below"
)


# Show the plot
if if_show_fig:
fig.show()

# %%
return fig
Loading

0 comments on commit 7afee3f

Please sign in to comment.