From 01d878310a1e22791bc6be65566382cd5632ff10 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 21 Oct 2024 14:47:48 -0700 Subject: [PATCH] Fix merge bot to use ref for merged PR (#6417) ghstack-source-id: 420c6810c526668f6fde2d640a4a3d62caf9cde3 Pull Request resolved: https://github.com/pytorch/executorch/pull/6416 Co-authored-by: Hansong Zhang --- .github/scripts/propose_ghstack_orig_pr.py | 18 ++++++++++++------ .github/workflows/ghstack_land.yml | 4 +--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.github/scripts/propose_ghstack_orig_pr.py b/.github/scripts/propose_ghstack_orig_pr.py index e363a30da2..b6706f6c9e 100644 --- a/.github/scripts/propose_ghstack_orig_pr.py +++ b/.github/scripts/propose_ghstack_orig_pr.py @@ -26,9 +26,9 @@ def parse_args(): required=True, ) parser.add_argument( - "--pr", - type=int, - help="Number of the PR in the stack to check and create corresponding PR", + "--ref", + type=str, + help="Ref fo PR in the stack to check and create corresponding PR", required=True, ) return parser.parse_args() @@ -68,12 +68,18 @@ def extract_stack_from_body(pr_body: str) -> List[int]: return list(reversed(prs)) -def get_pr_stack_from_number(pr_number: int, repo: Repository) -> List[int]: +def get_pr_stack_from_number(ref: str, repo: Repository) -> List[int]: + if ref.isnumeric(): + pr_number = int(ref) + else: + branch_name = ref.replace("refs/heads/", "") + pr_number = repo.get_branch(branch_name).commit.get_pulls()[0].number + pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body) if not pr_stack: raise Exception( - f"Could not find PR stack in body of #{pr_number}. " + f"Could not find PR stack in body of ref. " + "Please make sure that the PR was created with ghstack." ) @@ -129,7 +135,7 @@ def main(): with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh: repo = gh.get_repo(args.repo) - create_prs_for_orig_branch(get_pr_stack_from_number(args.pr, repo), repo) + create_prs_for_orig_branch(get_pr_stack_from_number(args.ref, repo), repo) if __name__ == "__main__": diff --git a/.github/workflows/ghstack_land.yml b/.github/workflows/ghstack_land.yml index 2c91a1aa40..8a9f8e89a7 100644 --- a/.github/workflows/ghstack_land.yml +++ b/.github/workflows/ghstack_land.yml @@ -32,9 +32,7 @@ jobs: run: | pip install pygithub - PR_NUMBER=$(echo "$GITHUB_REF" | grep -oE '[0-9]+') - - python .github/scripts/propose_ghstack_orig_pr.py --pr $PR_NUMBER --repo pytorch/executorch + python .github/scripts/propose_ghstack_orig_pr.py --ref $GITHUB_REF --repo pytorch/executorch env: GITHUB_TOKEN: ${{ secrets.GH_PYTORCHBOT_CHERRY_PICK_TOKEN }} GITHUB_REF: ${{ github.ref }}