From da7b2a710284a8603b6e0b538866186d99baf393 Mon Sep 17 00:00:00 2001 From: eugene Date: Thu, 29 Aug 2024 21:23:39 +0800 Subject: [PATCH] marimo.app_meta.theme and custom themers to auto switch plotting themes based on the display theme (#2126) * some changes for the Chinese README * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make the gutter bigger when the window is super wide * revert changes for fixing #2035 * Update README_Chinese.md highlights * first attemptation * remove "system" option from display theme * Add marimo.app_meta * Add themers to auto switch plotting themes based on the display theme * fix type errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert fe changes * revert fe changes * apply_theme for formatters * update app_meta * some small errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set default value for theme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update api.yaml --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Myles Scolnick --- marimo/__init__.py | 9 +- marimo/_config/config.py | 3 +- .../_output/formatters/altair_formatters.py | 6 + marimo/_output/formatters/bokeh_formatters.py | 6 + .../_output/formatters/formatter_factory.py | 13 ++ marimo/_output/formatters/formatters.py | 6 +- .../formatters/holoviews_formatters.py | 11 ++ .../formatters/matplotlib_formatters.py | 8 + .../_output/formatters/plotly_formatters.py | 6 + marimo/_runtime/app_meta.py | 23 +++ marimo/_runtime/runtime.py | 20 ++- marimo/_server/sessions.py | 4 +- marimo/_smoke_tests/theming/apply_theme.py | 150 ++++++++++++++++++ 13 files changed, 260 insertions(+), 5 deletions(-) create mode 100644 marimo/_runtime/app_meta.py create mode 100644 marimo/_smoke_tests/theming/apply_theme.py diff --git a/marimo/__init__.py b/marimo/__init__.py index fda238e2912..4bf5b68b35b 100644 --- a/marimo/__init__.py +++ b/marimo/__init__.py @@ -24,6 +24,7 @@ "MarimoIslandGenerator", "accordion", "carousel", + "app_meta", "as_html", "audio", "callout", @@ -109,7 +110,13 @@ ) from marimo._runtime.context.utils import running_in_notebook from marimo._runtime.control_flow import MarimoStopError, stop -from marimo._runtime.runtime import cli_args, defs, query_params, refs +from marimo._runtime.runtime import ( + app_meta, + cli_args, + defs, + query_params, + refs, +) from marimo._runtime.state import state from marimo._server.asgi import create_asgi_app from marimo._sql.sql import sql diff --git a/marimo/_config/config.py b/marimo/_config/config.py index c9d92028098..b907c117d22 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -101,6 +101,7 @@ class RuntimeConfig(TypedDict): # TODO(akshayka): remove normal, migrate to compact # normal == compact WidthType = Literal["normal", "compact", "medium", "full"] +Theme = Literal["light", "dark", "system"] @mddoc @@ -116,7 +117,7 @@ class DisplayConfig(TypedDict): - `dataframes`: `"rich"` or `"plain"` """ - theme: Literal["light", "dark", "system"] + theme: Theme code_editor_font_size: int cell_output: Literal["above", "below"] default_width: WidthType diff --git a/marimo/_output/formatters/altair_formatters.py b/marimo/_output/formatters/altair_formatters.py index 24363b352ef..22aa0b9ba20 100644 --- a/marimo/_output/formatters/altair_formatters.py +++ b/marimo/_output/formatters/altair_formatters.py @@ -3,6 +3,7 @@ import html +from marimo._config.config import Theme from marimo._messaging.mimetypes import KnownMimeType from marimo._output.builder import h from marimo._output.formatters.formatter_factory import FormatterFactory @@ -48,3 +49,8 @@ def _show_chart(chart: altair.Chart) -> tuple[KnownMimeType, str]: ) ), ) + + def apply_theme(self, theme: Theme) -> None: + import altair as alt # type: ignore + + alt.themes.enable("dark" if theme == "dark" else "default") # type: ignore diff --git a/marimo/_output/formatters/bokeh_formatters.py b/marimo/_output/formatters/bokeh_formatters.py index 3b925b17664..51122153b86 100644 --- a/marimo/_output/formatters/bokeh_formatters.py +++ b/marimo/_output/formatters/bokeh_formatters.py @@ -3,6 +3,7 @@ from typing import Optional +from marimo._config.config import Theme from marimo._messaging.mimetypes import KnownMimeType from marimo._output.builder import h from marimo._output.formatters.formatter_factory import FormatterFactory @@ -67,3 +68,8 @@ def _show_plot( ) ), ) + + def apply_theme(self, theme: Theme) -> None: + from bokeh.io import curdoc # type: ignore + + curdoc().theme = "dark_minimal" if theme == "dark" else None # type: ignore diff --git a/marimo/_output/formatters/formatter_factory.py b/marimo/_output/formatters/formatter_factory.py index 9845ab286a6..c8933e5c2e3 100644 --- a/marimo/_output/formatters/formatter_factory.py +++ b/marimo/_output/formatters/formatter_factory.py @@ -4,6 +4,8 @@ import abc from typing import Callable, Optional +from marimo._config.config import Theme + # Abstract base class for formatters that are installed at runtime. class FormatterFactory(abc.ABC): @@ -29,3 +31,14 @@ def register(self) -> Callable[[], None] | None: patches. """ raise NotImplementedError + + def apply_theme(self, theme: Theme) -> None: + """ + Apply the theme (light/dark) to third party libraries. + If the theme is set to "system", then we fallback to "light". + + Args: + theme: The theme to apply. + """ + del theme + return diff --git a/marimo/_output/formatters/formatters.py b/marimo/_output/formatters/formatters.py index 4891677efd8..db0a638a74b 100644 --- a/marimo/_output/formatters/formatters.py +++ b/marimo/_output/formatters/formatters.py @@ -4,6 +4,7 @@ import sys from typing import Any, Callable, Sequence +from marimo._config.config import Theme from marimo._output.formatters.altair_formatters import AltairFormatter from marimo._output.formatters.anywidget_formatters import AnyWidgetFormatter from marimo._output.formatters.bokeh_formatters import BokehFormatter @@ -55,7 +56,7 @@ ] -def register_formatters() -> None: +def register_formatters(theme: Theme = "light") -> None: """Register formatters with marimo. marimo comes packaged with rich formatters for a number of third-party @@ -81,6 +82,7 @@ def register_formatters() -> None: for package, factory in THIRD_PARTY_FACTORIES.items(): if package in sys.modules: factory.register() + factory.apply_theme(theme) pre_registered.add(package) third_party_factories = { @@ -147,6 +149,7 @@ def exec_module( ) -> Any: loader_return_value = original_exec_module(module) factory.register() + factory.apply_theme(theme) return loader_return_value spec.loader.exec_module = exec_module @@ -161,3 +164,4 @@ def exec_module( # package import. So we can register them at program start-up. for factory in NATIVE_FACTORIES: factory.register() + factory.apply_theme(theme) diff --git a/marimo/_output/formatters/holoviews_formatters.py b/marimo/_output/formatters/holoviews_formatters.py index 3c520b5cfa6..89962032453 100644 --- a/marimo/_output/formatters/holoviews_formatters.py +++ b/marimo/_output/formatters/holoviews_formatters.py @@ -1,6 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations +from marimo._config.config import Theme from marimo._dependencies.dependencies import DependencyManager from marimo._messaging.mimetypes import KnownMimeType from marimo._output.formatters.formatter_factory import FormatterFactory @@ -53,3 +54,13 @@ def _show_chart( html = as_html(backend_output) return ("text/html", html.text) + + def apply_theme(self, theme: Theme) -> None: + import holoviews as hv # type: ignore + + hv.renderer("bokeh").theme = ( + "dark_minimal" if theme == "dark" else None + ) + hv.renderer("plotly").theme = ( + "plotly_dark" if theme == "dark" else "plotly" + ) diff --git a/marimo/_output/formatters/matplotlib_formatters.py b/marimo/_output/formatters/matplotlib_formatters.py index f6d58f45c3f..e1ff62fae88 100644 --- a/marimo/_output/formatters/matplotlib_formatters.py +++ b/marimo/_output/formatters/matplotlib_formatters.py @@ -1,6 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations +from marimo._config.config import Theme from marimo._messaging.mimetypes import KnownMimeType from marimo._output.formatters.formatter_factory import FormatterFactory @@ -55,3 +56,10 @@ def _show_bar_container(bc: BarContainer) -> tuple[KnownMimeType, str]: return mime_data_artist(bc.patches[0].figure) # type: ignore else: return ("text/plain", str(bc)) + + def apply_theme(self, theme: Theme) -> None: + import matplotlib # type: ignore + + matplotlib.style.use( + "dark_background" if theme == "dark" else "default" + ) diff --git a/marimo/_output/formatters/plotly_formatters.py b/marimo/_output/formatters/plotly_formatters.py index 7e97c4fc53d..18315cafa2c 100644 --- a/marimo/_output/formatters/plotly_formatters.py +++ b/marimo/_output/formatters/plotly_formatters.py @@ -4,6 +4,7 @@ import json from typing import Any +from marimo._config.config import Theme from marimo._messaging.mimetypes import KnownMimeType from marimo._output.formatters.formatter_factory import FormatterFactory from marimo._output.hypertext import Html @@ -51,3 +52,8 @@ def render_plotly_dict(json: dict[Any, Any]) -> Html: args={"figure": json, "config": resolved_config}, ) ) + + def apply_theme(self, theme: Theme) -> None: + import plotly.io as pio # type: ignore + + pio.templates.default = "plotly_dark" if theme == "dark" else "plotly" diff --git a/marimo/_runtime/app_meta.py b/marimo/_runtime/app_meta.py new file mode 100644 index 00000000000..3a522517c81 --- /dev/null +++ b/marimo/_runtime/app_meta.py @@ -0,0 +1,23 @@ +# Copyright 2024 Marimo. All rights reserved. +from marimo._config.utils import load_config + + +class AppMeta: + """ + Metadata about the app. + + This is used to store metadata about the app + that is not part of the app's code or state. + """ + + def __init__(self) -> None: + self.user_config = load_config() + + @property + def theme(self) -> str: + """The display theme of the app.""" + theme = self.user_config["display"]["theme"] or "light" + if theme == "system": + # TODO(mscolnick): have frontend tell the backend the system theme + return "light" + return theme diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 78fce625314..51ddc4afc7f 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -67,6 +67,7 @@ from marimo._plugins.core.web_component import JSONType from marimo._plugins.ui._core.ui_element import MarimoConvertValueException from marimo._runtime import dataflow, handlers, marimo_pdb, patches +from marimo._runtime.app_meta import AppMeta from marimo._runtime.complete import complete, completion_worker from marimo._runtime.context import ( ContextNotInitializedError, @@ -232,6 +233,23 @@ def query_params() -> QueryParams: return get_context().query_params +@mddoc +def app_meta() -> AppMeta: + """Get the metadata of a marimo app. + + **Examples**: + + ```python3 + theme = mo.app_meta().theme + ``` + + **Returns**: + + - An `AppMeta` object containing the app's metadata. + """ + return AppMeta() + + @mddoc def cli_args() -> CLIArgs: """Get the command line arguments of a marimo notebook. @@ -1891,7 +1909,7 @@ def launch_kernel( # kernels are processes in edit mode, and each process needs to # install the formatter import hooks - register_formatters() + register_formatters(theme=user_config["display"]["theme"]) signal.signal( signal.SIGINT, handlers.construct_interrupt_handler(kernel) diff --git a/marimo/_server/sessions.py b/marimo/_server/sessions.py index 600b12a579e..5eb2d7bb44e 100644 --- a/marimo/_server/sessions.py +++ b/marimo/_server/sessions.py @@ -215,7 +215,9 @@ def launch_kernel_with_cleanup(*args: Any) -> None: # install formatter import hooks, which will be shared by all # threads (in edit mode, the single kernel process installs # formatters ...) - register_formatters() + register_formatters( + theme=self.user_config_manager.config["display"]["theme"] + ) # Make threads daemons so killing the server immediately brings # down all client sessions diff --git a/marimo/_smoke_tests/theming/apply_theme.py b/marimo/_smoke_tests/theming/apply_theme.py new file mode 100644 index 00000000000..55dcfce5c1b --- /dev/null +++ b/marimo/_smoke_tests/theming/apply_theme.py @@ -0,0 +1,150 @@ +import marimo + +__generated_with = "0.8.3" +app = marimo.App(width="medium") + + +@app.cell +def __(mo): + mo.app_meta().theme + return + + +@app.cell +def __(mo): + mo.md(r"""# Seaborn""") + return + + +@app.cell +def __(df, plt): + import seaborn as sns + + plt.figure(figsize=(10, 6)) + sns.lineplot(x="x", y="y", data=df) + plt.title("Seaborn: Sine Wave") + return (sns,) + + +@app.cell +def __(mo): + mo.md(r"""# Matplotlib""") + return + + +@app.cell +def __(x, y): + import matplotlib.pyplot as plt + + plt.figure(figsize=(10, 6)) + plt.plot(x, y) + plt.title("Matplotlib: Sine Wave") + return (plt,) + + +@app.cell(hide_code=True) +def __(mo): + mo.md(r"""# Holoviews""") + return + + +@app.cell +def __(df): + import holoviews as hv + + hv.extension("bokeh") + curve = hv.Curve(df, "x", "y") + hv.render(curve.opts(title="Holoviews: Sine Wave", width=800, height=400)) + return curve, hv + + +@app.cell(hide_code=True) +def __(mo): + mo.md(r"""# Bokeh""") + return + + +@app.cell +def __(x, y): + # Bokeh + from bokeh.plotting import figure, show + + p = figure( + title="Bokeh: Sine Wave", + x_axis_label="x", + y_axis_label="y", + width=800, + height=400, + ) + p.line(x, y, line_width=2) + p + return figure, p, show + + +@app.cell(hide_code=True) +def __(mo): + mo.md(r"""# Altair""") + return + + +@app.cell(hide_code=True) +def __(mo): + import altair as alt + from vega_datasets import data + + chart = ( + alt.Chart(data.cars()) + .mark_point() + .encode( + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + ) + ) + + chart = mo.ui.altair_chart(chart) + chart + return alt, chart, data + + +@app.cell(hide_code=True) +def __(mo): + mo.md(r"""# Plotly""") + return + + +@app.cell(hide_code=True) +def __(df): + # Plotly + import plotly.express as px + + px.line(df, x="x", y="y", title="Plotly: Sine Wave") + return (px,) + + +@app.cell +def __(np, pd): + # Sample data + x = np.linspace(0, 10, 100) + y = np.sin(x) + df = pd.DataFrame({"x": x, "y": y}) + return df, x, y + + +@app.cell +def __(): + import marimo as mo + + return (mo,) + + +@app.cell +def __(): + import numpy as np + import pandas as pd + + return np, pd + + +if __name__ == "__main__": + app.run()