diff --git a/google/colab/_quickchart.py b/google/colab/_quickchart.py index b6a4232e..e101fb81 100644 --- a/google/colab/_quickchart.py +++ b/google/colab/_quickchart.py @@ -57,6 +57,36 @@ def _ensure_dataframe_registry(): _ensure_dataframe_registry() + chart_sections = determine_charts( + df, _DATAFRAME_REGISTRY, max_chart_instances + ) + if not chart_sections: + print('No charts were generated by quickchart') + return chart_sections + + +def find_charts_json(df_name: str, max_chart_instances=None): + """Equivalent to find_charts, but emits to JSON for use from browser.""" + + class FixedDataframeRegistry: + + def get_or_register_varname(self, _) -> str: + """Returns the name of the fixed dataframe name.""" + return df_name + + dataframe = IPython.get_ipython().user_ns[df_name] + + chart_sections = determine_charts( + dataframe, FixedDataframeRegistry(), max_chart_instances + ) + return IPython.display.JSON([s.to_json() for s in chart_sections]) + + +def determine_charts(df, dataframe_registry, max_chart_instances=None): + """Finds charts compatible with dtypes of the given data frame.""" + # Lazy import to avoid loading matplotlib and transitive deps on kernel init. + from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top + dtype_groups = _classify_dtypes(df) numeric_cols = dtype_groups['numeric'] categorical_cols = dtype_groups['categorical'] @@ -66,7 +96,7 @@ def _ensure_dataframe_registry(): if numeric_cols: chart_sections.append( _quickchart_helpers.histograms_section( - df, numeric_cols[:max_chart_instances], _DATAFRAME_REGISTRY + df, numeric_cols[:max_chart_instances], dataframe_registry ) ) @@ -74,7 +104,7 @@ def _ensure_dataframe_registry(): selected_categorical_cols = categorical_cols[:max_chart_instances] chart_sections += [ _quickchart_helpers.categorical_histograms_section( - df, selected_categorical_cols, _DATAFRAME_REGISTRY + df, selected_categorical_cols, dataframe_registry ), ] @@ -83,7 +113,7 @@ def _ensure_dataframe_registry(): _quickchart_helpers.scatter_section( df, _select_first_k_pairs(numeric_cols, k=max_chart_instances), - _DATAFRAME_REGISTRY, + dataframe_registry, ), ] @@ -97,14 +127,14 @@ def _ensure_dataframe_registry(): categorical_cols=categorical_cols, k=max_chart_instances, ), - _DATAFRAME_REGISTRY, + dataframe_registry, ), ) if numeric_cols: chart_sections.append( _quickchart_helpers.value_plots_section( - df, numeric_cols[:max_chart_instances], _DATAFRAME_REGISTRY + df, numeric_cols[:max_chart_instances], dataframe_registry ) ) @@ -113,7 +143,7 @@ def _ensure_dataframe_registry(): _quickchart_helpers.heatmaps_section( df, _select_first_k_pairs(categorical_cols, k=max_chart_instances), - _DATAFRAME_REGISTRY, + dataframe_registry, ), ] @@ -124,12 +154,9 @@ def _ensure_dataframe_registry(): _select_faceted_numeric_cols( numeric_cols, categorical_cols, k=max_chart_instances ), - _DATAFRAME_REGISTRY, + dataframe_registry, ), ] - - if not chart_sections: - print('No charts were generated by quickchart') return chart_sections diff --git a/google/colab/_quickchart_helpers.py b/google/colab/_quickchart_helpers.py index 643a281c..6414c8a0 100644 --- a/google/colab/_quickchart_helpers.py +++ b/google/colab/_quickchart_helpers.py @@ -99,6 +99,13 @@ def display(self): for d in self._displayables: d.display() + def to_json(self): + charts = [chart.to_json() for chart in self._charts] + return { + 'section_type': self._section_type, + 'charts': charts, + } + class SectionTitle: """Section title used for delineating chart sections.""" @@ -169,18 +176,29 @@ def display(self): """Displays the chart within a notebook context.""" IPython.display.display(self) - def get_code(self): - """Gets the code and associated dependencies + context for a given chart.""" + def to_json(self): + data = self.get_code_and_title() + return { + 'code': data.code, + 'title': data.title, + } + + def get_code_and_title(self): if self._df_varname is None: self._df_varname = self._df_registry.get_or_register_varname(self._df) return self._plot_func(self._df_varname, *self._args, **self._kwargs) + def get_code(self): + """Gets the code and associated dependencies + context for a given chart.""" + + return self.get_code_and_title().code + def _repr_html_(self): """Gets the HTML representation of the chart.""" if self._chart is None: with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)): - exec_code = self._plot_func('df', *self._args, **self._kwargs) + exec_code, _ = self._plot_func('df', *self._args, **self._kwargs) exec(exec_code, {'df': self._df}) # pylint: disable=exec-used self._chart = _quickchart_lib.autoviz.MplChart.from_current_mpl_state() diff --git a/google/colab/_quickchart_lib.py b/google/colab/_quickchart_lib.py index 95ec4a6c..90b781da 100644 --- a/google/colab/_quickchart_lib.py +++ b/google/colab/_quickchart_lib.py @@ -1,6 +1,7 @@ """Library of charts for use by quickchart.""" import base64 +import dataclasses import io import IPython.display @@ -10,6 +11,12 @@ # pylint:disable=g-import-not-at-top +@dataclasses.dataclass +class ChartData: + title: str + code: str + + class MplChart: """Matplotlib chart wrapper that displays charts to PNG elements.""" @@ -57,7 +64,7 @@ class autoviz: # pylint:disable=invalid-name MplChart = MplChart # pylint:disable=invalid-name -def histogram(df_name: str, colname: str, num_bins=20): +def histogram(df_name: str, colname: str, num_bins=20) -> ChartData: """Generates a histogram for the given data column. Args: @@ -68,13 +75,13 @@ def histogram(df_name: str, colname: str, num_bins=20): Returns: Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt {df_name}[{colname!r}].plot(kind='hist', bins={num_bins}, title={colname!r}) -plt.gca().spines[['top', 'right',]].set_visible(False) -plt.tight_layout()""" +plt.gca().spines[['top', 'right',]].set_visible(False)""" + return ChartData(title=colname, code=code) -def categorical_histogram(df_name, colname): +def categorical_histogram(df_name, colname) -> ChartData: """Generates a single categorical histogram. Args: @@ -84,13 +91,15 @@ def categorical_histogram(df_name, colname): Returns: Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt import seaborn as sns {df_name}.groupby({colname!r}).size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2')) plt.gca().spines[['top', 'right',]].set_visible(False)""" + return ChartData(title=colname, code=code) -def heatmap(df_name: str, x_colname: str, y_colname: str): + +def heatmap(df_name: str, x_colname: str, y_colname: str) -> ChartData: """Generates a single heatmap. Args: @@ -101,7 +110,7 @@ def heatmap(df_name: str, x_colname: str, y_colname: str): Returns: Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt import seaborn as sns import pandas as pd plt.subplots(figsize=(8, 8)) @@ -111,12 +120,14 @@ def heatmap(df_name: str, x_colname: str, y_colname: str): }}) sns.heatmap(df_2dhist, cmap='viridis') plt.xlabel({x_colname!r}) -plt.ylabel({y_colname!r})""" +_ = plt.ylabel({y_colname!r})""" + + return ChartData(f'{x_colname} vs {y_colname}', code) def swarm_plot( df_name: str, value_colname: str, facet_colname: str, jitter_domain_width=8 -): +) -> ChartData: """Generates a single swarm plot. Incorporated from altair example gallery: @@ -132,7 +143,7 @@ def swarm_plot( Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt import numpy as np import seaborn as sns palette = sns.palettes.mpl_palette('Dark2') @@ -148,12 +159,14 @@ def swarm_plot( ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color) ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical') plt.title({facet_colname!r}) -plt.ylabel({value_colname!r})""" +_ = plt.ylabel({value_colname!r})""" + + return ChartData(f'{facet_colname} vs {value_colname}', code) def violin_plot( df_name: str, value_colname: str, facet_colname: str, inner: str -): +) -> ChartData: """Generates a single violin plot. Args: @@ -165,15 +178,17 @@ def violin_plot( Returns: Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt import seaborn as sns figsize = (12, 1.2 * len({df_name}[{facet_colname!r}].unique())) plt.figure(figsize=figsize) sns.violinplot({df_name}, x={value_colname!r}, y={facet_colname!r}, inner={inner!r}, palette='Dark2') sns.despine(top=True, right=True, bottom=True, left=True)""" + return ChartData(f'{facet_colname} vs {value_colname}', code) -def value_plot(df_name: str, y: str): + +def value_plot(df_name: str, y: str) -> ChartData: """Generates a single value plot. Args: @@ -184,13 +199,14 @@ def value_plot(df_name: str, y: str): Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt {df_name}[{y!r}].plot(kind='line', figsize=(8, 4), title={y!r}) -plt.gca().spines[['top', 'right']].set_visible(False) -plt.tight_layout()""" +plt.gca().spines[['top', 'right']].set_visible(False)""" + + return ChartData(y, code) -def scatter_plot(df_name: str, x_colname: str, y_colname: str): +def scatter_plot(df_name: str, x_colname: str, y_colname: str) -> ChartData: """Generates a single scatter plot. Args: @@ -202,16 +218,16 @@ def scatter_plot(df_name: str, x_colname: str, y_colname: str): Code to generate the plot. """ - return f"""from matplotlib import pyplot as plt -plt.figure(figsize=(6, 6)) + code = f"""from matplotlib import pyplot as plt {df_name}.plot(kind='scatter', x={x_colname!r}, y={y_colname!r}, s=32, alpha=.8) -plt.gca().spines[['top', 'right',]].set_visible(False) -plt.tight_layout()""" +plt.gca().spines[['top', 'right',]].set_visible(False)""" + + return ChartData(f'{x_colname} vs {y_colname}', code) def time_series_multiline( df_name: str, timelike_colname: str, value_colname: str, series_colname: str -): +) -> ChartData: """Generates a single time series plot. Args: @@ -241,7 +257,7 @@ def time_series_multiline( _plot_series(series, series_name, i) fig.legend(title={series_colname!r}, bbox_to_anchor=(1, 1), loc='upper left')""" - return f"""from matplotlib import pyplot as plt + code = f"""from matplotlib import pyplot as plt import seaborn as sns def _plot_series(series, series_name, series_index=0): from matplotlib import pyplot as plt @@ -255,4 +271,6 @@ def _plot_series(series, series_name, series_index=0): {series_impl} sns.despine(fig=fig, ax=ax) plt.xlabel({timelike_colname!r}) -plt.ylabel({value_colname!r})""" +_ = plt.ylabel({value_colname!r})""" + + return ChartData(f'{timelike_colname} vs {value_colname}', code)