diff --git a/entropylab/cli/tests/test_main.py b/entropylab/cli/tests/test_main.py index 098f139a..361b35a0 100644 --- a/entropylab/cli/tests/test_main.py +++ b/entropylab/cli/tests/test_main.py @@ -13,6 +13,8 @@ def test_init_with_no_args(): args.directory = "" # act init(args) + # clean up + shutil.rmtree(".entropy") def test_init_with_current_dir(): diff --git a/entropylab/dashboard/assets/dashboard.css b/entropylab/dashboard/assets/dashboard.css index 578b382d..67919855 100644 --- a/entropylab/dashboard/assets/dashboard.css +++ b/entropylab/dashboard/assets/dashboard.css @@ -33,7 +33,7 @@ html,body { #experiments-title, -#plots-title { +#figures-title { margin-top:20px; } @@ -118,7 +118,7 @@ input.current-page::placeholder { .add-button-container { margin-top:20px; - display: flex + display: flex; align-items: center; justify-content: center; } @@ -139,7 +139,7 @@ input.current-page::placeholder { height: 60px; margin: 4px; border-color: #666666; - display: flex + display: flex; align-items: center; justify-content: center; } diff --git a/entropylab/dashboard/pages/results/callbacks.py b/entropylab/dashboard/pages/results/callbacks.py index 98c032d8..1612e910 100644 --- a/entropylab/dashboard/pages/results/callbacks.py +++ b/entropylab/dashboard/pages/results/callbacks.py @@ -13,27 +13,28 @@ from entropylab.dashboard.pages.results.dashboard_data import FAVORITE_TRUE from entropylab.dashboard.theme import ( colors, - dark_plot_layout, + dark_figure_layout, ) from entropylab.logger import logger -from entropylab.pipeline.api.data_reader import PlotRecord, FigureRecord +from entropylab.pipeline.api.data_reader import ( + FigureRecord, + MatplotlibFigureRecord, +) from entropylab.pipeline.api.errors import EntropyError REFRESH_INTERVAL_IN_MILLIS = 3000 EXPERIMENTS_PAGE_SIZE = 6 +IMG_TAB_KEY = "m" +FIGURE_TAB_KEY = "f" def register_callbacks(app, dashboard_data_reader): - """Initialize the results dashboard Dash app to display an Entropy project + """Add callbacks and helper methods to the dashboard Dash app :param app the Dash app to add the callbacks to :param dashboard_data_reader a DashboardDataReader instance to read dashboard data from""" - """ Creating and setting up our Dash app """ - - """ CALLBACKS and their helper functions """ - @app.callback( Output("experiments-table", "data"), Output("empty-project-modal", "is_open"), @@ -68,17 +69,16 @@ def open_failed_plotting_alert_when_its_not_empty(children): return children != "" @app.callback( - Output("plot-tabs", "children"), + Output("figure-tabs", "children"), Output("figures-by-key", "data"), Output("prev-selected-rows", "data"), Output("failed-plotting-alert", "children"), - Output("add-button", "disabled"), Input("experiments-table", "selected_rows"), State("experiments-table", "data"), State("figures-by-key", "data"), State("prev-selected-rows", "data"), ) - def render_plot_tabs_from_selected_experiments_table_rows( + def render_figure_tabs_from_selected_experiments_table_rows( selected_rows, data, figures_by_key, prev_selected_rows ): result = [] @@ -87,160 +87,160 @@ def render_plot_tabs_from_selected_experiments_table_rows( prev_selected_rows = prev_selected_rows or {} alert_text = "" added_row = get_added_row(prev_selected_rows, selected_rows) - add_button_disabled = False if data and selected_rows: for row_num in selected_rows: alert_on_fail = row_num == added_row exp_id = data[row_num]["id"] try: - plots_and_figures = dashboard_data_reader.get_plot_and_figure_data( - exp_id - ) + figure_records = dashboard_data_reader.get_figure_records(exp_id) except EntropyError: logger.exception( - f"Exception when getting plot/figure data for exp_id={exp_id}" + f"Exception when getting figure data for exp_id={exp_id}" ) if alert_on_fail: alert_text = ( - f"⚠ Error when reading plot/figure data for this " + f"⚠ Error when reading figure data for this " f"experiment. (id: {exp_id})" ) - plots_and_figures = None - if plots_and_figures and len(plots_and_figures) > 0: - failed_plot_ids = [] - figures_by_key = build_plot_tabs( + figure_records = None + if figure_records and len(figure_records) > 0: + failed_figure_keys = [] + figures_by_key = build_figure_tabs( alert_on_fail, - failed_plot_ids, + failed_figure_keys, figures_by_key, - plots_and_figures, + figure_records, result, ) - if len(failed_plot_ids) > 0: + if len(failed_figure_keys) > 0: alert_text = ( - f"⚠ Some plots could not be rendered. " - f"(ids: {','.join(failed_plot_ids)})" + f"⚠ Some figures could not be rendered. " + f"(ids: {','.join(failed_figure_keys)})" ) else: if alert_on_fail and alert_text == "": alert_text = ( - f"⚠ Experiment has no plots to render. (id: {exp_id})" + f"⚠ Experiment has no figures to render. (id: {exp_id})" ) if len(result) == 0: - result = [build_plot_tabs_placeholder()] - add_button_disabled = True + result = [build_figure_tabs_placeholder()] return ( result, figures_by_key, selected_rows, alert_text, - add_button_disabled, ) - def build_plot_tabs( - alert_on_fail, failed_plot_ids, figures_by_key, plots_and_figures, result + def build_figure_tabs( + alert_on_fail, failed_figure_keys, figures_by_key, figure_records, result ): - for plot_or_figure in plots_and_figures: + for figure_record in figure_records: try: color = colors[len(result) % len(colors)] - plot_tab, figures_by_key = build_plot_tab_from_plot_or_figure( - figures_by_key, plot_or_figure, color + figure_tab, figures_by_key = build_tab_from_figure( + figures_by_key, figure_record, color ) - result.append(plot_tab) + result.append(figure_tab) except (EntropyError, TypeError): - logger.exception(f"Failed to render plot id [{plot_or_figure.id}]") + logger.exception( + f"Failed to render figure record id [{figure_record.id}]" + ) if alert_on_fail: - plot_key = f"{plot_or_figure.experiment_id}/{plot_or_figure.id}" - failed_plot_ids.append(plot_key) + figure_key = f"{figure_record.experiment_id}/{figure_record.id}" + failed_figure_keys.append(figure_key) return figures_by_key - def build_plot_tabs_placeholder(): + def build_figure_tabs_placeholder(): return dbc.Tab( html.Div( html.Div( - "Select an experiment above to display its plots here", + "Select an experiment above to display its figures here", className="tab-placeholder-text", ), className="tab-placeholder-container", ), - label="Plots", - tab_id="plot-tab-placeholder", + label="Figures", + tab_id="figure-tab-placeholder", ) - def build_plot_tab_from_plot_or_figure( - figures_by_key, plot_or_figure: PlotRecord | FigureRecord, color: str + def build_tab_from_figure( + figures_by_key, + figure_record: FigureRecord | MatplotlibFigureRecord, + color: str, ) -> (dbc.Tab, Dict): - if isinstance(plot_or_figure, PlotRecord): - # For backwards compatibility with soon to be deprecated Plots API: - plot_rec = cast(PlotRecord, plot_or_figure) - key = f"{plot_rec.experiment_id}/{plot_rec.id}/p" - name = f"Plot {key[:-2]}" - figure = go.Figure() - plot_or_figure.generator.plot_plotly( - figure, - plot_or_figure.plot_data, - name=name, - color=color, - showlegend=False, - ) + if isinstance(figure_record, MatplotlibFigureRecord): + record = cast(MatplotlibFigureRecord, figure_record) + key = f"{record.experiment_id}/{record.id}/{IMG_TAB_KEY}" + name = f"Image {key[:-2]}" + return build_img_tab(record.img_src, name, key), figures_by_key else: - figure_rec = cast(FigureRecord, plot_or_figure) - key = f"{figure_rec.experiment_id}/{figure_rec.id}/f" + record = cast(FigureRecord, figure_record) + key = f"{record.experiment_id}/{record.id}/{FIGURE_TAB_KEY}" name = f"Figure {key[:-2]}" - figure = figure_rec.figure - figure.update_layout(dark_plot_layout) - figures_by_key[key] = dict(figure=figure, color=color) - return build_plot_tab(figure, name, key), figures_by_key + figure = record.figure + figure.update_layout(dark_figure_layout) + figures_by_key[key] = dict(figure=figure, color=color) + return build_figure_tab(figure, name, key), figures_by_key - def build_plot_tab( - plot_figure: go.Figure, plot_name: str, plot_key: str + def build_figure_tab( + figure: go.Figure, figure_name: str, figure_key: str ) -> dbc.Tab: return dbc.Tab( - dcc.Graph(figure=plot_figure, responsive=True), - label=plot_name, - id=f"plot-tab-{plot_key}", - tab_id=f"plot-tab-{plot_key}", + dcc.Graph(figure=figure, responsive=True), + label=figure_name, + id=f"figure-tab-{figure_key}", + tab_id=f"figure-tab-{figure_key}", + ) + + def build_img_tab(img_src: str, figure_name: str, figure_key: str) -> dbc.Tab: + return dbc.Tab( + # TODO: Fit img into tab dimensions + html.Img(src=img_src), + label=figure_name, + id=f"figure-tab-{figure_key}", + tab_id=f"figure-tab-{figure_key}", ) @app.callback( - Output("plot-keys-to-combine", "data"), + Output("figure-keys-to-combine", "data"), Input("add-button", "n_clicks"), Input({"type": "remove-button", "index": ALL}, "n_clicks"), - State("plot-tabs", "active_tab"), - State("plot-keys-to-combine", "data"), + State("figure-tabs", "active_tab"), + State("figure-keys-to-combine", "data"), ) - def add_or_remove_plot_keys_based_on_click_events( - _, __, active_tab, plot_keys_to_combine + def add_or_remove_figure_keys_based_on_click_events( + _, __, active_tab, figure_keys_to_combine ): - plot_keys_to_combine = plot_keys_to_combine or [] + figure_keys_to_combine = figure_keys_to_combine or [] prop_id = dash.callback_context.triggered[0]["prop_id"] if prop_id == "add-button.n_clicks" and active_tab: - active_plot_key = active_tab.replace("plot-tab-", "") - if active_plot_key not in plot_keys_to_combine: - plot_keys_to_combine.append(active_plot_key) + active_figure_key = active_tab.replace("figure-tab-", "") + if active_figure_key not in figure_keys_to_combine: + figure_keys_to_combine.append(active_figure_key) elif "remove-button" in prop_id: id_dict = json.loads(prop_id.replace(".n_clicks", "")) - remove_plot_key = id_dict["index"] - if remove_plot_key in plot_keys_to_combine: - plot_keys_to_combine.remove(remove_plot_key) - return plot_keys_to_combine + remove_figure_key = id_dict["index"] + if remove_figure_key in figure_keys_to_combine: + figure_keys_to_combine.remove(remove_figure_key) + return figure_keys_to_combine @app.callback( Output("aggregate-tab", "children"), Output("remove-buttons", "children"), - Input("plot-keys-to-combine", "data"), + Input("figure-keys-to-combine", "data"), State("figures-by-key", "data"), ) - def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures): - if plot_keys_to_combine and len(plot_keys_to_combine) > 0: + def build_combined_figure_from_figure_keys(figure_keys_to_combine, figures_by_key): + if figure_keys_to_combine and len(figure_keys_to_combine) > 0: combined_figure = make_subplots(specs=[[{"secondary_y": True}]]) remove_buttons = [] - for plot_id in plot_keys_to_combine: - figure = plot_figures[plot_id]["figure"] - color = plot_figures[plot_id]["color"] + for figure_key in figure_keys_to_combine: + figure = figures_by_key[figure_key]["figure"] + color = figures_by_key[figure_key]["color"] combined_figure.add_trace(figure["data"][0]) - button = build_remove_button(plot_id, color) + button = build_remove_button(figure_key, color) remove_buttons.append(button) - combined_figure.update_layout(dark_plot_layout) + combined_figure.update_layout(dark_figure_layout) return ( dcc.Graph( id="aggregate-graph", @@ -255,26 +255,26 @@ def build_combined_plot_from_plot_keys(plot_keys_to_combine, plot_figures): [html.Div()], ) - def build_remove_button(plot_id, color): + def build_remove_button(figure_key, color): return dbc.Button( dbc.Row( children=[ dbc.Col( "✖", ), - dbc.Col(f"{plot_id}", className="remove-button-label"), + dbc.Col(f"{figure_key}", className="remove-button-label"), ], ), style={"background-color": color}, class_name="remove-button", - id={"type": "remove-button", "index": plot_id}, + id={"type": "remove-button", "index": figure_key}, ) def build_aggregate_tab_placeholder(): return html.Div( [ html.Div( - "Add a plot on the left to aggregate it here", + "Add a figure (above) to aggregate it here", className="tab-placeholder-text", ), dcc.Graph( @@ -288,15 +288,22 @@ def build_aggregate_tab_placeholder(): ) @app.callback( - Output("plot-tabs", "active_tab"), - Input("plot-tabs", "children"), + Output("figure-tabs", "active_tab"), + Input("figure-tabs", "children"), ) - def activate_last_plot_tab_when_tabs_are_changed(children): + def activate_last_figure_tab_when_tabs_are_changed(children): if len(children) > 0: last_tab = len(children) - 1 return children[last_tab]["props"]["tab_id"] return 0 + @app.callback( + Output("add-button", "disabled"), + Input("figure-tabs", "active_tab"), + ) + def disable_add_button_when_active_tab_is_img_or_placeholder(active_tab): + return active_tab.endswith("/m") or active_tab == "figure-tab-placeholder" + @app.callback( Output("aggregate-clipboard", "content"), Input("aggregate-clipboard", "n_clicks"), diff --git a/entropylab/dashboard/pages/results/dashboard_data.py b/entropylab/dashboard/pages/results/dashboard_data.py index d69b8b42..0e57b19b 100644 --- a/entropylab/dashboard/pages/results/dashboard_data.py +++ b/entropylab/dashboard/pages/results/dashboard_data.py @@ -8,7 +8,10 @@ from entropylab import SqlAlchemyDB from entropylab.dashboard.pages.results.auto_plot import auto_plot from entropylab.logger import logger -from entropylab.pipeline.api.data_reader import PlotRecord, FigureRecord +from entropylab.pipeline.api.data_reader import ( + FigureRecord, + MatplotlibFigureRecord, +) FAVORITE_TRUE = "⭐" FAVORITE_FALSE = "✰" @@ -25,7 +28,9 @@ def get_last_experiments( pass @abc.abstractmethod - def get_plot_and_figure_data(self, exp_id: int) -> List[PlotRecord]: + def get_figure_records( + self, exp_id: int + ) -> List[FigureRecord | MatplotlibFigureRecord]: pass @@ -65,22 +70,26 @@ def get_last_result_of_experiment( ): return self._db.get_last_result_of_experiment(experiment_id) - def get_plot_and_figure_data(self, exp_id: int) -> List[PlotRecord | FigureRecord]: - plots = self._db.get_plots(exp_id) + def get_figure_records( + self, exp_id: int + ) -> List[FigureRecord | MatplotlibFigureRecord]: + # Plotly figures if exp_id not in self._figures_cache: logger.debug(f"Figures cache miss. exp_id=[{exp_id}]") self._figures_cache[exp_id] = self._db.get_figures(exp_id) else: logger.debug(f"Figures cache hit. exp_id=[{exp_id}]") figures = self._figures_cache[exp_id] - if len(plots) > 0 or len(figures) > 0: - return [*plots, *figures] + # Matplotlib figures + # TODO: Cache matplotlib figures? + matplotlib_figures = self._db.get_matplotlib_figures(exp_id) + if len(figures) > 0 or len(matplotlib_figures) > 0: + return [*figures, *matplotlib_figures] else: - # TODO: auto_plot to produce figures, not plots last_result = self._db.get_last_result_of_experiment(exp_id) if last_result is not None and last_result.data is not None: - plot = auto_plot(exp_id, last_result.data) - return [plot] + figure = auto_plot(exp_id, last_result.data) + return [figure] else: return [] diff --git a/entropylab/dashboard/pages/results/layout.py b/entropylab/dashboard/pages/results/layout.py index 12365b79..53c9d0e1 100644 --- a/entropylab/dashboard/pages/results/layout.py +++ b/entropylab/dashboard/pages/results/layout.py @@ -23,7 +23,7 @@ def build_layout(path: str, dashboard_data_reader: DashboardDataReader): className="main", children=[ dcc.Store(id="figures-by-key", storage_type="session"), - dcc.Store(id="plot-keys-to-combine", storage_type="session"), + dcc.Store(id="figure-keys-to-combine", storage_type="session"), dcc.Store(id="prev-selected-rows", storage_type="session"), dcc.Store(id="favorites", storage_type="session"), dcc.Interval( @@ -71,17 +71,19 @@ def build_layout(path: str, dashboard_data_reader: DashboardDataReader): [ dbc.Row( [ - html.H5("Plots and Figures", id="plots-title"), + html.H5( + "Figures and Images", id="figures-title" + ), dcc.Loading( - id="plot-tabs-loading", - children=[dbc.Tabs(id="plot-tabs")], + id="figure-tabs-loading", + children=[dbc.Tabs(id="figure-tabs")], type="default", ), ] ), dbc.Row( dbc.Button( - "➕ Add Plot to Aggregate View", + "➕ Add Figure to Aggregate View", id="add-button", ), className="add-button-container", diff --git a/entropylab/dashboard/theme.py b/entropylab/dashboard/theme.py index 9ba4c012..e53cc374 100644 --- a/entropylab/dashboard/theme.py +++ b/entropylab/dashboard/theme.py @@ -17,7 +17,7 @@ plot_paper_bgcolor = "rgb(48, 48, 48)" plot_plot_bgcolor = "rgb(173, 181, 189)" -dark_plot_layout = dict( +dark_figure_layout = dict( font_color=plot_font_color, legend_font_color=plot_legend_font_color, paper_bgcolor=plot_paper_bgcolor, diff --git a/entropylab/pipeline/api/data_reader.py b/entropylab/pipeline/api/data_reader.py index d615202c..ab084cae 100644 --- a/entropylab/pipeline/api/data_reader.py +++ b/entropylab/pipeline/api/data_reader.py @@ -2,13 +2,10 @@ from dataclasses import dataclass from datetime import datetime from typing import List, Any, Optional, Iterable -from warnings import warn from pandas import DataFrame from plotly import graph_objects as go -from entropylab.pipeline.api.data_writer import PlotGenerator - class ScriptViewer: def __init__(self, stages: List[str]) -> None: @@ -84,29 +81,26 @@ class ResultRecord: @dataclass -class PlotRecord: +class FigureRecord: """ - A single plot information and plotting instructions that was saved during the - experiment + A single plotly figure that was saved during the experiment """ experiment_id: int id: int - plot_data: Any = None - generator: Optional[PlotGenerator] = None - label: Optional[str] = None - story: Optional[str] = None + figure: go.Figure + time: datetime @dataclass -class FigureRecord: +class MatplotlibFigureRecord: """ - A single plotly figure that was saved during the experiment + A single matplotlib figure image that was saved during the experiment """ experiment_id: int id: int - figure: go.Figure + img_src: str time: datetime @@ -227,23 +221,19 @@ def get_debug_record(self, experiment_id: int) -> Optional[DebugRecord]: """ pass - # noinspection PyTypeChecker @abstractmethod - def get_plots(self, experiment_id: int) -> List[PlotRecord]: + def get_figures(self, experiment_id: int) -> List[FigureRecord]: """ - returns a list of all plots saved in the requested experiment + returns a list of all plotly figures saved with the requested experiment """ - warn( - "This method will soon be deprecated. Please use get_figures() instead", - PendingDeprecationWarning, - stacklevel=2, - ) pass @abstractmethod - def get_figures(self, experiment_id: int) -> List[FigureRecord]: + def get_matplotlib_figures( + self, experiment_id: int + ) -> List[MatplotlibFigureRecord]: """ - returns a list of all figures saved in the requested experiment + returns a list of all matplotlib figures saved with the requested experiment """ pass @@ -336,19 +326,14 @@ def get_results(self, label: Optional[str] = None) -> Iterable[ResultRecord]: """ return self._data_reader.get_results(self._experiment_id, label) - def get_plots(self) -> List[PlotRecord]: - """ - returns a list of plot records that were saved for current experiment - """ - warn( - "This method will soon be deprecated. Please use get_plots() instead", - PendingDeprecationWarning, - stacklevel=2, - ) - return self._data_reader.get_plots(self._experiment_id) - def get_figures(self) -> List[FigureRecord]: """ returns a list of plotly figures that were saved for current experiment """ return self._data_reader.get_figures(self._experiment_id) + + def get_matplotlib_figures(self) -> List[MatplotlibFigureRecord]: + """ + returns a list of matplotlib figures that were saved for current experiment + """ + return self._data_reader.get_matplotlib_figures(self._experiment_id) diff --git a/entropylab/pipeline/api/data_writer.py b/entropylab/pipeline/api/data_writer.py index 78bd85b6..3003203f 100644 --- a/entropylab/pipeline/api/data_writer.py +++ b/entropylab/pipeline/api/data_writer.py @@ -1,12 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, Type -from warnings import warn +from typing import Any -from bokeh.models import Renderer -from bokeh.plotting import Figure -from matplotlib.figure import Figure as matplotlibFigure +from matplotlib.figure import Figure from plotly import graph_objects as go @@ -75,59 +72,6 @@ class Debug: extra: str -class PlotGenerator(ABC): - """ - An abstract class for plots. - Implementations of this class will let Entropy to save and view plots. - Every implementation can either implement all plotting functions - (within the different environments), or just part of it. - """ - - def __init__(self) -> None: - super().__init__() - - @abstractmethod - def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer: - """ - plot the given data within the Bokeh Figure - :param figure: Bokeh figure to plot in - :param data: plot data - :param kwargs: extra parameters for plotting - """ - pass - - @abstractmethod - def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs): - """ - plot the given data within the matplotlib Figure - :param figure: matplotlib figure - :param data: plot data - :param kwargs: extra parameters for plotting - """ - pass - - @abstractmethod - def plot_plotly(self, figure: go.Figure, data, **kwargs) -> None: - """ - plot the given data within the plot.ly Figure - :param figure: plot.ly figure - :param data: plot data - :param kwargs: extra parameters for plotting - """ - pass - - -@dataclass(frozen=True, eq=True) -class PlotSpec: - """ - Description and plotting instructions for a plot that will be saved - """ - - generator: Optional[Type[PlotGenerator]] = None - label: Optional[str] = None - story: Optional[str] = "" - - @dataclass class NodeData: """ @@ -184,24 +128,19 @@ def save_debug(self, experiment_id: int, debug: Debug): """ pass - @abstractmethod - def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any): - """ - save a new plot to the db according to the PlotSpec class - :param experiment_id: the experiment id - :param plot: plotting instructions - :param data: the data of the plot - """ - warn( - "This method will soon be deprecated. Please use save_figure() instead", - PendingDeprecationWarning, - stacklevel=2, - ) + def save_figure(self, experiment_id: int, figure: go.Figure) -> None: + """ + saves a new plotly figure to the db and associates it with an experiment + + :param experiment_id: the id of the experiment to associate the figure to + :param figure: the figure to save to the database + """ pass - def save_figure(self, experiment_id: int, figure: go.Figure) -> None: + def save_matplotlib_figure(self, experiment_id: int, figure: Figure) -> None: """ - save a new plotly figure to the db and associates it with an experiment + saves a new matplotlib figure to the db (as a base64-encoded string + representation of a PNG image) and associates it with an experiment :param experiment_id: the id of the experiment to associate the figure to :param figure: the figure to save to the database diff --git a/entropylab/pipeline/api/execution.py b/entropylab/pipeline/api/execution.py index 30407858..3b5fb331 100644 --- a/entropylab/pipeline/api/execution.py +++ b/entropylab/pipeline/api/execution.py @@ -4,13 +4,12 @@ from plotly import graph_objects as go +from entropylab.components.lab_topology import ExperimentResources from entropylab.pipeline.api.data_writer import ( DataWriter, RawResultData, Metadata, - PlotSpec, ) -from entropylab.components.lab_topology import ExperimentResources class EntropyContext: @@ -60,15 +59,6 @@ def add_metadata(self, label: str, metadata: Any): self._exp_id, Metadata(label, self._stage_id, metadata) ) - def add_plot(self, plot: PlotSpec, data: Any): - """ - saves a new plot from this experiment in the database - - :param plot: description and plotting instructions - :param data: the data for plotting - """ - self._data_writer.save_plot(self._exp_id, plot, data) - def add_figure(self, figure: go.Figure) -> None: """ saves a new figure from this experiment in the database diff --git a/entropylab/pipeline/api/memory_reader_writer.py b/entropylab/pipeline/api/memory_reader_writer.py index 8e2b6583..b8199b1f 100644 --- a/entropylab/pipeline/api/memory_reader_writer.py +++ b/entropylab/pipeline/api/memory_reader_writer.py @@ -1,8 +1,9 @@ import random from datetime import datetime from time import time_ns -from typing import List, Optional, Iterable, Any, Dict, Tuple +from typing import List, Optional, Iterable, Dict, Tuple +import matplotlib.figure from pandas import DataFrame from plotly import graph_objects as go @@ -13,10 +14,10 @@ MetadataRecord, ExperimentRecord, ScriptViewer, - PlotRecord, FigureRecord, + MatplotlibFigureRecord, ) -from entropylab.pipeline.api.data_writer import DataWriter, PlotSpec, NodeData +from entropylab.pipeline.api.data_writer import DataWriter, NodeData from entropylab.pipeline.api.data_writer import ( ExperimentInitialData, ExperimentEndData, @@ -24,6 +25,9 @@ Metadata, Debug, ) +from entropylab.pipeline.results_backend.sqlalchemy.db import ( + matplotlib_figure_to_img_src, +) class MemoryOnlyDataReaderWriter(DataWriter, DataReader): @@ -41,8 +45,8 @@ def __init__(self): self._results: List[Tuple[RawResultData, datetime]] = [] self._metadata: List[Tuple[Metadata, datetime]] = [] self._debug: Optional[Debug] = None - self._plot: Dict[PlotSpec, Any] = {} self._figure: Dict[int, List[FigureRecord]] = {} + self._matplotlib_figure: Dict[int, List[MatplotlibFigureRecord]] = {} self._nodes: List[NodeData] = [] def save_experiment_initial_data(self, initial_data: ExperimentInitialData) -> int: @@ -61,9 +65,6 @@ def save_metadata(self, experiment_id: int, metadata: Metadata): def save_debug(self, experiment_id: int, debug: Debug): self._debug = debug - def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any): - self._plot[plot] = data - def save_figure(self, experiment_id: int, figure: go.Figure) -> None: figure_record = FigureRecord( experiment_id=experiment_id, @@ -76,10 +77,26 @@ def save_figure(self, experiment_id: int, figure: go.Figure) -> None: else: self._figure[experiment_id] = [figure_record] + def save_matplotlib_figure( + self, experiment_id: int, figure: matplotlib.figure.Figure + ) -> None: + record = MatplotlibFigureRecord( + experiment_id=experiment_id, + id=random.randint(0, 2**31 - 1), + img_src=matplotlib_figure_to_img_src(figure), + time=datetime.now(), + ) + if experiment_id in self._matplotlib_figure: + self._matplotlib_figure[experiment_id].append(record) + else: + self._matplotlib_figure[experiment_id] = [record] + def save_node(self, experiment_id: int, node_data: NodeData): self._nodes.append(node_data) - def get_experiments_range(self, starting_from_index: int, count: int) -> DataFrame: + def get_experiments_range( + self, starting_from_index: int, count: int, success: bool = None + ) -> DataFrame: raise NotImplementedError() def get_experiments( @@ -163,22 +180,14 @@ def get_debug_record(self, experiment_id: int) -> Optional[DebugRecord]: else: return None - def get_plots(self, experiment_id: int) -> List[PlotRecord]: - return [ - PlotRecord( - experiment_id, - id(plot), - self._plot[plot], - plot.generator(), - plot.label, - plot.story, - ) - for plot in self._plot - ] - def get_figures(self, experiment_id: int) -> List[FigureRecord]: return self._figure[experiment_id] + def get_matplotlib_figures( + self, experiment_id: int + ) -> List[MatplotlibFigureRecord]: + return self._matplotlib_figure[experiment_id] + def get_node_stage_ids_by_label( self, label: str, experiment_id: Optional[int] = None ) -> List[int]: diff --git a/entropylab/pipeline/api/plot.py b/entropylab/pipeline/api/plot.py deleted file mode 100644 index b6533bd2..00000000 --- a/entropylab/pipeline/api/plot.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import List - -import plotly -import plotly.express as px -import plotly.graph_objects as go -from bokeh.models import Renderer -from bokeh.plotting import Figure -from matplotlib.axes import Axes as Axes -from matplotlib.figure import Figure as matplotlibFigure - -from entropylab.pipeline.api.data_writer import PlotGenerator - - -class LinePlotGenerator(PlotGenerator): - def __init__(self) -> None: - super().__init__() - - def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs) -> Renderer: - raise NotImplementedError() - - def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer: - if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]): - x = data[0] - y = data[1] - return figure.line( - x, - y, - color=kwargs.get("color", "blue"), - legend_label=kwargs.get("label", ""), - ) - else: - raise TypeError("data type is not supported") - - def plot_plotly(self, figure: plotly.graph_objects.Figure, data, **kwargs): - if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]): - x = data[0] - y = data[1] - color = kwargs.pop("color", "blue") - figure.add_trace( - go.Scatter( - mode="lines", - x=x, - y=y, - line_color=color, - **kwargs, - ) - ) - return figure - else: - raise TypeError("data type is not supported") - - -class CirclePlotGenerator(PlotGenerator): - def __init__(self) -> None: - super().__init__() - - def plot_matplotlib(self, figure: Axes, data, **kwargs): - raise NotImplementedError() - - def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer: - if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]): - x = data[0] - y = data[1] - color = (kwargs.get("color", "blue"),) - return figure.circle( - x, - y, - size=10, - color=color, - legend_label=kwargs.get("label", ""), - alpha=0.5, - ) - else: - raise TypeError("data type is not supported") - - def plot_plotly(self, figure: plotly.graph_objects.Figure, data, **kwargs): - if isinstance(data, List) and len(data) == 2 and len(data[0]) == len(data[1]): - x = data[0] - y = data[1] - color = kwargs.pop("color", "blue") - # noinspection PyTypeChecker - figure.add_trace( - go.Scatter( - mode="markers", - x=x, - y=y, - marker_color=color, - marker_size=10, - **kwargs, - ) - ) - return figure - else: - raise TypeError("data type is not supported") - - -class ImShowPlotGenerator(PlotGenerator): - def __init__(self) -> None: - super().__init__() - - def plot_plotly(self, figure: go.Figure, data, **kwargs) -> None: - headtmap_fig = px.imshow(data) - figure.add_trace(headtmap_fig.data[0]) - - return figure - - def plot_bokeh(self, figure: Figure, data, **kwargs) -> Renderer: - raise NotImplementedError() - - def plot_matplotlib(self, figure: matplotlibFigure, data, **kwargs): - raise NotImplementedError() diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py index 96960c32..8b80deb2 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-03-17-15-57-28_f1ada2484fe2_create_figures_table.py @@ -27,10 +27,10 @@ def upgrade(): "Figures", sa.Column("id", sa.Integer(), nullable=False), sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("figure", sa.String(), nullable=False), + sa.Column("img_src", sa.String(), nullable=False), sa.Column("time", sa.DATETIME(), nullable=False), sa.ForeignKeyConstraint( - ["experiment_id"], ["Experiments.id"], ondelete="CASCADE" + ("experiment_id",), ("Experiments.id",), ondelete="CASCADE" ), sa.PrimaryKeyConstraint("id"), ) diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py new file mode 100644 index 00000000..573a8a4b --- /dev/null +++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.py @@ -0,0 +1,45 @@ +"""Matplotlib figures + +Revision ID: da8d38e19ff8 +Revises: 09f3b5a1689c +Create Date: 2022-07-03 08:56:23.627973+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +from sqlalchemy.engine import Inspector + +# revision identifiers, used by Alembic. +from entropylab.pipeline.results_backend.sqlalchemy.model import MatplotlibFigureTable + +revision = "da8d38e19ff8" +down_revision = "09f3b5a1689c" +branch_labels = None +depends_on = None + +TABLE_NAME = MatplotlibFigureTable.__tablename__ + + +def upgrade(): + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + tables = inspector.get_table_names() + if TABLE_NAME not in tables: + op.create_table( + TABLE_NAME, + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("img_src", sa.String(), nullable=False), + sa.Column("time", sa.DATETIME(), nullable=False), + sa.ForeignKeyConstraint( + ("experiment_id",), ("Experiments.id",), ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade(): + op.drop_table(TABLE_NAME) diff --git a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py index 37563680..290ea142 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/alembic/versions/2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.py @@ -20,7 +20,7 @@ # revision identifiers, used by Alembic. revision = "997e336572b8" -down_revision = "09f3b5a1689c" +down_revision = "da8d38e19ff8" branch_labels = None depends_on = None diff --git a/entropylab/pipeline/results_backend/sqlalchemy/db.py b/entropylab/pipeline/results_backend/sqlalchemy/db.py index 489c0240..7aa4002e 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/db.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/db.py @@ -1,9 +1,11 @@ +import base64 +import io from datetime import datetime -from typing import List, TypeVar, Optional, ContextManager, Iterable, Union, Any +from typing import List, TypeVar, Optional, ContextManager, Iterable, Union from typing import Set -from warnings import warn import jsonpickle +import matplotlib import pandas as pd from pandas import DataFrame from plotly import graph_objects as go @@ -30,8 +32,8 @@ ResultRecord, MetadataRecord, DebugRecord, - PlotRecord, FigureRecord, + MatplotlibFigureRecord, ) from entropylab.pipeline.api.data_writer import ( DataWriter, @@ -40,19 +42,18 @@ RawResultData, Metadata, Debug, - PlotSpec, NodeData, ) from entropylab.pipeline.api.errors import EntropyError from entropylab.pipeline.results_backend.sqlalchemy.db_initializer import _DbInitializer from entropylab.pipeline.results_backend.sqlalchemy.model import ( ExperimentTable, - PlotTable, ResultTable, DebugTable, MetadataTable, NodeTable, FigureTable, + MatplotlibFigureTable, ) T = TypeVar( @@ -141,19 +142,17 @@ def save_debug(self, experiment_id: int, debug: Debug): transaction = DebugTable.from_model(experiment_id, debug) return self._execute_transaction(transaction) - def save_plot(self, experiment_id: int, plot: PlotSpec, data: Any): - warn( - "This method will soon be deprecated. Please use save_figure() instead", - PendingDeprecationWarning, - stacklevel=2, - ) - transaction = PlotTable.from_model(experiment_id, plot, data) - return self._execute_transaction(transaction) - def save_figure(self, experiment_id: int, figure: go.Figure) -> None: transaction = FigureTable.from_model(experiment_id, figure) return self._execute_transaction(transaction) + def save_matplotlib_figure( + self, experiment_id: int, figure: matplotlib.figure.Figure + ) -> None: + img_src = matplotlib_figure_to_img_src(figure) + transaction = MatplotlibFigureTable.from_model(experiment_id, img_src) + return self._execute_transaction(transaction) + def save_node(self, experiment_id: int, node_data: NodeData): transaction = NodeTable.from_model(experiment_id, node_data) return self._execute_transaction(transaction) @@ -272,27 +271,24 @@ def get_all_results_with_label(self, exp_id, name) -> DataFrame: ) return self._query_pandas(query) - def get_plots(self, experiment_id: int) -> List[PlotRecord]: - warn( - "This method will soon be deprecated. Please use get_figures() instead", - PendingDeprecationWarning, - stacklevel=2, - ) + def get_figures(self, experiment_id: int) -> List[FigureRecord]: with self._session_maker() as sess: query = ( - sess.query(PlotTable) - .filter(PlotTable.experiment_id == int(experiment_id)) + sess.query(FigureTable) + .filter(FigureTable.experiment_id == int(experiment_id)) .all() ) if query: - return [plot.to_record() for plot in query] + return [figure.to_record() for figure in query] return [] - def get_figures(self, experiment_id: int) -> List[FigureRecord]: + def get_matplotlib_figures( + self, experiment_id: int + ) -> List[MatplotlibFigureRecord]: with self._session_maker() as sess: query = ( - sess.query(FigureTable) - .filter(FigureTable.experiment_id == int(experiment_id)) + sess.query(MatplotlibFigureTable) + .filter(MatplotlibFigureTable.experiment_id == int(experiment_id)) .all() ) if query: @@ -522,3 +518,14 @@ def __hdf5_storage_enabled(self) -> bool: else: enabled = self._enable_hdf5_storage return enabled + + +def matplotlib_figure_to_img_src(figure: matplotlib.figure.Figure): + """Converts a matplotlib Figure instance into a string that can be used as the + 'src' attribute of an HTML """ + buf = io.BytesIO() + figure.savefig(buf, format="png") + # figure.close() + data = base64.b64encode(buf.getbuffer()).decode("utf8") + img_src = "data:image/png;base64,{}".format(data) + return img_src diff --git a/entropylab/pipeline/results_backend/sqlalchemy/model.py b/entropylab/pipeline/results_backend/sqlalchemy/model.py index 5e750aa7..f4d9bf37 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/model.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/model.py @@ -3,7 +3,6 @@ import pickle from datetime import datetime from io import BytesIO -from typing import Any import numpy as np from plotly import graph_objects as go @@ -28,15 +27,14 @@ ResultRecord, MetadataRecord, DebugRecord, - PlotRecord, FigureRecord, + MatplotlibFigureRecord, ) from entropylab.pipeline.api.data_writer import ( ExperimentInitialData, RawResultData, Metadata, Debug, - PlotSpec, NodeData, ) from entropylab.pipeline.api.errors import EntropyError @@ -99,7 +97,6 @@ class ExperimentTable(Base): results = relationship("ResultTable", cascade="all, delete-orphan") experiment_metadata = relationship("MetadataTable", cascade="all, delete-orphan") debug = relationship("DebugTable", cascade="all, delete-orphan") - plots = relationship("PlotTable", cascade="all, delete-orphan") def __repr__(self): return f"<_Experiment(id='{self.id}')>" @@ -240,72 +237,56 @@ def from_model(experiment_id: int, node_data: NodeData): ) -class PlotTable(Base): - __tablename__ = "Plots" - +class FigureTable(Base): + __tablename__ = "Figures" id = Column(Integer, primary_key=True) experiment_id = Column(Integer, ForeignKey("Experiments.id", ondelete="CASCADE")) - plot_data = Column(BLOB) - data_type = Column(Enum(ResultDataType)) - generator_module = Column(String) - generator_class = Column(String) + figure = Column(String) time = Column(DATETIME) - label = Column(String) - story = Column(String) def __repr__(self): - return f"" + return f"" - def to_record(self) -> PlotRecord: - data = _decode_serialized_data(self.plot_data, self.data_type) - generator = _get_class(self.generator_module, self.generator_class) - return PlotRecord( + def to_record(self) -> FigureRecord: + return FigureRecord( experiment_id=self.experiment_id, id=self.id, - label=self.label, - story=self.story, - plot_data=data, - generator=generator(), + figure=from_json(self.figure), + time=self.time, ) @staticmethod - def from_model(experiment_id: int, plot: PlotSpec, data: Any): - data_type, serialized_data = _encode_serialized_data(data) - return PlotTable( + def from_model(experiment_id: int, figure: go.Figure): + return FigureTable( experiment_id=experiment_id, - plot_data=serialized_data, - data_type=data_type, - generator_module=plot.generator.__module__, - generator_class=plot.generator.__qualname__, + figure=to_json(figure), time=datetime.now(), - label=plot.label, - story=plot.story, ) -class FigureTable(Base): - __tablename__ = "Figures" +class MatplotlibFigureTable(Base): + __tablename__ = "MatplotlibFigures" id = Column(Integer, primary_key=True) experiment_id = Column(Integer, ForeignKey("Experiments.id", ondelete="CASCADE")) - figure = Column(String) + img_src = Column(String) time = Column(DATETIME) def __repr__(self): - return f"" + return f"" - def to_record(self) -> FigureRecord: - return FigureRecord( + def to_record(self) -> MatplotlibFigureRecord: + return MatplotlibFigureRecord( experiment_id=self.experiment_id, id=self.id, - figure=from_json(self.figure), + img_src=self.img_src, time=self.time, ) @staticmethod - def from_model(experiment_id: int, figure: go.Figure): - return FigureTable( + def from_model(experiment_id: int, img_src: str): + return MatplotlibFigureTable( experiment_id=experiment_id, - figure=to_json(figure), + img_src=img_src, time=datetime.now(), ) diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db b/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db new file mode 100644 index 00000000..730a9208 Binary files /dev/null and b/entropylab/pipeline/results_backend/sqlalchemy/tests/db_templates/empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db differ diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/library/TestNode.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/library/TestNode.py new file mode 100644 index 00000000..e69de29b diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json new file mode 100644 index 00000000..f6b88ba9 --- /dev/null +++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/entropynodes/schema/TestNode.json @@ -0,0 +1 @@ +{"name": "TestNode", "description": "test", "command": "python3", "bin": "_jb_pytest_runner.py", "icon": "bootstrap/person-circle.svg", "inputs": [{"description": {"stream_1": "desc", "state_1": "desc"}, "units": {"stream_1": "test", "state_1": "test"}, "type": {"stream_1": 2, "state_1": 1}}], "outputs": []} \ No newline at end of file diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py index d98ec3fa..06ad167a 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db.py @@ -2,6 +2,7 @@ from datetime import datetime import pytest +from matplotlib import pyplot as plt from plotly import express as px from entropylab import SqlAlchemyDB, RawResultData @@ -75,6 +76,24 @@ def test_save_figure_(initialized_project_dir_path): assert actual.figure == figure +def test_save_matplotlib_figure_(initialized_project_dir_path): + # arrange + db = SqlAlchemyDB(initialized_project_dir_path) + x = [1, 2, 3, 4] + y = [10, 40, 20, 30] + plt.scatter(x, y) + figure = plt.gcf() + # act + db.save_matplotlib_figure(0, figure) + # assert + actual = db.get_matplotlib_figures(0)[0] + assert actual.img_src.startswith( + "data:image/png;base64," + "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBN" + "YXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9F" + ) + + def test_get_experiments_range_reads_all_columns(): # arrange target = SqlAlchemyDB() diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py index 3950740b..d27a87d8 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_db_upgrader.py @@ -153,6 +153,8 @@ def test__migrate_metadata_to_hdf5(initialized_project_dir_path): "empty_after_2022-05-19-09-12-01_06140c96c8c4_wrapping_param_store_values.db", "empty_after_2022-06-16-09-26-06_7fa75ca1263f_del_results_and_metadata.db", "empty_after_2022-06-23-10-16-39_273a9fae6206_experiments_favorite_col.db", + "empty_after_2022-06-28-12-13-39_09f3b5a1689c_fixing_param_qualified_name.db", + "empty_after_2022-07-03-08-56-23_da8d38e19ff8_matplotlib_figures.db", ], indirect=True, ) diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py index 75f5924e..07d994d8 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_migrations.py @@ -19,8 +19,7 @@ def test_ctor_creates_up_to_date_schema_when_in_memory(path: str): [ None, # new db "empty.db", # existing but empty - "empty_after_2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.db" - # "empty_after_2022-06-28-12-13-39_09f3b5a1689c_fixing_param_qualified_name.db" + "empty_after_2022-08-07-11-53-59_997e336572b8_paramstore_json_v0_3.db", # ⬆ latest version in pipeline/results_backend/sqlalchemy/alembic/versions ], indirect=True, diff --git a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py index 69604e4d..50a23589 100644 --- a/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py +++ b/entropylab/pipeline/results_backend/sqlalchemy/tests/test_model.py @@ -1,7 +1,12 @@ +import datetime + from plotly import express as px from plotly.io import to_json -from entropylab.pipeline.results_backend.sqlalchemy.model import FigureTable +from entropylab.pipeline.results_backend.sqlalchemy.model import ( + FigureTable, + MatplotlibFigureTable, +) class TestFigureTable: @@ -24,3 +29,29 @@ def test_from_model(self): assert actual.experiment_id == 1 assert actual.figure == to_json(figure) assert actual.time is not None + + +class TestMatplotlibFigureTable: + def test_to_record(self): + target = MatplotlibFigureTable( + id=42, + experiment_id=1337, + img_src="data:image/png;base64,", + time=datetime.datetime.utcnow(), + ) + + actual = target.to_record() + + assert actual.id == target.id + assert actual.experiment_id == target.experiment_id + assert actual.img_src == target.img_src + assert actual.time == target.time + + def test_from_model(self): + target = MatplotlibFigureTable() + + actual = target.from_model(1, "data:image/png;base64,") + + assert actual.experiment_id == 1 + assert actual.img_src == "data:image/png;base64," + assert actual.time is not None diff --git a/entropylab/pipeline/tests/test_async_graph.py b/entropylab/pipeline/tests/test_async_graph.py index 8d81bf8d..8414829b 100644 --- a/entropylab/pipeline/tests/test_async_graph.py +++ b/entropylab/pipeline/tests/test_async_graph.py @@ -1,7 +1,6 @@ import asyncio import numpy as np -from bokeh.plotting import Figure from entropylab.pipeline.graph_experiment import ( Graph, @@ -136,11 +135,6 @@ def test_async_graph(): results = graph.run().results print(results.get_experiment_info()) - plots = results.get_plots() - for plot in plots: - figure = Figure() - plot.generator.plot_bokeh(figure, plot.plot_data) - # save(figure, f"try{plot.label}.html") def test_async_graph_run_to_node(): diff --git a/entropylab/pipeline/tests/test_executor.py b/entropylab/pipeline/tests/test_executor.py index 339e5116..d2beef49 100644 --- a/entropylab/pipeline/tests/test_executor.py +++ b/entropylab/pipeline/tests/test_executor.py @@ -1,11 +1,10 @@ from datetime import datetime import pytest +from plotly import express as px -from entropylab.pipeline.api.data_writer import PlotSpec -from entropylab.pipeline.api.execution import EntropyContext -from entropylab.pipeline.api.plot import CirclePlotGenerator, LinePlotGenerator from entropylab.components.lab_topology import LabResources, ExperimentResources +from entropylab.pipeline.api.execution import EntropyContext from entropylab.pipeline.results_backend.sqlalchemy.db import SqlAlchemyDB from entropylab.pipeline.script_experiment import Script, script_experiment from entropylab.pipeline.tests.mock_instruments import MockScope @@ -64,14 +63,10 @@ def an_experiment_with_plot(experiment: EntropyContext): experiment.add_result("b_result" + str(i), b1 + i + datetime.now().microsecond) micro = datetime.now().microsecond - experiment.add_plot( - PlotSpec( - label="plot", - story="created this plot in experiment", - generator=CirclePlotGenerator, - ), - data=[ - [ + + experiment.add_figure( + px.scatter( + x=[ 1 * micro, 2 * micro, 3 * micro, @@ -81,20 +76,15 @@ def an_experiment_with_plot(experiment: EntropyContext): 7 * micro, 8 * micro, ], - [0, 1, 2, 3, 4, 5, 6, 7], - ], + y=[0, 1, 2, 3, 4, 5, 6, 7], + ) ) - experiment.add_plot( - PlotSpec( - label="another plot", - story="just showing off now", - generator=LinePlotGenerator, - ), - data=[ - [1, 2, 3, 4, 5, 6, 7, 8], - [4, 5, 6, 7, 0, 1, 2, 3], - ], + experiment.add_figure( + px.line( + x=[1, 2, 3, 4, 5, 6, 7, 8], + y=[4, 5, 6, 7, 0, 1, 2, 3], + ) ) diff --git a/entropylab/pipeline/tests/test_graph.py b/entropylab/pipeline/tests/test_graph.py index 4114c7f1..1e049f2a 100644 --- a/entropylab/pipeline/tests/test_graph.py +++ b/entropylab/pipeline/tests/test_graph.py @@ -1,14 +1,9 @@ import asyncio -import os from time import sleep import pytest -from bokeh.io import save -from bokeh.plotting import Figure -from entropylab.pipeline.api.data_writer import PlotSpec from entropylab.pipeline.api.execution import EntropyContext -from entropylab.pipeline.api.plot import CirclePlotGenerator from entropylab.pipeline.graph_experiment import ( Graph, PyNode, @@ -52,10 +47,6 @@ def d(x, y): async def e(y, z, context: EntropyContext): print(f"Node e resting for {y / z}") print(f"e Result: {y + z}") - context.add_plot( - PlotSpec(CirclePlotGenerator, "the best plot"), - data=[[0, 1, 2, 3, 4, 5], [y + z, 7, 6, 20, 10, 11]], - ) return {"y_z": [0, 1, 2, 3, 4, 5, y + z, 7, 6, 20, 10, 11]} @@ -147,14 +138,6 @@ def test_sync_graph(): results = Graph(None, g, "run_a").run().results print(results.get_experiment_info()) - # TODO: Use figures instead of plots here - plots = results.get_plots() - for plot in plots: - figure = Figure() - plot.generator.plot_bokeh(figure, plot.plot_data) - if not os.path.exists("tests_cache"): - os.mkdir("tests_cache") - save(figure, f"tests_cache/bokeh-exported-{plot.label}.html") def test_sync_graph_run_to_node(): diff --git a/entropylab/pipeline/tests/test_plot.py b/entropylab/pipeline/tests/test_plot.py deleted file mode 100644 index 03c079c0..00000000 --- a/entropylab/pipeline/tests/test_plot.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np -import plotly - -from entropylab.pipeline.api.plot import CirclePlotGenerator, ImShowPlotGenerator -from plotly.graph_objects import Figure - - -def test_circle_plot_plotly(): - target = CirclePlotGenerator() - figure = Figure() - data = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]] - target.plot_plotly(figure, data) - i = 0 - - -def test_imshow_plot_plotly(): - target = ImShowPlotGenerator() - figure = Figure() - data = np.random.rand(10, 10) - target.plot_plotly(figure, data) - assert isinstance(figure.data[0], plotly.graph_objs._heatmap.Heatmap)