Skip to content

Commit

Permalink
use flexible stage color
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Jan 7, 2024
1 parent 051b55c commit 0389afb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
24 changes: 21 additions & 3 deletions code/aind_auto_train/plot/curriculum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import re

from graphviz import Digraph
from aind_auto_train.plot.manager import stage_color_mapper

from matplotlib import pyplot as plt
import matplotlib

def get_stage_color_mapper(stage_list):
# Mapping stages to colors from red to green, return rgb values
# Interpolate between red and green using the number of stages
cmap = plt.cm.get_cmap('RdYlGn', 100)
stage_color_mapper = {
stage: matplotlib.colors.rgb2hex(
cmap(i / (len(stage_list) - 1)))
for i, stage in enumerate(stage_list)
}
return stage_color_mapper

def _format_lambda_full(string: str):
return '\n'.join([line.lstrip() # Remove the first line
Expand All @@ -25,8 +36,10 @@ def draw_diagram_rules(curriculum):
curriculum (Curriculum): _description_
"""

# Script data extracted from the user's script
stages = curriculum.parameters.keys()
stage_color_mapper = get_stage_color_mapper(
[s.name for s in list(stages)] + ['GRADUATED']
)

# Create Digraph object
dot = Digraph(comment='Curriculum for Dynamic Foraging - Coupled Baiting')
Expand Down Expand Up @@ -104,6 +117,11 @@ def draw_diagram_paras(curriculum,
with change of parameters highlighted in green
"""

stages = curriculum.parameters.keys()
stage_color_mapper = get_stage_color_mapper(
[s.name for s in list(stages)] + ['GRADUATED']
)

dot = Digraph('G')

# Graph attributes to control the overall appearance
Expand Down
16 changes: 6 additions & 10 deletions code/aind_auto_train/plot/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import pandas as pd

from aind_auto_train.schema.curriculum import TrainingStage

# Define color scale - mapping stages to colors from red to green
# TODO: make this flexible
stage_color_mapper = {
TrainingStage.STAGE_1.name: 'red',
TrainingStage.STAGE_2.name: 'orange',
TrainingStage.STAGE_3.name: 'yellow',
TrainingStage.STAGE_FINAL.name: 'lightgreen',
TrainingStage.GRADUATED.name: 'green'
}
from aind_auto_train.plot.curriculum import get_stage_color_mapper


def plot_manager_all_progress(manager: 'AutoTrainManager',
Expand All @@ -26,6 +17,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
'descending'] = 'descending',
if_show_fig=True
):

# %%
# Set default order
df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'],
Expand Down Expand Up @@ -63,6 +55,10 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
traces = []
for n, subject_id in enumerate(subject_ids):
df_subject = df_manager[df_manager['subject_id'] == subject_id]

# Get stage_color_mapper
stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__))

# Get h2o if available
if 'h2o' in manager.df_behavior:
h2o = manager.df_behavior[
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
's3fs',
'graphviz',
'nbformat==5.1.2',
'matplotlib',
]

[tool.setuptools.packages.find]
Expand Down

0 comments on commit 0389afb

Please sign in to comment.