Skip to content

Commit

Permalink
Switch from entrypoints to importlib.metadata
Browse files Browse the repository at this point in the history
Since the minimum Python version supported now directly supports
querying entry points using the standard library, write a wrapper around
to support both upstream APIs, and make use of it, rather than the
external entrypoints package.
  • Loading branch information
s-t-e-v-e-n-k committed Nov 1, 2024
1 parent 5384731 commit 30394a4
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 11 deletions.
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ include *.toml

include .bumpversion.cfg

include papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA
include papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt

# Documentation
prune docs

Expand Down
5 changes: 2 additions & 3 deletions papermill/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from functools import wraps

import dateutil
import entrypoints

from .clientwrap import PapermillNotebookClient
from .exceptions import PapermillException
from .iorw import write_ipynb
from .log import logger
from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args
from .utils import get_entrypoints_group, merge_kwargs, nb_kernel_name, nb_language, remove_args


class PapermillEngines:
Expand All @@ -33,7 +32,7 @@ def register_entry_points(self):
Load handlers provided by other packages
"""
for entrypoint in entrypoints.get_group_all("papermill.engine"):
for entrypoint in get_entrypoints_group("papermill.engine"):
self.register(entrypoint.name, entrypoint.load())

def get_engine(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from contextlib import contextmanager

import entrypoints
import nbformat
import requests
import yaml
Expand All @@ -18,7 +17,7 @@
missing_environment_variable_generator,
)
from .log import logger
from .utils import chdir
from .utils import chdir, get_entrypoints_group
from .version import version as __version__

try:
Expand Down Expand Up @@ -116,7 +115,7 @@ def register(self, scheme, handler):

def register_entry_points(self):
# Load handlers provided by other packages
for entrypoint in entrypoints.get_group_all("papermill.io"):
for entrypoint in get_entrypoints_group("papermill.io"):
self.register(entrypoint.name, entrypoint.load())

def get_handler(self, path, extensions=None):
Expand Down
3 changes: 3 additions & 0 deletions papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Metadata-Version: 2.3
Name: foo
Version: 0.0.1
2 changes: 2 additions & 0 deletions papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[papermill.tests.fake]
foo = bar
5 changes: 3 additions & 2 deletions papermill/tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def test_registering_entry_points(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint.name = "fake-engine"

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
entry_points = {"papermill.engine": [fake_entrypoint]}
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
self.papermill_engines.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.engine")
mock_entry_points.assert_called_once()
self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value)
5 changes: 3 additions & 2 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def test_entrypoint_register(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint.name = "fake-from-entry-point://"

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
entry_points = {"papermill.io": [fake_entrypoint]}
with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points:
self.papermill_io.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.io")
mock_entry_points.assert_called_once()
fake_ = self.papermill_io.get_handler("fake-from-entry-point://")
assert fake_ == fake_entrypoint.load.return_value

Expand Down
13 changes: 13 additions & 0 deletions papermill/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -10,6 +11,7 @@
from ..utils import (
any_tagged_cell,
chdir,
get_entrypoints_group,
merge_kwargs,
remove_args,
retry,
Expand Down Expand Up @@ -58,3 +60,14 @@ def test_chdir():
assert Path.cwd() == Path(temp_dir)

assert Path.cwd() == old_cwd


def test_get_entrypoints_group():
# We don't need to mock anything here, there is just enough metadata
# present to give us one entry point.
sys.path.insert(0, Path(__file__).parent / "fixtures")
# We need to cast to a list here, 3.8/3.9 and 3.10+ return different
# types.
eps = list(get_entrypoints_group("papermill.tests.fake"))
sys.path.pop()
assert eps[0].name == "foo"
18 changes: 18 additions & 0 deletions papermill/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from contextlib import contextmanager
from functools import wraps
from importlib.metadata import entry_points

from .exceptions import PapermillParameterOverwriteWarning

Expand Down Expand Up @@ -190,3 +191,20 @@ def chdir(path):
yield
finally:
os.chdir(old_dir)


def get_entrypoints_group(group):
"""Return a given group of entrypoints.
Since the importlib.metadata entry points API is very simple in 3.8 and
more complete in 3.10+, we need to support both. This function can be
removed when 3.10 is the minimum supported version, and replaced
with ``entry_points(group=group)``.
"""
eps = entry_points()
if hasattr(eps, "select"):
# New and shiny Python 3.10+ API
return eps.select(group=group)
else:
# Python 3.8 and 3.9
return eps.get(group, [])
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ nbformat >= 5.2.0
nbclient >= 0.2.0
tqdm >= 4.32.2
requests
entrypoints
tenacity >= 5.0.2
aiohttp >=3.9.0; python_version=="3.12"
ansicolors

0 comments on commit 30394a4

Please sign in to comment.