diff --git a/google/colab/_quickchart_helpers.py b/google/colab/_quickchart_helpers.py index 1d9289af..dd189534 100644 --- a/google/colab/_quickchart_helpers.py +++ b/google/colab/_quickchart_helpers.py @@ -1,6 +1,5 @@ """Supporting code for quickchart functionality.""" -import inspect import textwrap import uuid as _uuid @@ -160,12 +159,7 @@ def __init__(self, df, plot_func, args, kwargs, df_registry): self._kwargs = kwargs self._chart_id = f'chart-{str(_uuid.uuid4())}' - with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)): - # We want the charts to be small when there are many of them, but larger - # when a user inserts a single chart. We set `figscale` here so that it - # isn't remembered if a user clicks on a chart. - figscale = 0.5 - self._chart = plot_func(df, figscale=figscale, *args, **kwargs) + self._chart = None @property def chart_id(self): @@ -180,30 +174,15 @@ def get_code(self): if self._df_varname is None: self._df_varname = self._df_registry.get_or_register_varname(self._df) - plot_func_src = inspect.getsource(self._plot_func) - plot_invocation = textwrap.dedent( - """\ - chart = {plot_func}({df_varname}, *{args}, **{kwargs}) - chart""".format( - plot_func=self._plot_func.__name__, - args=str(self._args), - kwargs=str(self._kwargs), - df_varname=self._df_varname, - ) - ) - - chart_src = textwrap.dedent("""\ - import numpy as np - from google.colab import autoviz - """) - chart_src += '\n' - chart_src += plot_func_src - chart_src += '\n' - chart_src += plot_invocation - return chart_src + return self._plot_func(self._df_varname, *self._args, **self._kwargs) def _repr_html_(self): """Gets the HTML representation of the chart.""" + if self._chart is None: + with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)): + exec_code = self._plot_func('df', *self._args, **self._kwargs) + exec(exec_code, {'df': self._df}) # pylint: disable=exec-used + self._chart = _quickchart_lib.autoviz.MplChart.from_current_mpl_state() chart_html = self._chart._repr_mimebundle_()['text/html'] # pylint:disable = protected-access script_start = chart_html.find(' - """) + return cls(f"""""") def _repr_html_(self): return self.chart_html @@ -59,122 +57,202 @@ class autoviz: # pylint:disable=invalid-name MplChart = MplChart # pylint:disable=invalid-name -# Note: pyformat disabled for this module since these are user-facing code -# snippets that are meant to be more notebook-style per b/283014273#comment8. -# pyformat: disable -# pylint:disable=missing-function-docstring +def histogram(df_name: str, colname: str, num_bins=20): + """Generates a histogram for the given data column. + Args: + df_name: Variable name of a dataframe. + colname: Name of the column to plot. + num_bins: The number of value bins. -# NOTE: All functions below must have a `figscale` keyword arg. + Returns: + Code to generate the plot. + """ + return f"""from matplotlib import pyplot as plt +{df_name}[{colname!r}].plot(kind='hist', bins={num_bins}, title={colname!r}) +plt.gca().spines[['top', 'right',]].set_visible(False) +plt.tight_layout()""" -def histogram(df, colname, num_bins=20, figscale=1): - from matplotlib import pyplot as plt - df[colname].plot(kind='hist', bins=num_bins, title=colname, figsize=(8*figscale, 4*figscale)) - plt.gca().spines[['top', 'right',]].set_visible(False) - plt.tight_layout() - return autoviz.MplChart.from_current_mpl_state() +def categorical_histogram(df_name, colname): + """Generates a single categorical histogram. + Args: + df_name: Variable name of a dataframe. + colname: The column name to plot. -def categorical_histogram(df, colname, figscale=1, mpl_palette_name='Dark2'): - from matplotlib import pyplot as plt - import seaborn as sns - df.groupby(colname).size().plot(kind='barh', color=sns.palettes.mpl_palette(mpl_palette_name), figsize=(8*figscale, 4.8*figscale)) - plt.gca().spines[['top', 'right',]].set_visible(False) - return autoviz.MplChart.from_current_mpl_state() + Returns: + Code to generate the plot. + """ + return f"""from matplotlib import pyplot as plt +import seaborn as sns +{df_name}.groupby({colname!r}).size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2')) +plt.gca().spines[['top', 'right',]].set_visible(False)""" -def heatmap(df, x_colname, y_colname, figscale=1, mpl_palette_name='viridis'): - from matplotlib import pyplot as plt - import seaborn as sns - import pandas as pd - plt.subplots(figsize=(8 * figscale, 8 * figscale)) - df_2dhist = pd.DataFrame({ - x_label: grp[y_colname].value_counts() - for x_label, grp in df.groupby(x_colname) - }) - sns.heatmap(df_2dhist, cmap=mpl_palette_name) - plt.xlabel(x_colname) - plt.ylabel(y_colname) - return autoviz.MplChart.from_current_mpl_state() - - -def swarm_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', jitter_domain_width=8): - from matplotlib import pyplot as plt - import seaborn as sns - palette = sns.palettes.mpl_palette(mpl_palette_name) - facet_values = list(sorted(df[facet_colname].unique())) - figsize = (1.2 * figscale * len(facet_values), 8 * figscale) - _, ax = plt.subplots(figsize=figsize) - ax.spines[['top', 'right']].set_visible(False) - xtick_locs = [jitter_domain_width*i for i in range(len(facet_values))] - for i, facet_value in enumerate(facet_values): - color = palette[i % len(palette)] - values = df[df[facet_colname] == facet_value][value_colname] - r1, r2 = np.random.random(len(values)), np.random.random(len(values)) - jitter = np.sqrt(-2*np.log(r1))*np.cos(2*np.pi*r2) # Box-Muller. - ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color) - ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical') - plt.title(facet_colname) - plt.ylabel(value_colname) - return autoviz.MplChart.from_current_mpl_state() - - -def violin_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', **kwargs): - from matplotlib import pyplot as plt - import seaborn as sns - figsize = (12 * figscale, 1.2 * figscale * len(df[facet_colname].unique())) - plt.figure(figsize=figsize) - sns.violinplot(df, x=value_colname, y=facet_colname, palette=mpl_palette_name, **kwargs) - sns.despine(top=True, right=True, bottom=True, left=True) - return autoviz.MplChart.from_current_mpl_state() +def heatmap(df_name: str, x_colname: str, y_colname: str): + """Generates a single heatmap. + Args: + df_name: Variable name of a dataframe. + x_colname: The x-axis column name. + y_colname: The y-axis column name. -def value_plot(df, y, figscale=1): - from matplotlib import pyplot as plt - df[y].plot(kind='line', figsize=(8 * figscale, 4 * figscale), title=y) - plt.gca().spines[['top', 'right']].set_visible(False) - plt.tight_layout() - return autoviz.MplChart.from_current_mpl_state() + Returns: + Code to generate the plot. + """ + return f"""from matplotlib import pyplot as plt +import seaborn as sns +import pandas as pd +plt.subplots(figsize=(8, 8)) +df_2dhist = pd.DataFrame({{ + x_label: grp[{y_colname!r}].value_counts() + for x_label, grp in {df_name}.groupby({x_colname!r}) +}}) +sns.heatmap(df_2dhist, cmap='viridis') +plt.xlabel({x_colname!r}) +plt.ylabel({y_colname!r})""" + + +def swarm_plot( + df_name: str, value_colname: str, facet_colname: str, jitter_domain_width=8 +): + """Generates a single swarm plot. + + Incorporated from altair example gallery: + https://altair-viz.github.io/gallery/stripplot.html + + Args: + df_name: Variable name of a dataframe. + value_colname: The value distribution column name. + facet_colname: The faceting column name. + jitter_domain_width: Jitter width. + + Returns: + Code to generate the plot. + """ + return f"""from matplotlib import pyplot as plt +import numpy as np +import seaborn as sns +palette = sns.palettes.mpl_palette('Dark2') +facet_values = list(sorted({df_name}[{facet_colname!r}].unique())) +_, ax = plt.subplots(figsize=(1.2 * len(facet_values), 8)) +ax.spines[['top', 'right']].set_visible(False) +xtick_locs = [{jitter_domain_width}*i for i in range(len(facet_values))] +for i, facet_value in enumerate(facet_values): + color = palette[i % len(palette)] + values = {df_name}[{df_name}[{facet_colname!r}] == facet_value][{value_colname!r}] + r1, r2 = np.random.random(len(values)), np.random.random(len(values)) + jitter = np.sqrt(-2*np.log(r1))*np.cos(2*np.pi*r2) # Box-Muller. + ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color) +ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical') +plt.title({facet_colname!r}) +plt.ylabel({value_colname!r})""" + + +def violin_plot( + df_name: str, value_colname: str, facet_colname: str, inner: str +): + """Generates a single violin plot. + + Args: + df_name: Variable name of a dataframe. + value_colname: The value distribution column name. + facet_colname: The faceting column name. + inner: Representation of the data in the violin interior. + + Returns: + Code to generate the plot. + """ + return f"""from matplotlib import pyplot as plt +import seaborn as sns +figsize = (12, 1.2 * len({df_name}[{facet_colname!r}].unique())) +plt.figure(figsize=figsize) +sns.violinplot({df_name}, x={value_colname!r}, y={facet_colname!r}, inner={inner!r}, palette='Dark2') +sns.despine(top=True, right=True, bottom=True, left=True)""" -def scatter_plot(df, x_colname, y_colname, figscale=1, alpha=.8): - from matplotlib import pyplot as plt - plt.figure(figsize=(6 * figscale, 6 * figscale)) - df.plot(kind='scatter', x=x_colname, y=y_colname, s=(32 * figscale), alpha=alpha) - plt.gca().spines[['top', 'right',]].set_visible(False) - plt.tight_layout() - return autoviz.MplChart.from_current_mpl_state() +def value_plot(df_name: str, y: str): + """Generates a single value plot. -def time_series_multiline(df, timelike_colname, value_colname, series_colname, figscale=1, mpl_palette_name='Dark2'): + Args: + df_name: Variable name of a dataframe. + y: The series name to plot. + + Returns: + Code to generate the plot. + """ + + return f"""from matplotlib import pyplot as plt +{df_name}[{y!r}].plot(kind='line', figsize=(8, 4), title={y!r}) +plt.gca().spines[['top', 'right']].set_visible(False) +plt.tight_layout()""" + + +def scatter_plot(df_name: str, x_colname: str, y_colname: str): + """Generates a single scatter plot. + + Args: + df_name: Variable name of a dataframe. + x_colname: Column name for the X axis. + y_colname: Column name for the Y axis. + + Returns: + Code to generate the plot. + """ + + return f"""from matplotlib import pyplot as plt +plt.figure(figsize=(6, 6)) +{df_name}.plot(kind='scatter', x={x_colname!r}, y={y_colname!r}, s=32, alpha=.8) +plt.gca().spines[['top', 'right',]].set_visible(False) +plt.tight_layout()""" + + +def time_series_multiline( + df_name: str, timelike_colname: str, value_colname: str, series_colname: str +): + """Generates a single time series plot. + + Args: + df_name: Variable name of a dataframe. + timelike_colname: Column name for the time based column. + value_colname: Column name for the value column. + series_colname: Column name for the series column. + + Returns: + Code to generate the plot. + """ + plot_series_impl = f"""xs = series[{timelike_colname!r}] + ys = series[{value_colname!r}] + """ + if value_colname == 'count()': + plot_series_impl = f"""counted = (series[{timelike_colname!r}] + .value_counts() + .reset_index(name='counts') + .rename({{'index': {timelike_colname!r}}}, axis=1) + .sort_values({timelike_colname!r}, ascending=True)) + xs = counted[{timelike_colname!r}] + ys = counted['counts']""" + + series_impl = """_plot_series(df_sorted, '')""" + if series_colname: + series_impl = f"""for i, (series_name, series) in enumerate(df_sorted.groupby({series_colname!r})): + _plot_series(series, series_name, i) + fig.legend(title={series_colname!r}, bbox_to_anchor=(1, 1), loc='upper left')""" + + return f"""from matplotlib import pyplot as plt +import seaborn as sns +def _plot_series(series, series_name, series_index=0): from matplotlib import pyplot as plt import seaborn as sns - figsize = (10 * figscale, 5.2 * figscale) - palette = list(sns.palettes.mpl_palette(mpl_palette_name)) - def _plot_series(series, series_name, series_index=0): - if value_colname == 'count()': - counted = (series[timelike_colname] - .value_counts() - .reset_index(name='counts') - .rename({'index': timelike_colname}, axis=1) - .sort_values(timelike_colname, ascending=True)) - xs = counted[timelike_colname] - ys = counted['counts'] - else: - xs = series[timelike_colname] - ys = series[value_colname] - plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)]) - - fig, ax = plt.subplots(figsize=figsize, layout='constrained') - df = df.sort_values(timelike_colname, ascending=True) - if series_colname: - for i, (series_name, series) in enumerate(df.groupby(series_colname)): - _plot_series(series, series_name, i) - fig.legend(title=series_colname, bbox_to_anchor=(1, 1), loc='upper left') - else: - _plot_series(df, '') - sns.despine(fig=fig, ax=ax) - plt.xlabel(timelike_colname) - plt.ylabel(value_colname) - return autoviz.MplChart.from_current_mpl_state() + palette = list(sns.palettes.mpl_palette('Dark2')) + {plot_series_impl} + plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)]) + +fig, ax = plt.subplots(figsize=(10, 5.2), layout='constrained') +df_sorted = {df_name}.sort_values({timelike_colname!r}, ascending=True) +{series_impl} +sns.despine(fig=fig, ax=ax) +plt.xlabel({timelike_colname!r}) +plt.ylabel({value_colname!r})"""