Commit
✨ integrate conflict detection
tackyunicorn committed Oct 19, 2023
1 parent 85e24aa commit e9f3398
Showing 4 changed files with 96 additions and 48 deletions.
49 changes: 26 additions & 23 deletions code/get_file_at_version.py
@@ -1,26 +1,29 @@
-from typing import Dict
-import requests
 import json
+import requests
+import subprocess
+from urllib.parse import quote
+from get_github_owner_repo import get_github_owner_repo
 
-
-def get_file_at_version(path: str, sha: str, token: str) -> Dict[str, str]:
-    instance = 'https://api.github.com'
+def get_file_at_version(path: str, sha: str) -> str:
     try:
-        url = f'{instance}/repos/jwgerlach00/ml_protein_degradation_multitask/contents/{path}?ref={sha}'
-        headers = {'Authorization': f'Bearer {token}'}
-
-        response = requests.get(url, headers=headers)
-
-        if response.status_code == 404:
-            return ''
-
-        contents = response.text
-        return json.loads(content)
-
-    except Exception as error:
-        print(error)
-
-# token = 'YOUR-TOKEN-HERE'
-# sha = '470ae25'
-# path = 'package/src/linkerology_multitask/dataset_creation/LinkerologySampler.py'
-# get_file_at_version(path, sha, token)
+        owner, repo = get_github_owner_repo()
+        gh_cli_command = [
+            "gh", "api", f"repos/{owner}/{repo}/contents/{quote(path)}?ref={sha}",
+        ]
+        output = subprocess.run(
+            gh_cli_command,
+            capture_output=True,
+            text=True,
+            check=True
+        )
+        output = output.stdout.strip()
+        output = dict(json.loads(output))
+        download_url = output["download_url"]
+
+        content = ""
+        response = requests.get(download_url);
+        if response.status_code == 200:
+            content = response.text
+        return content
+    except subprocess.CalledProcessError:
+        return ""
17 changes: 17 additions & 0 deletions code/get_github_owner_repo.py
@@ -0,0 +1,17 @@
+import subprocess
+
+def get_github_owner_repo():
+    try:
+        output = subprocess.run(
+            ["git", "config", "--get", "remote.origin.url"],
+            capture_output=True,
+            text=True,
+            check=True
+        )
+        origin_url = output.stdout.strip().split('/')
+        owner = origin_url[-2]
+        repo = origin_url[-1][0:-4]
+        return owner, repo
+    except subprocess.CalledProcessError:
+        return ""
+
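The new get_github_owner_repo helper derives the owner and repository name by splitting remote.origin.url on '/' and trimming the trailing '.git'. A small illustration of that parsing with a hypothetical HTTPS remote; an SSH-style remote (git@github.com:owner/repo.git) has no '/' between host and owner, so it would need separate handling:

    # Hypothetical HTTPS remote URL, for illustration only.
    url = "https://github.com/tackyunicorn/example-repo.git"
    parts = url.strip().split('/')
    owner, repo = parts[-2], parts[-1][0:-4]  # drop the trailing ".git"
    print(owner, repo)  # tackyunicorn example-repo
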
56 changes: 42 additions & 14 deletions code/gits_check_conflicts.py
@@ -1,5 +1,8 @@
 import subprocess
 import json
+from parse_diff import normalize_line_endings, get_common_modified_lines
+from get_file_at_version import get_file_at_version
+from get_github_owner_repo import get_github_owner_repo
 
 def check_conflicts(args):
     target_branch = ""
@@ -47,8 +50,10 @@ def check_conflicts(args):
             "number": pull_request["number"],
             "title": pull_request["title"],
             "author": dict(pull_request["author"])["login"],
-            "baseRef": dict(dict(pull_request["baseRef"])["target"])["oid"],
-            "headRef": dict(dict(pull_request["headRef"])["target"])["oid"]
+            "baseRef": pull_request["baseRefOid"],
+            "headRef": dict(dict(pull_request["headRef"])["target"])["oid"],
+            "hasConflict": False,
+            "skipped": False
         }
 
         for path in pull_request.get("commonFilePaths", []):
@@ -57,6 +62,23 @@ def check_conflicts(args):
             else:
                 file_to_pr[path] = [pr]
 
+    for file in file_to_pr:
+        for pr in file_to_pr[file]:
+            local_base = normalize_line_endings(git_file_at_version(file, merge_base))
+            pr_base = normalize_line_endings(get_file_at_version(file, pr["baseRef"]))
+
+            if(local_base != pr_base):
+                pr["skipped"] = True
+                continue
+
+            local_head = normalize_line_endings(get_local_file(file))
+            pr_head = normalize_line_endings(get_file_at_version(file, pr["headRef"]))
+
+            common_modified_lines = get_common_modified_lines(pr_base, local_head, pr_head)
+            if (len(common_modified_lines) != 0):
+                pr["hasConflict"] = True
+                pr["commonModifiedLines"] = common_modified_lines
+
     print(json.dumps(file_to_pr, indent=4))
 
 def git_merge_base(target_branch):
@@ -84,10 +106,10 @@ def git_modified_files(merge_base_sha):
     except subprocess.CalledProcessError:
         return ""
 
-def git_origin_url():
+def git_file_at_version(path, sha):
     try:
         output = subprocess.run(
-            ["git", "config", "--get", "remote.origin.url"],
+            ["git", "--no-pager", "show", f"{sha}:{path}"],
             capture_output=True,
             text=True,
             check=True
@@ -96,23 +118,29 @@ def git_origin_url():
     except subprocess.CalledProcessError:
         return ""
 
-def get_recent_prs(target_branch):
-    origin_url = git_origin_url()
+def get_local_file(file):
+    try:
+        with open(file, 'r') as file:
+            contents = file.read()
+            return contents
+    except FileNotFoundError:
+        return ""
+    except Exception:
+        return ""
 
-    origin_url = origin_url.split('/')
-    github_project_owner = origin_url[-2]
-    github_project_name = origin_url[-1][0:-4]
+def get_recent_prs(target_branch):
+    github_repo_owner, github_repo_name = get_github_owner_repo()
 
     recent_prs_query = f"""
-    query($github_project_owner: String!, $github_project_name: String!, $target_branch: String!) {{
-        repository (owner: $github_project_owner, name: $github_project_name) {{
+    query($github_repo_owner: String!, $github_repo_name: String!, $target_branch: String!) {{
+        repository (owner: $github_repo_owner, name: $github_repo_name) {{
             pullRequests (first: 100, states: OPEN, baseRefName: $target_branch, orderBy: {{ field: CREATED_AT, direction: DESC }}) {{
                 nodes {{
                     number
                     title
                     author {{ login }}
                     files (first: 100) {{ nodes {{ path }} }}
-                    baseRef {{ target {{ oid }} }}
+                    baseRefOid
                     headRef {{ target {{ oid }} }}
                 }}
             }}
@@ -122,8 +150,8 @@ def get_recent_prs(target_branch):
 
     gh_cli_command = [
        "gh", "api", "graphql",
-        "-F", f"github_project_owner={github_project_owner}",
-        "-F", f"github_project_name={github_project_name}",
+        "-F", f"github_repo_owner={github_repo_owner}",
+        "-F", f"github_repo_name={github_repo_name}",
        "-F", f"target_branch={target_branch}",
        "-f", f"query={recent_prs_query}"
     ]
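
The block added to check_conflicts above is the conflict detection itself. For every file that both the local branch and an open PR touch, it first compares the file at the local merge base (via git show) with the file at the PR's base commit (via the GitHub contents API); if they differ, the PR is marked skipped, because the two changes were not made against the same snapshot. Otherwise it diffs the local working copy and the PR head against that shared base and sets hasConflict when the two diffs modify at least one common line. A condensed sketch of that decision, using the helpers from this commit; the wrapper function itself is hypothetical:

    def file_conflicts_with_pr(path, merge_base_sha, pr):
        # Hypothetical wrapper around the helpers introduced in this commit.
        local_base = normalize_line_endings(git_file_at_version(path, merge_base_sha))
        pr_base = normalize_line_endings(get_file_at_version(path, pr["baseRef"]))
        if local_base != pr_base:
            return "skipped", []  # the PR branched from a different base snapshot
        local_head = normalize_line_endings(get_local_file(path))
        pr_head = normalize_line_endings(get_file_at_version(path, pr["headRef"]))
        overlap = get_common_modified_lines(pr_base, local_head, pr_head)
        return ("conflict" if overlap else "clean"), overlap
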
22 changes: 11 additions & 11 deletions code/parse_diff.py
@@ -3,22 +3,22 @@
 import re
 
 
-def get_common_modified_lines(gmr_old: List[str], gmr1_new: List[str], gmr2_new: List[str]) -> List[int]:
+def get_common_modified_lines(pr_old, pr1_new, pr2_new) -> List[int]:
     # Normalize for different OS. Split lines into list, expected by difflib.united_diff
-    gmr_old = normalize_line_endings(gmr_old).split('\n')
-    gmr1_new = normalize_line_endings(gmr1_new).split('\n')
-    gmr2_new = normalize_line_endings(gmr2_new).split('\n')
+    pr_old = pr_old.split('\n')
+    pr1_new = pr1_new.split('\n')
+    pr2_new = pr2_new.split('\n')
 
-    # Generate unified diff patches for gmr_old to gmr1_new and gmr_old to gmr2_new
-    gmr1_patch = difflib.unified_diff(gmr_old, gmr1_new, lineterm='', fromfile='a', tofile='b')
-    gmr2_patch = difflib.unified_diff(gmr_old, gmr2_new, lineterm='', fromfile='a', tofile='b')
+    # Generate unified diff patches for pr_old to pr1_new and pr_old to pr2_new
+    pr1_patch = difflib.unified_diff(pr_old, pr1_new, lineterm='', fromfile='a', tofile='b')
+    pr2_patch = difflib.unified_diff(pr_old, pr2_new, lineterm='', fromfile='a', tofile='b')
 
     # Get lists of modified lines from the generated patches
-    gmr1_modified_lines = get_modified_lines(gmr1_patch)
-    gmr2_modified_lines = get_modified_lines(gmr2_patch)
+    pr1_modified_lines = get_modified_lines(pr1_patch)
+    pr2_modified_lines = get_modified_lines(pr2_patch)
 
-    # Find the common modified lines between gmr1 and gmr2
-    return sorted(list(set(gmr1_modified_lines) & set(gmr2_modified_lines)))
+    # Find the common modified lines between pr1 and pr2
+    return sorted(list(set(pr1_modified_lines) & set(pr2_modified_lines)))
 
 def normalize_line_endings(text: str) -> str:
     # Normalize line endings to '\n'
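
parse_diff.py only renames its identifiers here (gmr_* becomes pr_*), but it carries the underlying heuristic: two edits conflict when their unified diffs against the same base text touch at least one common line number. A tiny self-contained example with made-up contents; get_modified_lines is below the fold in this view, so the exact numbering it reports is not shown here, only that overlapping edits yield a non-empty result:

    from parse_diff import get_common_modified_lines, normalize_line_endings

    # Made-up three-line file and two competing edits, for illustration only.
    base = normalize_line_endings("a\nb\nc")
    edit_one = normalize_line_endings("a\nB\nc")   # rewrites the middle line
    edit_two = normalize_line_endings("a\nb2\nc")  # also rewrites the middle line
    print(get_common_modified_lines(base, edit_one, edit_two))  # non-empty list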
