diff --git a/entropylab/cli/tests/test_main.py b/entropylab/cli/tests/test_main.py
index 098f139a..361b35a0 100644
--- a/entropylab/cli/tests/test_main.py
+++ b/entropylab/cli/tests/test_main.py
@@ -13,6 +13,8 @@ def test_init_with_no_args():
args.directory = ""
# act
init(args)
+ # clean up
+ shutil.rmtree(".entropy")
def test_init_with_current_dir():
diff --git a/entropylab/dashboard/assets/dashboard.css b/entropylab/dashboard/assets/dashboard.css
index 578b382d..67919855 100644
--- a/entropylab/dashboard/assets/dashboard.css
+++ b/entropylab/dashboard/assets/dashboard.css
@@ -33,7 +33,7 @@ html,body {
#experiments-title,
-#plots-title {
+#figures-title {
margin-top:20px;
}
@@ -118,7 +118,7 @@ input.current-page::placeholder {
.add-button-container {
margin-top:20px;
- display: flex
+ display: flex;
align-items: center;
justify-content: center;
}
@@ -139,7 +139,7 @@ input.current-page::placeholder {
height: 60px;
margin: 4px;
border-color: #666666;
- display: flex
+ display: flex;
align-items: center;
justify-content: center;
}
diff --git a/entropylab/dashboard/pages/results/callbacks.py b/entropylab/dashboard/pages/results/callbacks.py
index 98c032d8..1612e910 100644
--- a/entropylab/dashboard/pages/results/callbacks.py
+++ b/entropylab/dashboard/pages/results/callbacks.py
@@ -13,27 +13,28 @@
from entropylab.dashboard.pages.results.dashboard_data import FAVORITE_TRUE
from entropylab.dashboard.theme import (
colors,
- dark_plot_layout,
+ dark_figure_layout,
)
from entropylab.logger import logger
-from entropylab.pipeline.api.data_reader import PlotRecord, FigureRecord
+from entropylab.pipeline.api.data_reader import (
+ FigureRecord,
+ MatplotlibFigureRecord,
+)
from entropylab.pipeline.api.errors import EntropyError
REFRESH_INTERVAL_IN_MILLIS = 3000
EXPERIMENTS_PAGE_SIZE = 6
+IMG_TAB_KEY = "m"
+FIGURE_TAB_KEY = "f"
def register_callbacks(app, dashboard_data_reader):
- """Initialize the results dashboard Dash app to display an Entropy project
+ """Add callbacks and helper methods to the dashboard Dash app
:param app the Dash app to add the callbacks to
:param dashboard_data_reader a DashboardDataReader instance to read dashboard
data from"""
- """ Creating and setting up our Dash app """
-
- """ CALLBACKS and their helper functions """
-
@app.callback(
Output("experiments-table", "data"),
Output("empty-project-modal", "is_open"),
@@ -68,17 +69,16 @@ def open_failed_plotting_alert_when_its_not_empty(children):
return children != ""
@app.callback(
- Output("plot-tabs", "children"),
+ Output("figure-tabs", "children"),
Output("figures-by-key", "data"),
Output("prev-selected-rows", "data"),
Output("failed-plotting-alert", "children"),
- Output("add-button", "disabled"),
Input("experiments-table", "selected_rows"),
State("experiments-table", "data"),
State("figures-by-key", "data"),
State("prev-selected-rows", "data"),
)
- def render_plot_tabs_from_selected_experiments_table_rows(
+ def render_figure_tabs_from_selected_experiments_table_rows(
selected_rows, data, figures_by_key, prev_selected_rows
):
result = []
@@ -87,160 +87,160 @@ def render_plot_tabs_from_selected_experiments_table_rows(
prev_selected_rows = prev_selected_rows or {}
alert_text = ""
added_row = get_added_row(prev_selected_rows, selected_rows)
- add_button_disabled = False
if data and selected_rows:
for row_num in selected_rows:
alert_on_fail = row_num == added_row
exp_id = data[row_num]["id"]
try:
- plots_and_figures = dashboard_data_reader.get_plot_and_figure_data(
- exp_id
- )
+ figure_records = dashboard_data_reader.get_figure_records(exp_id)
except EntropyError:
logger.exception(
- f"Exception when getting plot/figure data for exp_id={exp_id}"
+ f"Exception when getting figure data for exp_id={exp_id}"
)
if alert_on_fail:
alert_text = (
- f"⚠ Error when reading plot/figure data for this "
+ f"⚠ Error when reading figure data for this "
f"experiment. (id: {exp_id})"
)
- plots_and_figures = None
- if plots_and_figures and len(plots_and_figures) > 0:
- failed_plot_ids = []
- figures_by_key = build_plot_tabs(
+ figure_records = None
+ if figure_records and len(figure_records) > 0:
+ failed_figure_keys = []
+ figures_by_key = build_figure_tabs(
alert_on_fail,
- failed_plot_ids,
+ failed_figure_keys,
figures_by_key,
- plots_and_figures,
+ figure_records,
result,
)
- if len(failed_plot_ids) > 0:
+ if len(failed_figure_keys) > 0:
alert_text = (
- f"⚠ Some plots could not be rendered. "
- f"(ids: {','.join(failed_plot_ids)})"
+ f"⚠ Some figures could not be rendered. "
+ f"(ids: {','.join(failed_figure_keys)})"
)
else:
if alert_on_fail and alert_text == "":
alert_text = (
- f"⚠ Experiment has no plots to render. (id: {exp_id})"
+ f"⚠ Experiment has no figures to render. (id: {exp_id})"
)
if len(result) == 0:
- result = [build_plot_tabs_placeholder()]
- add_button_disabled = True
+ result = [build_figure_tabs_placeholder()]
return (
result,
figures_by_key,
selected_rows,
alert_text,
- add_button_disabled,
)
- def build_plot_tabs(
- alert_on_fail, failed_plot_ids, figures_by_key, plots_and_figures, result
+ def build_figure_tabs(
+ alert_on_fail, failed_figure_keys, figures_by_key, figure_records, result
):
- for plot_or_figure in plots_and_figures:
+ for figure_record in figure_records:
try:
color = colors[len(result) % len(colors)]
- plot_tab, figures_by_key = build_plot_tab_from_plot_or_figure(
- figures_by_key, plot_or_figure, color
+ figure_tab, figures_by_key = build_tab_from_figure(
+ figures_by_key, figure_record, color
)
- result.append(plot_tab)
+ result.append(figure_tab)
except (EntropyError, TypeError):
- logger.exception(f"Failed to render plot id [{plot_or_figure.id}]")
+ logger.exception(
+ f"Failed to render figure record id [{figure_record.id}]"
+ )
if alert_on_fail:
- plot_key = f"{plot_or_figure.experiment_id}/{plot_or_figure.id}"
- failed_plot_ids.append(plot_key)
+ figure_key = f"{figure_record.experiment_id}/{figure_record.id}"
+ failed_figure_keys.append(figure_key)
return figures_by_key
- def build_plot_tabs_placeholder():
+ def build_figure_tabs_placeholder():
return dbc.Tab(
html.Div(
html.Div(
- "Select an experiment above to display its plots here",
+ "Select an experiment above to display its figures here",
className="tab-placeholder-text",
),
className="tab-placeholder-container",
),
- label="Plots",
- tab_id="plot-tab-placeholder",
+ label="Figures",
+ tab_id="figure-tab-placeholder",
)
- def build_plot_tab_from_plot_or_figure(
- figures_by_key, plot_or_figure: PlotRecord | FigureRecord, color: str
+ def build_tab_from_figure(
+ figures_by_key,
+ figure_record: FigureRecord | MatplotlibFigureRecord,
+ color: str,
) -> (dbc.Tab, Dict):
- if isinstance(plot_or_figure, PlotRecord):
- # For backwards compatibility with soon to be deprecated Plots API:
- plot_rec = cast(PlotRecord, plot_or_figure)
- key = f"{plot_rec.experiment_id}/{plot_rec.id}/p"
- name = f"Plot {key[:-2]}"
- figure = go.Figure()
- plot_or_figure.generator.plot_plotly(
- figure,
- plot_or_figure.plot_data,
- name=name,
- color=color,
- showlegend=False,
- )
+ if isinstance(figure_record, MatplotlibFigureRecord):
+ record = cast(MatplotlibFigureRecord, figure_record)
+ key = f"{record.experiment_id}/{record.id}/{IMG_TAB_KEY}"
+ name = f"Image {key[:-2]}"
+ return build_img_tab(record.img_src, name, key), figures_by_key
else:
- figure_rec = cast(FigureRecord, plot_or_figure)
- key = f"{figure_rec.experiment_id}/{figure_rec.id}/f"
+ record = cast(FigureRecord, figure_record)
+ key = f"{record.experiment_id}/{record.id}/{FIGURE_TAB_KEY}"
name = f"Figure {key[:-2]}"
- figure = figure_rec.figure
- figure.update_layout(dark_plot_layout)
- figures_by_key[key] = dict(figure=figure, color=color)
- return build_plot_tab(figure, name, key), figures_by_key
+ figure = record.figure
+ figure.update_layout(dark_figure_layout)
+ figures_by_key[key] = dict(figure=figure, color=color)
+ return build_figure_tab(figure, name, key), figures_by_key
- def build_plot_tab(
- plot_figure: go.Figure, plot_name: str, plot_key: str
+ def build_figure_tab(
+ figure: go.Figure, figure_name: str, figure_key: str
) -> dbc.Tab:
return dbc.Tab(
- dcc.Graph(figure=plot_figure, responsive=True),
- label=plot_name,
- id=f"plot-tab-{plot_key}",
- tab_id=f"plot-tab-{plot_key}",
+ dcc.Graph(figure=figure, responsive=True),
+ label=figure_name,
+ id=f"figure-tab-{figure_key}",
+ tab_id=f"figure-tab-{figure_key}",
+ )
+
+ def build_img_tab(img_src: str, figure_name: str, figure_key: str) -> dbc.Tab:
+ return dbc.Tab(
+ # TODO: Fit img into tab dimensions
+ html.Img(src=img_src),
+ label=figure_name,
+ id=f"figure-tab-{figure_key}",
+ tab_id=f"figure-tab-{figure_key}",
)
@app.callback(
- Output("plot-keys-to-combine", "data"),
+ Output("figure-keys-to-combine", "data"),
Input("add-button", "n_clicks"),
Input({"type": "remove-button", "index": ALL}, "n_clicks"),
- State("plot-tabs", "active_tab"),
- State("plot-keys-to-combine", "data"),
+ State("figure-tabs", "active_tab"),
+ State("figure-keys-to-combine", "data"),
)
- def add_or_remove_plot_keys_based_on_click_events(
- _, __, active_tab, plot_keys_to_combine
+ def add_or_remove_figure_keys_based_on_click_events(
+ _, __, active_tab, figure_keys_to_combine
):
- plot_keys_to_combine = plot_keys_to_combine or []
+ figure_keys_to_combine = figure_keys_to_combine or []
prop_id = dash.callback_context.triggered[0]["prop_id"]
if prop_id == "add-button.n_clicks" and active_tab:
- active_plot_key = active_tab.replace("plot-tab-", "")
- if active_plot_key not in plot_keys_to_combine:
- plot_keys_to_combine.append(active_plot_key)
+ active_figure_key = active_tab.replace("figure-tab-", "")
+ if active_figure_key not in figure_keys_to_combine:
+ figure_keys_to_combine.append(active_figure_key)
elif "remove-button" in prop_id:
id_dict = json.loads(prop_id.replace(".n_clicks", ""))
- remove_plot_key = id_dict["index"]
- if remove_plot_key in plot_keys_to_combine:
- plot_keys_to_combine.remove(remove_plot_key)
- return plot_keys_to_combine
+ remove_figure_key = id_dict["index"]
+ if remove_figure_key in figure_keys_to_combine:
+ figure_keys_to_combine.remove(remove_figure_key)
+ return figure_keys_to_combine
@app.callback(
Output("aggregate-tab", "children"),
Output("remove-buttons", "children"),
- Input("plot-keys-to-combine", "data"),
+ Input("figure-keys-to-combine", "data"),
State("figures-by-key", "data"),
)
- def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures):
- if plot_keys_to_combine and len(plot_keys_to_combine) > 0:
+ def build_combined_figure_from_figure_keys(figure_keys_to_combine, figures_by_key):
+ if figure_keys_to_combine and len(figure_keys_to_combine) > 0:
combined_figure = make_subplots(specs=[[{"secondary_y": True}]])
remove_buttons = []
- for plot_id in plot_keys_to_combine:
- figure = plot_figures[plot_id]["figure"]
- color = plot_figures[plot_id]["color"]
+ for figure_key in figure_keys_to_combine:
+ figure = figures_by_key[figure_key]["figure"]
+ color = figures_by_key[figure_key]["color"]
combined_figure.add_trace(figure["data"][0])
- button = build_remove_button(plot_id, color)
+ button = build_remove_button(figure_key, color)
remove_buttons.append(button)
- combined_figure.update_layout(dark_plot_layout)
+ combined_figure.update_layout(dark_figure_layout)
return (
dcc.Graph(
id="aggregate-graph",
@@ -255,26 +255,26 @@ def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures):
[html.Div()],
)
- def build_remove_button(plot_id, color):
+ def build_remove_button(figure_key, color):
return dbc.Button(
dbc.Row(
children=[
dbc.Col(
"✖",
),
- dbc.Col(f"{plot_id}", className="remove-button-label"),
+ dbc.Col(f"{figure_key}", className="remove-button-label"),
],
),
style={"background-color": color},
class_name="remove-button",
- id={"type": "remove-button", "index": plot_id},
+ id={"type": "remove-button", "index": figure_key},
)
def build_aggregate_tab_placeholder():
return html.Div(
[
html.Div(
- "Add a plot on the left to aggregate it here",
+ "Add a figure (above) to aggregate it here",
className="tab-placeholder-text",
),
dcc.Graph(
@@ -288,15 +288,22 @@ def build_aggregate_tab_placeholder():
)
@app.callback(
- Output("plot-tabs", "active_tab"),
- Input("plot-tabs", "children"),
+ Output("figure-tabs", "active_tab"),
+ Input("figure-tabs", "children"),
)
- def activate_last_plot_tab_when_tabs_are_changed(children):
+ def activate_last_figure_tab_when_tabs_are_changed(children):
if len(children) > 0:
last_tab = len(children) - 1
return children[last_tab]["props"]["tab_id"]
return 0
+ @app.callback(
+ Output("add-button", "disabled"),
+ Input("figure-tabs", "active_tab"),
+ )
+ def disable_add_button_when_active_tab_is_img_or_placeholder(active_tab):
+ return active_tab.endswith("/m") or active_tab == "figure-tab-placeholder"
+
@app.callback(
Output("aggregate-clipboard", "content"),
Input("aggregate-clipboard", "n_clicks"),
diff --git a/entropylab/dashboard/pages/results/dashboard_data.py b/entropylab/dashboard/pages/results/dashboard_data.py
index d69b8b42..0e57b19b 100644
--- a/entropylab/dashboard/pages/results/dashboard_data.py
+++ b/entropylab/dashboard/pages/results/dashboard_data.py
@@ -8,7 +8,10 @@
from entropylab import SqlAlchemyDB
from entropylab.dashboard.pages.results.auto_plot import auto_plot
from entropylab.logger import logger
-from entropylab.pipeline.api.data_reader import PlotRecord, FigureRecord
+from entropylab.pipeline.api.data_reader import (
+ FigureRecord,
+ MatplotlibFigureRecord,
+)
FAVORITE_TRUE = "⭐"
FAVORITE_FALSE = "✰"
@@ -25,7 +28,9 @@ def get_last_experiments(
pass
@abc.abstractmethod
- def get_plot_and_figure_data(self, exp_id: int) -> List[PlotRecord]:
+ def get_figure_records(
+ self, exp_id: int
+ ) -> List[FigureRecord | MatplotlibFigureRecord]:
pass
@@ -65,22 +70,26 @@ def get_last_result_of_experiment(
):
return self._db.get_last_result_of_experiment(experiment_id)
- def get_plot_and_figure_data(self, exp_id: int) -> List[PlotRecord | FigureRecord]:
- plots = self._db.get_plots(exp_id)
+ def get_figure_records(
+ self, exp_id: int
+ ) -> List[FigureRecord | MatplotlibFigureRecord]:
+ # Plotly figures
if exp_id not in self._figures_cache:
logger.debug(f"Figures cache miss. exp_id=[{exp_id}]")
self._figures_cache[exp_id] = self._db.get_figures(exp_id)
else:
logger.debug(f"Figures cache hit. exp_id=[{exp_id}]")
figures = self._figures_cache[exp_id]
- if len(plots) > 0 or len(figures) > 0:
- return [*plots, *figures]
+ # Matplotlib figures
+ # TODO: Cache matplotlib figures?
+ matplotlib_figures = self._db.get_matplotlib_figures(exp_id)
+ if len(figures) > 0 or len(matplotlib_figures) > 0:
+ return [*figures, *matplotlib_figures]
else:
- # TODO: auto_plot to produce figures, not plots
last_result = self._db.get_last_result_of_experiment(exp_id)
if last_result is not None and last_result.data is not None:
- plot = auto_plot(exp_id, last_result.data)
- return [plot]
+ figure = auto_plot(exp_id, last_result.data)
+ return [figure]
else:
return []
diff --git a/entropylab/dashboard/pages/results/layout.py b/entropylab/dashboard/pages/results/layout.py
index 12365b79..53c9d0e1 100644
--- a/entropylab/dashboard/pages/results/layout.py
+++ b/entropylab/dashboard/pages/results/layout.py
@@ -23,7 +23,7 @@ def build_layout(path: str, dashboard_data_reader: DashboardDataReader):
className="main",
children=[
dcc.Store(id="figures-by-key", storage_type="session"),
- dcc.Store(id="plot-keys-to-combine", storage_type="session"),
+ dcc.Store(id="figure-keys-to-combine", storage_type="session"),
dcc.Store(id="prev-selected-rows", storage_type="session"),
dcc.Store(id="favorites", storage_type="session"),
dcc.Interval(
@@ -71,17 +71,19 @@ def build_layout(path: str, dashboard_data_reader: DashboardDataReader):
[
dbc.Row(
[
- html.H5("Plots and Figures", id="plots-title"),
+ html.H5(
+ "Figures and Images", id="figures-title"
+ ),
dcc.Loading(
- id="plot-tabs-loading",
- children=[dbc.Tabs(id="plot-tabs")],
+ id="figure-tabs-loading",
+ children=[dbc.Tabs(id="figure-tabs")],
type="default",
),
]
),
dbc.Row(
dbc.Button(
- "➕ Add Plot to Aggregate View",
+ "➕ Add Figure to Aggregate View",
id="add-button",
),
className="add-button-container",
diff --git a/entropylab/dashboard/theme.py b/entropylab/dashboard/theme.py
index 9ba4c012..e53cc374 100644
--- a/entropylab/dashboard/theme.py
+++ b/entropylab/dashboard/theme.py
@@ -17,7 +17,7 @@
plot_paper_bgcolor = "rgb(48, 48, 48)"
plot_plot_bgcolor = "rgb(173, 181, 189)"
-dark_plot_layout = dict(
+dark_figure_layout = dict(
font_color=plot_font_color,
legend_font_color=plot_legend_font_color,
paper_bgcolor=plot_paper_bgcolor,
diff --git a/entropylab/pipeline/api/data_reader.py b/entropylab/pipeline/api/data_reader.py
index d615202c..ab084cae 100644
--- a/entropylab/pipeline/api/data_reader.py
+++ b/entropylab/pipeline/api/data_reader.py
@@ -2,13 +2,10 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List, Any, Optional, Iterable
-from warnings import warn
from pandas import DataFrame
from plotly import graph_objects as go
-from entropylab.pipeline.api.data_writer import PlotGenerator
-
class ScriptViewer:
def __init__(self, stages: List[str]) -> None:
@@ -84,29 +81,26 @@ class ResultRecord:
@dataclass
-class PlotRecord:
+class FigureRecord:
"""
- A single plot information and plotting instructions that was saved during the
- experiment
+ A single plotly figure that was saved during the experiment
"""
experiment_id: int
id: int
- plot_data: Any = None
- generator: Optional[PlotGenerator] = None
- label: Optional[str] = None
- story: Optional[str] = None
+ figure: go.Figure
+ time: datetime
@dataclass
-class FigureRecord:
+class MatplotlibFigureRecord:
"""
- A single plotly figure that was saved during the experiment
+ A single matplotlib figure image that was saved during the experiment
"""
experiment_id: int
id: int
- figure: go.Figure
+ img_src: str
time: datetime
@@ -227,23 +221,19 @@ def get_debug_record(self, experiment_id: int) -> Optional[DebugRecord]:
"""
pass
- # noinspection PyTypeChecker
@abstractmethod
- def get_plots(self, experiment_id: int) -> List[PlotRecord]:
+ def get_figures(self, experiment_id: int) -> List[FigureRecord]:
"""
- returns a list of all plots saved in the requested experiment
+ returns a list of all plotly figures saved with the requested experiment
"""
- warn(
- "This method will soon be deprecated. Please use get_figures() instead",
- PendingDeprecationWarning,
- stacklevel=2,
- )
pass
@abstractmethod
- def get_figures(self, experiment_id: int) -> List[FigureRecord]:
+ def get_matplotlib_figures(
+ self, experiment_id: int
+ ) -> List[MatplotlibFigureRecord]:
"""
- returns a list of all figures saved in the requested experiment
+ returns a list of all matplotlib figures saved with the requested experiment
"""
pass
@@ -336,19 +326,14 @@ def get_results(self, label: Optional[str] = None) -> Iterable[ResultRecord]:
"""
return self._data_reader.get_results(self._experiment_id, label)
- def get_plots(self) -> List[PlotRecord]:
- """
- returns a list of plot records that were saved for current experiment
- """
- warn(
- "This method will soon be deprecated. Please use get_plots() instead",
- PendingDeprecationWarning,
- stacklevel=2,
- )
- return self._data_reader.get_plots(self._experiment_id)
-
def get_figures(self) -> List[FigureRecord]:
"""
returns a list of plotly figures that were saved for current experiment
"""
return self._data_reader.get_figures(self._experiment_id)
+
+ def get_matplotlib_figures(self) -> List[MatplotlibFigureRecord]:
+ """
+ returns a list of matplotlib figures that were saved for current experiment
+ """
+ return self._data_reader.get_matplotlib_figures(self._experiment_id)
diff --git a/entropylab/pipeline/api/data_writer.py b/entropylab/pipeline/api/data_writer.py
index 78bd85b6..3003203f 100644
--- a/entropylab/pipeline/api/data_writer.py
+++ b/entropylab/pipeline/api/data_writer.py
@@ -1,12 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
-from typing import Any, Optional, Type
-from warnings import warn
+from typing import Any
-from bokeh.models import Renderer
-from bokeh.plotting import Figure
-from matplotlib.figure import Figure as matplotlibFigure
+from matplotlib.figure import Figure
from plotly import graph_objects as go
@@ -75,59 +72,6 @@ class Debug:
extra: str
-class PlotGenerator(ABC):
- """
- An abstract class for plots.
- Implementations of this class will let Entropy to save and view plots.
- Every implementation can either implement all plotting functions
- (within the different environments), or just part of it.
- """
-
- def __init__(self) -> None:
- super().__init__()
-
- @abstractmethod
- def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer:
- """
- plot the given data within the Bokeh Figure
- :param figure: Bokeh figure to plot in
- :param data: plot data
- :param kwargs: extra parameters for plotting
- """
- pass
-
- @abstractmethod
- def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs):
- """
- plot the given data within the matplotlib Figure
- :param figure: matplotlib figure
- :param data: plot data
- :param kwargs: extra parameters for plotting
- """
- pass
-
- @abstractmethod
- def plot_plotly(self, figure: go.Figure, data, **kwargs) -> None:
- """
- plot the given data within the plot.ly Figure
- :param figure: plot.ly figure
- :param data: plot data
- :param kwargs: extra parameters for plotting
- """
- pass
-
-
-@dataclass(frozen=True, eq=True)
-class PlotSpec:
- """
- Description and plotting instructions for a plot that will be saved
- """
-
- generator: Optional[Type[PlotGenerator]] = None
- label: Optional[str] = None
- story: Optional[str] = ""
-
-
@dataclass
class NodeData:
"""
@@ -184,24 +128,19 @@ def save_debug(self, experiment_id: int, debug: Debug):
"""
pass
- @abstractmethod
- def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any):
- """
- save a new plot to the db according to the PlotSpec class
- :param experiment_id: the experiment id
- :param plot: plotting instructions
- :param data: the data of the plot
- """
- warn(
- "This method will soon be deprecated. Please use save_figure() instead",
- PendingDeprecationWarning,
- stacklevel=2,
- )
+ def save_figure(self, experiment_id: int, figure: go.Figure) -> None:
+ """
+ saves a new plotly figure to the db and associates it with an experiment
+
+ :param experiment_id: the id of the experiment to associate the figure to
+ :param figure: the figure to save to the database
+ """
pass
- def save_figure(self, experiment_id: int, figure: go.Figure) -> None:
+ def save_matplotlib_figure(self, experiment_id: int, figure: Figure) -> None:
"""
- save a new plotly figure to the db and associates it with an experiment
+ saves a new matplotlib figure to the db (as a base64-encoded string
+ representation of a PNG image) and associates it with an experiment
:param experiment_id: the id of the experiment to associate the figure to
:param figure: the figure to save to the database
diff --git a/entropylab/pipeline/api/execution.py b/entropylab/pipeline/api/execution.py
index 30407858..3b5fb331 100644
--- a/entropylab/pipeline/api/execution.py
+++ b/entropylab/pipeline/api/execution.py
@@ -4,13 +4,12 @@
from plotly import graph_objects as go
+from entropylab.components.lab_topology import ExperimentResources
from entropylab.pipeline.api.data_writer import (
DataWriter,
RawResultData,
Metadata,
- PlotSpec,
)
-from entropylab.components.lab_topology import ExperimentResources
class EntropyContext:
@@ -60,15 +59,6 @@ def add_metadata(self, label: str, metadata: Any):
self._exp_id, Metadata(label, self._stage_id, metadata)
)
- def add_plot(self, plot: PlotSpec, data: Any):
- """
- saves a new plot from this experiment in the database
-
- :param plot: description and plotting instructions
- :param data: the data for plotting
- """
- self._data_writer.save_plot(self._exp_id, plot, data)
-
def add_figure(self, figure: go.Figure) -> None:
"""
saves a new figure from this experiment in the database
diff --git a/entropylab/pipeline/api/memory_reader_writer.py b/entropylab/pipeline/api/memory_reader_writer.py
index 8e2b6583..b8199b1f 100644
--- a/entropylab/pipeline/api/memory_reader_writer.py
+++ b/entropylab/pipeline/api/memory_reader_writer.py
@@ -1,8 +1,9 @@
import random
from datetime import datetime
from time import time_ns
-from typing import List, Optional, Iterable, Any, Dict, Tuple
+from typing import List, Optional, Iterable, Dict, Tuple
+import matplotlib.figure
from pandas import DataFrame
from plotly import graph_objects as go
@@ -13,10 +14,10 @@
MetadataRecord,
ExperimentRecord,
ScriptViewer,
- PlotRecord,
FigureRecord,
+ MatplotlibFigureRecord,
)
-from entropylab.pipeline.api.data_writer import DataWriter, PlotSpec, NodeData
+from entropylab.pipeline.api.data_writer import DataWriter, NodeData
from entropylab.pipeline.api.data_writer import (
ExperimentInitialData,
ExperimentEndData,
@@ -24,6 +25,9 @@
Metadata,
Debug,
)
+from entropylab.pipeline.results_backend.sqlalchemy.db import (
+ matplotlib_figure_to_img_src,
+)
class MemoryOnlyDataReaderWriter(DataWriter, DataReader):
@@ -41,8 +45,8 @@ def __init__(self):
self._results: List[Tuple[RawResultData, datetime]] = []
self._metadata: List[Tuple[Metadata, datetime]] = []
self._debug: Optional[Debug] = None
- self._plot: Dict[PlotSpec, Any] = {}
self._figure: Dict[int, List[FigureRecord]] = {}
+ self._matplotlib_figure: Dict[int, List[MatplotlibFigureRecord]] = {}
self._nodes: List[NodeData] = []
def save_experiment_initial_data(self, initial_data: ExperimentInitialData) -> int:
@@ -61,9 +65,6 @@ def save_metadata(self, experiment_id: int, metadata: Metadata):
def save_debug(self, experiment_id: int, debug: Debug):
self._debug = debug
- def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any):
- self._plot[plot] = data
-
def save_figure(self, experiment_id: int, figure: go.Figure) -> None:
figure_record = FigureRecord(
experiment_id=experiment_id,
@@ -76,10 +77,26 @@ def save_figure(self, experiment_id: int, figure: go.Figure) -> None:
else:
self._figure[experiment_id] = [figure_record]
+ def save_matplotlib_figure(
+ self, experiment_id: int, figure: matplotlib.figure.Figure
+ ) -> None:
+ record = MatplotlibFigureRecord(
+ experiment_id=experiment_id,
+ id=random.randint(0, 2**31 - 1),
+ img_src=matplotlib_figure_to_img_src(figure),
+ time=datetime.now(),
+ )
+ if experiment_id in self._matplotlib_figure:
+ self._matplotlib_figure[experiment_id].append(record)
+ else:
+ self._matplotlib_figure[experiment_id] = [record]
+
def save_node(self, experiment_id: int, node_data: NodeData):
self._nodes.append(node_data)
- def get_experiments_range(self, starting_from_index: int, count: int) -> DataFrame:
+ def get_experiments_range(
+ self, starting_from_index: int, count: int, success: bool = None
+ ) -> DataFrame:
raise NotImplementedError()
def get_experiments(
@@ -163,22 +180,14 @@ def get_debug_record(self, experiment_id: int) -> Optional[DebugRecord]:
else:
return None
- def get_plots(self, experiment_id: int) -> List[PlotRecord]:
- return [
- PlotRecord(
- experiment_id,
- id(plot),
- self._plot[plot],
- plot.generator(),
- plot.label,
- plot.story,
- )
- for plot in self._plot
- ]
-
def get_figures(self, experiment_id: int) -> List[FigureRecord]:
return self._figure[experiment_id]
+ def get_matplotlib_figures(
+ self, experiment_id: int
+ ) -> List[MatplotlibFigureRecord]:
+ return self._matplotlib_figure[experiment_id]
+
def get_node_stage_ids_by_label(
self, label: str, experiment_id: Optional[int] = None
) -> List[int]:
diff --git a/entropylab/pipeline/api/plot.py b/entropylab/pipeline/api/plot.py
deleted file mode 100644
index b6533bd2..00000000
--- a/entropylab/pipeline/api/plot.py
+++ /dev/null
@@ -1,111 +0,0 @@
-from typing import List
-
-import plotly
-import plotly.express as px
-import plotly.graph_objects as go
-from bokeh.models import Renderer
-from bokeh.plotting import Figure
-from matplotlib.axes import Axes as Axes
-from matplotlib.figure import Figure as matplotlibFigure
-
-from entropylab.pipeline.api.data_writer import PlotGenerator
-
-
-class LinePlotGenerator(PlotGenerator):
- def __init__(self) -> None:
- super().__init__()
-
- def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs) -> Renderer:
- raise NotImplementedError()
-
- def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer:
- if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]):
- x = data[0]
- y = data[1]
- return figure.line(
- x,
- y,
- color=kwargs.get("color", "blue"),
- legend_label=kwargs.get("label", ""),
- )
- else:
- raise TypeError("data type is not supported")
-
- def plot_plotly(self, figure: plotly.graph_objects.Figure, data, **kwargs):
- if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]):
- x = data[0]
- y = data[1]
- color = kwargs.pop("color", "blue")
- figure.add_trace(
- go.Scatter(
- mode="lines",
- x=x,
- y=y,
- line_color=color,
- **kwargs,
- )
- )
- return figure
- else:
- raise TypeError("data type is not supported")
-
-
-class CirclePlotGenerator(PlotGenerator):
- def __init__(self) -> None:
- super().__init__()
-
- def plot_matplotlib(self, figure: Axes, data, **kwargs):
- raise NotImplementedError()
-
- def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer:
- if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]):
- x = data[0]
- y = data[1]
- color = (kwargs.get("color", "blue"),)
- return figure.circle(
- x,
- y,
- size=10,
- color=color,
- legend_label=kwargs.get("label", ""),
- alpha=0.5,
- )
- else:
- raise TypeError("data type is not supported")
-
- def plot_plotly(self, figure: plotly.graph_objects.Figure, data, **kwargs):
- if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]):
- x = data[0]
- y = data[1]
- color = kwargs.pop("color", "blue")
- # noinspection PyTypeChecker
- figure.add_trace(
- go.Scatter(
- mode="markers",
- x=x,
- y=y,
- marker_color=color,
- marker_size=10,
- **kwargs,
- )
- )
- return figure
- else:
- raise TypeError("data type is not supported")
-
-
-class ImShowPlotGenerator(PlotGenerator):
- def __init__(self) -> None:
- super().__init__()
-
- def plot_plotly(self, figure: go.Figure, data, **kwargs) -> None:
- headtmap_fig = px.imshow(data)
- figure.add_trace(headtmap_fig.data[0])
-
- return figure
-
- def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer:
- raise NotImplementedError()
-
- def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs):
- raise NotImplementedError()
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py
index 96960c32..8b80deb2 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py
@@ -27,10 +27,10 @@ def upgrade():
"Figures",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("experiment_id", sa.Integer(), nullable=False),
- sa.Column("figure", sa.String(), nullable=False),
+ sa.Column("img_src", sa.String(), nullable=False),
sa.Column("time", sa.DATETIME(), nullable=False),
sa.ForeignKeyConstraint(
- ["experiment_id"], ["Experiments.id"], ondelete="CASCADE"
+ ("experiment_id",), ("Experiments.id",), ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
)
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py
new file mode 100644
index 00000000..573a8a4b
--- /dev/null
+++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py
@@ -0,0 +1,45 @@
+"""Matplotlib figures
+
+Revision ID: da8d38e19ff8
+Revises: 09f3b5a1689c
+Create Date: 2022-07-03 08:56:23.627973+00:00
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+from sqlalchemy.engine import Inspector
+
+# revision identifiers, used by Alembic.
+from entropylab.pipeline.results_backend.sqlalchemy.model import MatplotlibFigureTable
+
+revision = "da8d38e19ff8"
+down_revision = "09f3b5a1689c"
+branch_labels = None
+depends_on = None
+
+TABLE_NAME = MatplotlibFigureTable.__tablename__
+
+
+def upgrade():
+ conn = op.get_bind()
+ inspector = Inspector.from_engine(conn)
+ tables = inspector.get_table_names()
+ if TABLE_NAME not in tables:
+ op.create_table(
+ TABLE_NAME,
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("experiment_id", sa.Integer(), nullable=False),
+ sa.Column("img_src", sa.String(), nullable=False),
+ sa.Column("time", sa.DATETIME(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ("experiment_id",), ("Experiments.id",), ondelete="CASCADE"
+ ),
+ sa.PrimaryKeyConstraint("id"),
+ )
+
+
+def downgrade():
+ op.drop_table(TABLE_NAME)
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py
index 37563680..290ea142 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py
@@ -20,7 +20,7 @@
# revision identifiers, used by Alembic.
revision = "997e336572b8"
-down_revision = "09f3b5a1689c"
+down_revision = "da8d38e19ff8"
branch_labels = None
depends_on = None
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/db.py b/entropylab/pipeline/results_backend/sqlalchemy/db.py
index 489c0240..7aa4002e 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/db.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/db.py
@@ -1,9 +1,11 @@
+import base64
+import io
from datetime import datetime
-from typing import List, TypeVar, Optional, ContextManager, Iterable, Union, Any
+from typing import List, TypeVar, Optional, ContextManager, Iterable, Union
from typing import Set
-from warnings import warn
import jsonpickle
+import matplotlib
import pandas as pd
from pandas import DataFrame
from plotly import graph_objects as go
@@ -30,8 +32,8 @@
ResultRecord,
MetadataRecord,
DebugRecord,
- PlotRecord,
FigureRecord,
+ MatplotlibFigureRecord,
)
from entropylab.pipeline.api.data_writer import (
DataWriter,
@@ -40,19 +42,18 @@
RawResultData,
Metadata,
Debug,
- PlotSpec,
NodeData,
)
from entropylab.pipeline.api.errors import EntropyError
from entropylab.pipeline.results_backend.sqlalchemy.db_initializer import _DbInitializer
from entropylab.pipeline.results_backend.sqlalchemy.model import (
ExperimentTable,
- PlotTable,
ResultTable,
DebugTable,
MetadataTable,
NodeTable,
FigureTable,
+ MatplotlibFigureTable,
)
T = TypeVar(
@@ -141,19 +142,17 @@ def save_debug(self, experiment_id: int, debug: Debug):
transaction = DebugTable.from_model(experiment_id, debug)
return self._execute_transaction(transaction)
- def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any):
- warn(
- "This method will soon be deprecated. Please use save_figure() instead",
- PendingDeprecationWarning,
- stacklevel=2,
- )
- transaction = PlotTable.from_model(experiment_id, plot, data)
- return self._execute_transaction(transaction)
-
def save_figure(self, experiment_id: int, figure: go.Figure) -> None:
transaction = FigureTable.from_model(experiment_id, figure)
return self._execute_transaction(transaction)
+ def save_matplotlib_figure(
+ self, experiment_id: int, figure: matplotlib.figure.Figure
+ ) -> None:
+ img_src = matplotlib_figure_to_img_src(figure)
+ transaction = MatplotlibFigureTable.from_model(experiment_id, img_src)
+ return self._execute_transaction(transaction)
+
def save_node(self, experiment_id: int, node_data: NodeData):
transaction = NodeTable.from_model(experiment_id, node_data)
return self._execute_transaction(transaction)
@@ -272,27 +271,24 @@ def get_all_results_with_label(self, exp_id, name) -> DataFrame:
)
return self._query_pandas(query)
- def get_plots(self, experiment_id: int) -> List[PlotRecord]:
- warn(
- "This method will soon be deprecated. Please use get_figures() instead",
- PendingDeprecationWarning,
- stacklevel=2,
- )
+ def get_figures(self, experiment_id: int) -> List[FigureRecord]:
with self._session_maker() as sess:
query = (
- sess.query(PlotTable)
- .filter(PlotTable.experiment_id == int(experiment_id))
+ sess.query(FigureTable)
+ .filter(FigureTable.experiment_id == int(experiment_id))
.all()
)
if query:
- return [plot.to_record() for plot in query]
+ return [figure.to_record() for figure in query]
return []
- def get_figures(self, experiment_id: int) -> List[FigureRecord]:
+ def get_matplotlib_figures(
+ self, experiment_id: int
+ ) -> List[MatplotlibFigureRecord]:
with self._session_maker() as sess:
query = (
- sess.query(FigureTable)
- .filter(FigureTable.experiment_id == int(experiment_id))
+ sess.query(MatplotlibFigureTable)
+ .filter(MatplotlibFigureTable.experiment_id == int(experiment_id))
.all()
)
if query:
@@ -522,3 +518,14 @@ def __hdf5_storage_enabled(self) -> bool:
else:
enabled = self._enable_hdf5_storage
return enabled
+
+
+def matplotlib_figure_to_img_src(figure: matplotlib.figure.Figure):
+ """Converts a matplotlib Figure instance into a string that can be used as the
+ 'src' attribute of an HTML """
+ buf = io.BytesIO()
+ figure.savefig(buf, format="png")
+ # figure.close()
+ data = base64.b64encode(buf.getbuffer()).decode("utf8")
+ img_src = "data:image/png;base64,{}".format(data)
+ return img_src
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/model.py b/entropylab/pipeline/results_backend/sqlalchemy/model.py
index 5e750aa7..f4d9bf37 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/model.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/model.py
@@ -3,7 +3,6 @@
import pickle
from datetime import datetime
from io import BytesIO
-from typing import Any
import numpy as np
from plotly import graph_objects as go
@@ -28,15 +27,14 @@
ResultRecord,
MetadataRecord,
DebugRecord,
- PlotRecord,
FigureRecord,
+ MatplotlibFigureRecord,
)
from entropylab.pipeline.api.data_writer import (
ExperimentInitialData,
RawResultData,
Metadata,
Debug,
- PlotSpec,
NodeData,
)
from entropylab.pipeline.api.errors import EntropyError
@@ -99,7 +97,6 @@ class ExperimentTable(Base):
results = relationship("ResultTable", cascade="all, delete-orphan")
experiment_metadata = relationship("MetadataTable", cascade="all, delete-orphan")
debug = relationship("DebugTable", cascade="all, delete-orphan")
- plots = relationship("PlotTable", cascade="all, delete-orphan")
def __repr__(self):
return f"<_Experiment(id='{self.id}')>"
@@ -240,72 +237,56 @@ def from_model(experiment_id: int, node_data: NodeData):
)
-class PlotTable(Base):
- __tablename__ = "Plots"
-
+class FigureTable(Base):
+ __tablename__ = "Figures"
id = Column(Integer, primary_key=True)
experiment_id = Column(Integer, ForeignKey("Experiments.id", ondelete="CASCADE"))
- plot_data = Column(BLOB)
- data_type = Column(Enum(ResultDataType))
- generator_module = Column(String)
- generator_class = Column(String)
+ figure = Column(String)
time = Column(DATETIME)
- label = Column(String)
- story = Column(String)
def __repr__(self):
- return f""
+ return f""
- def to_record(self) -> PlotRecord:
- data = _decode_serialized_data(self.plot_data, self.data_type)
- generator = _get_class(self.generator_module, self.generator_class)
- return PlotRecord(
+ def to_record(self) -> FigureRecord:
+ return FigureRecord(
experiment_id=self.experiment_id,
id=self.id,
- label=self.label,
- story=self.story,
- plot_data=data,
- generator=generator(),
+ figure=from_json(self.figure),
+ time=self.time,
)
@staticmethod
- def from_model(experiment_id: int, plot: PlotSpec, data: Any):
- data_type, serialized_data = _encode_serialized_data(data)
- return PlotTable(
+ def from_model(experiment_id: int, figure: go.Figure):
+ return FigureTable(
experiment_id=experiment_id,
- plot_data=serialized_data,
- data_type=data_type,
- generator_module=plot.generator.__module__,
- generator_class=plot.generator.__qualname__,
+ figure=to_json(figure),
time=datetime.now(),
- label=plot.label,
- story=plot.story,
)
-class FigureTable(Base):
- __tablename__ = "Figures"
+class MatplotlibFigureTable(Base):
+ __tablename__ = "MatplotlibFigures"
id = Column(Integer, primary_key=True)
experiment_id = Column(Integer, ForeignKey("Experiments.id", ondelete="CASCADE"))
- figure = Column(String)
+ img_src = Column(String)
time = Column(DATETIME)
def __repr__(self):
- return f""
+ return f""
- def to_record(self) -> FigureRecord:
- return FigureRecord(
+ def to_record(self) -> MatplotlibFigureRecord:
+ return MatplotlibFigureRecord(
experiment_id=self.experiment_id,
id=self.id,
- figure=from_json(self.figure),
+ img_src=self.img_src,
time=self.time,
)
@staticmethod
- def from_model(experiment_id: int, figure: go.Figure):
- return FigureTable(
+ def from_model(experiment_id: int, img_src: str):
+ return MatplotlibFigureTable(
experiment_id=experiment_id,
- figure=to_json(figure),
+ img_src=img_src,
time=datetime.now(),
)
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db b/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db
new file mode 100644
index 00000000..730a9208
Binary files /dev/null and b/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db differ
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/library/TestNode.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/library/TestNode.py
new file mode 100644
index 00000000..e69de29b
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json
new file mode 100644
index 00000000..f6b88ba9
--- /dev/null
+++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json
@@ -0,0 +1 @@
+{"name": "TestNode", "description": "test", "command": "python3", "bin": "_jb_pytest_runner.py", "icon": "bootstrap/person-circle.svg", "inputs": [{"description": {"stream_1": "desc", "state_1": "desc"}, "units": {"stream_1": "test", "state_1": "test"}, "type": {"stream_1": 2, "state_1": 1}}], "outputs": []}
\ No newline at end of file
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py
index d98ec3fa..06ad167a 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py
@@ -2,6 +2,7 @@
from datetime import datetime
import pytest
+from matplotlib import pyplot as plt
from plotly import express as px
from entropylab import SqlAlchemyDB, RawResultData
@@ -75,6 +76,24 @@ def test_save_figure_(initialized_project_dir_path):
assert actual.figure == figure
+def test_save_matplotlib_figure_(initialized_project_dir_path):
+ # arrange
+ db = SqlAlchemyDB(initialized_project_dir_path)
+ x = [1, 2, 3, 4]
+ y = [10, 40, 20, 30]
+ plt.scatter(x, y)
+ figure = plt.gcf()
+ # act
+ db.save_matplotlib_figure(0, figure)
+ # assert
+ actual = db.get_matplotlib_figures(0)[0]
+ assert actual.img_src.startswith(
+ "data:image/png;base64,"
+ "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBN"
+ "YXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9F"
+ )
+
+
def test_get_experiments_range_reads_all_columns():
# arrange
target = SqlAlchemyDB()
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py
index 3950740b..d27a87d8 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py
@@ -153,6 +153,8 @@ def test__migrate_metadata_to_hdf5(initialized_project_dir_path):
"empty_after_2022-05-19-09-12-01_06140c96c8c4_wrapping_param_store_values.db",
"empty_after_2022-06-16-09-26-06_7fa75ca1263f_del_results_and_metadata.db",
"empty_after_2022-06-23-10-16-39_273a9fae6206_experiments_favorite_col.db",
+ "empty_after_2022-06-28-12-13-39_09f3b5a1689c_fixing_param_qualified_name.db",
+ "empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db",
],
indirect=True,
)
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py
index 75f5924e..07d994d8 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py
@@ -19,8 +19,7 @@ def test_ctor_creates_up_to_date_schema_when_in_memory(path: str):
[
None, # new db
"empty.db", # existing but empty
- "empty_after_2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.db"
- # "empty_after_2022-06-28-12-13-39_09f3b5a1689c_fixing_param_qualified_name.db"
+ "empty_after_2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.db",
# ⬆ latest version in pipeline/results_backend/sqlalchemy/alembic/versions
],
indirect=True,
diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py
index 69604e4d..50a23589 100644
--- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py
+++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py
@@ -1,7 +1,12 @@
+import datetime
+
from plotly import express as px
from plotly.io import to_json
-from entropylab.pipeline.results_backend.sqlalchemy.model import FigureTable
+from entropylab.pipeline.results_backend.sqlalchemy.model import (
+ FigureTable,
+ MatplotlibFigureTable,
+)
class TestFigureTable:
@@ -24,3 +29,29 @@ def test_from_model(self):
assert actual.experiment_id == 1
assert actual.figure == to_json(figure)
assert actual.time is not None
+
+
+class TestMatplotlibFigureTable:
+ def test_to_record(self):
+ target = MatplotlibFigureTable(
+ id=42,
+ experiment_id=1337,
+ img_src="data:image/png;base64,",
+ time=datetime.datetime.utcnow(),
+ )
+
+ actual = target.to_record()
+
+ assert actual.id == target.id
+ assert actual.experiment_id == target.experiment_id
+ assert actual.img_src == target.img_src
+ assert actual.time == target.time
+
+ def test_from_model(self):
+ target = MatplotlibFigureTable()
+
+ actual = target.from_model(1, "data:image/png;base64,")
+
+ assert actual.experiment_id == 1
+ assert actual.img_src == "data:image/png;base64,"
+ assert actual.time is not None
diff --git a/entropylab/pipeline/tests/test_async_graph.py b/entropylab/pipeline/tests/test_async_graph.py
index 8d81bf8d..8414829b 100644
--- a/entropylab/pipeline/tests/test_async_graph.py
+++ b/entropylab/pipeline/tests/test_async_graph.py
@@ -1,7 +1,6 @@
import asyncio
import numpy as np
-from bokeh.plotting import Figure
from entropylab.pipeline.graph_experiment import (
Graph,
@@ -136,11 +135,6 @@ def test_async_graph():
results = graph.run().results
print(results.get_experiment_info())
- plots = results.get_plots()
- for plot in plots:
- figure = Figure()
- plot.generator.plot_bokeh(figure, plot.plot_data)
- # save(figure, f"try{plot.label}.html")
def test_async_graph_run_to_node():
diff --git a/entropylab/pipeline/tests/test_executor.py b/entropylab/pipeline/tests/test_executor.py
index 339e5116..d2beef49 100644
--- a/entropylab/pipeline/tests/test_executor.py
+++ b/entropylab/pipeline/tests/test_executor.py
@@ -1,11 +1,10 @@
from datetime import datetime
import pytest
+from plotly import express as px
-from entropylab.pipeline.api.data_writer import PlotSpec
-from entropylab.pipeline.api.execution import EntropyContext
-from entropylab.pipeline.api.plot import CirclePlotGenerator, LinePlotGenerator
from entropylab.components.lab_topology import LabResources, ExperimentResources
+from entropylab.pipeline.api.execution import EntropyContext
from entropylab.pipeline.results_backend.sqlalchemy.db import SqlAlchemyDB
from entropylab.pipeline.script_experiment import Script, script_experiment
from entropylab.pipeline.tests.mock_instruments import MockScope
@@ -64,14 +63,10 @@ def an_experiment_with_plot(experiment: EntropyContext):
experiment.add_result("b_result" + str(i), b1 + i + datetime.now().microsecond)
micro = datetime.now().microsecond
- experiment.add_plot(
- PlotSpec(
- label="plot",
- story="created this plot in experiment",
- generator=CirclePlotGenerator,
- ),
- data=[
- [
+
+ experiment.add_figure(
+ px.scatter(
+ x=[
1 * micro,
2 * micro,
3 * micro,
@@ -81,20 +76,15 @@ def an_experiment_with_plot(experiment: EntropyContext):
7 * micro,
8 * micro,
],
- [0, 1, 2, 3, 4, 5, 6, 7],
- ],
+ y=[0, 1, 2, 3, 4, 5, 6, 7],
+ )
)
- experiment.add_plot(
- PlotSpec(
- label="another plot",
- story="just showing off now",
- generator=LinePlotGenerator,
- ),
- data=[
- [1, 2, 3, 4, 5, 6, 7, 8],
- [4, 5, 6, 7, 0, 1, 2, 3],
- ],
+ experiment.add_figure(
+ px.line(
+ x=[1, 2, 3, 4, 5, 6, 7, 8],
+ y=[4, 5, 6, 7, 0, 1, 2, 3],
+ )
)
diff --git a/entropylab/pipeline/tests/test_graph.py b/entropylab/pipeline/tests/test_graph.py
index 4114c7f1..1e049f2a 100644
--- a/entropylab/pipeline/tests/test_graph.py
+++ b/entropylab/pipeline/tests/test_graph.py
@@ -1,14 +1,9 @@
import asyncio
-import os
from time import sleep
import pytest
-from bokeh.io import save
-from bokeh.plotting import Figure
-from entropylab.pipeline.api.data_writer import PlotSpec
from entropylab.pipeline.api.execution import EntropyContext
-from entropylab.pipeline.api.plot import CirclePlotGenerator
from entropylab.pipeline.graph_experiment import (
Graph,
PyNode,
@@ -52,10 +47,6 @@ def d(x, y):
async def e(y, z, context: EntropyContext):
print(f"Node e resting for {y / z}")
print(f"e Result: {y + z}")
- context.add_plot(
- PlotSpec(CirclePlotGenerator, "the best plot"),
- data=[[0, 1, 2, 3, 4, 5], [y + z, 7, 6, 20, 10, 11]],
- )
return {"y_z": [0, 1, 2, 3, 4, 5, y + z, 7, 6, 20, 10, 11]}
@@ -147,14 +138,6 @@ def test_sync_graph():
results = Graph(None, g, "run_a").run().results
print(results.get_experiment_info())
- # TODO: Use figures instead of plots here
- plots = results.get_plots()
- for plot in plots:
- figure = Figure()
- plot.generator.plot_bokeh(figure, plot.plot_data)
- if not os.path.exists("tests_cache"):
- os.mkdir("tests_cache")
- save(figure, f"tests_cache/bokeh-exported-{plot.label}.html")
def test_sync_graph_run_to_node():
diff --git a/entropylab/pipeline/tests/test_plot.py b/entropylab/pipeline/tests/test_plot.py
deleted file mode 100644
index 03c079c0..00000000
--- a/entropylab/pipeline/tests/test_plot.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import numpy as np
-import plotly
-
-from entropylab.pipeline.api.plot import CirclePlotGenerator, ImShowPlotGenerator
-from plotly.graph_objects import Figure
-
-
-def test_circle_plot_plotly():
- target = CirclePlotGenerator()
- figure = Figure()
- data = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
- target.plot_plotly(figure, data)
- i = 0
-
-
-def test_imshow_plot_plotly():
- target = ImShowPlotGenerator()
- figure = Figure()
- data = np.random.rand(10, 10)
- target.plot_plotly(figure, data)
- assert isinstance(figure.data[0], plotly.graph_objs._heatmap.Heatmap)