Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
feat: Adds compute route using OpenAI API (#21)
Browse files Browse the repository at this point in the history
* feat: Adds compute route

* fix: Fixes typo in schema

* fix: Fixes typo in frequency penalty

* style: Fixes typing

* feat: Improved guideline output

* feat: Improves docker env

* feat: Adds env var to test docker

* ci: Updates labeler

* feat: Adds safeguard

* ci: Updates CI jobs
  • Loading branch information
frgfm authored Oct 30, 2023
1 parent 8f9b4ec commit 9ba44da
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
'endpoint: guidelines':
- src/app/api/*/endpoints/guidelines.py

'endpoint: compute':
- src/app/api/*/endpoints/compute.py

'topic: build':
- pyproject.toml
- poetry.lock
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/builds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }}
GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }}
GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: docker-compose up -d --build
- name: Docker sanity check
run: sleep 20 && nc -vz api.localhost 8050
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }}
GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }}
GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: docker compose -f docker-compose.test.yml up -d --build
- name: Docker sanity check
run: sleep 20 && nc -vz localhost 8050
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
SUPERUSER_PWD: ${{ secrets.SUPERUSER_PWD }}
GH_OAUTH_ID: ${{ secrets.GH_OAUTH_ID }}
GH_OAUTH_SECRET: ${{ secrets.GH_OAUTH_SECRET }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: docker compose -f docker-compose.test.yml up -d --build
- name: Run docker test
run: |
Expand Down
1 change: 1 addition & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ services:
- SUPERUSER_PWD=${SUPERUSER_PWD}
- GH_OAUTH_ID=${GH_OAUTH_ID}
- GH_OAUTH_SECRET=${GH_OAUTH_SECRET}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- SERVER_NAME=dummy_server
- DEBUG=true
depends_on:
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ services:
- SUPERUSER_PWD=${SUPERUSER_PWD}
- GH_OAUTH_ID=${GH_OAUTH_ID}
- GH_OAUTH_SECRET=${GH_OAUTH_SECRET}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- SERVER_NAME=${SERVER_NAME}
- SECRET_KEY=${SECRET_KEY}
- DEBUG=true
depends_on:
db:
Expand Down
33 changes: 33 additions & 0 deletions src/app/api/api_v1/endpoints/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (C) 2023, Quack AI.

# All rights reserved.
# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner.

from typing import List

from fastapi import APIRouter, Depends, Security, status

from app.api.dependencies import get_current_user, get_guideline_crud, get_repo_crud
from app.crud import GuidelineCRUD, RepositoryCRUD
from app.models import UserScope
from app.schemas.compute import ComplianceResult, Snippet
from app.services.openai import openai_client
from app.services.telemetry import telemetry_client

router = APIRouter()


@router.post("/analyze", status_code=status.HTTP_200_OK)
async def analyze_snippet(
payload: Snippet,
repos: RepositoryCRUD = Depends(get_repo_crud),
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> List[ComplianceResult]:
telemetry_client.capture(user.id, event="snippet-analysis", properties={"repo_id": payload.repo_id})
# Check repo
await repos.get(payload.repo_id, strict=True)
# Fetch guidelines
guideline_list = [elt for elt in await guidelines.fetch_all(("repo_id", payload.repo_id))]
# Run analysis
return openai_client.analyze_snippet(payload.code, guideline_list)
3 changes: 2 additions & 1 deletion src/app/api/api_v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from fastapi import APIRouter

from app.api.api_v1.endpoints import guidelines, login, repos, users
from app.api.api_v1.endpoints import compute, guidelines, login, repos, users

api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["access"])
api_router.include_router(repos.router, prefix="/repos", tags=["repos"])
api_router.include_router(guidelines.router, prefix="/guidelines", tags=["guidelines"])
api_router.include_router(compute.router, prefix="/compute", tags=["compute"])
5 changes: 5 additions & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from pydantic import BaseSettings, validator

from app.schemas.services import OpenAIModel

__all__ = ["settings"]


Expand All @@ -34,6 +36,9 @@ class Settings(BaseSettings):
SUPERUSER_LOGIN: str = os.environ["SUPERUSER_LOGIN"]
SUPERUSER_ID: int = int(os.environ["SUPERUSER_ID"])
SUPERUSER_PWD: str = os.environ["SUPERUSER_PWD"]
# Compute
OPENAI_API_KEY: str = os.environ["OPENAI_API_KEY"]
OPENAI_MODEL: OpenAIModel = OpenAIModel.GPT3_5

@validator("POSTGRES_URL", pre=True)
@classmethod
Expand Down
20 changes: 20 additions & 0 deletions src/app/schemas/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (C) 2023, Quack AI.

# All rights reserved.
# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner.


from pydantic import BaseModel, Field

__all__ = ["Snippet", "ComplianceResult"]


class Snippet(BaseModel):
code: str = Field(..., min_length=1)
repo_id: int = Field(..., gt=0)


class ComplianceResult(BaseModel):
guideline_id: int = Field(..., gt=0)
is_compliant: bool
comment: str
61 changes: 61 additions & 0 deletions src/app/schemas/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (C) 2023, Quack AI.

# All rights reserved.
# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner.

from enum import Enum
from typing import Any, Dict, List

from pydantic import BaseModel

__all__ = ["ChatCompletion"]


class OpenAIModel(str, Enum):
# https://platform.openai.com/docs/models/overview
GPT3_5: str = "gpt-3.5-turbo-0613"
GPT3_5_LONG: str = "gpt-3.5-turbo-16k-0613"
GPT4: str = "gpt-4-0613"
GPT4_LONG: str = "gpt-4-32k-0613"


class OpenAIChatRole(str, Enum):
SYSTEM: str = "system"
USER: str = "user"
ASSISTANT: str = "assistant"


class FieldSchema(BaseModel):
type: str
description: str


class ObjectSchema(BaseModel):
type: str = "object"
properties: Dict[str, Any]
required: List[str]


class ArraySchema(BaseModel):
type: str = "array"
items: ObjectSchema


class OpenAIFunction(BaseModel):
name: str
description: str
parameters: ObjectSchema


class OpenAIMessage(BaseModel):
role: OpenAIChatRole
content: str


class ChatCompletion(BaseModel):
model: OpenAIModel = OpenAIModel.GPT3_5
messages: List[OpenAIMessage]
functions: List[OpenAIFunction]
function_call: Dict[str, str]
temperature: float
frequency_penalty: float
129 changes: 129 additions & 0 deletions src/app/services/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (C) 2023, Quack AI.

# All rights reserved.
# Copying and/or distributing is strictly prohibited without the express permission of its copyright owner.

import json
import logging
from datetime import datetime
from typing import List

import requests
from fastapi import HTTPException, status

from app.core.config import settings
from app.models import Guideline
from app.schemas.compute import ComplianceResult
from app.schemas.services import (
ArraySchema,
ChatCompletion,
FieldSchema,
ObjectSchema,
OpenAIChatRole,
OpenAIFunction,
OpenAIMessage,
OpenAIModel,
)

logger = logging.getLogger("uvicorn.error")

__all__ = ["openai_client"]


RESPONSE_SCHEMA = ObjectSchema(
type="object",
properties={
"result": ArraySchema(
type="array",
items=ObjectSchema(
type="object",
properties={
"is_compliant": FieldSchema(
type="boolean", description="whether the guideline has been followed in the code snippet"
),
"comment": FieldSchema(
type="string",
description="instruction to make the snippet compliant, addressed to the developer who wrote the code. Should be empty if the snippet is compliant",
),
# "suggestion": FieldSchema(type="string", description="the modified code snippet that meets the guideline, with minimal modifications. Should be empty if the snippet is compliant"),
},
required=["is_compliant", "comment"],
),
),
},
required=["result"],
)


class OpenAIClient:
ENDPOINT: str = "https://api.openai.com/v1/chat/completions"

def __init__(
self, api_key: str, model: OpenAIModel, temperature: float = 0.0, frequency_penalty: float = 1.0
) -> None:
self.headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
# Validate model
model_card = requests.get(f"https://api.openai.com/v1/models/{model}", headers=self.headers, timeout=2)
if model_card.status_code != 200:
raise HTTPException(status_code=model_card.status_code, detail=model_card.json()["error"]["message"])
self.model = model
self.temperature = temperature
self.frequency_penalty = frequency_penalty
logger.info(
f"Using OpenAI model: {self.model} (created at {datetime.fromtimestamp(model_card.json()['created']).isoformat()})"
)

def analyze_snippet(self, code: str, guidelines: List[Guideline], timeout: int = 10) -> List[ComplianceResult]:
# Check args before sending a request
if len(code) == 0 or len(guidelines) == 0 or any(len(guideline.details) == 0 for guideline in guidelines):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No code or guideline provided for analysis."
)
# Ideas: check which programming language & whether it's correct code
# Prepare the request
payload = ChatCompletion(
model=self.model,
messages=[
OpenAIMessage(
role=OpenAIChatRole.SYSTEM,
content=(
"As a code compliance agent, you are going to receive requests from user with two elements: a code snippet, and a list of guidelines. "
"You should answer with a list of compliance results, one for each guideline (in the same order). "
"For a given compliance results, the comment should be an empty string if the code is compliant with the corresponding guideline."
),
),
OpenAIMessage(
role=OpenAIChatRole.USER,
content=json.dumps({"code": code, "guidelines": [guideline.details for guideline in guidelines]}),
),
],
functions=[
OpenAIFunction(
name="analyze_code",
description="Check code against a set of guidelines",
parameters=RESPONSE_SCHEMA,
)
],
function_call={"name": "analyze_code"},
temperature=self.temperature,
frequency_penalty=self.frequency_penalty,
)
# Send the request
response = requests.post(self.ENDPOINT, json=payload.dict(), headers=self.headers, timeout=10)

# Check status
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response.json()["error"]["message"])
# Ideas: check the returned code can run
parsed_response = json.loads(response.json()["choices"][0]["message"]["function_call"]["arguments"])["result"]
if len(parsed_response) != len(guidelines):
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid model response")
# Return with pydantic validation
return [
# ComplianceResult(is_compliant=res["is_compliant"], comment="" if res["is_compliant"] else res["comment"])
ComplianceResult(guideline_id=guideline.id, **res)
for guideline, res in zip(guidelines, parsed_response)
]


openai_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL)

0 comments on commit 9ba44da

Please sign in to comment.