Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576334842
  • Loading branch information
blois authored and colaboratory-team committed Oct 25, 2023
1 parent 7b84345 commit ccf9365
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 39 deletions.
47 changes: 37 additions & 10 deletions google/colab/_quickchart.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,36 @@ def _ensure_dataframe_registry():

_ensure_dataframe_registry()

chart_sections = determine_charts(
df, _DATAFRAME_REGISTRY, max_chart_instances
)
if not chart_sections:
print('No charts were generated by quickchart')
return chart_sections


def find_charts_json(df_name: str, max_chart_instances=None):
  """Equivalent to find_charts, but returns JSON for consumption by the browser.

  Args:
    df_name: Key under which the dataframe is looked up in the IPython user
      namespace; also the variable name handed out by the registry below.
    max_chart_instances: Passed through unchanged to determine_charts.

  Returns:
    An IPython.display.JSON wrapping a list holding the to_json() result of
    each chart section produced by determine_charts.
  """

  class _SingleNameRegistry:
    """Registry stand-in that answers every lookup with `df_name`."""

    def get_or_register_varname(self, _) -> str:
      """Returns the fixed dataframe variable name."""
      return df_name

  # The dataframe is resolved by name from the interactive namespace rather
  # than being passed in directly.
  user_namespace = IPython.get_ipython().user_ns
  target_df = user_namespace[df_name]

  sections = determine_charts(
      target_df, _SingleNameRegistry(), max_chart_instances
  )
  payload = [section.to_json() for section in sections]
  return IPython.display.JSON(payload)


def determine_charts(df, dataframe_registry, max_chart_instances=None):
"""Finds charts compatible with dtypes of the given data frame."""
# Lazy import to avoid loading matplotlib and transitive deps on kernel init.
from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top

dtype_groups = _classify_dtypes(df)
numeric_cols = dtype_groups['numeric']
categorical_cols = dtype_groups['categorical']
Expand All @@ -66,15 +96,15 @@ def _ensure_dataframe_registry():
if numeric_cols:
chart_sections.append(
_quickchart_helpers.histograms_section(
df, numeric_cols[:max_chart_instances], _DATAFRAME_REGISTRY
df, numeric_cols[:max_chart_instances], dataframe_registry
)
)

if categorical_cols:
selected_categorical_cols = categorical_cols[:max_chart_instances]
chart_sections += [
_quickchart_helpers.categorical_histograms_section(
df, selected_categorical_cols, _DATAFRAME_REGISTRY
df, selected_categorical_cols, dataframe_registry
),
]

Expand All @@ -83,7 +113,7 @@ def _ensure_dataframe_registry():
_quickchart_helpers.scatter_section(
df,
_select_first_k_pairs(numeric_cols, k=max_chart_instances),
_DATAFRAME_REGISTRY,
dataframe_registry,
),
]

Expand All @@ -97,14 +127,14 @@ def _ensure_dataframe_registry():
categorical_cols=categorical_cols,
k=max_chart_instances,
),
_DATAFRAME_REGISTRY,
dataframe_registry,
),
)

if numeric_cols:
chart_sections.append(
_quickchart_helpers.value_plots_section(
df, numeric_cols[:max_chart_instances], _DATAFRAME_REGISTRY
df, numeric_cols[:max_chart_instances], dataframe_registry
)
)

Expand All @@ -113,7 +143,7 @@ def _ensure_dataframe_registry():
_quickchart_helpers.heatmaps_section(
df,
_select_first_k_pairs(categorical_cols, k=max_chart_instances),
_DATAFRAME_REGISTRY,
dataframe_registry,
),
]

Expand All @@ -124,12 +154,9 @@ def _ensure_dataframe_registry():
_select_faceted_numeric_cols(
numeric_cols, categorical_cols, k=max_chart_instances
),
_DATAFRAME_REGISTRY,
dataframe_registry,
),
]

if not chart_sections:
print('No charts were generated by quickchart')
return chart_sections


Expand Down
24 changes: 21 additions & 3 deletions google/colab/_quickchart_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def display(self):
for d in self._displayables:
d.display()

def to_json(self):
  """Returns a dict with this section's type and its charts' to_json() output."""
  serialized_charts = []
  for chart in self._charts:
    serialized_charts.append(chart.to_json())
  return dict(section_type=self._section_type, charts=serialized_charts)


class SectionTitle:
"""Section title used for delineating chart sections."""
Expand Down Expand Up @@ -169,18 +176,29 @@ def display(self):
"""Displays the chart within a notebook context."""
IPython.display.display(self)

def get_code(self):
"""Gets the code and associated dependencies + context for a given chart."""
def to_json(self):
  """Returns a dict with the chart's generated code and its title."""
  chart_data = self.get_code_and_title()
  return dict(code=chart_data.code, title=chart_data.title)

def get_code_and_title(self):
  """Returns the plot function's result for this chart's dataframe variable.

  The variable name is requested from the dataframe registry on first use and
  stored on the instance, so later calls reuse the same name.
  """
  varname = self._df_varname
  if varname is None:
    varname = self._df_registry.get_or_register_varname(self._df)
    self._df_varname = varname

  return self._plot_func(varname, *self._args, **self._kwargs)

def get_code(self):
  """Gets the code and associated dependencies + context for a given chart."""
  chart_data = self.get_code_and_title()
  return chart_data.code

def _repr_html_(self):
"""Gets the HTML representation of the chart."""
if self._chart is None:
with mpl.rc_context(dict(_MPL_STYLE_OPTIONS)):
exec_code = self._plot_func('df', *self._args, **self._kwargs)
exec_code, _ = self._plot_func('df', *self._args, **self._kwargs)
exec(exec_code, {'df': self._df}) # pylint: disable=exec-used
self._chart = _quickchart_lib.autoviz.MplChart.from_current_mpl_state()

Expand Down
70 changes: 44 additions & 26 deletions google/colab/_quickchart_lib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Library of charts for use by quickchart."""

import base64
import dataclasses
import io

import IPython.display
Expand All @@ -10,6 +11,12 @@
# pylint:disable=g-import-not-at-top


@dataclasses.dataclass
class ChartData:
  """Title and generating code for a single chart."""

  # Title describing the chart.
  title: str
  # Source code that produces the chart.
  code: str


class MplChart:
"""Matplotlib chart wrapper that displays charts to PNG <image> elements."""

Expand Down Expand Up @@ -57,7 +64,7 @@ class autoviz: # pylint:disable=invalid-name
MplChart = MplChart # pylint:disable=invalid-name


def histogram(df_name: str, colname: str, num_bins=20):
def histogram(df_name: str, colname: str, num_bins=20) -> ChartData:
"""Generates a histogram for the given data column.
Args:
Expand All @@ -68,13 +75,13 @@ def histogram(df_name: str, colname: str, num_bins=20):
Returns:
A ChartData holding the chart title and the code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
{df_name}[{colname!r}].plot(kind='hist', bins={num_bins}, title={colname!r})
plt.gca().spines[['top', 'right',]].set_visible(False)
plt.tight_layout()"""
plt.gca().spines[['top', 'right',]].set_visible(False)"""
return ChartData(title=colname, code=code)


def categorical_histogram(df_name, colname):
def categorical_histogram(df_name, colname) -> ChartData:
"""Generates a single categorical histogram.
Args:
Expand All @@ -84,13 +91,15 @@ def categorical_histogram(df_name, colname):
Returns:
A ChartData holding the chart title and the code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
import seaborn as sns
{df_name}.groupby({colname!r}).size().plot(kind='barh', color=sns.palettes.mpl_palette('Dark2'))
plt.gca().spines[['top', 'right',]].set_visible(False)"""

return ChartData(title=colname, code=code)

def heatmap(df_name: str, x_colname: str, y_colname: str):

def heatmap(df_name: str, x_colname: str, y_colname: str) -> ChartData:
"""Generates a single heatmap.
Args:
Expand All @@ -101,7 +110,7 @@ def heatmap(df_name: str, x_colname: str, y_colname: str):
Returns:
A ChartData holding the chart title and the code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
plt.subplots(figsize=(8, 8))
Expand All @@ -111,12 +120,14 @@ def heatmap(df_name: str, x_colname: str, y_colname: str):
}})
sns.heatmap(df_2dhist, cmap='viridis')
plt.xlabel({x_colname!r})
plt.ylabel({y_colname!r})"""
_ = plt.ylabel({y_colname!r})"""

return ChartData(f'{x_colname} vs {y_colname}', code)


def swarm_plot(
df_name: str, value_colname: str, facet_colname: str, jitter_domain_width=8
):
) -> ChartData:
"""Generates a single swarm plot.
Incorporated from altair example gallery:
Expand All @@ -132,7 +143,7 @@ def swarm_plot(
A ChartData holding the chart title and the code to generate the plot.
"""

return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
palette = sns.palettes.mpl_palette('Dark2')
Expand All @@ -148,12 +159,14 @@ def swarm_plot(
ax.scatter(xtick_locs[i] + jitter, values, s=1.5, alpha=.8, color=color)
ax.xaxis.set_ticks(xtick_locs, facet_values, rotation='vertical')
plt.title({facet_colname!r})
plt.ylabel({value_colname!r})"""
_ = plt.ylabel({value_colname!r})"""

return ChartData(f'{facet_colname} vs {value_colname}', code)


def violin_plot(
df_name: str, value_colname: str, facet_colname: str, inner: str
):
) -> ChartData:
"""Generates a single violin plot.
Args:
Expand All @@ -165,15 +178,17 @@ def violin_plot(
Returns:
A ChartData holding the chart title and the code to generate the plot.
"""
return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
import seaborn as sns
figsize = (12, 1.2 * len({df_name}[{facet_colname!r}].unique()))
plt.figure(figsize=figsize)
sns.violinplot({df_name}, x={value_colname!r}, y={facet_colname!r}, inner={inner!r}, palette='Dark2')
sns.despine(top=True, right=True, bottom=True, left=True)"""

return ChartData(f'{facet_colname} vs {value_colname}', code)

def value_plot(df_name: str, y: str):

def value_plot(df_name: str, y: str) -> ChartData:
"""Generates a single value plot.
Args:
Expand All @@ -184,13 +199,14 @@ def value_plot(df_name: str, y: str):
A ChartData holding the chart title and the code to generate the plot.
"""

return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
{df_name}[{y!r}].plot(kind='line', figsize=(8, 4), title={y!r})
plt.gca().spines[['top', 'right']].set_visible(False)
plt.tight_layout()"""
plt.gca().spines[['top', 'right']].set_visible(False)"""

return ChartData(y, code)


def scatter_plot(df_name: str, x_colname: str, y_colname: str):
def scatter_plot(df_name: str, x_colname: str, y_colname: str) -> ChartData:
"""Generates a single scatter plot.
Args:
Expand All @@ -202,16 +218,16 @@ def scatter_plot(df_name: str, x_colname: str, y_colname: str):
A ChartData holding the chart title and the code to generate the plot.
"""

return f"""from matplotlib import pyplot as plt
plt.figure(figsize=(6, 6))
code = f"""from matplotlib import pyplot as plt
{df_name}.plot(kind='scatter', x={x_colname!r}, y={y_colname!r}, s=32, alpha=.8)
plt.gca().spines[['top', 'right',]].set_visible(False)
plt.tight_layout()"""
plt.gca().spines[['top', 'right',]].set_visible(False)"""

return ChartData(f'{x_colname} vs {y_colname}', code)


def time_series_multiline(
df_name: str, timelike_colname: str, value_colname: str, series_colname: str
):
) -> ChartData:
"""Generates a single time series plot.
Args:
Expand Down Expand Up @@ -241,7 +257,7 @@ def time_series_multiline(
_plot_series(series, series_name, i)
fig.legend(title={series_colname!r}, bbox_to_anchor=(1, 1), loc='upper left')"""

return f"""from matplotlib import pyplot as plt
code = f"""from matplotlib import pyplot as plt
import seaborn as sns
def _plot_series(series, series_name, series_index=0):
from matplotlib import pyplot as plt
Expand All @@ -255,4 +271,6 @@ def _plot_series(series, series_name, series_index=0):
{series_impl}
sns.despine(fig=fig, ax=ax)
plt.xlabel({timelike_colname!r})
plt.ylabel({value_colname!r})"""
_ = plt.ylabel({value_colname!r})"""

return ChartData(f'{timelike_colname} vs {value_colname}', code)

0 comments on commit ccf9365

Please sign in to comment.