Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of large cascade plots #71

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions p3analysis/plot/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
<https://github.com/intel/p3-analysis-library/issues/new/choose>`_.
"""

import itertools
import string

__all__ = [
"Plot",
"CascadePlot",
Expand Down Expand Up @@ -51,3 +54,24 @@ class NavChart(Plot):

def __init__(self, backend):
super().__init__(backend)


def _get_platform_labels(platforms: list[str]) -> dict[str, str]:
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing argument in docstring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 6471f60.

Returns
-------
dict[str, str]:
A mapping from platform names to unique labels.
"""
if len(platforms) <= len(string.ascii_uppercase):
labels = string.ascii_uppercase
elif len(platforms) <= len(string.ascii_uppercase) ** 2:
labels = []
for x, y in itertools.product(string.ascii_uppercase, repeat=2):
labels.append(f"{x}{y}")
else:
raise RuntimeError(
"The number of platforms supported by cascade plots is "
+ f"currently limited to {len(string.ascii_uppercase)**2}.",
)
return dict(zip(platforms, labels))
26 changes: 13 additions & 13 deletions p3analysis/plot/backend/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
:py:mod:`matplotlib` backend.
"""

import string

import matplotlib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
Expand All @@ -17,7 +15,7 @@
import p3analysis.metrics
from p3analysis._utils import _require_numeric
from p3analysis.plot._common import ApplicationStyle, Legend, PlatformStyle
from p3analysis.plot.backend import CascadePlot, NavChart
from p3analysis.plot.backend import CascadePlot, NavChart, _get_platform_labels


def _get_colors(applications, kwarg):
Expand Down Expand Up @@ -57,6 +55,10 @@ def create_artists(
):
artist = []

# Make adjustments for large numbers of platforms.
if len(self.labels) > 26:
width *= 1.5

# Draw a box using the platform's assigned color
name = orig_handle.get_label()
color = self.colors[name]
Expand Down Expand Up @@ -127,13 +129,16 @@ def __init__(
"filled_markers",
)

# If the size is unset, default to 6 x 5
if not size:
size = (6, 5)

platforms = df["platform"].unique()
applications = df["application"].unique()

# If the size is unset, try to pick a sensible default.
if not size:
if len(platforms) <= 26:
size = (6, 5)
else:
size = (12, 10)

# Create a 2x2 grid of subplots sharing axes
fig = plt.figure(figsize=size)
ratios = [6, len(applications) * 0.5]
Expand Down Expand Up @@ -168,12 +173,7 @@ def __init__(
plat_colors = _get_colors(platforms, plat_style.colors)

# Choose labels for each platform
if len(platforms) > len(string.ascii_uppercase):
raise RuntimeError(
"The number of platforms supported by cascade plots is "
+ f"currently limited to {len(string.ascii_uppercase)}.",
)
plat_labels = dict(zip(platforms, string.ascii_uppercase))
plat_labels = _get_platform_labels(platforms)

# Plot the efficiency cascade in the top-left (0, 0)
app_handles = self.__efficiency_cascade(
Expand Down
9 changes: 2 additions & 7 deletions p3analysis/plot/backend/pgfplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
:py:mod:`pgfplots` backend.
"""

import string

import jinja2
import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -16,7 +14,7 @@
import p3analysis.metrics
from p3analysis._utils import _require_numeric
from p3analysis.plot._common import ApplicationStyle, Legend, PlatformStyle
from p3analysis.plot.backend import CascadePlot, NavChart
from p3analysis.plot.backend import CascadePlot, NavChart, _get_platform_labels

# Define 19 default markers for LaTeX plots
_pgfplots_markers = [
Expand Down Expand Up @@ -139,7 +137,7 @@ def __init__(self, df, eff_column, size=None, stream=None, **kwargs):
).strip("()")

# Build a dictionary of platforms to labels
plat_labels = dict(zip(platforms, string.ascii_uppercase))
plat_labels = _get_platform_labels(platforms)

# Choose colors for each platform and then convert the dictionary to
# RGB colors using the platform labels
Expand All @@ -162,9 +160,6 @@ def __init__(self, df, eff_column, size=None, stream=None, **kwargs):
app
] = f"{app_to_tex_name[app]}, thick, solid, mark={mark}"

# Choose labels for each platform
plat_labels = dict(zip(platforms, string.ascii_uppercase))

# Set the number of rows in the platform key (if set)
# NOTE: This is different to matplotlib because PGF plots uses columns,
# and then transposes (so columns become rows)
Expand Down
Loading