diff --git a/ci/scripts/towncrier_automation.py b/ci/scripts/towncrier_automation.py index e367ebbac..b49ba340e 100755 --- a/ci/scripts/towncrier_automation.py +++ b/ci/scripts/towncrier_automation.py @@ -4,6 +4,7 @@ import argparse import re import subprocess +from functools import cache from typing import TYPE_CHECKING from packaging.version import Version @@ -12,17 +13,34 @@ from collections.abc import Sequence -class Args(argparse.Namespace): - version: str - dry_run: bool +class BumpVersion(Version): + def __init__(self, version: str) -> None: + super().__init__(version) + + if len(self.release) != 3: + msg = f"{version} must contain major, minor, and patch version." + raise argparse.ArgumentTypeError(msg) + base_branch = get_base_branch() + patch_branch_pattern = re.compile(r"\d+\.\d+\.x") + if self.micro != 0 and not patch_branch_pattern.fullmatch(base_branch): + msg = ( + f"{version} is a patch release, but " + f"you are trying to release from a non-patch release branch: {base_branch}." + ) + raise argparse.ArgumentTypeError(msg) -class NoPatchReleaseOnMainError(Exception): - pass + if self.micro == 0 and base_branch != "main": + msg = ( + f"{version} is a minor or major release, " + f"but you are trying to release not from main: {base_branch}." + ) + raise argparse.ArgumentTypeError(msg) -class NoMinorMajorReleaseOffMainError(Exception): - pass +class Args(argparse.Namespace): + version: BumpVersion + dry_run: bool def parse_args(argv: Sequence[str] | None = None) -> Args: @@ -37,7 +55,7 @@ def parse_args(argv: Sequence[str] | None = None) -> Args: ) parser.add_argument( "version", - type=str, + type=BumpVersion, help=( "The new version for the release must have at least three parts, like `major.minor.patch` and no `major.minor`. " "It can have a suffix like `major.minor.patch.dev0` or `major.minor.0rc1`." @@ -49,10 +67,6 @@ def parse_args(argv: Sequence[str] | None = None) -> Args: action="store_true", ) args = parser.parse_args(argv, Args()) - # validate the version - if len(Version(args.version).release) != 3: - msg = f"Version argument {args.version} must contain major, minor, and patch version." - raise ValueError(msg) return args @@ -65,23 +79,7 @@ def main(argv: Sequence[str] | None = None) -> None: ) # Check if we are on the main branch to know if we need to backport - base_branch = subprocess.run( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], - capture_output=True, - text=True, - check=True, - ).stdout.strip() - patch_branch_pattern = r"\d+\.\d+\.x" - if Version(args.version).micro != 0 and not re.fullmatch( - patch_branch_pattern, base_branch - ): - msg = f"Version {args.version} is a patch release, but " - "you are trying to release from a non-patch release branch: {base_branch}." - raise NoPatchReleaseOnMainError(msg) - if Version(args.version).micro == 0 and base_branch != "main": - msg = f"Version {args.version} is a minor or major release, " - "but you are trying to release not from main: {base_branch}." - raise NoMinorMajorReleaseOffMainError(msg) + base_branch = get_base_branch() pr_description = "" if base_branch == "main" else "@meeseeksdev backport to main" branch_name = f"release_notes_{args.version}" @@ -124,5 +122,15 @@ def main(argv: Sequence[str] | None = None) -> None: print("Dry run, not merging") +@cache +def get_base_branch(): + return subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + + if __name__ == "__main__": main()