Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port pytest_test macro #401

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions examples/pytest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
load("@aspect_rules_py//py:defs.bzl", "py_pytest_main", "py_test")
load("@aspect_rules_py//pytest:defs.bzl", "py_pytest_test")

py_pytest_main(
name = "__test__",
deps = ["@pypi_pytest//:pkg"],
)

py_test(
py_pytest_test(
name = "pytest_test",
srcs = [
"foo_test.py",
":__test__",
],
imports = ["../.."],
main = ":__test__.py",
package_collisions = "warning",
pip_repo = "pypi",
deps = [
":__test__",
"@pypi_ftfy//:pkg",
"@pypi_neptune//:pkg",
"@pypi_pytest//:pkg",
],
)
20 changes: 12 additions & 8 deletions py/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,26 @@ py_unpacked_wheel = _py_unpacked_wheel
resolutions = _resolutions

def _py_binary_or_test(name, rule, srcs, main, deps = [], resolutions = {}, **kwargs):
if main and type(main) not in ["string", "Label"]:
fail("main must be a Label or a string, not {}".format(type(main)))

# Compatibility with rules_python, see docs in py_executable.bzl
main_target = "_{}.find_main".format(name)
determine_main(
name = main_target,
target_name = name,
main = main,
srcs = srcs,
**propagate_common_rule_attributes(kwargs)
)
if type(main) != "Label":
determine_main(
name = main_target,
target_name = name,
main = main,
srcs = srcs,
**propagate_common_rule_attributes(kwargs)
)

package_collisions = kwargs.pop("package_collisions", None)

rule(
name = name,
srcs = srcs,
main = main_target,
main = main if type(main) == "Label" else main_target,
deps = deps,
resolutions = resolutions,
package_collisions = package_collisions,
Expand Down
4 changes: 4 additions & 0 deletions pytest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
exports_files(
["pytest_shim.py"],
visibility = ["//visibility:public"],
)
65 changes: 65 additions & 0 deletions pytest/defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Use pytest to run tests, using a wrapper script to interface with Bazel.

Example:

```starlark
load("@aspect_rules_py//pytest:defs.bzl", "py_pytest_test")

py_pytest_test(
name = "test_w_pytest",
size = "small",
srcs = ["test.py"],
)
```

By default, `@pip//pytest` is added to `deps`.
If sharding is used (when `shard_count > 1`) then `@pip//pytest_shard` is also added.
To instead provide explicit deps for the pytest library, set `pytest_deps`:

```starlark
py_pytest_test(
name = "test_w_my_pytest",
shard_count = 2,
srcs = ["test.py"],
pytest_deps = [requirement("pytest"), requirement("pytest-shard"), ...],
)
```
"""

load("//py:defs.bzl", "py_test")

def py_pytest_test(name, srcs, deps = [], args = [], pytest_deps = None, pip_repo = "pip", **kwargs):
"""
Wrapper macro for `py_test` which supports pytest.

Args:
name: A unique name for this target.
srcs: Python source files.
deps: Dependencies, typically `py_library`.
args: Additional command-line arguments to pytest.
See https://docs.pytest.org/en/latest/how-to/usage.html
pytest_deps: Labels of the pytest tool and other packages it may import.
pip_repo: Name of the external repository where Python packages are installed.
It's typically created by `pip.parse`.
This attribute is used only when `pytest_deps` is unset.
**kwargs: Additional named parameters to py_test.
"""
shim_label = Label("//pytest:pytest_shim.py")

if pytest_deps == None:
pytest_deps = ["@{}//pytest".format(pip_repo)]
if kwargs.get("shard_count", 1) > 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 1 seems unnecessary? Check for None, which would be falsey?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user can explicitly say shard_count = 1 which I hoped to account for

pytest_deps.append("@{}//pytest_shard".format(pip_repo))

py_test(
name = name,
srcs = [
shim_label,
] + srcs,
main = shim_label,
args = [
"--capture=no",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to have this defined here, rather than in the shim?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I didn't write this code. Maybe Casey is willing to be the PR author.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This absolutely can be in the shim. In the very beginning, the shim was a trivial wrapper around pytest.main and adding default arguments here was less code than in the shim. As more pytest features were hooked up to native Bazel features, the shim gathered more argument parsing logic and I simply missed calling this out in the pull request that first began adding default arguments in the shim. I see no reason this shouldn't be part of pytest_args = ["--ignore=external"] in the shim.

] + args + ["$(location :%s)" % x for x in srcs],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these automatically spill? The arg limit length could be easily hit here.

deps = deps + pytest_deps,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to swap the ordering, so that this levels pytest dependencies will "win" over any defined transitively (perhaps there is a test helper as a py_library somewhere in the graph that consumes a different pytest version)

**kwargs
)
71 changes: 71 additions & 0 deletions pytest/pytest_shim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""A shim for executing pytest that supports test filtering, sharding, and more.

Copied from https://github.com/caseyduquettesc/rules_python_pytest/blob/331e0e511130cf4859b7589a479db6c553974abf/python_pytest/pytest_shim.py
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @caseyduquettesc - thank you for writing this!

"""

import sys
import os

import pytest


if __name__ == "__main__":
pytest_args = ["--ignore=external"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other template had different args around disabling caching, check if they are needed here.


args = sys.argv[1:]
# pytest < 8.0 runs tests twice if __init__.py is passed explicitly as an argument.
# Remove any __init__.py file to avoid that.
# pytest.version_tuple is available since pytest 7.0
# https://github.com/pytest-dev/pytest/issues/9313
if not hasattr(pytest, "version_tuple") or pytest.version_tuple < (8, 0):
args = [arg for arg in args if arg.startswith("-") or os.path.basename(arg) != "__init__.py"]

if os.environ.get("XML_OUTPUT_FILE"):
pytest_args.append("--junitxml={xml_output_file}".format(xml_output_file=os.environ.get("XML_OUTPUT_FILE")))

# Handle test sharding - requires pytest-shard plugin.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the presence of the plugin be detected, and an error printed if it's missing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I file and leave as a follow-up? Right now I'm looking for parity with v1.1.1 https://registry.bazel.build/modules/caseyduquettesc_rules_python_pytest

if os.environ.get("TEST_SHARD_INDEX") and os.environ.get("TEST_TOTAL_SHARDS"):
pytest_args.append("--shard-id={shard_id}".format(shard_id=os.environ.get("TEST_SHARD_INDEX")))
pytest_args.append("--num-shards={num_shards}".format(num_shards=os.environ.get("TEST_TOTAL_SHARDS")))
if os.environ.get("TEST_SHARD_STATUS_FILE"):
open(os.environ["TEST_SHARD_STATUS_FILE"], "a").close()

# Handle plugins that generate reports - if they are provided with relative paths (via args),
# re-write it under bazel's test undeclared outputs dir.
if os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"):
undeclared_output_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")

# Flags that take file paths as value.
path_flags = [
"--report-log", # pytest-reportlog
"--json-report-file", # pytest-json-report
"--html", # pytest-html
]
for i, arg in enumerate(args):
for flag in path_flags:
if arg.startswith(f"{flag}="):
arg_split = arg.split("=", 1)
if len(arg_split) == 2 and not os.path.isabs(arg_split[1]):
args[i] = f"{flag}={undeclared_output_dir}/{arg_split[1]}"

if os.environ.get("TESTBRIDGE_TEST_ONLY"):
test_filter = os.environ["TESTBRIDGE_TEST_ONLY"]

# If the test filter does not start with a class-like name, then use test filtering instead
if not test_filter[0].isupper():
# --test_filter=test_module.test_fn or --test_filter=test_module/test_file.py
pytest_args.extend(args)
pytest_args.append("-k={filter}".format(filter=test_filter))
else:
# --test_filter=TestClass.test_fn
for arg in args:
if not arg.startswith("--"):
# arg is a src file. Add test class/method selection to it.
# test.py::TestClass::test_fn
arg = "{arg}::{module_fn}".format(arg=arg, module_fn=test_filter.replace(".", "::"))
pytest_args.append(arg)
else:
pytest_args.extend(args)

print(pytest_args, file=sys.stderr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a little spammy, print it if there's some flag set?

raise SystemExit(pytest.main(pytest_args))