"""Tests for ``conduitpylib.viz.compact_xaxis_units``."""

from matplotlib import pyplot as plt

from conduitpylib.viz import compact_xaxis_units


def _compacted_xlabel(x_values, xlabel, base_unit):
    """Plot `x_values`, compact the x-axis units, and return the new label.

    A dedicated figure is created for every call and closed in a ``finally``
    clause, so a failure inside one test cannot leak plotted data or labels
    into the tests that run after it.
    """
    fig, ax = plt.subplots()
    try:
        # setup: ask matplotlib to use scientific notation outside 1e-2..1e2
        ax.ticklabel_format(style="sci", scilimits=(-2, 2))
        ax.plot(x_values, [1, 2, 3])
        ax.set_xlabel(xlabel)

        # apply
        compact_xaxis_units(ax, base_unit=base_unit)

        return ax.get_xlabel()
    finally:
        # clean up, even when the call under test raises
        plt.close(fig)


def test_prefix_removal():
    # 1e3 multiplier folded into milli gives the bare base unit
    assert _compacted_xlabel([1000, 2000, 3000], "Time (ms)", "s") == "Time (s)"


def test_prefix_addition():
    # 1e3 multiplier folded into the bare base unit gives kilo
    assert (
        _compacted_xlabel([1000, 2000, 3000], "Weight (g)", "g")
        == "Weight (kg)"
    )


def test_prefix_nop():
    # no multiplier chosen by matplotlib, so the label must stay untouched
    assert _compacted_xlabel([1, 2, 3], "Weight (g)", "g") == "Weight (g)"


def test_nanoseconds_to_milliseconds():
    # 1e6 multiplier folded into nano gives milli
    assert _compacted_xlabel([1e6, 2e6, 3e6], "Time (ns)", "s") == "Time (ms)"
import re

from bidict import frozenbidict
from matplotlib import axis
import numpy as np

from ..utils import round_to_multiple, splice

# Two-way lookup between SI prefix symbols and their power-of-ten exponents;
# `_si_prefixes[symbol]` gives the exponent, `_si_prefixes.inv[exponent]` the
# symbol.
_si_prefixes = frozenbidict(
    {
        "Y": 24,  # yotta
        "Z": 21,  # zetta
        "E": 18,  # exa
        "P": 15,  # peta
        "T": 12,  # tera
        "G": 9,  # giga
        "M": 6,  # mega
        "k": 3,  # kilo
        "": 0,  # base
        "m": -3,  # milli
        "u": -6,  # micro
        "n": -9,  # nano
        "p": -12,  # pico
        "f": -15,  # femto
        "a": -18,  # atto
        "z": -21,  # zepto
        "y": -24,  # yocto
    },
)


def compact_xaxis_units(
    ax: axis.Axis,
    base_unit: str = "s",
    regex_pattern_template: str = r"\((.*?){base_unit}\)",
) -> None:
    r"""Fold matplotlib's x-axis multiplier into the unit prefix of the label.

    This function finds the multiplier (such as 10^3, 10^6, etc.) used by
    matplotlib for the x-axis, identifies the current unit prefix in the axis
    label (like k for kilo, M for mega), and consolidates the two into a single
    SI prefix. The axis label is rewritten and the offset text is hidden. This
    is particularly useful for graphs where the axis values are very large or
    very small, and a more readable unit is desired.

    If matplotlib has not applied a multiplier (empty offset text), or the
    offset text is a purely additive offset (leading sign), the axis is left
    untouched.

    Parameters
    ----------
    ax : axis.Axis
        Object whose x-axis units are to be compacted inplace. It must provide
        ``get_figure()``, ``xaxis``, ``get_xlabel()`` and ``set_xlabel()``.
        NOTE(review): that is the interface of ``matplotlib.axes.Axes``, not of
        ``matplotlib.axis.Axis`` as annotated; the annotation is kept as-is
        for compatibility and should be confirmed.
    base_unit : str, default "s"
        The base unit (without prefix) for the axis.
    regex_pattern_template : str, default r"\((.*?){base_unit}\)"
        A regular expression template used to identify and extract the unit
        prefix from the axis label.

        The template must contain `{base_unit}` as a placeholder for the
        actual base unit, and its first capture group must capture the prefix.
        Default matches parenthesized expressions like "(ks)" for kiloseconds.

    Raises
    ------
    ValueError
        If the offset text cannot be parsed as a number, or if the axis label
        does not contain exactly one match of the unit pattern.
    KeyError
        If the prefix found in the label, or the resulting net exponent, has
        no entry in the SI prefix table.
    AssertionError
        If the multiplier is not a power of ten, or the net exponent is not a
        multiple of three.
    """
    regex_pattern = regex_pattern_template.format(base_unit=base_unit)
    regex = re.compile(regex_pattern)

    # force a draw so the tick formatter populates the offset text
    fig = ax.get_figure()
    fig.canvas.draw()

    # get unit multiplier chosen by matplotlib
    offset_text = ax.xaxis.get_offset_text()
    if offset_text is None:  # handle unit (i.e., non) multiplier
        return
    offset_string = offset_text.get_text()
    if offset_string == "":  # handle unit (i.e., non) multiplier
        return

    # matplotlib may render "-" as U+2212 MINUS SIGN (axes.unicode_minus),
    # which float() rejects; normalize so negative exponents parse
    normalized_offset = offset_string.replace("\N{MINUS SIGN}", "-")
    if normalized_offset[0] in "+-":
        # a leading sign marks a purely additive offset (e.g. "+1e5"); there
        # is no multiplier to fold into the unit, so leave the axis untouched
        return
    offset_amount = float(normalized_offset)

    # the multiplier must be an exact power of ten; compare exponents rather
    # than inspecting the digits of str(float), which breaks on reprs such as
    # "1e-12" or "1e+16"
    assert offset_amount > 0 and np.isclose(
        np.log10(offset_amount),
        np.round(np.log10(offset_amount)),
    ), f"axis multiplier {offset_string!r} is not a power of 10"

    # get unit prefix and multiplier in axis label
    old_label_string = ax.get_xlabel()
    prefix_matches = list(regex.finditer(old_label_string))
    if len(prefix_matches) != 1:
        raise ValueError(
            f"expected exactly one match of {regex_pattern!r} in x-axis "
            f"label {old_label_string!r}, found {len(prefix_matches)}",
        )
    (old_prefix_match,) = prefix_matches
    old_prefix = old_prefix_match.group(1)
    old_multiplier_pow = _si_prefixes[old_prefix]
    old_multiplier = 10.0**old_multiplier_pow

    # calculate new prefix from net multiplier
    new_multiplier = offset_amount * old_multiplier
    approx_pow = np.log10(new_multiplier)
    new_multiplier_pow = round_to_multiple(approx_pow, multiple=3)
    assert np.isclose(approx_pow, new_multiplier_pow), (
        f"net exponent {approx_pow} is not a multiple of 3, "
        "so no SI prefix in the table represents it"
    )
    new_prefix = _si_prefixes.inv[new_multiplier_pow]

    # apply new multiplier, consolidating offset axis scale and label unit
    offset_text.set(visible=False)  # remove offset axis scale
    old_prefix_span = old_prefix_match.span(1)
    new_label_string = splice(
        old_label_string,
        old_prefix_span,
        new_prefix,
    )
    ax.set_xlabel(new_label_string)