diff --git a/google/colab/_quickchart_helpers.py b/google/colab/_quickchart_helpers.py
index 1d9289af..dd189534 100644
--- a/google/colab/_quickchart_helpers.py
+++ b/google/colab/_quickchart_helpers.py
@@ -1,6 +1,5 @@
"""Supporting code for quickchart functionality."""
-import inspect
import textwrap
import uuid as _uuid
@@ -160,12 +159,7 @@ def __init__(self, df, plot_func, args, kwargs, df_registry):
self._kwargs = kwargs
self._chart_id = f'chart-{str(_uuid.uuid4())}'
- with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)):
- # We want the charts to be small when there are many of them, but larger
- # when a user inserts a single chart. We set `figscale` here so that it
- # isn't remembered if a user clicks on a chart.
- figscale = 0.5
- self._chart = plot_func(df, figscale=figscale, *args, **kwargs)
+ self._chart = None
@property
def chart_id(self):
@@ -180,30 +174,15 @@ def get_code(self):
if self._df_varname is None:
self._df_varname = self._df_registry.get_or_register_varname(self._df)
- plot_func_src = inspect.getsource(self._plot_func)
- plot_invocation = textwrap.dedent(
- """\
- chart = {plot_func}({df_varname}, *{args}, **{kwargs})
- chart""".format(
- plot_func=self._plot_func.__name__,
- args=str(self._args),
- kwargs=str(self._kwargs),
- df_varname=self._df_varname,
- )
- )
-
- chart_src = textwrap.dedent("""\
- import numpy as np
- from google.colab import autoviz
- """)
- chart_src += '\n'
- chart_src += plot_func_src
- chart_src += '\n'
- chart_src += plot_invocation
- return chart_src
+ return self._plot_func(self._df_varname, *self._args, **self._kwargs)
def _repr_html_(self):
"""Gets the HTML representation of the chart."""
+ if self._chart is None:
+ with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)):
+ exec_code = self._plot_func('df', *self._args, **self._kwargs)
+ exec(exec_code, {'df': self._df}) # pylint: disable=exec-used
+ self._chart = _quickchart_lib.autoviz.MplChart.from_current_mpl_state()
chart_html = self._chart._repr_mimebundle_()['text/html'] # pylint:disable = protected-access
script_start = chart_html.find('""")
+ return cls(f"""""")
def _repr_html_(self):
return self.chart_html
@@ -59,122 +57,202 @@ class autoviz: # pylint:disable=invalid-name
MplChart = MplChart # pylint:disable=invalid-name
-# Note: pyformat disabled for this module since these are user-facing code
-# snippets that are meant to be more notebook-style per b/283014273#comment8.
-# pyformat: disable
-# pylint:disable=missing-function-docstring
+def histogram(df_name: str, colname: str, num_bins=20):
+ """Generates a histogram for the given data column.
+ Args:
+ df_name: Variable name of a dataframe.
+ colname: Name of the column to plot.
+ num_bins: The number of value bins.
-# NOTE: All functions below must have a `figscale` keyword arg.
+ Returns:
+ Code to generate the plot.
+ """
+ return f"""from matplotlib import pyplot as plt
+{df_name}[{colname!r}].plot(kind='hist', bins={num_bins}, title={colname!r})
+plt.gca().spines[['top', 'right',]].set_visible(False)
+plt.tight_layout()"""
-def histogram(df, colname, num_bins=20, figscale=1):
- from matplotlib import pyplot as plt
- df[colname].plot(kind='hist', bins=num_bins, title=colname, figsize=(8*figscale, 4*figscale))
- plt.gca().spines[['top', 'right',]].set_visible(False)
- plt.tight_layout()
- return autoviz.MplChart.from_current_mpl_state()
+def categorical_histogram(df_name, colname):
+ """Generates a single categorical histogram.
+ Args:
+ df_name: Variable name of a dataframe.
+ colname: The column name to plot.
-def categorical_histogram(df, colname, figscale=1, mpl_palette_name='Dark2'):
- from matplotlib import pyplot as plt
- import seaborn as sns
- df.groupby(colname).size().plot(kind='barh', color=sns.palettes.mpl_palette(mpl_palette_name), figsize=(8*figscale, 4.8*figscale))
- plt.gca().spines[['top', 'right',]].set_visible(False)
- return autoviz.MplChart.from_current_mpl_state()
+ Returns:
+ Code to generate the plot.
+ """
+ return f"""from matplotlib import pyplot as plt
+import seaborn as sns
+{df_name}.groupby({colname!r}).size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2'))
+plt.gca().spines[['top', 'right',]].set_visible(False)"""
-def heatmap(df, x_colname, y_colname, figscale=1, mpl_palette_name='viridis'):
- from matplotlib import pyplot as plt
- import seaborn as sns
- import pandas as pd
- plt.subplots(figsize=(8 * figscale, 8 * figscale))
- df_2dhist = pd.DataFrame({
- x_label: grp[y_colname].value_counts()
- for x_label, grp in df.groupby(x_colname)
- })
- sns.heatmap(df_2dhist, cmap=mpl_palette_name)
- plt.xlabel(x_colname)
- plt.ylabel(y_colname)
- return autoviz.MplChart.from_current_mpl_state()
-
-
-def swarm_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', jitter_domain_width=8):
- from matplotlib import pyplot as plt
- import seaborn as sns
- palette = sns.palettes.mpl_palette(mpl_palette_name)
- facet_values = list(sorted(df[facet_colname].unique()))
- figsize = (1.2 * figscale * len(facet_values), 8 * figscale)
- _, ax = plt.subplots(figsize=figsize)
- ax.spines[['top', 'right']].set_visible(False)
- xtick_locs = [jitter_domain_width*i for i in range(len(facet_values))]
- for i, facet_value in enumerate(facet_values):
- color = palette[i % len(palette)]
- values = df[df[facet_colname] == facet_value][value_colname]
- r1, r2 = np.random.random(len(values)), np.random.random(len(values))
- jitter = np.sqrt(-2*np.log(r1))*np.cos(2*np.pi*r2) # Box-Muller.
- ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color)
- ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical')
- plt.title(facet_colname)
- plt.ylabel(value_colname)
- return autoviz.MplChart.from_current_mpl_state()
-
-
-def violin_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', **kwargs):
- from matplotlib import pyplot as plt
- import seaborn as sns
- figsize = (12 * figscale, 1.2 * figscale * len(df[facet_colname].unique()))
- plt.figure(figsize=figsize)
- sns.violinplot(df, x=value_colname, y=facet_colname, palette=mpl_palette_name, **kwargs)
- sns.despine(top=True, right=True, bottom=True, left=True)
- return autoviz.MplChart.from_current_mpl_state()
+def heatmap(df_name: str, x_colname: str, y_colname: str):
+ """Generates a single heatmap.
+ Args:
+ df_name: Variable name of a dataframe.
+ x_colname: The x-axis column name.
+ y_colname: The y-axis column name.
-def value_plot(df, y, figscale=1):
- from matplotlib import pyplot as plt
- df[y].plot(kind='line', figsize=(8 * figscale, 4 * figscale), title=y)
- plt.gca().spines[['top', 'right']].set_visible(False)
- plt.tight_layout()
- return autoviz.MplChart.from_current_mpl_state()
+ Returns:
+ Code to generate the plot.
+ """
+ return f"""from matplotlib import pyplot as plt
+import seaborn as sns
+import pandas as pd
+plt.subplots(figsize=(8, 8))
+df_2dhist = pd.DataFrame({{
+ x_label: grp[{y_colname!r}].value_counts()
+ for x_label, grp in {df_name}.groupby({x_colname!r})
+}})
+sns.heatmap(df_2dhist, cmap='viridis')
+plt.xlabel({x_colname!r})
+plt.ylabel({y_colname!r})"""
+
+
+def swarm_plot(
+ df_name: str, value_colname: str, facet_colname: str, jitter_domain_width=8
+):
+ """Generates a single swarm plot.
+
+ Incorporated from altair example gallery:
+ https://altair-viz.github.io/gallery/stripplot.html
+
+ Args:
+ df_name: Variable name of a dataframe.
+ value_colname: The value distribution column name.
+ facet_colname: The faceting column name.
+ jitter_domain_width: Jitter width.
+
+ Returns:
+ Code to generate the plot.
+ """
+ return f"""from matplotlib import pyplot as plt
+import numpy as np
+import seaborn as sns
+palette = sns.palettes.mpl_palette('Dark2')
+facet_values = list(sorted({df_name}[{facet_colname!r}].unique()))
+_, ax = plt.subplots(figsize=(1.2 * len(facet_values), 8))
+ax.spines[['top', 'right']].set_visible(False)
+xtick_locs = [{jitter_domain_width}*i for i in range(len(facet_values))]
+for i, facet_value in enumerate(facet_values):
+ color = palette[i % len(palette)]
+ values = {df_name}[{df_name}[{facet_colname!r}] == facet_value][{value_colname!r}]
+ r1, r2 = np.random.random(len(values)), np.random.random(len(values))
+ jitter = np.sqrt(-2*np.log(r1))*np.cos(2*np.pi*r2) # Box-Muller.
+ ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color)
+ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical')
+plt.title({facet_colname!r})
+plt.ylabel({value_colname!r})"""
+
+
+def violin_plot(
+ df_name: str, value_colname: str, facet_colname: str, inner: str
+):
+ """Generates a single violin plot.
+
+ Args:
+ df_name: Variable name of a dataframe.
+ value_colname: The value distribution column name.
+ facet_colname: The faceting column name.
+ inner: Representation of the data in the violin interior.
+
+ Returns:
+ Code to generate the plot.
+ """
+ return f"""from matplotlib import pyplot as plt
+import seaborn as sns
+figsize = (12, 1.2 * len({df_name}[{facet_colname!r}].unique()))
+plt.figure(figsize=figsize)
+sns.violinplot({df_name}, x={value_colname!r}, y={facet_colname!r}, inner={inner!r}, palette='Dark2')
+sns.despine(top=True, right=True, bottom=True, left=True)"""
-def scatter_plot(df, x_colname, y_colname, figscale=1, alpha=.8):
- from matplotlib import pyplot as plt
- plt.figure(figsize=(6 * figscale, 6 * figscale))
- df.plot(kind='scatter', x=x_colname, y=y_colname, s=(32 * figscale), alpha=alpha)
- plt.gca().spines[['top', 'right',]].set_visible(False)
- plt.tight_layout()
- return autoviz.MplChart.from_current_mpl_state()
+def value_plot(df_name: str, y: str):
+ """Generates a single value plot.
-def time_series_multiline(df, timelike_colname, value_colname, series_colname, figscale=1, mpl_palette_name='Dark2'):
+ Args:
+ df_name: Variable name of a dataframe.
+ y: The series name to plot.
+
+ Returns:
+ Code to generate the plot.
+ """
+
+ return f"""from matplotlib import pyplot as plt
+{df_name}[{y!r}].plot(kind='line', figsize=(8, 4), title={y!r})
+plt.gca().spines[['top', 'right']].set_visible(False)
+plt.tight_layout()"""
+
+
+def scatter_plot(df_name: str, x_colname: str, y_colname: str):
+ """Generates a single scatter plot.
+
+ Args:
+ df_name: Variable name of a dataframe.
+ x_colname: Column name for the X axis.
+ y_colname: Column name for the Y axis.
+
+ Returns:
+ Code to generate the plot.
+ """
+
+ return f"""from matplotlib import pyplot as plt
+plt.figure(figsize=(6, 6))
+{df_name}.plot(kind='scatter', x={x_colname!r}, y={y_colname!r}, s=32, alpha=.8)
+plt.gca().spines[['top', 'right',]].set_visible(False)
+plt.tight_layout()"""
+
+
+def time_series_multiline(
+ df_name: str, timelike_colname: str, value_colname: str, series_colname: str
+):
+ """Generates a single time series plot.
+
+ Args:
+ df_name: Variable name of a dataframe.
+ timelike_colname: Column name for the time based column.
+ value_colname: Column name for the value column.
+ series_colname: Column name for the series column.
+
+ Returns:
+ Code to generate the plot.
+ """
+ plot_series_impl = f"""xs = series[{timelike_colname!r}]
+ ys = series[{value_colname!r}]
+ """
+ if value_colname == 'count()':
+ plot_series_impl = f"""counted = (series[{timelike_colname!r}]
+ .value_counts()
+ .reset_index(name='counts')
+ .rename({{'index': {timelike_colname!r}}}, axis=1)
+ .sort_values({timelike_colname!r}, ascending=True))
+ xs = counted[{timelike_colname!r}]
+ ys = counted['counts']"""
+
+ series_impl = """_plot_series(df_sorted, '')"""
+ if series_colname:
+ series_impl = f"""for i, (series_name, series) in enumerate(df_sorted.groupby({series_colname!r})):
+ _plot_series(series, series_name, i)
+ fig.legend(title={series_colname!r}, bbox_to_anchor=(1, 1), loc='upper left')"""
+
+ return f"""from matplotlib import pyplot as plt
+import seaborn as sns
+def _plot_series(series, series_name, series_index=0):
from matplotlib import pyplot as plt
import seaborn as sns
- figsize = (10 * figscale, 5.2 * figscale)
- palette = list(sns.palettes.mpl_palette(mpl_palette_name))
- def _plot_series(series, series_name, series_index=0):
- if value_colname == 'count()':
- counted = (series[timelike_colname]
- .value_counts()
- .reset_index(name='counts')
- .rename({'index': timelike_colname}, axis=1)
- .sort_values(timelike_colname, ascending=True))
- xs = counted[timelike_colname]
- ys = counted['counts']
- else:
- xs = series[timelike_colname]
- ys = series[value_colname]
- plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)])
-
- fig, ax = plt.subplots(figsize=figsize, layout='constrained')
- df = df.sort_values(timelike_colname, ascending=True)
- if series_colname:
- for i, (series_name, series) in enumerate(df.groupby(series_colname)):
- _plot_series(series, series_name, i)
- fig.legend(title=series_colname, bbox_to_anchor=(1, 1), loc='upper left')
- else:
- _plot_series(df, '')
- sns.despine(fig=fig, ax=ax)
- plt.xlabel(timelike_colname)
- plt.ylabel(value_colname)
- return autoviz.MplChart.from_current_mpl_state()
+ palette = list(sns.palettes.mpl_palette('Dark2'))
+ {plot_series_impl}
+ plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)])
+
+fig, ax = plt.subplots(figsize=(10, 5.2), layout='constrained')
+df_sorted = {df_name}.sort_values({timelike_colname!r}, ascending=True)
+{series_impl}
+sns.despine(fig=fig, ax=ax)
+plt.xlabel({timelike_colname!r})
+plt.ylabel({value_colname!r})"""