Skip to content

Commit

Permalink
Merge pull request irgolic#130 from raphael-francis/autoclose
Browse files Browse the repository at this point in the history
Autoclose
  • Loading branch information
irgolic authored Oct 21, 2023
2 parents 8cbabc1 + 076211d commit 7c2d4dc
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 19 deletions.
9 changes: 7 additions & 2 deletions autopr/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import uuid
from typing import Type

Expand Down Expand Up @@ -32,7 +33,9 @@ class MainService:
def __init__(self):
self.log = get_logger(service="main")

self.config_dir = ".autopr" # TODO: Make this configurable
# TODO make these configurable
self.config_dir = ".autopr"
self.cache_dir = os.path.join(self.config_dir, "cache")

self.settings = self.settings_class.parse_obj({}) # pyright workaround
self.repo_path = self.get_repo_path()
Expand All @@ -56,13 +59,14 @@ def __init__(self):
repo_path=self.repo_path,
branch_name=self.branch_name,
base_branch_name=self.base_branch_name,
cache_dir=self.cache_dir,
)
self.commit_service.ensure_branch_exists()

# Create action service and agent service
action_service = ActionService(
repo=self.repo,
config_dir=self.config_dir,
cache_dir=self.cache_dir,
platform_service=self.platform_service,
commit_service=self.commit_service,
)
Expand All @@ -83,6 +87,7 @@ def __init__(self):
triggers=triggers,
publish_service=self.publish_service,
workflow_service=self.workflow_service,
commit_service=self.commit_service,
)

async def run(self):
Expand Down
6 changes: 3 additions & 3 deletions autopr/services/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class ActionService:
def __init__(
self,
repo: Repo,
config_dir: str,
cache_dir: str,
platform_service: PlatformService,
commit_service: CommitService,
num_reasks: int = 3
):
self.repo = repo
self.config_dir = config_dir
self.cache_dir = cache_dir
self.platform_service = platform_service
self.commit_service = commit_service
self.num_reasks = num_reasks
Expand All @@ -58,7 +58,7 @@ def instantiate_action(
publish_service: PublishService,
) -> Action[Inputs, Outputs]:
cache_service = ShelveCacheService(
config_dir=self.config_dir,
cache_dir=self.cache_dir,
action_id=action_type.id,
)
return action_type(
Expand Down
8 changes: 4 additions & 4 deletions autopr/services/cache_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def retrieve(self, key: Any, namespace: Optional[str] = None) -> Optional[Any]:
class ShelveCacheService(CacheService):
def __init__(
self,
config_dir: str,
cache_dir: str,
action_id: str
):
self.cache_folder = os.path.join(config_dir, "cache")
os.makedirs(self.cache_folder, exist_ok=True)
self.cache_dir = cache_dir
os.makedirs(self.cache_dir, exist_ok=True)

self.default_namespace = action_id

Expand All @@ -28,7 +28,7 @@ def _prepare_key(self, key: Any) -> str:
def _load_shelf(self, namespace: str):
return shelve.open(
os.path.join(
self.cache_folder,
self.cache_dir,
f"{namespace}.db"
),
writeback=True,
Expand Down
21 changes: 20 additions & 1 deletion autopr/services/commit_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
from typing import Optional
from typing import Optional, Literal

from git.repo import Repo

from autopr.log_config import get_logger


CHANGES_STATUS = Literal["no_changes", "cache_only", "modified"]


class CommitService:
"""
Service for creating branches, committing changes, and calling `git push` on the repository.
Expand All @@ -19,11 +22,13 @@ def __init__(
repo_path: str,
branch_name: str,
base_branch_name: str,
cache_dir: str,
):
self.repo = repo
self.repo_path = repo_path
self.branch_name = branch_name
self.base_branch_name = base_branch_name
self.cache_dir = cache_dir

self._empty_commit_message = "[placeholder]"

Expand Down Expand Up @@ -101,3 +106,17 @@ def commit(
if push:
self.log.debug(f'Pushing branch {self.branch_name} to remote...')
self.repo.git.execute(["git", "push", "-f", "origin", self.branch_name])

def get_changes_status(self) -> CHANGES_STATUS:
"""
Returns the status of the changes on the branch.
"""
# Get status of changes
status = self.repo.git.execute(["git", "status", "--porcelain"])
status_text = str(status)
if status == "":
return "no_changes"
elif len(status_text.splitlines()) == 1 and self.cache_dir in status_text:
return "cache_only"
else:
return "modified"
20 changes: 20 additions & 0 deletions autopr/services/platform_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,20 @@ async def merge_pr(
"""
raise NotImplementedError

async def close_pr(
self,
pr_number: int,
):
"""
Close the pull request.
Parameters
----------
pr_number: int
The PR number
"""
raise NotImplementedError

async def update_pr_body(self, pr_number: int, body: str):
"""
Update the body of the pull request.
Expand Down Expand Up @@ -450,6 +464,12 @@ async def _patch_pr(self, pr_number: int, data: dict[str, Any]):
response=response,
)

async def close_pr(
self,
pr_number: int,
):
await self._patch_pr(pr_number, {'state': 'closed'})

def _is_draft_error(self, response_text: str):
response_obj = json.loads(response_text)
is_draft_error = 'message' in response_obj and \
Expand Down
15 changes: 12 additions & 3 deletions autopr/services/publish_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,7 @@ async def end_section(

await self.update()

async def merge(
self,
):
async def merge(self):
"""
Merge the pull request.
"""
Expand All @@ -305,6 +303,17 @@ async def merge(
commit_title=self.title,
)

async def close(self):
"""
Close the pull request.
"""
if self.root_publish_service is not None:
return await self.root_publish_service.close()
if self.pr_number is None:
self.log.warning("PR close requested, but does not exist")
return
return await self.platform_service.close_pr(self.pr_number)

def _contains_last_code_block(self, parent: UpdateSection) -> bool:
for section in reversed(parent.updates):
if isinstance(section, CodeBlock):
Expand Down
38 changes: 32 additions & 6 deletions autopr/services/trigger_service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from typing import Coroutine, Any
from typing import Coroutine, Any, Optional, assert_never

from autopr.log_config import get_logger
from autopr.models.config.elements import ActionConfig, WorkflowInvocation, IterableWorkflowInvocation, ContextAction
from autopr.models.config.entrypoints import Trigger
from autopr.models.events import EventUnion
from autopr.models.executable import Executable, ContextDict
from autopr.services.commit_service import CommitService
from autopr.services.platform_service import PlatformService
from autopr.services.publish_service import PublishService
from autopr.services.utils import truncate_strings, format_for_publishing
Expand All @@ -18,10 +19,12 @@ def __init__(
triggers: list[Trigger],
publish_service: PublishService,
workflow_service: WorkflowService,
commit_service: CommitService,
):
self.triggers = triggers
self.publish_service = publish_service
self.workflow_service = workflow_service
self.commit_service = commit_service

print("Loaded triggers:")
for t in self.triggers:
Expand Down Expand Up @@ -97,16 +100,39 @@ async def trigger_event(
self.log.error("Error in trigger", exc_info=r)
exceptions.append(r)

await self.finalize_trigger(
[trigger for trigger, _ in triggers_and_contexts],
exceptions=exceptions,
)

return results

async def finalize_trigger(
self,
triggers: list[Trigger],
exceptions: Optional[list[Exception]] = None,
):
if exceptions:
await self.publish_service.finalize(False, exceptions)
else:
await self.publish_service.finalize(True)
return

await self.publish_service.finalize(True)

changes_status = self.commit_service.get_changes_status()
# If the PR only makes changes to the cache, merge it
if changes_status == "cache_only":
await self.publish_service.merge()
# Else, if there are no changes, close the PR
elif changes_status == "no_changes":
await self.publish_service.close()
# Else, if there are material changes
elif changes_status == "modified":
# TODO split out multiple triggered workflows into separate PRs,
# so that automerge can be evaluated separately for each
if any(trigger.automerge for trigger, _ in triggers_and_contexts):
if any(trigger.automerge for trigger in triggers):
await self.publish_service.merge()

return results
else:
assert_never(changes_status)

async def handle_trigger(
self,
Expand Down

0 comments on commit 7c2d4dc

Please sign in to comment.