Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving matplotlib figures as images assigned to experiments #311

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions entropylab/cli/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def test_init_with_no_args():
args.directory = ""
# act
init(args)
# clean up
shutil.rmtree(".entropy")


def test_init_with_current_dir():
Expand Down
6 changes: 3 additions & 3 deletions entropylab/dashboard/assets/dashboard.css
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ html,body {


#experiments-title,
#plots-title {
#figures-title {
margin-top:20px;
}

Expand Down Expand Up @@ -118,7 +118,7 @@ input.current-page::placeholder {

.add-button-container {
margin-top:20px;
display: flex
display: flex;
align-items: center;
justify-content: center;
}
Expand All @@ -139,7 +139,7 @@ input.current-page::placeholder {
height: 60px;
margin: 4px;
border-color: #666666;
display: flex
display: flex;
align-items: center;
justify-content: center;
}
Expand Down
201 changes: 104 additions & 97 deletions entropylab/dashboard/pages/results/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,28 @@
from entropylab.dashboard.pages.results.dashboard_data import FAVORITE_TRUE
from entropylab.dashboard.theme import (
colors,
dark_plot_layout,
dark_figure_layout,
)
from entropylab.logger import logger
from entropylab.pipeline.api.data_reader import PlotRecord, FigureRecord
from entropylab.pipeline.api.data_reader import (
FigureRecord,
MatplotlibFigureRecord,
)
from entropylab.pipeline.api.errors import EntropyError

REFRESH_INTERVAL_IN_MILLIS = 3000
EXPERIMENTS_PAGE_SIZE = 6
IMG_TAB_KEY = "m"
FIGURE_TAB_KEY = "f"


def register_callbacks(app, dashboard_data_reader):
"""Initialize the results dashboard Dash app to display an Entropy project
"""Add callbacks and helper methods to the dashboard Dash app

:param app the Dash app to add the callbacks to
:param dashboard_data_reader a DashboardDataReader instance to read dashboard
data from"""

""" Creating and setting up our Dash app """

""" CALLBACKS and their helper functions """

@app.callback(
Output("experiments-table", "data"),
Output("empty-project-modal", "is_open"),
Expand Down Expand Up @@ -68,17 +69,16 @@ def open_failed_plotting_alert_when_its_not_empty(children):
return children != ""

@app.callback(
Output("plot-tabs", "children"),
Output("figure-tabs", "children"),
Output("figures-by-key", "data"),
Output("prev-selected-rows", "data"),
Output("failed-plotting-alert", "children"),
Output("add-button", "disabled"),
Input("experiments-table", "selected_rows"),
State("experiments-table", "data"),
State("figures-by-key", "data"),
State("prev-selected-rows", "data"),
)
def render_plot_tabs_from_selected_experiments_table_rows(
def render_figure_tabs_from_selected_experiments_table_rows(
selected_rows, data, figures_by_key, prev_selected_rows
):
result = []
Expand All @@ -87,160 +87,160 @@ def render_plot_tabs_from_selected_experiments_table_rows(
prev_selected_rows = prev_selected_rows or {}
alert_text = ""
added_row = get_added_row(prev_selected_rows, selected_rows)
add_button_disabled = False
if data and selected_rows:
for row_num in selected_rows:
alert_on_fail = row_num == added_row
exp_id = data[row_num]["id"]
try:
plots_and_figures = dashboard_data_reader.get_plot_and_figure_data(
exp_id
)
figure_records = dashboard_data_reader.get_figure_records(exp_id)
except EntropyError:
logger.exception(
f"Exception when getting plot/figure data for exp_id={exp_id}"
f"Exception when getting figure data for exp_id={exp_id}"
)
if alert_on_fail:
alert_text = (
f"⚠ Error when reading plot/figure data for this "
f"⚠ Error when reading figure data for this "
f"experiment. (id: {exp_id})"
)
plots_and_figures = None
if plots_and_figures and len(plots_and_figures) > 0:
failed_plot_ids = []
figures_by_key = build_plot_tabs(
figure_records = None
if figure_records and len(figure_records) > 0:
failed_figure_keys = []
figures_by_key = build_figure_tabs(
alert_on_fail,
failed_plot_ids,
failed_figure_keys,
figures_by_key,
plots_and_figures,
figure_records,
result,
)
if len(failed_plot_ids) > 0:
if len(failed_figure_keys) > 0:
alert_text = (
f"⚠ Some plots could not be rendered. "
f"(ids: {','.join(failed_plot_ids)})"
f"⚠ Some figures could not be rendered. "
f"(ids: {','.join(failed_figure_keys)})"
)
else:
if alert_on_fail and alert_text == "":
alert_text = (
f"⚠ Experiment has no plots to render. (id: {exp_id})"
f"⚠ Experiment has no figures to render. (id: {exp_id})"
)
if len(result) == 0:
result = [build_plot_tabs_placeholder()]
add_button_disabled = True
result = [build_figure_tabs_placeholder()]
return (
result,
figures_by_key,
selected_rows,
alert_text,
add_button_disabled,
)

def build_plot_tabs(
alert_on_fail, failed_plot_ids, figures_by_key, plots_and_figures, result
def build_figure_tabs(
alert_on_fail, failed_figure_keys, figures_by_key, figure_records, result
):
for plot_or_figure in plots_and_figures:
for figure_record in figure_records:
try:
color = colors[len(result) % len(colors)]
plot_tab, figures_by_key = build_plot_tab_from_plot_or_figure(
figures_by_key, plot_or_figure, color
figure_tab, figures_by_key = build_tab_from_figure(
figures_by_key, figure_record, color
)
result.append(plot_tab)
result.append(figure_tab)
except (EntropyError, TypeError):
logger.exception(f"Failed to render plot id [{plot_or_figure.id}]")
logger.exception(
f"Failed to render figure record id [{figure_record.id}]"
)
if alert_on_fail:
plot_key = f"{plot_or_figure.experiment_id}/{plot_or_figure.id}"
failed_plot_ids.append(plot_key)
figure_key = f"{figure_record.experiment_id}/{figure_record.id}"
failed_figure_keys.append(figure_key)
return figures_by_key

def build_plot_tabs_placeholder():
def build_figure_tabs_placeholder():
return dbc.Tab(
html.Div(
html.Div(
"Select an experiment above to display its plots here",
"Select an experiment above to display its figures here",
className="tab-placeholder-text",
),
className="tab-placeholder-container",
),
label="Plots",
tab_id="plot-tab-placeholder",
label="Figures",
tab_id="figure-tab-placeholder",
)

def build_plot_tab_from_plot_or_figure(
figures_by_key, plot_or_figure: PlotRecord | FigureRecord, color: str
def build_tab_from_figure(
figures_by_key,
figure_record: FigureRecord | MatplotlibFigureRecord,
color: str,
) -> (dbc.Tab, Dict):
if isinstance(plot_or_figure, PlotRecord):
# For backwards compatibility with soon to be deprecated Plots API:
plot_rec = cast(PlotRecord, plot_or_figure)
key = f"{plot_rec.experiment_id}/{plot_rec.id}/p"
name = f"Plot {key[:-2]}"
figure = go.Figure()
plot_or_figure.generator.plot_plotly(
figure,
plot_or_figure.plot_data,
name=name,
color=color,
showlegend=False,
)
if isinstance(figure_record, MatplotlibFigureRecord):
record = cast(MatplotlibFigureRecord, figure_record)
key = f"{record.experiment_id}/{record.id}/{IMG_TAB_KEY}"
name = f"Image {key[:-2]}"
return build_img_tab(record.img_src, name, key), figures_by_key
else:
figure_rec = cast(FigureRecord, plot_or_figure)
key = f"{figure_rec.experiment_id}/{figure_rec.id}/f"
record = cast(FigureRecord, figure_record)
key = f"{record.experiment_id}/{record.id}/{FIGURE_TAB_KEY}"
name = f"Figure {key[:-2]}"
figure = figure_rec.figure
figure.update_layout(dark_plot_layout)
figures_by_key[key] = dict(figure=figure, color=color)
return build_plot_tab(figure, name, key), figures_by_key
figure = record.figure
figure.update_layout(dark_figure_layout)
figures_by_key[key] = dict(figure=figure, color=color)
return build_figure_tab(figure, name, key), figures_by_key

def build_plot_tab(
plot_figure: go.Figure, plot_name: str, plot_key: str
def build_figure_tab(
figure: go.Figure, figure_name: str, figure_key: str
) -> dbc.Tab:
return dbc.Tab(
dcc.Graph(figure=plot_figure, responsive=True),
label=plot_name,
id=f"plot-tab-{plot_key}",
tab_id=f"plot-tab-{plot_key}",
dcc.Graph(figure=figure, responsive=True),
label=figure_name,
id=f"figure-tab-{figure_key}",
tab_id=f"figure-tab-{figure_key}",
)

def build_img_tab(img_src: str, figure_name: str, figure_key: str) -> dbc.Tab:
return dbc.Tab(
# TODO: Fit img into tab dimensions
html.Img(src=img_src),
label=figure_name,
id=f"figure-tab-{figure_key}",
tab_id=f"figure-tab-{figure_key}",
)

@app.callback(
Output("plot-keys-to-combine", "data"),
Output("figure-keys-to-combine", "data"),
Input("add-button", "n_clicks"),
Input({"type": "remove-button", "index": ALL}, "n_clicks"),
State("plot-tabs", "active_tab"),
State("plot-keys-to-combine", "data"),
State("figure-tabs", "active_tab"),
State("figure-keys-to-combine", "data"),
)
def add_or_remove_plot_keys_based_on_click_events(
_, __, active_tab, plot_keys_to_combine
def add_or_remove_figure_keys_based_on_click_events(
_, __, active_tab, figure_keys_to_combine
):
plot_keys_to_combine = plot_keys_to_combine or []
figure_keys_to_combine = figure_keys_to_combine or []
prop_id = dash.callback_context.triggered[0]["prop_id"]
if prop_id == "add-button.n_clicks" and active_tab:
active_plot_key = active_tab.replace("plot-tab-", "")
if active_plot_key not in plot_keys_to_combine:
plot_keys_to_combine.append(active_plot_key)
active_figure_key = active_tab.replace("figure-tab-", "")
if active_figure_key not in figure_keys_to_combine:
figure_keys_to_combine.append(active_figure_key)
elif "remove-button" in prop_id:
id_dict = json.loads(prop_id.replace(".n_clicks", ""))
remove_plot_key = id_dict["index"]
if remove_plot_key in plot_keys_to_combine:
plot_keys_to_combine.remove(remove_plot_key)
return plot_keys_to_combine
remove_figure_key = id_dict["index"]
if remove_figure_key in figure_keys_to_combine:
figure_keys_to_combine.remove(remove_figure_key)
return figure_keys_to_combine

@app.callback(
Output("aggregate-tab", "children"),
Output("remove-buttons", "children"),
Input("plot-keys-to-combine", "data"),
Input("figure-keys-to-combine", "data"),
State("figures-by-key", "data"),
)
def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures):
if plot_keys_to_combine and len(plot_keys_to_combine) > 0:
def build_combined_figure_from_figure_keys(figure_keys_to_combine, figures_by_key):
if figure_keys_to_combine and len(figure_keys_to_combine) > 0:
combined_figure = make_subplots(specs=[[{"secondary_y": True}]])
remove_buttons = []
for plot_id in plot_keys_to_combine:
figure = plot_figures[plot_id]["figure"]
color = plot_figures[plot_id]["color"]
for figure_key in figure_keys_to_combine:
figure = figures_by_key[figure_key]["figure"]
color = figures_by_key[figure_key]["color"]
combined_figure.add_trace(figure["data"][0])
button = build_remove_button(plot_id, color)
button = build_remove_button(figure_key, color)
remove_buttons.append(button)
combined_figure.update_layout(dark_plot_layout)
combined_figure.update_layout(dark_figure_layout)
return (
dcc.Graph(
id="aggregate-graph",
Expand All @@ -255,26 +255,26 @@ def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures):
[html.Div()],
)

def build_remove_button(plot_id, color):
def build_remove_button(figure_key, color):
return dbc.Button(
dbc.Row(
children=[
dbc.Col(
"✖",
),
dbc.Col(f"{plot_id}", className="remove-button-label"),
dbc.Col(f"{figure_key}", className="remove-button-label"),
],
),
style={"background-color": color},
class_name="remove-button",
id={"type": "remove-button", "index": plot_id},
id={"type": "remove-button", "index": figure_key},
)

def build_aggregate_tab_placeholder():
return html.Div(
[
html.Div(
"Add a plot on the left to aggregate it here",
"Add a figure (above) to aggregate it here",
className="tab-placeholder-text",
),
dcc.Graph(
Expand All @@ -288,15 +288,22 @@ def build_aggregate_tab_placeholder():
)

@app.callback(
Output("plot-tabs", "active_tab"),
Input("plot-tabs", "children"),
Output("figure-tabs", "active_tab"),
Input("figure-tabs", "children"),
)
def activate_last_plot_tab_when_tabs_are_changed(children):
def activate_last_figure_tab_when_tabs_are_changed(children):
if len(children) > 0:
last_tab = len(children) - 1
return children[last_tab]["props"]["tab_id"]
return 0

@app.callback(
Output("add-button", "disabled"),
Input("figure-tabs", "active_tab"),
)
def disable_add_button_when_active_tab_is_img_or_placeholder(active_tab):
return active_tab.endswith("/m") or active_tab == "figure-tab-placeholder"

@app.callback(
Output("aggregate-clipboard", "content"),
Input("aggregate-clipboard", "n_clicks"),
Expand Down
Loading