diff --git a/mypy.ini b/mypy.ini index 146a0e1929..231330d270 100644 --- a/mypy.ini +++ b/mypy.ini @@ -9,7 +9,7 @@ explicit_package_bases = True exclude = (?x)( ^build/ | ^.tox/ - | ^.egg/ + | ^.eggs/ | ^pkg_resources/tests/data/my-test-package-source/setup.py$ # Duplicate module name | ^.+?/(_vendor|extern)/ # Vendored | ^setuptools/_distutils/ # Vendored diff --git a/newsfragments/4246.feature.rst b/newsfragments/4246.feature.rst new file mode 100644 index 0000000000..d5dd2ead98 --- /dev/null +++ b/newsfragments/4246.feature.rst @@ -0,0 +1,4 @@ +Improve error message when ``pkg_resources.ZipProvider`` tries to extract resources with a missing Egg -- by :user:`Avasam` + +Added variables and parameter type annotations to ``pkg_resources`` to be nearly on par with typeshed.\* -- by :user:`Avasam` +\* Excluding ``TypeVar`` and ``overload``. Return types are currently inferred. diff --git a/pkg_resources/__init__.py b/pkg_resources/__init__.py index faee7dec79..c86d9f095c 100644 --- a/pkg_resources/__init__.py +++ b/pkg_resources/__init__.py @@ -1,3 +1,6 @@ +# TODO: Add Generic type annotations to initialized collections. +# For now we'd simply use implicit Any/Unknown which would add redundant annotations +# mypy: disable-error-code="var-annotated" """ Package resource API -------------------- @@ -28,6 +31,16 @@ import re import types from typing import ( + Any, + Mapping, + MutableSequence, + NamedTuple, + NoReturn, + Sequence, + Set, + Tuple, + Type, + Union, TYPE_CHECKING, List, Protocol, @@ -55,6 +68,7 @@ import ntpath import posixpath import importlib +import importlib.abc import importlib.machinery from pkgutil import get_importer @@ -62,6 +76,8 @@ # capture these to bypass sandboxing from os import utime +from os import open as os_open +from os.path import isdir, split try: from os import mkdir, rename, unlink @@ -71,9 +87,6 @@ # no write support, probably under GAE WRITE_SUPPORT = False -from os import open as os_open -from os.path import isdir, split - from pkg_resources.extern.jaraco.text import ( yield_lines, drop_comment, @@ -85,6 +98,8 @@ from pkg_resources.extern.packaging import version as _packaging_version from pkg_resources.extern.platformdirs import user_cache_dir as _user_cache_dir +if TYPE_CHECKING: + from _typeshed import StrPath warnings.warn( "pkg_resources is deprecated as an API. " @@ -93,7 +108,22 @@ stacklevel=2, ) + T = TypeVar("T") +# Type aliases +_NestedStr = Union[str, Iterable[Union[str, Iterable["_NestedStr"]]]] +_InstallerType = Callable[["Requirement"], Optional["Distribution"]] +_PkgReqType = Union[str, "Requirement"] +_EPDistType = Union["Distribution", _PkgReqType] +_MetadataType = Optional["IResourceProvider"] +# Any object works, but let's indicate we expect something like a module (optionally has __loader__ or __file__) +_ModuleLike = Union[object, types.ModuleType] +_AdapterType = Callable[..., Any] # Incomplete + + +# Use _typeshed.importlib.LoaderProtocol once available https://github.com/python/typeshed/pull/11890 +class _LoaderProtocol(Protocol): + def load_module(self, fullname: str, /) -> types.ModuleType: ... _PEP440_FALLBACK = re.compile(r"^v?(?P(?:[0-9]+!)?[0-9]+(?:\.[0-9]+)*)", re.I) @@ -290,7 +320,7 @@ def req(self): def report(self): return self._template.format(**locals()) - def with_context(self, required_by): + def with_context(self, required_by: Set[Union["Distribution", str]]): """ If required_by is non-empty, return a version of self that is a ContextualVersionConflict. @@ -347,7 +377,7 @@ class UnknownExtra(ResolutionError): """Distribution doesn't have an "extra feature" of the given name""" -_provider_factories = {} +_provider_factories: Dict[Type[_ModuleLike], _AdapterType] = {} PY_MAJOR = '{}.{}'.format(*sys.version_info) EGG_DIST = 3 @@ -357,7 +387,9 @@ class UnknownExtra(ResolutionError): DEVELOP_DIST = -1 -def register_loader_type(loader_type, provider_factory): +def register_loader_type( + loader_type: Type[_ModuleLike], provider_factory: _AdapterType +): """Register `provider_factory` to make providers for `loader_type` `loader_type` is the type or class of a PEP 302 ``module.__loader__``, @@ -367,7 +399,7 @@ def register_loader_type(loader_type, provider_factory): _provider_factories[loader_type] = provider_factory -def get_provider(moduleOrReq): +def get_provider(moduleOrReq: Union[str, "Requirement"]): """Return an IResourceProvider for the named module or requirement""" if isinstance(moduleOrReq, Requirement): return working_set.find(moduleOrReq) or require(str(moduleOrReq))[0] @@ -429,7 +461,7 @@ def get_build_platform(): get_platform = get_build_platform -def compatible_platforms(provided, required): +def compatible_platforms(provided: Optional[str], required: Optional[str]): """Can code for the `provided` platform run on the `required` platform? Returns true if either platform is ``None``, or the platforms are equal. @@ -478,7 +510,7 @@ def compatible_platforms(provided, required): return False -def get_distribution(dist): +def get_distribution(dist: _EPDistType): """Return a current distribution object for a Requirement or string""" if isinstance(dist, str): dist = Requirement.parse(dist) @@ -489,78 +521,80 @@ def get_distribution(dist): return dist -def load_entry_point(dist, group, name): +def load_entry_point(dist: _EPDistType, group: str, name: str): """Return `name` entry point of `group` for `dist` or raise ImportError""" return get_distribution(dist).load_entry_point(group, name) -def get_entry_map(dist, group=None): +def get_entry_map(dist: _EPDistType, group: Optional[str] = None): """Return the entry point map for `group`, or the full entry map""" return get_distribution(dist).get_entry_map(group) -def get_entry_info(dist, group, name): +def get_entry_info(dist: _EPDistType, group: str, name: str): """Return the EntryPoint object for `group`+`name`, or ``None``""" return get_distribution(dist).get_entry_info(group, name) class IMetadataProvider(Protocol): - def has_metadata(self, name) -> bool: + def has_metadata(self, name: str) -> bool: """Does the package's distribution contain the named metadata?""" - def get_metadata(self, name): + def get_metadata(self, name: str): """The named metadata resource as a string""" - def get_metadata_lines(self, name): + def get_metadata_lines(self, name: str): """Yield named metadata resource as list of non-blank non-comment lines Leading and trailing whitespace is stripped from each line, and lines with ``#`` as the first non-blank character are omitted.""" - def metadata_isdir(self, name) -> bool: + def metadata_isdir(self, name: str) -> bool: """Is the named metadata a directory? (like ``os.path.isdir()``)""" - def metadata_listdir(self, name): + def metadata_listdir(self, name: str): """List of metadata names in the directory (like ``os.listdir()``)""" - def run_script(self, script_name, namespace): + def run_script(self, script_name: str, namespace: Dict[str, Any]): """Execute the named script in the supplied namespace dictionary""" class IResourceProvider(IMetadataProvider, Protocol): """An object that provides access to package resources""" - def get_resource_filename(self, manager, resource_name): + def get_resource_filename(self, manager: "ResourceManager", resource_name: str): """Return a true filesystem path for `resource_name` - `manager` must be an ``IResourceManager``""" + `manager` must be a ``ResourceManager``""" - def get_resource_stream(self, manager, resource_name): + def get_resource_stream(self, manager: "ResourceManager", resource_name: str): """Return a readable file-like object for `resource_name` - `manager` must be an ``IResourceManager``""" + `manager` must be a ``ResourceManager``""" - def get_resource_string(self, manager, resource_name) -> bytes: + def get_resource_string( + self, manager: "ResourceManager", resource_name: str + ) -> bytes: """Return the contents of `resource_name` as :obj:`bytes` - `manager` must be an ``IResourceManager``""" + `manager` must be a ``ResourceManager``""" - def has_resource(self, resource_name): + def has_resource(self, resource_name: str): """Does the package contain the named resource?""" - def resource_isdir(self, resource_name): + def resource_isdir(self, resource_name: str): """Is the named resource a directory? (like ``os.path.isdir()``)""" - def resource_listdir(self, resource_name): + def resource_listdir(self, resource_name: str): """List of resource names in the directory (like ``os.listdir()``)""" class WorkingSet: """A collection of active distributions on sys.path (or a similar list)""" - def __init__(self, entries=None): + def __init__(self, entries: Optional[Iterable[str]] = None): """Create working set from list of path entries (default=sys.path)""" - self.entries = [] + self.entries: List[str] = [] self.entry_keys = {} self.by_key = {} self.normalized_to_canonical_keys = {} @@ -614,7 +648,7 @@ def _build_from_requirements(cls, req_spec): sys.path[:] = ws.entries return ws - def add_entry(self, entry): + def add_entry(self, entry: str): """Add a path item to ``.entries``, finding any distributions on it ``find_distributions(entry, True)`` is used to find distributions @@ -629,11 +663,11 @@ def add_entry(self, entry): for dist in find_distributions(entry, True): self.add(dist, entry, False) - def __contains__(self, dist): + def __contains__(self, dist: "Distribution"): """True if `dist` is the active distribution for its project""" return self.by_key.get(dist.key) == dist - def find(self, req): + def find(self, req: "Requirement"): """Find a distribution matching requirement `req` If there is an active distribution for the requested project, this @@ -657,7 +691,7 @@ def find(self, req): raise VersionConflict(dist, req) return dist - def iter_entry_points(self, group, name=None): + def iter_entry_points(self, group: str, name: Optional[str] = None): """Yield entry point objects from `group` matching `name` If `name` is None, yields all entry points in `group` from all @@ -671,7 +705,7 @@ def iter_entry_points(self, group, name=None): if name is None or name == entry.name ) - def run_script(self, requires, script_name): + def run_script(self, requires: str, script_name: str): """Locate distribution for `requires` and run `script_name` script""" ns = sys._getframe(1).f_globals name = ns['__name__'] @@ -696,7 +730,13 @@ def __iter__(self): seen[key] = 1 yield self.by_key[key] - def add(self, dist, entry=None, insert=True, replace=False): + def add( + self, + dist: "Distribution", + entry: Optional[str] = None, + insert: bool = True, + replace: bool = False, + ): """Add `dist` to working set, associated with `entry` If `entry` is unspecified, it defaults to the ``.location`` of `dist`. @@ -730,11 +770,11 @@ def add(self, dist, entry=None, insert=True, replace=False): def resolve( self, - requirements, - env=None, - installer=None, - replace_conflicting=False, - extras=None, + requirements: Iterable["Requirement"], + env: Optional["Environment"] = None, + installer: Optional[_InstallerType] = None, + replace_conflicting: bool = False, + extras: Optional[Tuple[str, ...]] = None, ): """List all distributions needed to (recursively) meet `requirements` @@ -804,7 +844,7 @@ def resolve( def _resolve_dist( self, req, best, replace_conflicting, env, installer, required_by, to_activate - ): + ) -> "Distribution": dist = best.get(req.key) if dist is None: # Find the best distribution and add it to the map @@ -833,7 +873,13 @@ def _resolve_dist( raise VersionConflict(dist, req).with_context(dependent_req) return dist - def find_plugins(self, plugin_env, full_env=None, installer=None, fallback=True): + def find_plugins( + self, + plugin_env: "Environment", + full_env: Optional["Environment"] = None, + installer: Optional[_InstallerType] = None, + fallback: bool = True, + ): """Find all activatable distributions in `plugin_env` Example usage:: @@ -914,7 +960,7 @@ def find_plugins(self, plugin_env, full_env=None, installer=None, fallback=True) return sorted_distributions, error_info - def require(self, *requirements): + def require(self, *requirements: _NestedStr): """Ensure that distributions matching `requirements` are activated `requirements` must be a string or a (possibly-nested) sequence @@ -930,7 +976,9 @@ def require(self, *requirements): return needed - def subscribe(self, callback, existing=True): + def subscribe( + self, callback: Callable[["Distribution"], object], existing: bool = True + ): """Invoke `callback` for all distributions If `existing=True` (default), @@ -966,12 +1014,14 @@ def __setstate__(self, e_k_b_n_c): self.callbacks = callbacks[:] -class _ReqExtras(dict): +class _ReqExtras(Dict["Requirement", Tuple[str, ...]]): """ Map each requirement to the extras that demanded it. """ - def markers_pass(self, req, extras=None): + def markers_pass( + self, req: "Requirement", extras: Optional[Tuple[str, ...]] = None + ): """ Evaluate markers for req against each extra that demanded it. @@ -990,7 +1040,10 @@ class Environment: """Searchable snapshot of distributions on a search path""" def __init__( - self, search_path=None, platform=get_supported_platform(), python=PY_MAJOR + self, + search_path: Optional[Sequence[str]] = None, + platform: Optional[str] = get_supported_platform(), + python: Optional[str] = PY_MAJOR, ): """Snapshot distributions available on a search path @@ -1013,7 +1066,7 @@ def __init__( self.python = python self.scan(search_path) - def can_add(self, dist): + def can_add(self, dist: "Distribution"): """Is distribution `dist` acceptable for this environment? The distribution must match the platform and python version @@ -1027,11 +1080,11 @@ def can_add(self, dist): ) return py_compat and compatible_platforms(dist.platform, self.platform) - def remove(self, dist): + def remove(self, dist: "Distribution"): """Remove `dist` from the environment""" self._distmap[dist.key].remove(dist) - def scan(self, search_path=None): + def scan(self, search_path: Optional[Sequence[str]] = None): """Scan `search_path` for distributions usable in this environment Any distributions found are added to the environment. @@ -1046,7 +1099,7 @@ def scan(self, search_path=None): for dist in find_distributions(item): self.add(dist) - def __getitem__(self, project_name): + def __getitem__(self, project_name: str): """Return a newest-to-oldest list of distributions for `project_name` Uses case-insensitive `project_name` comparison, assuming all the @@ -1057,7 +1110,7 @@ def __getitem__(self, project_name): distribution_key = project_name.lower() return self._distmap.get(distribution_key, []) - def add(self, dist): + def add(self, dist: "Distribution"): """Add `dist` if we ``can_add()`` it and it has not already been added""" if self.can_add(dist) and dist.has_version(): dists = self._distmap.setdefault(dist.key, []) @@ -1065,7 +1118,13 @@ def add(self, dist): dists.append(dist) dists.sort(key=operator.attrgetter('hashcmp'), reverse=True) - def best_match(self, req, working_set, installer=None, replace_conflicting=False): + def best_match( + self, + req: "Requirement", + working_set: WorkingSet, + installer: Optional[Callable[["Requirement"], Any]] = None, + replace_conflicting: bool = False, + ): """Find distribution best matching `req` and usable on `working_set` This calls the ``find(req)`` method of the `working_set` to see if a @@ -1092,7 +1151,11 @@ def best_match(self, req, working_set, installer=None, replace_conflicting=False # try to download/install return self.obtain(req, installer) - def obtain(self, requirement, installer=None): + def obtain( + self, + requirement: "Requirement", + installer: Optional[Callable[["Requirement"], Any]] = None, + ): """Obtain a distribution matching `requirement` (e.g. via download) Obtain a distro that matches requirement (e.g. via download). In the @@ -1109,7 +1172,7 @@ def __iter__(self): if self[key]: yield key - def __iadd__(self, other): + def __iadd__(self, other: Union["Distribution", "Environment"]): """In-place addition of a distribution or environment""" if isinstance(other, Distribution): self.add(other) @@ -1121,7 +1184,7 @@ def __iadd__(self, other): raise TypeError("Can't add %r to environment" % (other,)) return self - def __add__(self, other): + def __add__(self, other: Union["Distribution", "Environment"]): """Add an environment or distribution to an environment""" new = self.__class__([], platform=None, python=None) for env in self, other: @@ -1148,46 +1211,54 @@ class ExtractionError(RuntimeError): The exception instance that caused extraction to fail """ + manager: "ResourceManager" + cache_path: str + original_error: Optional[BaseException] + class ResourceManager: """Manage resource extraction and packages""" - extraction_path = None + extraction_path: Optional[str] = None def __init__(self): self.cached_files = {} - def resource_exists(self, package_or_requirement, resource_name): + def resource_exists(self, package_or_requirement: _PkgReqType, resource_name: str): """Does the named resource exist?""" return get_provider(package_or_requirement).has_resource(resource_name) - def resource_isdir(self, package_or_requirement, resource_name): + def resource_isdir(self, package_or_requirement: _PkgReqType, resource_name: str): """Is the named resource an existing directory?""" return get_provider(package_or_requirement).resource_isdir(resource_name) - def resource_filename(self, package_or_requirement, resource_name): + def resource_filename( + self, package_or_requirement: _PkgReqType, resource_name: str + ): """Return a true filesystem path for specified resource""" return get_provider(package_or_requirement).get_resource_filename( self, resource_name ) - def resource_stream(self, package_or_requirement, resource_name): + def resource_stream(self, package_or_requirement: _PkgReqType, resource_name: str): """Return a readable file-like object for specified resource""" return get_provider(package_or_requirement).get_resource_stream( self, resource_name ) - def resource_string(self, package_or_requirement, resource_name) -> bytes: + def resource_string( + self, package_or_requirement: _PkgReqType, resource_name: str + ) -> bytes: """Return specified resource as :obj:`bytes`""" return get_provider(package_or_requirement).get_resource_string( self, resource_name ) - def resource_listdir(self, package_or_requirement, resource_name): + def resource_listdir(self, package_or_requirement: _PkgReqType, resource_name: str): """List the contents of the named resource directory""" return get_provider(package_or_requirement).resource_listdir(resource_name) - def extraction_error(self): + def extraction_error(self) -> NoReturn: """Give an error message for problems extracting file(s)""" old_exc = sys.exc_info()[1] @@ -1217,7 +1288,7 @@ def extraction_error(self): err.original_error = old_exc raise err - def get_cache_path(self, archive_name, names=()): + def get_cache_path(self, archive_name: str, names: Iterable[str] = ()): """Return absolute location in cache for `archive_name` and `names` The parent directory of the resulting path will be created if it does @@ -1269,7 +1340,7 @@ def _warn_unsafe_extraction_path(path): ).format(**locals()) warnings.warn(msg, UserWarning) - def postprocess(self, tempname, filename): + def postprocess(self, tempname: str, filename: str): """Perform any platform-specific postprocessing of `tempname` This is where Mac header rewrites should be done; other platforms don't @@ -1289,7 +1360,7 @@ def postprocess(self, tempname, filename): mode = ((os.stat(tempname).st_mode) | 0o555) & 0o7777 os.chmod(tempname, mode) - def set_extraction_path(self, path): + def set_extraction_path(self, path: str): """Set the base path where resources will be extracted to, if needed. If you do not call this routine before any extractions take place, the @@ -1313,7 +1384,7 @@ def set_extraction_path(self, path): self.extraction_path = path - def cleanup_resources(self, force=False) -> List[str]: + def cleanup_resources(self, force: bool = False) -> List[str]: """ Delete all extracted resource files and directories, returning a list of the file and directory names that could not be successfully removed. @@ -1337,7 +1408,7 @@ def get_default_cache(): return os.environ.get('PYTHON_EGG_CACHE') or _user_cache_dir(appname='Python-Eggs') -def safe_name(name): +def safe_name(name: str): """Convert an arbitrary string to a standard distribution name Any runs of non-alphanumeric/. characters are replaced with a single '-'. @@ -1345,7 +1416,7 @@ def safe_name(name): return re.sub('[^A-Za-z0-9.]+', '-', name) -def safe_version(version): +def safe_version(version: str): """ Convert an arbitrary string to a standard version string """ @@ -1389,7 +1460,7 @@ def _safe_segment(segment): return re.sub(r'\.[^A-Za-z0-9]+', '.', segment).strip(".-") -def safe_extra(extra): +def safe_extra(extra: str): """Convert an arbitrary string to a standard 'extra' name Any runs of non-alphanumeric characters are replaced with a single '_', @@ -1398,7 +1469,7 @@ def safe_extra(extra): return re.sub('[^A-Za-z0-9.-]+', '_', extra).lower() -def to_filename(name): +def to_filename(name: str): """Convert a project or version name to its filename-escaped form Any '-' characters are currently replaced with '_'. @@ -1406,7 +1477,7 @@ def to_filename(name): return name.replace('-', '_') -def invalid_marker(text): +def invalid_marker(text: str): """ Validate text as a PEP 508 environment marker; return an exception if invalid or False otherwise. @@ -1420,7 +1491,7 @@ def invalid_marker(text): return False -def evaluate_marker(text, extra=None): +def evaluate_marker(text: str, extra: Optional[str] = None): """ Evaluate a PEP 508 environment marker. Return a boolean indicating the marker result in this environment. @@ -1438,37 +1509,40 @@ def evaluate_marker(text, extra=None): class NullProvider: """Try to implement resources and metadata for arbitrary PEP 302 loaders""" - egg_name = None - egg_info = None - loader = None + egg_name: Optional[str] = None + egg_info: Optional[str] = None + loader: Optional[_LoaderProtocol] = None + module_path: Optional[str] # Some subclasses can have a None module_path - def __init__(self, module): + def __init__(self, module: _ModuleLike): self.loader = getattr(module, '__loader__', None) self.module_path = os.path.dirname(getattr(module, '__file__', '')) - def get_resource_filename(self, manager, resource_name): + def get_resource_filename(self, manager: ResourceManager, resource_name: str): return self._fn(self.module_path, resource_name) - def get_resource_stream(self, manager, resource_name): + def get_resource_stream(self, manager: ResourceManager, resource_name: str): return io.BytesIO(self.get_resource_string(manager, resource_name)) - def get_resource_string(self, manager, resource_name) -> bytes: + def get_resource_string( + self, manager: ResourceManager, resource_name: str + ) -> bytes: return self._get(self._fn(self.module_path, resource_name)) - def has_resource(self, resource_name): + def has_resource(self, resource_name: str): return self._has(self._fn(self.module_path, resource_name)) def _get_metadata_path(self, name): return self._fn(self.egg_info, name) - def has_metadata(self, name) -> bool: + def has_metadata(self, name: str) -> bool: if not self.egg_info: return False path = self._get_metadata_path(name) return self._has(path) - def get_metadata(self, name): + def get_metadata(self, name: str): if not self.egg_info: return "" path = self._get_metadata_path(name) @@ -1481,24 +1555,24 @@ def get_metadata(self, name): exc.reason += ' in {} file at path: {}'.format(name, path) raise - def get_metadata_lines(self, name): + def get_metadata_lines(self, name: str): return yield_lines(self.get_metadata(name)) - def resource_isdir(self, resource_name): + def resource_isdir(self, resource_name: str): return self._isdir(self._fn(self.module_path, resource_name)) - def metadata_isdir(self, name) -> bool: + def metadata_isdir(self, name: str) -> bool: return bool(self.egg_info and self._isdir(self._fn(self.egg_info, name))) - def resource_listdir(self, resource_name): + def resource_listdir(self, resource_name: str): return self._listdir(self._fn(self.module_path, resource_name)) - def metadata_listdir(self, name): + def metadata_listdir(self, name: str): if self.egg_info: return self._listdir(self._fn(self.egg_info, name)) return [] - def run_script(self, script_name, namespace): + def run_script(self, script_name: str, namespace: Dict[str, Any]): script = 'scripts/' + script_name if not self.has_metadata(script): raise ResolutionError( @@ -1541,7 +1615,7 @@ def _listdir(self, path): "Can't perform this operation for unregistered loader type" ) - def _fn(self, base, resource_name): + def _fn(self, base, resource_name: str): self._validate_resource_path(resource_name) if resource_name: return os.path.join(base, *resource_name.split('/')) @@ -1624,7 +1698,8 @@ def _validate_resource_path(path): def _get(self, path) -> bytes: if hasattr(self.loader, 'get_data') and self.loader: - return self.loader.get_data(path) + # Already checked get_data exists + return self.loader.get_data(path) # type: ignore[attr-defined] raise NotImplementedError( "Can't perform this operation for loaders without 'get_data()'" ) @@ -1647,7 +1722,7 @@ def _parents(path): class EggProvider(NullProvider): """Provider based on a virtual filesystem""" - def __init__(self, module): + def __init__(self, module: _ModuleLike): super().__init__(module) self._setup_prefix() @@ -1658,7 +1733,7 @@ def _setup_prefix(self): egg = next(eggs, None) egg and self._set_egg(egg) - def _set_egg(self, path): + def _set_egg(self, path: str): self.egg_name = os.path.basename(path) self.egg_info = os.path.join(path, 'EGG-INFO') self.egg_root = path @@ -1676,7 +1751,7 @@ def _isdir(self, path) -> bool: def _listdir(self, path): return os.listdir(path) - def get_resource_stream(self, manager, resource_name): + def get_resource_stream(self, manager: object, resource_name: str): return open(self._fn(self.module_path, resource_name), 'rb') def _get(self, path) -> bytes: @@ -1717,13 +1792,14 @@ def __init__(self): empty_provider = EmptyProvider() -class ZipManifests(dict): +class ZipManifests(Dict[str, "MemoizedZipManifests.manifest_mod"]): """ zip manifest builder """ + # `path` could be `Union["StrPath", IO[bytes]]` but that violates the LSP for `MemoizedZipManifests.load` @classmethod - def build(cls, path): + def build(cls, path: str): """ Build a dictionary similar to the zipimport directory caches, except instead of tuples, store ZipInfo objects. @@ -1749,9 +1825,11 @@ class MemoizedZipManifests(ZipManifests): Memoized zipfile manifests. """ - manifest_mod = collections.namedtuple('manifest_mod', 'manifest mtime') + class manifest_mod(NamedTuple): + manifest: Dict[str, zipfile.ZipInfo] + mtime: float - def load(self, path): + def load(self, path: str): # type: ignore[override] # ZipManifests.load is a classmethod """ Load a manifest at path or return a suitable manifest already loaded. """ @@ -1768,10 +1846,12 @@ def load(self, path): class ZipProvider(EggProvider): """Resource support for zips and eggs""" - eagers = None + eagers: Optional[List[str]] = None _zip_manifests = MemoizedZipManifests() + # ZipProvider's loader should always be a zipimporter or equivalent + loader: zipimport.zipimporter - def __init__(self, module): + def __init__(self, module: _ModuleLike): super().__init__(module) self.zip_pre = self.loader.archive + os.sep @@ -1797,7 +1877,7 @@ def _parts(self, zip_path): def zipinfo(self): return self._zip_manifests.load(self.loader.archive) - def get_resource_filename(self, manager, resource_name): + def get_resource_filename(self, manager: ResourceManager, resource_name: str): if not self.egg_name: raise NotImplementedError( "resource_filename() only supported for .egg, not .zip" @@ -1820,7 +1900,7 @@ def _get_date_and_size(zip_stat): return timestamp, size # FIXME: 'ZipProvider._extract_resource' is too complex (12) - def _extract_resource(self, manager, zip_path): # noqa: C901 + def _extract_resource(self, manager: ResourceManager, zip_path): # noqa: C901 if zip_path in self._index(): for name in self._index()[zip_path]: last = self._extract_resource(manager, os.path.join(zip_path, name)) @@ -1834,6 +1914,10 @@ def _extract_resource(self, manager, zip_path): # noqa: C901 '"os.rename" and "os.unlink" are not supported ' 'on this platform' ) try: + if not self.egg_name: + raise OSError( + '"egg_name" is empty. This likely means no egg could be found from the "module_path".' + ) real_path = manager.get_cache_path(self.egg_name, self._parts(zip_path)) if self._is_current(real_path, zip_path): @@ -1922,10 +2006,10 @@ def _isdir(self, fspath) -> bool: def _listdir(self, fspath): return list(self._index().get(self._zipinfo_name(fspath), ())) - def _eager_to_zip(self, resource_name): + def _eager_to_zip(self, resource_name: str): return self._zipinfo_name(self._fn(self.egg_root, resource_name)) - def _resource_to_zip(self, resource_name): + def _resource_to_zip(self, resource_name: str): return self._zipinfo_name(self._fn(self.module_path, resource_name)) @@ -1944,13 +2028,13 @@ class FileMetadata(EmptyProvider): the provided location. """ - def __init__(self, path): + def __init__(self, path: "StrPath"): self.path = path def _get_metadata_path(self, name): return self.path - def has_metadata(self, name) -> bool: + def has_metadata(self, name: str) -> bool: return name == 'PKG-INFO' and os.path.isfile(self.path) def get_metadata(self, name): @@ -1993,7 +2077,7 @@ class PathMetadata(DefaultProvider): dist = Distribution.from_filename(egg_path, metadata=metadata) """ - def __init__(self, path, egg_info): + def __init__(self, path: str, egg_info: str): self.module_path = path self.egg_info = egg_info @@ -2001,7 +2085,7 @@ def __init__(self, path, egg_info): class EggMetadata(ZipProvider): """Metadata provider for .egg files""" - def __init__(self, importer): + def __init__(self, importer: zipimport.zipimporter): """Create a metadata provider from a zipimporter""" self.zip_pre = importer.archive + os.sep @@ -2018,7 +2102,7 @@ def __init__(self, importer): ] = _declare_state('dict', '_distribution_finders', {}) -def register_finder(importer_type, distribution_finder): +def register_finder(importer_type: type, distribution_finder: _AdapterType): """Register `distribution_finder` to find distributions in sys.path items `importer_type` is the type or class of a PEP 302 "Importer" (sys.path item @@ -2028,14 +2112,16 @@ def register_finder(importer_type, distribution_finder): _distribution_finders[importer_type] = distribution_finder -def find_distributions(path_item, only=False): +def find_distributions(path_item: str, only: bool = False): """Yield distributions accessible via `path_item`""" importer = get_importer(path_item) finder = _find_adapter(_distribution_finders, importer) return finder(importer, path_item, only) -def find_eggs_in_zip(importer, path_item, only=False): +def find_eggs_in_zip( + importer: zipimport.zipimporter, path_item: str, only: bool = False +): """ Find eggs in zip files; possibly multiple nested eggs. """ @@ -2064,14 +2150,16 @@ def find_eggs_in_zip(importer, path_item, only=False): register_finder(zipimport.zipimporter, find_eggs_in_zip) -def find_nothing(importer, path_item, only=False): +def find_nothing( + importer: Optional[object], path_item: Optional[str], only: Optional[bool] = False +): return () register_finder(object, find_nothing) -def find_on_path(importer, path_item, only=False): +def find_on_path(importer: Optional[object], path_item, only=False): """Yield distributions accessible on a sys.path directory""" path_item = _normalize_cached(path_item) @@ -2196,7 +2284,7 @@ def resolve_egg_link(path): ) -def register_namespace_handler(importer_type, namespace_handler): +def register_namespace_handler(importer_type: type, namespace_handler: _AdapterType): """Register `namespace_handler` to declare namespace packages `importer_type` is the type or class of a PEP 302 "Importer" (sys.path item @@ -2251,7 +2339,7 @@ def _handle_ns(packageName, path_item): return subpath -def _rebuild_mod_path(orig_path, package_name, module): +def _rebuild_mod_path(orig_path, package_name, module: types.ModuleType): """ Rebuild module.__path__ ensuring that all entries are ordered corresponding to their sys.path order @@ -2285,7 +2373,7 @@ def position_in_sys_path(path): module.__path__ = new_path -def declare_namespace(packageName): +def declare_namespace(packageName: str): """Declare that package 'packageName' is a namespace package""" msg = ( @@ -2302,7 +2390,7 @@ def declare_namespace(packageName): if packageName in _namespace_packages: return - path = sys.path + path: MutableSequence[str] = sys.path parent, _, _ = packageName.rpartition('.') if parent: @@ -2328,7 +2416,7 @@ def declare_namespace(packageName): _imp.release_lock() -def fixup_namespace_packages(path_item, parent=None): +def fixup_namespace_packages(path_item: str, parent: Optional[str] = None): """Ensure that previously-declared namespace packages include path_item""" _imp.acquire_lock() try: @@ -2340,7 +2428,12 @@ def fixup_namespace_packages(path_item, parent=None): _imp.release_lock() -def file_ns_handler(importer, path_item, packageName, module): +def file_ns_handler( + importer: Optional[importlib.abc.PathEntryFinder], + path_item, + packageName, + module: types.ModuleType, +): """Compute an ns-package subpath for a filesystem or zipfile importer""" subpath = os.path.join(path_item, packageName.split('.')[-1]) @@ -2360,19 +2453,24 @@ def file_ns_handler(importer, path_item, packageName, module): register_namespace_handler(importlib.machinery.FileFinder, file_ns_handler) -def null_ns_handler(importer, path_item, packageName, module): +def null_ns_handler( + importer: Optional[importlib.abc.PathEntryFinder], + path_item: Optional[str], + packageName: Optional[str], + module: Optional[_ModuleLike], +): return None register_namespace_handler(object, null_ns_handler) -def normalize_path(filename): +def normalize_path(filename: "StrPath"): """Normalize a file/dir name for comparison purposes""" return os.path.normcase(os.path.realpath(os.path.normpath(_cygwin_patch(filename)))) -def _cygwin_patch(filename): # pragma: nocover +def _cygwin_patch(filename: "StrPath"): # pragma: nocover """ Contrary to POSIX 2008, on Cygwin, getcwd (3) contains symlink components. Using @@ -2438,7 +2536,14 @@ def _set_parent_ns(packageName): class EntryPoint: """Object representing an advertised importable object""" - def __init__(self, name, module_name, attrs=(), extras=(), dist=None): + def __init__( + self, + name: str, + module_name: str, + attrs: Iterable[str] = (), + extras: Iterable[str] = (), + dist: Optional["Distribution"] = None, + ): if not MODULE(module_name): raise ValueError("Invalid module name", module_name) self.name = name @@ -2458,7 +2563,12 @@ def __str__(self): def __repr__(self): return "EntryPoint.parse(%r)" % str(self) - def load(self, require=True, *args, **kwargs): + def load( + self, + require: bool = True, + *args: Optional[Union[Environment, _InstallerType]], + **kwargs: Optional[Union[Environment, _InstallerType]], + ): """ Require packages for this EntryPoint, then resolve it. """ @@ -2470,7 +2580,9 @@ def load(self, require=True, *args, **kwargs): stacklevel=2, ) if require: - self.require(*args, **kwargs) + # We could pass `env` and `installer` directly, + # but keeping `*args` and `**kwargs` for backwards compatibility + self.require(*args, **kwargs) # type: ignore return self.resolve() def resolve(self): @@ -2483,7 +2595,11 @@ def resolve(self): except AttributeError as exc: raise ImportError(str(exc)) from exc - def require(self, env=None, installer=None): + def require( + self, + env: Optional[Environment] = None, + installer: Optional[_InstallerType] = None, + ): if not self.dist: error_cls = UnknownExtra if self.extras else AttributeError raise error_cls("Can't require() without a distribution", self) @@ -2507,7 +2623,7 @@ def require(self, env=None, installer=None): ) @classmethod - def parse(cls, src, dist=None): + def parse(cls, src: str, dist: Optional["Distribution"] = None): """Parse a single entry point from string `src` Entry point syntax follows the form:: @@ -2536,7 +2652,12 @@ def _parse_extras(cls, extras_spec): return req.extras @classmethod - def parse_group(cls, group, lines, dist=None): + def parse_group( + cls, + group: str, + lines: _NestedStr, + dist: Optional["Distribution"] = None, + ): """Parse an entry point group""" if not MODULE(group): raise ValueError("Invalid group name", group) @@ -2549,13 +2670,17 @@ def parse_group(cls, group, lines, dist=None): return this @classmethod - def parse_map(cls, data, dist=None): + def parse_map( + cls, + data: Union[str, Iterable[str], Dict[str, Union[str, Iterable[str]]]], + dist: Optional["Distribution"] = None, + ): """Parse a map of entry point groups""" if isinstance(data, dict): _data = data.items() else: _data = split_sections(data) - maps = {} + maps: Dict[str, Dict[str, "EntryPoint"]] = {} for group, lines in _data: if group is None: if not lines: @@ -2590,13 +2715,13 @@ class Distribution: def __init__( self, - location=None, - metadata=None, - project_name=None, - version=None, - py_version=PY_MAJOR, - platform=None, - precedence=EGG_DIST, + location: Optional[str] = None, + metadata: _MetadataType = None, + project_name: Optional[str] = None, + version: Optional[str] = None, + py_version: Optional[str] = PY_MAJOR, + platform: Optional[str] = None, + precedence: int = EGG_DIST, ): self.project_name = safe_name(project_name or 'Unknown') if version is not None: @@ -2608,7 +2733,13 @@ def __init__( self._provider = metadata or empty_provider @classmethod - def from_location(cls, location, basename, metadata=None, **kw): + def from_location( + cls, + location: str, + basename: str, + metadata: _MetadataType = None, + **kw: int, # We could set `precedence` explicitly, but keeping this as `**kw` for full backwards and subclassing compatibility + ): project_name, version, py_version, platform = [None] * 4 basename, ext = os.path.splitext(basename) if ext.lower() in _distributionImpl: @@ -2646,25 +2777,25 @@ def hashcmp(self): def __hash__(self): return hash(self.hashcmp) - def __lt__(self, other): + def __lt__(self, other: "Distribution"): return self.hashcmp < other.hashcmp - def __le__(self, other): + def __le__(self, other: "Distribution"): return self.hashcmp <= other.hashcmp - def __gt__(self, other): + def __gt__(self, other: "Distribution"): return self.hashcmp > other.hashcmp - def __ge__(self, other): + def __ge__(self, other: "Distribution"): return self.hashcmp >= other.hashcmp - def __eq__(self, other): + def __eq__(self, other: object): if not isinstance(other, self.__class__): # It's not a Distribution, so they are not equal return False return self.hashcmp == other.hashcmp - def __ne__(self, other): + def __ne__(self, other: object): return not self == other # These properties have to be lazy so that we don't have to load any @@ -2774,7 +2905,7 @@ def _build_dep_map(self): dm.setdefault(extra, []).extend(parse_requirements(reqs)) return dm - def requires(self, extras=()): + def requires(self, extras: Iterable[str] = ()): """List of Requirements needed for this distro if `extras` are used""" dm = self._dep_map deps = [] @@ -2813,7 +2944,7 @@ def _get_version(self): lines = self._get_metadata(self.PKG_INFO) return _version_from_file(lines) - def activate(self, path=None, replace=False): + def activate(self, path: Optional[List[str]] = None, replace: bool = False): """Ensure distribution is importable on `path` (default=sys.path)""" if path is None: path = sys.path @@ -2863,7 +2994,12 @@ def __dir__(self): ) @classmethod - def from_filename(cls, filename, metadata=None, **kw): + def from_filename( + cls, + filename: str, + metadata: _MetadataType = None, + **kw: int, # We could set `precedence` explicitly, but keeping this as `**kw` for full backwards and subclassing compatibility + ): return cls.from_location( _normalize_cached(filename), os.path.basename(filename), metadata, **kw ) @@ -2877,14 +3013,14 @@ def as_requirement(self): return Requirement.parse(spec) - def load_entry_point(self, group, name): + def load_entry_point(self, group: str, name: str): """Return the `name` entry point of `group` or raise ImportError""" ep = self.get_entry_info(group, name) if ep is None: raise ImportError("Entry point %r not found" % ((group, name),)) return ep.load() - def get_entry_map(self, group=None): + def get_entry_map(self, group: Optional[str] = None): """Return the entry point map for `group`, or the full entry map""" if not hasattr(self, "_ep_map"): self._ep_map = EntryPoint.parse_map( @@ -2894,12 +3030,17 @@ def get_entry_map(self, group=None): return self._ep_map.get(group, {}) return self._ep_map - def get_entry_info(self, group, name): + def get_entry_info(self, group: str, name: str): """Return the EntryPoint object for `group`+`name`, or ``None``""" return self.get_entry_map(group).get(name) # FIXME: 'Distribution.insert_on' is too complex (13) - def insert_on(self, path, loc=None, replace=False): # noqa: C901 + def insert_on( # noqa: C901 + self, + path: List[str], + loc=None, + replace: bool = False, + ): """Ensure self.location is on path If replace=False (default): @@ -3004,13 +3145,14 @@ def has_version(self): return False return True - def clone(self, **kw): + def clone(self, **kw: Optional[Union[str, int, IResourceProvider]]): """Copy this distribution, substituting in any changed keyword args""" names = 'project_name version py_version platform location precedence' for attr in names.split(): kw.setdefault(attr, getattr(self, attr, None)) kw.setdefault('metadata', self._provider) - return self.__class__(**kw) + # Unsafely unpacking. But keeping **kw for backwards and subclassing compatibility + return self.__class__(**kw) # type:ignore[arg-type] @property def extras(self): @@ -3107,7 +3249,7 @@ def issue_warning(*args, **kw): warnings.warn(stacklevel=level + 1, *args, **kw) -def parse_requirements(strs): +def parse_requirements(strs: _NestedStr): """ Yield ``Requirement`` objects for each specification in `strs`. @@ -3121,14 +3263,15 @@ class RequirementParseError(_packaging_requirements.InvalidRequirement): class Requirement(_packaging_requirements.Requirement): - def __init__(self, requirement_string): + def __init__(self, requirement_string: str): """DO NOT CALL THIS UNDOCUMENTED METHOD; use Requirement.parse()!""" super().__init__(requirement_string) self.unsafe_name = self.name project_name = safe_name(self.name) self.project_name, self.key = project_name, project_name.lower() self.specs = [(spec.operator, spec.version) for spec in self.specifier] - self.extras = tuple(map(safe_extra, self.extras)) + # packaging.requirements.Requirement uses a set for its extras. We use a variable-length tuple + self.extras: Tuple[str] = tuple(map(safe_extra, self.extras)) self.hashCmp = ( self.key, self.url, @@ -3138,13 +3281,13 @@ def __init__(self, requirement_string): ) self.__hash = hash(self.hashCmp) - def __eq__(self, other): + def __eq__(self, other: object): return isinstance(other, Requirement) and self.hashCmp == other.hashCmp def __ne__(self, other): return not self == other - def __contains__(self, item): + def __contains__(self, item: Union[Distribution, str, Tuple[str, ...]]): if isinstance(item, Distribution): if item.key != self.key: return False @@ -3163,7 +3306,7 @@ def __repr__(self): return "Requirement.parse(%r)" % str(self) @staticmethod - def parse(s): + def parse(s: Union[str, Iterable[str]]): (req,) = parse_requirements(s) return req @@ -3178,7 +3321,7 @@ def _always_object(classes): return classes -def _find_adapter(registry, ob): +def _find_adapter(registry: Mapping[type, _AdapterType], ob: object): """Return an adapter factory for `ob` from `registry`""" types = _always_object(inspect.getmro(getattr(ob, '__class__', type(ob)))) for t in types: @@ -3189,7 +3332,7 @@ def _find_adapter(registry, ob): raise TypeError(f"Could not find adapter for {registry} and {ob}") -def ensure_directory(path): +def ensure_directory(path: str): """Ensure that the parent directory of `path` exists""" dirname = os.path.dirname(path) os.makedirs(dirname, exist_ok=True) @@ -3208,7 +3351,7 @@ def _bypass_ensure_directory(path): pass -def split_sections(s): +def split_sections(s: _NestedStr): """Split a string or iterable thereof into (section, content) pairs Each ``section`` is a stripped version of the section header ("[section]") diff --git a/pkg_resources/extern/__init__.py b/pkg_resources/extern/__init__.py index bfb9eb8bdf..a1b7490dfb 100644 --- a/pkg_resources/extern/__init__.py +++ b/pkg_resources/extern/__init__.py @@ -1,5 +1,8 @@ +from importlib.machinery import ModuleSpec import importlib.util import sys +from types import ModuleType +from typing import Iterable, Optional, Sequence class VendorImporter: @@ -8,7 +11,12 @@ class VendorImporter: or otherwise naturally-installed packages from root_name. """ - def __init__(self, root_name, vendored_names=(), vendor_pkg=None): + def __init__( + self, + root_name: str, + vendored_names: Iterable[str] = (), + vendor_pkg: Optional[str] = None, + ): self.root_name = root_name self.vendored_names = set(vendored_names) self.vendor_pkg = vendor_pkg or root_name.replace('extern', '_vendor') @@ -26,7 +34,7 @@ def _module_matches_namespace(self, fullname): root, base, target = fullname.partition(self.root_name + '.') return not root and any(map(target.startswith, self.vendored_names)) - def load_module(self, fullname): + def load_module(self, fullname: str): """ Iterate over the search path to locate and load fullname. """ @@ -48,16 +56,22 @@ def load_module(self, fullname): "distribution.".format(**locals()) ) - def create_module(self, spec): + def create_module(self, spec: ModuleSpec): return self.load_module(spec.name) - def exec_module(self, module): + def exec_module(self, module: ModuleType): pass - def find_spec(self, fullname, path=None, target=None): + def find_spec( + self, + fullname: str, + path: Optional[Sequence[str]] = None, + target: Optional[ModuleType] = None, + ): """Return a module spec for vendored names.""" return ( - importlib.util.spec_from_loader(fullname, self) + # This should fix itself next mypy release https://github.com/python/typeshed/pull/11890 + importlib.util.spec_from_loader(fullname, self) # type: ignore[arg-type] if self._module_matches_namespace(fullname) else None )