Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 575926840
  • Loading branch information
blois authored and colaboratory-team committed Oct 23, 2023
1 parent 4bac8ed commit 4e56a62
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 135 deletions.
35 changes: 7 additions & 28 deletions google/colab/_quickchart_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Supporting code for quickchart functionality."""

import inspect
import textwrap
import uuid as _uuid

Expand Down Expand Up @@ -160,12 +159,7 @@ def __init__(self, df, plot_func, args, kwargs, df_registry):
self._kwargs = kwargs

self._chart_id = f'chart-{str(_uuid.uuid4())}'
with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)):
# We want the charts to be small when there are many of them, but larger
# when a user inserts a single chart. We set `figscale` here so that it
# isn't remembered if a user clicks on a chart.
figscale = 0.5
self._chart = plot_func(df, figscale=figscale, *args, **kwargs)
self._chart = None

@property
def chart_id(self):
Expand All @@ -180,30 +174,15 @@ def get_code(self):
if self._df_varname is None:
self._df_varname = self._df_registry.get_or_register_varname(self._df)

plot_func_src = inspect.getsource(self._plot_func)
plot_invocation = textwrap.dedent(
"""\
chart = {plot_func}({df_varname}, *{args}, **{kwargs})
chart""".format(
plot_func=self._plot_func.__name__,
args=str(self._args),
kwargs=str(self._kwargs),
df_varname=self._df_varname,
)
)

chart_src = textwrap.dedent("""\
import numpy as np
from google.colab import autoviz
""")
chart_src += '\n'
chart_src += plot_func_src
chart_src += '\n'
chart_src += plot_invocation
return chart_src
return self._plot_func(self._df_varname, *self._args, **self._kwargs)

def _repr_html_(self):
"""Gets the HTML representation of the chart."""
if self._chart is None:
with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)):
exec_code = self._plot_func('df', *self._args, **self._kwargs)
exec(exec_code, {'df': self._df}) # pylint: disable=exec-used
self._chart = _quickchart_lib.autoviz.MplChart.from_current_mpl_state()

chart_html = self._chart._repr_mimebundle_()['text/html'] # pylint:disable = protected-access
script_start = chart_html.find('<script')
Expand Down
292 changes: 185 additions & 107 deletions google/colab/_quickchart_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import io

import IPython.display
import numpy as np


# Note: imports are done lazily throughout to minimize imports at kernel init.
Expand Down Expand Up @@ -34,9 +33,8 @@ def from_current_mpl_state(cls):
plt.close()
f.seek(0)
png_data = f.read()
return cls(f"""<img src="data:image/png;base64,{
base64.encodebytes(png_data).decode("ascii")}">
<script></script>""")
return cls(f"""<img style="width: 180px;" src="data:image/png;base64,{
base64.encodebytes(png_data).decode("ascii")}">""")

def _repr_html_(self):
return self.chart_html
Expand All @@ -59,122 +57,202 @@ class autoviz: # pylint:disable=invalid-name
MplChart = MplChart # pylint:disable=invalid-name


# Note: pyformat disabled for this module since these are user-facing code
# snippets that are meant to be more notebook-style per b/283014273#comment8.
# pyformat: disable
# pylint:disable=missing-function-docstring
def histogram(df_name: str, colname: str, num_bins=20):
"""Generates a histogram for the given data column.
Args:
df_name: Variable name of a dataframe.
colname: Name of the column to plot.
num_bins: The number of value bins.
# NOTE: All functions below must have a `figscale` keyword arg.
Returns:
Code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
{df_name}[{colname!r}].plot(kind='hist', bins={num_bins}, title={colname!r})
plt.gca().spines[['top', 'right',]].set_visible(False)
plt.tight_layout()"""


def histogram(df, colname, num_bins=20, figscale=1):
    """Plots a histogram of one dataframe column and captures it as a chart.

    Args:
      df: Dataframe holding the data.
      colname: Name of the column to plot; also used as the plot title.
      num_bins: The number of value bins.
      figscale: Multiplier applied to the base 8x4 figure size.

    Returns:
      An `autoviz.MplChart` built from the current matplotlib state.
    """
    from matplotlib import pyplot as plt

    figure_size = (8*figscale, 4*figscale)
    column = df[colname]
    column.plot(kind='hist', bins=num_bins, title=colname, figsize=figure_size)
    # Drop the top/right frame lines for a lighter look.
    current_axes = plt.gca()
    current_axes.spines[['top', 'right']].set_visible(False)
    plt.tight_layout()
    return autoviz.MplChart.from_current_mpl_state()
def categorical_histogram(df_name, colname):
"""Generates a single categorical histogram.
Args:
df_name: Variable name of a dataframe.
colname: The column name to plot.
def categorical_histogram(df, colname, figscale=1, mpl_palette_name='Dark2'):
from matplotlib import pyplot as plt
import seaborn as sns
df.groupby(colname).size().plot(kind='barh', color=sns.palettes.mpl_palette(mpl_palette_name), figsize=(8*figscale, 4.8*figscale))
plt.gca().spines[['top', 'right',]].set_visible(False)
return autoviz.MplChart.from_current_mpl_state()
Returns:
Code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
import seaborn as sns
{df_name}.groupby({colname!r}).size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2'))
plt.gca().spines[['top', 'right',]].set_visible(False)"""


def heatmap(df, x_colname, y_colname, figscale=1, mpl_palette_name='viridis'):
    """Plots a 2D-histogram heatmap of two columns and captures it as a chart.

    Args:
      df: Dataframe holding the data.
      x_colname: The x-axis column name.
      y_colname: The y-axis column name.
      figscale: Multiplier applied to the base 8x8 figure size.
      mpl_palette_name: Colormap name passed to seaborn as `cmap`.

    Returns:
      An `autoviz.MplChart` built from the current matplotlib state.
    """
    from matplotlib import pyplot as plt
    import seaborn as sns
    import pandas as pd

    side = 8 * figscale
    plt.subplots(figsize=(side, side))

    # One column per distinct x value, holding the counts of each y value
    # observed within that group.
    counts_by_x = {}
    for x_label, group in df.groupby(x_colname):
        counts_by_x[x_label] = group[y_colname].value_counts()
    df_2dhist = pd.DataFrame(counts_by_x)

    sns.heatmap(df_2dhist, cmap=mpl_palette_name)
    plt.xlabel(x_colname)
    plt.ylabel(y_colname)
    return autoviz.MplChart.from_current_mpl_state()


def swarm_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', jitter_domain_width=8):
    """Plots a jittered strip ("swarm") plot and captures it as a chart.

    Each distinct value of `facet_colname` gets its own x position; the
    points of that facet are scattered around it with random horizontal
    jitter.

    Args:
      df: Dataframe holding the data.
      value_colname: The value distribution column name.
      facet_colname: The faceting column name.
      figscale: Multiplier applied to the figure size.
      mpl_palette_name: Name of the matplotlib palette used to color facets.
      jitter_domain_width: Spacing between consecutive facet positions.

    Returns:
      An `autoviz.MplChart` built from the current matplotlib state.
    """
    from matplotlib import pyplot as plt
    import seaborn as sns

    colors = sns.palettes.mpl_palette(mpl_palette_name)
    facets = sorted(df[facet_colname].unique())

    # Width grows with the number of facets; height is fixed.
    width = 1.2 * figscale * len(facets)
    height = 8 * figscale
    _, ax = plt.subplots(figsize=(width, height))
    ax.spines[['top', 'right']].set_visible(False)

    tick_positions = [jitter_domain_width*idx for idx in range(len(facets))]
    for idx, (facet, center) in enumerate(zip(facets, tick_positions)):
        facet_color = colors[idx % len(colors)]
        facet_values = df[df[facet_colname] == facet][value_colname]
        count = len(facet_values)
        u1 = np.random.random(count)
        u2 = np.random.random(count)
        offsets = np.sqrt(-2*np.log(u1))*np.cos(2*np.pi*u2)  # Box-Muller.
        ax.scatter(center + offsets, facet_values, s=1.5, alpha=.8, color=facet_color)

    ax.xaxis.set_ticks(tick_positions, facets, rotation='vertical')
    plt.title(facet_colname)
    plt.ylabel(value_colname)
    return autoviz.MplChart.from_current_mpl_state()


def violin_plot(df, value_colname, facet_colname, figscale=1, mpl_palette_name='Dark2', **kwargs):
    """Plots one violin per facet value and captures the result as a chart.

    Args:
      df: Dataframe holding the data.
      value_colname: The value distribution column name (x axis).
      facet_colname: The faceting column name (y axis).
      figscale: Multiplier applied to the figure size.
      mpl_palette_name: Palette name passed to seaborn.
      **kwargs: Extra keyword arguments forwarded to `sns.violinplot`.

    Returns:
      An `autoviz.MplChart` built from the current matplotlib state.
    """
    from matplotlib import pyplot as plt
    import seaborn as sns

    # Height grows with the number of distinct facet values.
    facet_count = len(df[facet_colname].unique())
    width = 12 * figscale
    height = 1.2 * figscale * facet_count
    plt.figure(figsize=(width, height))

    sns.violinplot(
        df,
        x=value_colname,
        y=facet_colname,
        palette=mpl_palette_name,
        **kwargs,
    )
    sns.despine(top=True, right=True, bottom=True, left=True)
    return autoviz.MplChart.from_current_mpl_state()
def heatmap(df_name: str, x_colname: str, y_colname: str):
"""Generates a single heatmap.
Args:
df_name: Variable name of a dataframe.
x_colname: The x-axis column name.
y_colname: The y-axis column name.
def value_plot(df, y, figscale=1):
from matplotlib import pyplot as plt
df[y].plot(kind='line', figsize=(8 * figscale, 4 * figscale), title=y)
plt.gca().spines[['top', 'right']].set_visible(False)
plt.tight_layout()
return autoviz.MplChart.from_current_mpl_state()
Returns:
Code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
plt.subplots(figsize=(8, 8))
df_2dhist = pd.DataFrame({{
x_label: grp[{y_colname!r}].value_counts()
for x_label, grp in {df_name}.groupby({x_colname!r})
}})
sns.heatmap(df_2dhist, cmap='viridis')
plt.xlabel({x_colname!r})
plt.ylabel({y_colname!r})"""


def swarm_plot(
    df_name: str, value_colname: str, facet_colname: str, jitter_domain_width=8
):
    """Generates code for a single swarm plot.

    Incorporated from altair example gallery:
    https://altair-viz.github.io/gallery/stripplot.html

    The returned text is a standalone snippet: it refers to the dataframe by
    variable name and embeds the column names as Python literals (via
    `repr`). The body of the snippet's `for` loop is emitted indented so the
    snippet compiles and can be executed as-is.

    Args:
      df_name: Variable name of a dataframe.
      value_colname: The value distribution column name.
      facet_colname: The faceting column name.
      jitter_domain_width: Jitter width; used as the spacing between
        consecutive facet tick locations in the generated code.

    Returns:
      Code to generate the plot.
    """
    # The snippet's top-level lines deliberately start at column 0: they are
    # source text for the generated code, not part of this function's body.
    return f"""from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
palette = sns.palettes.mpl_palette('Dark2')
facet_values = list(sorted({df_name}[{facet_colname!r}].unique()))
_, ax = plt.subplots(figsize=(1.2 * len(facet_values), 8))
ax.spines[['top', 'right']].set_visible(False)
xtick_locs = [{jitter_domain_width}*i for i in range(len(facet_values))]
for i, facet_value in enumerate(facet_values):
    color = palette[i % len(palette)]
    values = {df_name}[{df_name}[{facet_colname!r}] == facet_value][{value_colname!r}]
    r1, r2 = np.random.random(len(values)), np.random.random(len(values))
    jitter = np.sqrt(-2*np.log(r1))*np.cos(2*np.pi*r2)  # Box-Muller.
    ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color)
ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical')
plt.title({facet_colname!r})
plt.ylabel({value_colname!r})"""


def violin_plot(
    df_name: str, value_colname: str, facet_colname: str, inner: str
):
    """Generates code for a single violin plot.

    The dataframe is referenced by variable name; column names and `inner`
    are embedded in the snippet as Python literals (via `repr`).

    Args:
      df_name: Variable name of a dataframe.
      value_colname: The value distribution column name.
      facet_colname: The faceting column name.
      inner: Representation of the data in the violin interior.

    Returns:
      Code to generate the plot.
    """
    value_literal = repr(value_colname)
    facet_literal = repr(facet_colname)
    inner_literal = repr(inner)
    snippet_lines = [
        'from matplotlib import pyplot as plt',
        'import seaborn as sns',
        f'figsize = (12, 1.2 * len({df_name}[{facet_literal}].unique()))',
        'plt.figure(figsize=figsize)',
        (
            f'sns.violinplot({df_name}, x={value_literal}, y={facet_literal},'
            f" inner={inner_literal}, palette='Dark2')"
        ),
        'sns.despine(top=True, right=True, bottom=True, left=True)',
    ]
    return '\n'.join(snippet_lines)

def scatter_plot(df, x_colname, y_colname, figscale=1, alpha=.8):
    """Plots a scatter plot of two columns and captures it as a chart.

    Args:
      df: Dataframe holding the data.
      x_colname: Column name for the X axis.
      y_colname: Column name for the Y axis.
      figscale: Multiplier applied to the base 6x6 figure size and to the
        base marker size of 32.
      alpha: Marker opacity.

    Returns:
      An `autoviz.MplChart` built from the current matplotlib state.
    """
    from matplotlib import pyplot as plt

    side = 6 * figscale
    plt.figure(figsize=(side, side))
    marker_size = 32 * figscale
    df.plot(kind='scatter', x=x_colname, y=y_colname, s=marker_size, alpha=alpha)
    # Drop the top/right frame lines for a lighter look.
    current_axes = plt.gca()
    current_axes.spines[['top', 'right']].set_visible(False)
    plt.tight_layout()
    return autoviz.MplChart.from_current_mpl_state()

def value_plot(df_name: str, y: str):
"""Generates a single value plot.
def time_series_multiline(df, timelike_colname, value_colname, series_colname, figscale=1, mpl_palette_name='Dark2'):
Args:
df_name: Variable name of a dataframe.
y: The series name to plot.
Returns:
Code to generate the plot.
"""

return f"""from matplotlib import pyplot as plt
{df_name}[{y!r}].plot(kind='line', figsize=(8, 4), title={y!r})
plt.gca().spines[['top', 'right']].set_visible(False)
plt.tight_layout()"""


def scatter_plot(df_name: str, x_colname: str, y_colname: str):
    """Generates code for a single scatter plot.

    The dataframe is referenced by variable name; column names are embedded
    in the snippet as Python literals (via `repr`).

    Args:
      df_name: Variable name of a dataframe.
      x_colname: Column name for the X axis.
      y_colname: Column name for the Y axis.

    Returns:
      Code to generate the plot.
    """
    x_literal = repr(x_colname)
    y_literal = repr(y_colname)
    snippet_lines = [
        'from matplotlib import pyplot as plt',
        'plt.figure(figsize=(6, 6))',
        f"{df_name}.plot(kind='scatter', x={x_literal}, y={y_literal}, s=32, alpha=.8)",
        "plt.gca().spines[['top', 'right',]].set_visible(False)",
        'plt.tight_layout()',
    ]
    return '\n'.join(snippet_lines)


def time_series_multiline(
df_name: str, timelike_colname: str, value_colname: str, series_colname: str
):
"""Generates a single time series plot.
Args:
df_name: Variable name of a dataframe.
timelike_colname: Column name for the time based column.
value_colname: Column name for the value column.
series_colname: Column name for the series column.
Returns:
Code to generate the plot.
"""
plot_series_impl = f"""xs = series[{timelike_colname!r}]
ys = series[{value_colname!r}]
"""
if value_colname == 'count()':
plot_series_impl = f"""counted = (series[{timelike_colname!r}]
.value_counts()
.reset_index(name='counts')
.rename({{'index': {timelike_colname!r}}}, axis=1)
.sort_values({timelike_colname!r}, ascending=True))
xs = counted[{timelike_colname!r}]
ys = counted['counts']"""

series_impl = """_plot_series(df_sorted, '')"""
if series_colname:
series_impl = f"""for i, (series_name, series) in enumerate(df_sorted.groupby({series_colname!r})):
_plot_series(series, series_name, i)
fig.legend(title={series_colname!r}, bbox_to_anchor=(1, 1), loc='upper left')"""

return f"""from matplotlib import pyplot as plt
import seaborn as sns
def _plot_series(series, series_name, series_index=0):
from matplotlib import pyplot as plt
import seaborn as sns
figsize = (10 * figscale, 5.2 * figscale)
palette = list(sns.palettes.mpl_palette(mpl_palette_name))
def _plot_series(series, series_name, series_index=0):
if value_colname == 'count()':
counted = (series[timelike_colname]
.value_counts()
.reset_index(name='counts')
.rename({'index': timelike_colname}, axis=1)
.sort_values(timelike_colname, ascending=True))
xs = counted[timelike_colname]
ys = counted['counts']
else:
xs = series[timelike_colname]
ys = series[value_colname]
plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)])

fig, ax = plt.subplots(figsize=figsize, layout='constrained')
df = df.sort_values(timelike_colname, ascending=True)
if series_colname:
for i, (series_name, series) in enumerate(df.groupby(series_colname)):
_plot_series(series, series_name, i)
fig.legend(title=series_colname, bbox_to_anchor=(1, 1), loc='upper left')
else:
_plot_series(df, '')
sns.despine(fig=fig, ax=ax)
plt.xlabel(timelike_colname)
plt.ylabel(value_colname)
return autoviz.MplChart.from_current_mpl_state()
palette = list(sns.palettes.mpl_palette('Dark2'))
{plot_series_impl}
plt.plot(xs, ys, label=series_name, color=palette[series_index % len(palette)])
fig, ax = plt.subplots(figsize=(10, 5.2), layout='constrained')
df_sorted = {df_name}.sort_values({timelike_colname!r}, ascending=True)
{series_impl}
sns.despine(fig=fig, ax=ax)
plt.xlabel({timelike_colname!r})
plt.ylabel({value_colname!r})"""

0 comments on commit 4e56a62

Please sign in to comment.