From 9ba44da0c7dfe1249795efe416c070623b543d54 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Mon, 30 Oct 2023 13:49:37 +0100 Subject: [PATCH] feat: Adds compute route using OpenAI API (#21) * feat: Adds compute route * fix: Fixes typo in schema * fix: Fixes typo in frequency penalty * style: Fixes typing * feat: Improved guideline output * feat: Improves docker env * feat: Adds env var to test docker * ci: Updates labeler * feat: Adds safeguard * ci: Updates CI jobs --- .github/labeler.yml | 3 + .github/workflows/builds.yml | 1 + .github/workflows/scripts.yml | 1 + .github/workflows/tests.yml | 1 + docker-compose.test.yml | 1 + docker-compose.yml | 2 + src/app/api/api_v1/endpoints/compute.py | 33 ++++++ src/app/api/api_v1/router.py | 3 +- src/app/core/config.py | 5 + src/app/schemas/compute.py | 20 ++++ src/app/schemas/services.py | 61 +++++++++++ src/app/services/openai.py | 129 ++++++++++++++++++++++++ 12 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 src/app/api/api_v1/endpoints/compute.py create mode 100644 src/app/schemas/compute.py create mode 100644 src/app/schemas/services.py create mode 100644 src/app/services/openai.py diff --git a/.github/labeler.yml b/.github/labeler.yml index 11d6506..01b709e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -38,6 +38,9 @@ 'endpoint: guidelines': - src/app/api/*/endpoints/guidelines.py +'endpoint: compute': +- src/app/api/*/endpoints/compute.py + 'topic: build': - pyproject.toml - poetry.lock diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml index debc483..4245eff 100644 --- a/.github/workflows/builds.yml +++ b/.github/workflows/builds.yml @@ -32,6 +32,7 @@ jobs: SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }} GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }} GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: docker-compose up -d --build - name: Docker sanity check run: sleep 20 && nc -vz api.localhost 8050 diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml index df87e8d..b262654 100644 --- a/.github/workflows/scripts.yml +++ b/.github/workflows/scripts.yml @@ -32,6 +32,7 @@ jobs: SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }} GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }} GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: docker compose -f docker-compose.test.yml up -d --build - name: Docker sanity check run: sleep 20 && nc -vz localhost 8050 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cdd3888..66741a9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,6 +26,7 @@ jobs: SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }} GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }} GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: docker compose -f docker-compose.test.yml up -d --build - name: Run docker test run: | diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 281946b..3f96dd6 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -15,6 +15,7 @@ services: - SUPERUSER_PWD=${SUPERUSER_PWD} - GH_OAUTH_ID=${GH_OAUTH_ID} - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} + - OPENAI_API_KEY=${OPENAI_API_KEY} - SERVER_NAME=dummy_server - DEBUG=true depends_on: diff --git a/docker-compose.yml b/docker-compose.yml index 7715b89..4611a5c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,7 +15,9 @@ services: - SUPERUSER_PWD=${SUPERUSER_PWD} - GH_OAUTH_ID=${GH_OAUTH_ID} - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} + - OPENAI_API_KEY=${OPENAI_API_KEY} - SERVER_NAME=${SERVER_NAME} + - SECRET_KEY=${SECRET_KEY} - DEBUG=true depends_on: db: diff --git a/src/app/api/api_v1/endpoints/compute.py b/src/app/api/api_v1/endpoints/compute.py new file mode 100644 index 0000000..b3e6096 --- /dev/null +++ b/src/app/api/api_v1/endpoints/compute.py @@ -0,0 +1,33 @@ +# Copyright (C) 2023, Quack AI. + +# All rights reserved. +# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. + +from typing import List + +from fastapi import APIRouter, Depends, Security, status + +from app.api.dependencies import get_current_user, get_guideline_crud, get_repo_crud +from app.crud import GuidelineCRUD, RepositoryCRUD +from app.models import UserScope +from app.schemas.compute import ComplianceResult, Snippet +from app.services.openai import openai_client +from app.services.telemetry import telemetry_client + +router = APIRouter() + + +@router.post("/analyze", status_code=status.HTTP_200_OK) +async def analyze_snippet( + payload: Snippet, + repos: RepositoryCRUD = Depends(get_repo_crud), + guidelines: GuidelineCRUD = Depends(get_guideline_crud), + user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), +) -> List[ComplianceResult]: + telemetry_client.capture(user.id, event="snippet-analysis", properties={"repo_id": payload.repo_id}) + # Check repo + await repos.get(payload.repo_id, strict=True) + # Fetch guidelines + guideline_list = [elt for elt in await guidelines.fetch_all(("repo_id", payload.repo_id))] + # Run analysis + return openai_client.analyze_snippet(payload.code, guideline_list) diff --git a/src/app/api/api_v1/router.py b/src/app/api/api_v1/router.py index 0836ffd..6ee93d6 100644 --- a/src/app/api/api_v1/router.py +++ b/src/app/api/api_v1/router.py @@ -5,10 +5,11 @@ from fastapi import APIRouter -from app.api.api_v1.endpoints import guidelines, login, repos, users +from app.api.api_v1.endpoints import compute, guidelines, login, repos, users api_router = APIRouter() api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(users.router, prefix="/users", tags=["access"]) api_router.include_router(repos.router, prefix="/repos", tags=["repos"]) api_router.include_router(guidelines.router, prefix="/guidelines", tags=["guidelines"]) +api_router.include_router(compute.router, prefix="/compute", tags=["compute"]) diff --git a/src/app/core/config.py b/src/app/core/config.py index 847a672..9fde3e6 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -9,6 +9,8 @@ from pydantic import BaseSettings, validator +from app.schemas.services import OpenAIModel + __all__ = ["settings"] @@ -34,6 +36,9 @@ class Settings(BaseSettings): SUPERUSER_LOGIN: str = os.environ["SUPERUSER_LOGIN"] SUPERUSER_ID: int = int(os.environ["SUPERUSER_ID"]) SUPERUSER_PWD: str = os.environ["SUPERUSER_PWD"] + # Compute + OPENAI_API_KEY: str = os.environ["OPENAI_API_KEY"] + OPENAI_MODEL: OpenAIModel = OpenAIModel.GPT3_5 @validator("POSTGRES_URL", pre=True) @classmethod diff --git a/src/app/schemas/compute.py b/src/app/schemas/compute.py new file mode 100644 index 0000000..34a5b9e --- /dev/null +++ b/src/app/schemas/compute.py @@ -0,0 +1,20 @@ +# Copyright (C) 2023, Quack AI. + +# All rights reserved. +# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. + + +from pydantic import BaseModel, Field + +__all__ = ["Snippet", "ComplianceResult"] + + +class Snippet(BaseModel): + code: str = Field(..., min_length=1) + repo_id: int = Field(..., gt=0) + + +class ComplianceResult(BaseModel): + guideline_id: int = Field(..., gt=0) + is_compliant: bool + comment: str diff --git a/src/app/schemas/services.py b/src/app/schemas/services.py new file mode 100644 index 0000000..79c3938 --- /dev/null +++ b/src/app/schemas/services.py @@ -0,0 +1,61 @@ +# Copyright (C) 2023, Quack AI. + +# All rights reserved. +# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel + +__all__ = ["ChatCompletion"] + + +class OpenAIModel(str, Enum): + # https://platform.openai.com/docs/models/overview + GPT3_5: str = "gpt-3.5-turbo-0613" + GPT3_5_LONG: str = "gpt-3.5-turbo-16k-0613" + GPT4: str = "gpt-4-0613" + GPT4_LONG: str = "gpt-4-32k-0613" + + +class OpenAIChatRole(str, Enum): + SYSTEM: str = "system" + USER: str = "user" + ASSISTANT: str = "assistant" + + +class FieldSchema(BaseModel): + type: str + description: str + + +class ObjectSchema(BaseModel): + type: str = "object" + properties: Dict[str, Any] + required: List[str] + + +class ArraySchema(BaseModel): + type: str = "array" + items: ObjectSchema + + +class OpenAIFunction(BaseModel): + name: str + description: str + parameters: ObjectSchema + + +class OpenAIMessage(BaseModel): + role: OpenAIChatRole + content: str + + +class ChatCompletion(BaseModel): + model: OpenAIModel = OpenAIModel.GPT3_5 + messages: List[OpenAIMessage] + functions: List[OpenAIFunction] + function_call: Dict[str, str] + temperature: float + frequency_penalty: float diff --git a/src/app/services/openai.py b/src/app/services/openai.py new file mode 100644 index 0000000..6c5c107 --- /dev/null +++ b/src/app/services/openai.py @@ -0,0 +1,129 @@ +# Copyright (C) 2023, Quack AI. + +# All rights reserved. +# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. + +import json +import logging +from datetime import datetime +from typing import List + +import requests +from fastapi import HTTPException, status + +from app.core.config import settings +from app.models import Guideline +from app.schemas.compute import ComplianceResult +from app.schemas.services import ( + ArraySchema, + ChatCompletion, + FieldSchema, + ObjectSchema, + OpenAIChatRole, + OpenAIFunction, + OpenAIMessage, + OpenAIModel, +) + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["openai_client"] + + +RESPONSE_SCHEMA = ObjectSchema( + type="object", + properties={ + "result": ArraySchema( + type="array", + items=ObjectSchema( + type="object", + properties={ + "is_compliant": FieldSchema( + type="boolean", description="whether the guideline has been followed in the code snippet" + ), + "comment": FieldSchema( + type="string", + description="instruction to make the snippet compliant, addressed to the developer who wrote the code. Should be empty if the snippet is compliant", + ), + # "suggestion": FieldSchema(type="string", description="the modified code snippet that meets the guideline, with minimal modifications. Should be empty if the snippet is compliant"), + }, + required=["is_compliant", "comment"], + ), + ), + }, + required=["result"], +) + + +class OpenAIClient: + ENDPOINT: str = "https://api.openai.com/v1/chat/completions" + + def __init__( + self, api_key: str, model: OpenAIModel, temperature: float = 0.0, frequency_penalty: float = 1.0 + ) -> None: + self.headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + # Validate model + model_card = requests.get(f"https://api.openai.com/v1/models/{model}", headers=self.headers, timeout=2) + if model_card.status_code != 200: + raise HTTPException(status_code=model_card.status_code, detail=model_card.json()["error"]["message"]) + self.model = model + self.temperature = temperature + self.frequency_penalty = frequency_penalty + logger.info( + f"Using OpenAI model: {self.model} (created at {datetime.fromtimestamp(model_card.json()['created']).isoformat()})" + ) + + def analyze_snippet(self, code: str, guidelines: List[Guideline], timeout: int = 10) -> List[ComplianceResult]: + # Check args before sending a request + if len(code) == 0 or len(guidelines) == 0 or any(len(guideline.details) == 0 for guideline in guidelines): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No code or guideline provided for analysis." + ) + # Ideas: check which programming language & whether it's correct code + # Prepare the request + payload = ChatCompletion( + model=self.model, + messages=[ + OpenAIMessage( + role=OpenAIChatRole.SYSTEM, + content=( + "As a code compliance agent, you are going to receive requests from user with two elements: a code snippet, and a list of guidelines. " + "You should answer with a list of compliance results, one for each guideline (in the same order). " + "For a given compliance results, the comment should be an empty string if the code is compliant with the corresponding guideline." + ), + ), + OpenAIMessage( + role=OpenAIChatRole.USER, + content=json.dumps({"code": code, "guidelines": [guideline.details for guideline in guidelines]}), + ), + ], + functions=[ + OpenAIFunction( + name="analyze_code", + description="Check code against a set of guidelines", + parameters=RESPONSE_SCHEMA, + ) + ], + function_call={"name": "analyze_code"}, + temperature=self.temperature, + frequency_penalty=self.frequency_penalty, + ) + # Send the request + response = requests.post(self.ENDPOINT, json=payload.dict(), headers=self.headers, timeout=10) + + # Check status + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.json()["error"]["message"]) + # Ideas: check the returned code can run + parsed_response = json.loads(response.json()["choices"][0]["message"]["function_call"]["arguments"])["result"] + if len(parsed_response) != len(guidelines): + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid model response") + # Return with pydantic validation + return [ + # ComplianceResult(is_compliant=res["is_compliant"], comment="" if res["is_compliant"] else res["comment"]) + ComplianceResult(guideline_id=guideline.id, **res) + for guideline, res in zip(guidelines, parsed_response) + ] + + +openai_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL)