diff --git a/src/app/api/api_v1/endpoints/compute.py b/src/app/api/api_v1/endpoints/compute.py index 831d758..0b988ab 100644 --- a/src/app/api/api_v1/endpoints/compute.py +++ b/src/app/api/api_v1/endpoints/compute.py @@ -31,7 +31,9 @@ async def check_code_against_repo_guidelines( # Fetch guidelines guideline_list = [elt for elt in await guidelines.fetch_all(("repo_id", repo_id))] # Run analysis - return openai_client.analyze_multi(payload.code, guideline_list, mode=ExecutionMode.MULTI, user_id=str(user.id)) + return openai_client.check_code_against_guidelines( + payload.code, guideline_list, mode=ExecutionMode.MULTI, user_id=str(user.id) + ) @router.post("/check/{guideline_id}", status_code=status.HTTP_200_OK) @@ -47,4 +49,4 @@ async def check_code_against_guideline( user.id, event="compute-check", properties={"repo_id": guideline.repo_id, "guideline_id": guideline_id} ) # Run analysis - return openai_client.analyze_mono(payload.code, guideline, user_id=str(user.id)) + return openai_client.check_code(payload.code, guideline, user_id=str(user.id)) diff --git a/src/app/api/api_v1/endpoints/guidelines.py b/src/app/api/api_v1/endpoints/guidelines.py index 40feb7e..de19eb5 100644 --- a/src/app/api/api_v1/endpoints/guidelines.py +++ b/src/app/api/api_v1/endpoints/guidelines.py @@ -12,8 +12,19 @@ from app.crud import GuidelineCRUD, RepositoryCRUD from app.models import Guideline, Repository, UserScope from app.schemas.base import OptionalGHToken -from app.schemas.guidelines import ContentUpdate, GuidelineCreate, GuidelineCreation, GuidelineEdit, OrderUpdate +from app.schemas.guidelines import ( + ContentUpdate, + ExampleRequest, + GuidelineContent, + GuidelineCreate, + GuidelineCreation, + GuidelineEdit, + GuidelineExample, + OrderUpdate, + TextContent, +) from app.services.github import gh_client +from app.services.openai import openai_client from app.services.telemetry import telemetry_client router = APIRouter() @@ -100,3 +111,23 @@ async def delete_guideline( repo = 
cast(Repository, await repos.get(guideline.repo_id, strict=True)) gh_client.check_user_permission(user, repo.full_name, repo.owner_id, payload.github_token, repo.installed_by) await guidelines.delete(guideline_id) + + +@router.post("/parse", status_code=status.HTTP_200_OK) +async def parse_guidelines_from_text( + payload: TextContent, + user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +) -> List[GuidelineContent]: + telemetry_client.capture(user.id, event="guideline-parse") + # Analyze with LLM + return openai_client.parse_guidelines_from_text(payload.content, user_id=str(user.id)) + + +@router.post("/examples", status_code=status.HTTP_200_OK) +async def generate_examples_for_text( + payload: ExampleRequest, + user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +) -> GuidelineExample: + telemetry_client.capture(user.id, event="guideline-examples") + # Analyze with LLM + return openai_client.generate_examples_for_instruction(payload.content, payload.language, user_id=str(user.id)) diff --git a/src/app/api/api_v1/endpoints/repos.py b/src/app/api/api_v1/endpoints/repos.py index 49e3064..db2ae3a 100644 --- a/src/app/api/api_v1/endpoints/repos.py +++ b/src/app/api/api_v1/endpoints/repos.py @@ -3,6 +3,8 @@ # All rights reserved. # Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. 
+import logging +from base64 import b64decode from datetime import datetime from typing import List, cast @@ -12,12 +14,14 @@ from app.crud import GuidelineCRUD, RepositoryCRUD from app.models import Guideline, Repository, User, UserScope from app.schemas.base import OptionalGHToken -from app.schemas.guidelines import OrderUpdate +from app.schemas.guidelines import OrderUpdate, ParsedGuideline from app.schemas.repos import GuidelineOrder, RepoCreate, RepoCreation, RepoUpdate from app.services.github import gh_client +from app.services.openai import openai_client from app.services.slack import slack_client from app.services.telemetry import telemetry_client +logger = logging.getLogger("uvicorn.error") router = APIRouter() @@ -151,6 +155,36 @@ async def fetch_guidelines_from_repo( return [elt for elt in await guidelines.fetch_all(("repo_id", repo_id))] +@router.post("/{repo_id}/parse", status_code=status.HTTP_200_OK) +async def parse_guidelines_from_github( + payload: OptionalGHToken, + repo_id: int = Path(..., gt=0), + repos: RepositoryCRUD = Depends(get_repo_crud), + user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +) -> List[ParsedGuideline]: + telemetry_client.capture(user.id, event="repo-parse-guidelines", properties={"repo_id": repo_id}) + # Sanity check + repo = cast(Repository, await repos.get(repo_id, strict=True)) + # STATIC CONTENT + # Parse CONTRIBUTING (README if CONTRIBUTING doesn't exist) + contributing = gh_client.get_file(repo.full_name, "CONTRIBUTING.md", payload.github_token) + # readme = gh_client.get_readme(payload.github_token) + # diff_hunk, body, path + # comments = gh_client.list_review_comments(payload.github_token) + # Ideas: filter on pulls with highest amount of comments recently, add the review output rejection/etc + # If not enough information, raise error + if contributing is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, detail="No useful information is accessible in the repository") + # Analyze with 
LLM + contributing_guidelines = openai_client.parse_guidelines_from_text( + b64decode(contributing["content"]).decode(), user_id=str(user.id) + ) + return [ + ParsedGuideline(**guideline.dict(), repo_id=repo_id, origin_path=contributing["path"]) + for guideline in contributing_guidelines + ] + + @router.post("/{repo_id}/waitlist", status_code=status.HTTP_200_OK) async def add_repo_to_waitlist( repo_id: int = Path(..., gt=0), diff --git a/src/app/schemas/guidelines.py b/src/app/schemas/guidelines.py index b83a1af..ae1fd28 100644 --- a/src/app/schemas/guidelines.py +++ b/src/app/schemas/guidelines.py @@ -12,11 +12,29 @@ __all__ = ["GuidelineCreate", "GuidelineEdit", "ContentUpdate", "OrderUpdate"] +class TextContent(BaseModel): + content: str = Field(..., min_length=10) + + +class ExampleRequest(TextContent): + language: str = Field("python", min_length=1, max_length=20) + + +class GuidelineExample(BaseModel): + positive: str = Field(..., min_length=3) + negative: str = Field(..., min_length=3) + + class GuidelineContent(BaseModel): title: str = Field(..., min_length=6, max_length=100) details: str = Field(..., min_length=6, max_length=1000) +class ParsedGuideline(GuidelineContent): + repo_id: int = Field(..., gt=0) + origin_path: str + + class GuidelineLocation(BaseModel): repo_id: int = Field(..., gt=0) order: int = Field(0, ge=0, nullable=False) diff --git a/src/app/services/github.py b/src/app/services/github.py index 25d1f6b..a66c075 100644 --- a/src/app/services/github.py +++ b/src/app/services/github.py @@ -3,7 +3,8 @@ # All rights reserved. # Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. 
-from typing import Any, Dict, Union +import logging +from typing import Any, Dict, List, Union import requests from fastapi import HTTPException, status @@ -13,6 +14,8 @@ from app.models import User, UserScope from app.schemas.services import GHToken, GHTokenRequest +logger = logging.getLogger("uvicorn.error") + __all__ = ["gh_client"] @@ -23,35 +26,46 @@ class GitHubClient: def __init__(self, token: Union[str, None] = None) -> None: self.headers = self._get_headers(token) - def _get_headers(self, token: Union[str, None] = None) -> Dict[str, str]: + @staticmethod + def _get_headers(token: Union[str, None] = None) -> Dict[str, str]: if isinstance(token, str): return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} return {"Content-Type": "application/json"} - def _get(self, route: str, token: Union[str, None] = None, timeout: int = 5) -> Dict[str, Any]: + def _get( + self, + route: str, + token: Union[str, None] = None, + timeout: int = 5, + status_code_tolerance: Union[int, None] = None, + **kwargs: Any, + ) -> requests.Response: response = requests.get( f"{self.ENDPOINT}/{route}", headers=self._get_headers(token) if isinstance(token, str) else self.headers, timeout=timeout, + params=kwargs, ) - json_response = response.json() - if response.status_code != status.HTTP_200_OK: + if response.status_code != status.HTTP_200_OK and ( + status_code_tolerance is None or response.status_code != status_code_tolerance + ): + json_response = response.json() raise HTTPException( status_code=response.status_code, detail=json_response.get("error", json_response["message"]) ) - return json_response + return response def get_repo(self, repo_id: int, **kwargs: Any) -> Dict[str, Any]: - return self._get(f"repositories/{repo_id}", **kwargs) + return self._get(f"repositories/{repo_id}", **kwargs).json() def get_user(self, user_id: int, **kwargs: Any) -> Dict[str, Any]: - return self._get(f"user/{user_id}", **kwargs) + return self._get(f"user/{user_id}", 
**kwargs).json() def get_my_user(self, token: str) -> Dict[str, Any]: - return self._get("user", token) + return self._get("user", token).json() - def get_permission(self, repo_name: str, user_name: str, github_token: str) -> str: - return self._get(f"repos/{repo_name}/collaborators/{user_name}/permission", github_token)["role_name"] + def get_permission(self, repo_name: str, user_name: str, token: str) -> str: + return self._get(f"repos/{repo_name}/collaborators/{user_name}/permission", token).json()["role_name"] def check_user_permission( self, @@ -94,5 +108,37 @@ def get_token_from_code(self, code: str, redirect_uri: HttpUrl, timeout: int = 5 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=response.json()["error_description"]) return GHToken(**response.json()) + def get_readme(self, repo_name: str, token: Union[str, None] = None) -> Union[Dict[str, Any], None]: + # https://docs.github.com/en/rest/repos/contents#get-a-repository-readme + response = self._get(f"repos/{repo_name}/readme", token, status_code_tolerance=status.HTTP_404_NOT_FOUND) + return response.json() if response.status_code != status.HTTP_404_NOT_FOUND else None + + def get_file(self, repo_name: str, file_path: str, token: Union[str, None] = None) -> Union[Dict[str, Any], None]: + # https://docs.github.com/en/rest/repos/contents#get-repository-content + response = self._get( + f"repos/{repo_name}/contents/{file_path}", token, status_code_tolerance=status.HTTP_404_NOT_FOUND + ) + return response.json() if response.status_code != status.HTTP_404_NOT_FOUND else None + + def list_pulls(self, repo_name: str, token: Union[str, None] = None, per_page: int = 30) -> List[Dict[str, Any]]: + # https://docs.github.com/en/rest/pulls/pulls#list-pull-requests + return self._get( + f"repos/{repo_name}/pulls", + token, + state="closed", + sort="popularity", + direction="desc", + base=self._get(f"repos/{repo_name}", token).json()["default_branch"], + per_page=per_page, + ).json() + + def 
list_review_comments(self, repo_name: str, token: Union[str, None] = None): + # https://docs.github.com/en/rest/pulls/comments#list-review-comments-in-a-repository + comments = self._get( + f"repos/{repo_name}/pulls/comments", token, sort="created_at", direction="desc", per_page=100 + ).json() + # Get comments (filter account type == user, & user != author) --> take diff_hunk, body, path + return [comment for comment in comments if comment["user"]["type"] == "User"] + gh_client = GitHubClient(settings.GH_TOKEN) diff --git a/src/app/services/openai.py b/src/app/services/openai.py index 5fbb386..18fb391 100644 --- a/src/app/services/openai.py +++ b/src/app/services/openai.py @@ -16,6 +16,7 @@ from app.core.config import settings from app.models import Guideline from app.schemas.compute import ComplianceResult +from app.schemas.guidelines import GuidelineContent, GuidelineExample from app.schemas.services import ( ArraySchema, ChatCompletion, @@ -88,6 +89,58 @@ class ExecutionMode(str, Enum): "For a given compliance results, the comment should be an empty string if the code is compliant with the corresponding guideline." ) +PARSING_SCHEMA = ObjectSchema( + type="object", + properties={ + "result": ArraySchema( + type="array", + items=ObjectSchema( + type="object", + properties={ + "title": FieldSchema(type="string", description="a short summary title of the guideline"), + "details": FieldSchema( + type="string", + description="a descriptive, comprehensive and unambiguous explanation of the guideline.", + ), + }, + required=["title", "details"], + ), + ), + }, + required=["result"], +) + +PARSING_PROMPT = ( + "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. " + "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. 
" + "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) " + "by a human developer without running additional commands or tools, it should only relate to the code within each file. " + "Only include guidelines for which you could generate positive and negative code snippets, " + "don't invent anything that isn't present in the input text. " + "You should answer in JSON format with only the list of string guidelines." +) + +EXAMPLE_SCHEMA = ObjectSchema( + type="object", + properties={ + "positive": FieldSchema( + type="string", description="a short code snippet where the instruction was correctly followed." + ), + "negative": FieldSchema( + type="string", description="the same snippet with minimal modifications that invalidate the instruction." + ), + }, + required=["positive", "negative"], +) + +EXAMPLE_PROMPT = ( + "You are responsible for producing concise code snippets to illustrate the company coding guidelines. " + "This will be used to teach new developers our way of engineering software. " + "You should answer in JSON format with only two short code snippets in the specified programming language: one that follows the rule correctly, " + "and a similar version with minimal modifications that violates the rule. " + "Make sure your code is functional, don't add extra comments or explanations." 
+) + class OpenAIClient: ENDPOINT: str = "https://api.openai.com/v1/chat/completions" @@ -111,7 +164,7 @@ def __init__( def _get_headers(api_key: str) -> Dict[str, str]: return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} - def analyze_multi( + def check_code_against_guidelines( self, code: str, guidelines: List[Guideline], @@ -157,7 +210,7 @@ def analyze_multi( for guideline, res in zip(guidelines, parsed_response) ] - def analyze_mono(self, code: str, guideline: Guideline, **kwargs: Any) -> ComplianceResult: + def check_code(self, code: str, guideline: Guideline, **kwargs: Any) -> ComplianceResult: # Check args before sending a request if len(code) == 0 or len(guideline.details) == 0: raise HTTPException( @@ -167,11 +220,11 @@ def analyze_mono(self, code: str, guideline: Guideline, **kwargs: Any) -> Compli # Return with pydantic validation return ComplianceResult(guideline_id=guideline.id, **res) - def _analyze( + def _request( self, - prompt: str, - payload: Dict[str, Any], - schema: ObjectSchema, + system_prompt: str, + openai_fn: OpenAIFunction, + message: str, timeout: int = 20, user_id: Union[str, None] = None, ) -> Dict[str, Any]: @@ -181,21 +234,15 @@ def _analyze( messages=[ OpenAIMessage( role=OpenAIChatRole.SYSTEM, - content=prompt, + content=system_prompt, ), OpenAIMessage( role=OpenAIChatRole.USER, - content=json.dumps(payload), + content=message, ), ], - functions=[ - OpenAIFunction( - name="analyze_code", - description="Analyze code", - parameters=schema, - ) - ], - function_call={"name": "analyze_code"}, + functions=[openai_fn], + function_call={"name": openai_fn.name}, temperature=self.temperature, frequency_penalty=self.frequency_penalty, user=user_id, @@ -209,5 +256,77 @@ def _analyze( return json.loads(response.json()["choices"][0]["message"]["function_call"]["arguments"]) + def _analyze( + self, + prompt: str, + payload: Dict[str, Any], + schema: ObjectSchema, + timeout: int = 20, + user_id: Union[str, None] = 
None, + ) -> Dict[str, Any]: + return self._request( + prompt, + OpenAIFunction( + name="analyze_code", + description="Analyze code", + parameters=schema, + ), + json.dumps(payload), + timeout, + user_id, + ) + + def parse_guidelines_from_text( + self, corpus: str, timeout: int = 20, user_id: Union[str, None] = None + ) -> List[GuidelineContent]: + if not isinstance(corpus, str): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="The input corpus needs to be a string." + ) + if len(corpus) == 0: + return [] + + response = self._request( + PARSING_PROMPT, + OpenAIFunction( + name="parse_guidelines_from_text", + description="Parse guidelines from corpus", + parameters=PARSING_SCHEMA, + ), + json.dumps(corpus), + timeout, + user_id, + ) + + return [GuidelineContent(**elt) for elt in response["result"]] + + def generate_examples_for_instruction( + self, instruction: str, language: str, timeout: int = 20, user_id: Union[str, None] = None + ) -> GuidelineExample: + if ( + not isinstance(instruction, str) + or len(instruction) == 0 + or not isinstance(language, str) + or len(language) == 0 + ): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="The instruction and language need to be non-empty strings.", + ) + + return GuidelineExample( + **self._request( + EXAMPLE_PROMPT, + OpenAIFunction( + name="generate_examples_from_instruction", + description="Generate code examples for a coding instruction", + parameters=EXAMPLE_SCHEMA, + ), + json.dumps({"instruction": instruction, "language": language}), + timeout, + user_id, + ) + ) + openai_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL) diff --git a/src/tests/test_services.py b/src/tests/test_services.py new file mode 100644 index 0000000..10ae0c9 --- /dev/null +++ b/src/tests/test_services.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import HTTPException + +from app.services.github import GitHubClient + + +@pytest.mark.parametrize( + 
("repo_id", "status_code", "status_detail", "expected_name"), + [ + (100, 404, "Not Found", None), + (249513553, 200, None, "frgfm/torch-cam"), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_get_repo(repo_id, status_code, status_detail, expected_name): + github_client = GitHubClient() + if isinstance(expected_name, str): + response = github_client.get_repo(repo_id) + assert response["full_name"] == expected_name + else: + response = github_client._get(f"repositories/{repo_id}", status_code_tolerance=status_code) + assert response.status_code == status_code + if isinstance(status_detail, str): + assert response.json()["message"] == status_detail + + +@pytest.mark.parametrize( + ("user_id", "status_code", "status_detail", "expected_name"), + [ + (1000000000, 404, "Not Found", None), + (26927750, 200, None, "frgfm"), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_get_user(user_id, status_code, status_detail, expected_name): + github_client = GitHubClient() + if isinstance(expected_name, str): + response = github_client.get_user(user_id) + assert response["login"] == expected_name + else: + response = github_client._get(f"user/{user_id}", status_code_tolerance=status_code) + assert response.status_code == status_code + if isinstance(status_detail, str): + assert response.json()["message"] == status_detail + + +@pytest.mark.parametrize( + ("repo_name", "status_code", "status_detail", "expected_path"), + [ + ("frgfm/hola", 404, "Not Found", None), + ("frgfm/torch-cam", 200, None, "README.md"), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_get_readme(repo_name, status_code, status_detail, expected_path): + github_client = GitHubClient() + if isinstance(expected_path, str): + response = github_client.get_readme(repo_name) + assert response["path"] == expected_path + else: + response = github_client._get(f"repos/{repo_name}/readme", status_code_tolerance=status_code) + assert response.status_code == status_code + if 
isinstance(status_detail, str): + assert response.json()["message"] == status_detail + + +@pytest.mark.parametrize( + ("repo_name", "file_path", "status_code", "status_detail"), + [ + ("frgfm/hola", "CONTRIBUTING.md", 404, "Not Found"), + ("frgfm/torch-cam", "Hola.md", 404, "Not Found"), + ("frgfm/torch-cam", "CONTRIBUTING.md", 200, None), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_get_file(repo_name, file_path, status_code, status_detail): + github_client = GitHubClient() + if status_code // 100 == 2: + response = github_client.get_file(repo_name, file_path) + assert isinstance(response, dict) + assert response["path"] == file_path + else: + response = github_client._get(f"repos/{repo_name}/contents/{file_path}", status_code_tolerance=status_code) + assert response.status_code == status_code + if isinstance(status_detail, str): + assert response.json()["message"] == status_detail + + +@pytest.mark.parametrize( + ("repo_name", "status_code", "status_detail"), + [ + ("frgfm/hola", 404, "Not Found"), + ("frgfm/torch-cam", 200, None), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_list_pulls(repo_name, status_code, status_detail): + github_client = GitHubClient() + if status_code // 100 == 2: + response = github_client.list_pulls(repo_name) + assert isinstance(response, list) + assert all(isinstance(elt, dict) for elt in response) + else: + with pytest.raises(HTTPException): + github_client.list_pulls(repo_name) + + +@pytest.mark.parametrize( + ("repo_name", "status_code", "status_detail"), + [ + ("frgfm/hola", 404, "Not Found"), + ("frgfm/torch-cam", 200, None), + ], +) +@pytest.mark.asyncio() +async def test_githubclient_list_review_comments(repo_name, status_code, status_detail): + github_client = GitHubClient() + if status_code // 100 == 2: + response = github_client.list_review_comments(repo_name) + assert isinstance(response, list) + assert all(isinstance(elt, dict) for elt in response) + else: + with pytest.raises(HTTPException): 
+ github_client.list_review_comments(repo_name)