Skip to content

Commit

Permalink
FIX: match signature of axes.legend with add_sorted_driver_legend (#621)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: theOehrly <[email protected]>
  • Loading branch information
formulatimer and theOehrly authored Jul 25, 2024
1 parent 93bfdde commit 25698cb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
31 changes: 28 additions & 3 deletions fastf1/plotting/_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import warnings
from collections.abc import Sequence
from typing import (
Any,
Expand All @@ -8,6 +9,7 @@
)

import matplotlib.axes
import matplotlib.legend

from fastf1.core import Session
from fastf1.internals.fuzzy import fuzzy_matcher
Expand Down Expand Up @@ -715,7 +717,9 @@ def list_compounds(session: Session) -> list[str]:
return list(_Constants[year].CompoundColors.keys())


def add_sorted_driver_legend(ax: matplotlib.axes.Axes, session: Session):
def add_sorted_driver_legend(
ax: matplotlib.axes.Axes, session: Session, *args, **kwargs
):
"""
Adds a legend to the axis where drivers are grouped by team and within each
team they are shown in the same order that is used for selecting plot
Expand All @@ -725,12 +729,18 @@ def add_sorted_driver_legend(ax: matplotlib.axes.Axes, session: Session):
``ax.legend()`` method. It can only be used when driver names or driver
abbreviations are used as labels for the legend.
This function supports the same ``*args`` and ``**kwargs`` as
Matplotlib's ``ax.legend()``, including the ``handles`` and ``labels``
arguments. Check the Matplotlib documentation for more information.
There is no particular need to use this function except to make the
legend more visually pleasing.
Args:
ax: An instance of a Matplotlib ``Axes`` object
session: the session for which the data should be obtained
*args: Matplotlib legend args
**kwargs: Matplotlib legend kwargs
Returns:
``matplotlib.legend.Legend``
Expand All @@ -740,7 +750,22 @@ def add_sorted_driver_legend(ax: matplotlib.axes.Axes, session: Session):
"""
dtm = _get_driver_team_mapping(session)
handles, labels = ax.get_legend_handles_labels()

try:
ret = matplotlib.legend._parse_legend_args([ax], *args, **kwargs)
if len(ret) == 3:
handles, labels, kwargs = ret
extra_args = []
else:
handles, labels, extra_args, kwargs = ret

except AttributeError:
warnings.warn("Failed to parse optional legend arguments correctly.",
UserWarning)
extra_args = []
kwargs.pop('handles', None)
kwargs.pop('labels', None)
handles, labels = ax.get_legend_handles_labels()

teams_list = list(dtm.teams_by_normalized.values())
driver_list = list(dtm.drivers_by_normalized.values())
Expand Down Expand Up @@ -769,7 +794,7 @@ def add_sorted_driver_legend(ax: matplotlib.axes.Axes, session: Session):
handles_new.append(elem[2])
labels_new.append(elem[3])

return ax.legend(handles_new, labels_new)
return ax.legend(handles_new, labels_new, *extra_args, **kwargs)


def set_default_colormap(colormap: str):
Expand Down
9 changes: 9 additions & 0 deletions fastf1/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,12 @@ def test_override_team_constants():

if fastf1.plotting.get_team_name('Haas', session, short=True) != 'Haas':
raise RuntimeError("Test cleanup failed!")


def test_import_internal_mpl_lgend_arg_kwarg_parser():
# Import the module and just try to access the internal function to see
# if it is available. This is not a test of the function itself but rather
# serves as an early warning if the function is removed or renamed by
# causing a test failure.
import matplotlib.legend
_ = matplotlib.legend._parse_legend_args

0 comments on commit 25698cb

Please sign in to comment.