Skip to content

Commit

Permalink
Ensure namespace packages and submodule entrypoints work
Browse files Browse the repository at this point in the history
  • Loading branch information
lkubb authored and dwoz committed May 22, 2024
1 parent 217fafd commit 33efd9c
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 56 deletions.
66 changes: 40 additions & 26 deletions salt/utils/thin.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,18 +229,19 @@ def _is_shareable(mod):
return os.path.basename(mod) in shareable


def _add_dependency(container, obj):
def _add_dependency(container, obj, namespace=None):
"""
Add a dependency to the top list.
:param obj:
:param is_file:
:param namespace: Optional tuple of parent namespaces for namespace packages
:return:
"""
if os.path.basename(obj.__file__).split(".")[0] == "__init__":
container.append(os.path.dirname(obj.__file__))
container.append((os.path.dirname(obj.__file__), namespace))
else:
container.append(obj.__file__.replace(".pyc", ".py"))
container.append((obj.__file__.replace(".pyc", ".py"), None))


def gte():
Expand Down Expand Up @@ -459,9 +460,9 @@ def get_tops(extra_mods="", so_mods=""):
moddir, modname = os.path.split(locals()[mod].__file__)
base, _ = os.path.splitext(modname)
if base == "__init__":
tops.append(moddir)
tops.append((moddir, None))
else:
tops.append(os.path.join(moddir, base + ".py"))
tops.append((os.path.join(moddir, base + ".py"), None))
except ImportError as err:
log.error(
'Unable to import extra-module "%s": %s', mod, err, exc_info=True
Expand All @@ -470,8 +471,8 @@ def get_tops(extra_mods="", so_mods=""):
for mod in [m for m in so_mods.split(",") if m]:
try:
locals()[mod] = __import__(mod)
tops.append(locals()[mod].__file__)
except ImportError as err:
tops.append((locals()[mod].__file__, None))
except ImportError:
log.error('Unable to import so-module "%s"', mod, exc_info=True)

return tops
Expand Down Expand Up @@ -607,10 +608,32 @@ def _catch_entry_points_exception(entry_point):
)


def _get_package_root_mod(mod):
"""
Given an imported module, find the topmost module
that is not a namespace package.
Returns a tuple of (root_mod, tuple), where the
second value is a tuple of parent namespaces.
Needed for saltext discovery if the entrypoint is not
part of the root module.
"""
parts = mod.__name__.split(".")
level = 0
while level < len(parts):
root_mod_name = ".".join(parts[: level + 1])
root_mod = sys.modules[root_mod_name]
# importlib.machinery.NamespaceLoader requires Python 3.11+
if type(root_mod.__path__) is list:
return root_mod, tuple(parts[:level])
level += 1
raise RuntimeError(f"Unable to determine package root mod for {mod}")


def _discover_saltexts(allowlist=None, blocklist=None):
mods = []
loaded_saltexts = {}
blocklist = blocklist or []

for entry_point in salt.utils.entrypoints.iter_entry_points("salt.loader"):
if allowlist is not None and entry_point.dist.name not in allowlist:
log.debug(
Expand Down Expand Up @@ -660,27 +683,16 @@ def _discover_saltexts(allowlist=None, blocklist=None):
"entrypoints": {},
}

if isinstance(loaded_entry_point, types.FunctionType):
func_mod = inspect.getmodule(loaded_entry_point)
try:
mod = sys.modules[func_mod.__package__]
except KeyError:
mod = func_mod
except AttributeError:
# func_mod was None
log.debug(
"Failed discovering module for function entrypoint '%s' defined by '%s'",
entry_point.name,
entry_point.dist.name,
)
continue
else:
mod = loaded_entry_point
mod = inspect.getmodule(loaded_entry_point)
with _catch_entry_points_exception(entry_point) as ctx:
root_mod, namespace = _get_package_root_mod(mod)
if ctx.exception_caught:
continue

loaded_saltexts[entry_point.dist.name]["entrypoints"][
entry_point.name
] = entry_point.value
_add_dependency(mods, mod)
_add_dependency(mods, root_mod, namespace=namespace)

# We need the mods to be in a deterministic order for the hash digest later
return list(sorted(set(mods))), loaded_saltexts
Expand Down Expand Up @@ -832,7 +844,7 @@ def gen_thin(
# Pack default data
log.debug("Packing default libraries based on current Salt version")
for py_ver, tops in tops_py_version_mapping.items():
for top in tops:
for top, namespace in tops:
if absonly and not os.path.isabs(top):
continue
base = os.path.basename(top)
Expand All @@ -859,7 +871,9 @@ def gen_thin(
for name in files:
if not name.endswith((".pyc", ".pyo")):
digest_collector.add(os.path.join(root, name))
arcname = os.path.join(site_pkg_dir, root, name)
arcname = os.path.join(
site_pkg_dir, *(namespace or ()), root, name
)
if hasattr(tfp, "getinfo"):
try:
# This is a little slow but there's no clear way to detect duplicates
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/files/conf/master
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ discovery: false
#enable_ssh_minions: True
#ignore_host_keys: True

# Ensure pytest-salt-factories is not included
# in the thin tar during integration tests
# (it defines a saltext, which are autodiscovered by default)
thin_saltext_blocklist:
- pytest-salt-factories

sdbetcd:
driver: etcd
etcd.host: 127.0.0.1
Expand Down
39 changes: 29 additions & 10 deletions tests/pytests/integration/ssh/test_saltext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,41 @@ def salt_extension(tmp_path_factory):


@pytest.fixture(scope="module")
def other_salt_extension(tmp_path_factory):
def namespaced_salt_extension(tmp_path_factory):
with FakeSaltExtension(
tmp_path_factory=tmp_path_factory,
name="salt-ext2-ssh-test",
name="saltext.ssh-test2",
virtualname="barbaz",
) as extension:
yield extension


@pytest.fixture(scope="module")
def venv(tmp_path_factory, salt_extension, other_salt_extension):
def namespaced_salt_extension_2(tmp_path_factory):
with FakeSaltExtension(
tmp_path_factory=tmp_path_factory,
name="saltext.ssh-test3",
virtualname="wut",
) as extension:
yield extension


@pytest.fixture(scope="module")
def venv(
tmp_path_factory,
salt_extension,
namespaced_salt_extension,
namespaced_salt_extension_2,
):
venv_dir = tmp_path_factory.mktemp("saltext-ssh-test-venv")
saltexts = (salt_extension, namespaced_salt_extension, namespaced_salt_extension_2)
try:
with SaltVirtualEnv(venv_dir=venv_dir) as _venv:
_venv.install(str(salt_extension.srcdir))
_venv.install(str(other_salt_extension.srcdir))
for saltext in saltexts:
_venv.install(str(saltext.srcdir))
installed_packages = _venv.get_installed_packages()
assert salt_extension.name in installed_packages
assert other_salt_extension.name in installed_packages
for saltext in saltexts:
assert saltext.name in installed_packages
yield _venv
finally:
shutil.rmtree(venv_dir, ignore_errors=True)
Expand Down Expand Up @@ -67,8 +83,8 @@ def args(venv, salt_master, salt_ssh_roster_file, sshd_config_dir):
"saltext_conf",
(
{},
{"thin_saltext_allowlist": ["salt-ext-ssh-test"]},
{"thin_saltext_blocklist": ["salt-ext2-ssh-test"]},
{"thin_saltext_allowlist": ["salt-ext-ssh-test", "saltext.ssh-test3"]},
{"thin_saltext_blocklist": ["saltext.ssh-test2"]},
),
indirect=True,
)
Expand All @@ -87,14 +103,17 @@ def test_saltexts_are_available_on_target(venv, args, saltext_conf):
else:
assert res.returncode > 0
assert "'barbaz.echo1' is not available" in res.stdout
ext3_args = args + ["wut.echo1", "wat"]
res = venv.run(*ext3_args, check=True)
assert res.stdout == "localhost:\n wat\n"


@pytest.mark.usefixtures("saltext_conf")
@pytest.mark.parametrize(
"saltext_conf", ({"thin_exclude_saltexts": True},), indirect=True
)
def test_saltexts_can_be_excluded(venv, args):
for ext in ("foobar", "barbaz"):
for ext in ("foobar", "barbaz", "wut"):
ext_args = args + [f"{ext}.echo1", "foo"]
res = venv.run(*ext_args, check=False)
assert res.returncode > 0
Expand Down
48 changes: 48 additions & 0 deletions tests/pytests/unit/utils/test_thin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import importlib
import os
import sys

import pytest
import saltfactories.utils.saltext

import salt.exceptions
import salt.utils.stringutils
import salt.utils.thin
from tests.support.mock import MagicMock, patch
from tests.support.pytest.helpers import FakeSaltExtension


def _mock_popen(return_value=None, side_effect=None, returncode=0):
Expand Down Expand Up @@ -77,3 +83,45 @@ def test_get_ext_tops(version):
else:
assert not [x for x in ret["namespace"]["dependencies"] if "distro" in x]
assert [x for x in ret["namespace"]["dependencies"] if "msgpack" in x]


def test_get_package_root_mod():
res = salt.utils.thin._get_package_root_mod(saltfactories.utils.saltext)
assert res[0] is saltfactories
assert res[1] == ()


@pytest.fixture
def namespaced_saltext(tmp_path_factory):
with FakeSaltExtension(
tmp_path_factory=tmp_path_factory,
name="saltext.wut",
) as extension:
try:
sys.path.insert(0, str(extension.srcdir / "src"))
yield extension
finally:
sys.path.pop(0)


def test_get_namespaced_package_root_mod(namespaced_saltext):
saltext = importlib.import_module(namespaced_saltext.name)
res = salt.utils.thin._get_package_root_mod(saltext)
assert res[0].__name__ == namespaced_saltext.name
assert res[1] == ("saltext",)


def test_discover_saltexts():
"""
pytest-salt-factories provides a saltext, which can be discovered here.
"""
mods, dists = salt.utils.thin._discover_saltexts()
assert mods
assert any(mod.endswith(f"{os.sep}saltfactories") and not ns for mod, ns in mods)
assert dists
dist = "pytest-salt-factories"
assert dist in dists
assert "entrypoints" in dists[dist]
assert "name" in dists[dist]
assert dists[dist]["name"].startswith("pytest_salt_factories")
assert dists[dist]["name"].endswith(".dist-info")
27 changes: 15 additions & 12 deletions tests/support/pytest/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ def _laydown_files(self):
if not setup_cfg.exists():
setup_cfg.write_text(
textwrap.dedent(
"""\
f"""\
[metadata]
name = {0}
name = {self.name}
version = 1.0
description = Salt Extension Test
author = Pedro
Expand All @@ -509,27 +509,30 @@ def _laydown_files(self):
[options]
zip_safe = False
include_package_data = True
packages = find:
package_dir =
=src
packages = find{'_namespace' if '.' in self.pkgname else ''}:
python_requires = >= 3.5
setup_requires =
wheel
setuptools>=50.3.2
[options.packages.find]
where = src
[options.entry_points]
salt.loader=
module_dirs = {1}
runner_dirs = {1}.loader:get_runner_dirs
states_dirs = {1}.loader:get_state_dirs
wheel_dirs = {1}.loader:get_new_style_entry_points
""".format(
self.name, self.pkgname
)
module_dirs = {self.pkgname}
runner_dirs = {self.pkgname}.loader:get_runner_dirs
states_dirs = {self.pkgname}.loader:get_state_dirs
wheel_dirs = {self.pkgname}.loader:get_new_style_entry_points
"""
)
)

extension_package_dir = self.srcdir / self.pkgname
extension_package_dir = self.srcdir.joinpath("src", *self.pkgname.split("."))
if not extension_package_dir.exists():
extension_package_dir.mkdir()
extension_package_dir.mkdir(parents=True)
extension_package_dir.joinpath("__init__.py").write_text("")
extension_package_dir.joinpath("loader.py").write_text(
textwrap.dedent(
Expand Down
Loading

0 comments on commit 33efd9c

Please sign in to comment.