Skip to content

Commit

Permalink
Impl, test compact_xaxis_units
Browse files Browse the repository at this point in the history
  • Loading branch information
mmore500 committed Nov 5, 2023
1 parent a3ef533 commit fe58788
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
73 changes: 73 additions & 0 deletions conduitpylib/test/test_viz/test_compact_xaxis_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from matplotlib import pyplot as plt

from conduitpylib.viz import compact_xaxis_units


def test_prefix_removal():
# setup
ax = plt.gca()
ax.ticklabel_format(style="sci", scilimits=(-2, 2)),
ax.plot([1000, 2000, 3000], [1, 2, 3])
ax.set_xlabel("Time (ms)")

# apply
compact_xaxis_units(ax, base_unit="s")

# verify
expected_label = "Time (s)"
assert ax.get_xlabel() == expected_label

# clean up
plt.clf()


def test_prefix_addition():
# setup
ax = plt.gca()
ax.ticklabel_format(style="sci", scilimits=(-2, 2)),
ax.plot([1000, 2000, 3000], [1, 2, 3])
ax.set_xlabel("Weight (g)")
# apply
compact_xaxis_units(ax, base_unit="g")

# verify
expected_label = "Weight (kg)"
assert ax.get_xlabel() == expected_label

# clean up
plt.clf()


def test_prefix_nop():
# setup
ax = plt.gca()
ax.ticklabel_format(style="sci", scilimits=(-2, 2)),
ax.plot([1, 2, 3], [1, 2, 3])
ax.set_xlabel("Weight (g)")
# apply
compact_xaxis_units(ax, base_unit="g")

# verify
expected_label = "Weight (g)"
assert ax.get_xlabel() == expected_label

# clean up
plt.clf()


def test_nanoseconds_to_milliseconds():
# setup
ax = plt.gca()
ax.ticklabel_format(style="sci", scilimits=(-2, 2)),
ax.plot([1e6, 2e6, 3e6], [1, 2, 3])
ax.set_xlabel("Time (ns)")

# apply
compact_xaxis_units(ax, base_unit="s")

# verify
expected_label = "Time (ms)"
assert ax.get_xlabel() == expected_label

# clean up
plt.clf()
2 changes: 2 additions & 0 deletions conduitpylib/viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Visualization tools."""

from ._compact_xaxis_units import compact_xaxis_units
from ._frame_scatter_subsets import frame_scatter_subsets
from ._performance_semantics_plot import performance_semantics_plot
from ._set_kde_lims import set_kde_lims


# adapted from https://stackoverflow.com/a/31079085
__all__ = [
"compact_xaxis_units",
"frame_scatter_subsets",
"performance_semantics_plot",
"set_kde_lims",
Expand Down
105 changes: 105 additions & 0 deletions conduitpylib/viz/_compact_xaxis_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import re

from bidict import frozenbidict
from matplotlib import axis
import numpy as np

from ..utils import round_to_multiple, splice

_si_prefixes = frozenbidict(
{
"Y": 24, # yotta
"Z": 21, # zetta
"E": 18, # exa
"P": 15, # peta
"T": 12, # tera
"G": 9, # giga
"M": 6, # mega
"k": 3, # kilo
"": 0, # base
"m": -3, # milli
"u": -6, # micro
"n": -9, # nano
"p": -12, # pico
"f": -15, # femto
"a": -18, # atto
"z": -21, # zepto
"y": -24, # yocto
},
)


def compact_xaxis_units(
ax: axis.Axis,
base_unit: str = "s",
regex_pattern_template: str = r"\((.*?){base_unit}\)",
) -> None:
r"""Adjusts the units on the x-axis of a matplotlib axis object to a more
compact form.
This function finds the multiplier (such as 10^3, 10^6, etc.) used by
matplotlib for the axis, identifies the current unit prefix (like k for
kilo, M for mega), and then consolidates these into a new, more compact unit
prefix. This is particularly useful for graphs where the axis values are
very large or very small, and a more readable unit is desired.
Parameters
----------
ax : axis.Axis
Axis object whose x-axis units are to be compacted inplace.
base_unit : str, default "s"
The base unit (without prefix) for the axis.
regex_pattern_template : str, default r"\((.*?){base_unit}\)"
A regular expression template used to identify and extract the unit
prefix from the axis label.
The template must contain `{base_unit}` as a
placeholder for the actual base unit. Default matches parenthesized expressions like "(ks)" for kiloseconds.
"""
regex_pattern = regex_pattern_template.format(base_unit=base_unit)
regex = re.compile(regex_pattern)

# force population of
fig = ax.get_figure()
fig.canvas.draw()

# get unit multiplier chosen by matplotlib
offset_text = ax.xaxis.get_offset_text()
if offset_text is None: # handle unit (i.e., non) multiplier
return
offset_string = offset_text.get_text()
if offset_string == "": # handle unit (i.e., non) multiplier
return
offset_amount = float(offset_string)

assert str(offset_amount).count("1") == 1 # power of 10

# get unit prefix and multiplier in axis label
old_label_string = ax.get_xlabel()
(old_prefix_match,) = regex.finditer(old_label_string)
old_prefix = old_prefix_match.group(1)
old_multiplier_pow = _si_prefixes[old_prefix]
old_multiplier = 10.0**old_multiplier_pow

# calculate new prefix from net multiplier
new_multiplier = offset_amount * old_multiplier
approx_pow = np.log10(new_multiplier)
new_multiplier_pow = round_to_multiple(approx_pow, multiple=3)
assert np.isclose(approx_pow, new_multiplier_pow)
new_prefix = _si_prefixes.inv[new_multiplier_pow]

# apply new multiplier, consolidating offset axis scale and label unit
def replace_group(match):
# Extracting the captured group
captured_group = match.group(1)
# Defining the replacement value for the captured group
return replacement_value

ax.xaxis.get_offset_text().set(visible=False) # remove offset axis scale
old_prefix_span = old_prefix_match.span(1)
new_label_string = splice(
old_label_string,
old_prefix_span,
new_prefix,
)
ax.set_xlabel(new_label_string)

0 comments on commit fe58788

Please sign in to comment.