Skip to content

Commit

Permalink
add date as x-axis
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Dec 20, 2023
1 parent 0c0c550 commit 7f32b6d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
4 changes: 2 additions & 2 deletions code/aind_auto_train/auto_train_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def compute_stats(self):

self.df_manager_stats = df_stats

def plot_all_progress(self, if_show_fig=True):
return plot_manager_all_progress(self, if_show_fig=if_show_fig)
def plot_all_progress(self, **kwargs):
return plot_manager_all_progress(self, **kwargs)

def _get_next_stage_suggested_on_last_session(self, subject_id, session) -> str:
df_this_mouse = self.df_manager.query(f'subject_id == "{subject_id}"')
Expand Down
22 changes: 16 additions & 6 deletions code/aind_auto_train/plot/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


def plot_manager_all_progress(manager: 'AutoTrainManager',
x_axis: ['session', 'date'] = 'session',
if_show_fig=True
):
# %%
Expand All @@ -32,16 +33,24 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
else:
h2o = None

# Set fill color (none if not closed loop)
# Handle open loop sessions
open_loop_ids = df_subject.if_closed_loop == False
color_actual = df_subject['current_stage_actual'].map(
stage_color_mapper)
color_actual[open_loop_ids] = 'lightgrey'
stage_actual = df_subject.current_stage_actual.values
stage_actual[open_loop_ids] = 'unknown (open loop)'

# Select x
if x_axis == 'session':
x = df_subject['session']
elif x_axis == 'date':
x = df_subject['session_date']
else:
raise ValueError(f'x_axis can only be "session" or "date"')

traces.append(go.Scattergl(
x=df_subject['session'],
x=x,
y=[n] * len(df_subject),
mode='markers',
marker=dict(
Expand All @@ -56,7 +65,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
),
name=f'Mouse {subject_id}',
hovertemplate=(f"<b>Subject {subject_id} ({h2o})"
"<br>Session %{x}, %{customdata[4]}</b>"
"<br>Session %{customdata[9]}, %{customdata[4]}</b>"
"<br>Curriculum: <b>%{customdata[7]}_v%{customdata[8]}</b>"
"<br>Suggested: <b>%{customdata[0]}</b>"
"<br>Actual: <b>%{customdata[1]}</b>"
Expand All @@ -75,15 +84,16 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
df_subject.task,
df_subject.curriculum_task,
df_subject.curriculum_version,
df_subject.session,
), axis=-1),
showlegend=False
)
)

# Add "x" for open loop sessions
traces.append(go.Scattergl(
x=df_subject['session'][open_loop_ids],
y=[n] * len(df_subject['session'][open_loop_ids]),
x=x[open_loop_ids],
y=[n] * len(x[open_loop_ids]),
mode='markers',
marker=dict(
size=5,
Expand All @@ -99,7 +109,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
fig = go.Figure(data=traces)
fig.update_layout(
title=f'Training Progress of All Mice ({manager.manager_name}, curriculum_task = {manager.df_manager.curriculum_task[0]})',
xaxis_title='Session',
xaxis_title=x_axis,
yaxis_title='Mouse',
height=1200,
)
Expand Down

0 comments on commit 7f32b6d

Please sign in to comment.