Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
fix: Fixes the route input schemas for repos & guidelines (#25)
Browse files Browse the repository at this point in the history
* fix: Fixes schemas

* fix: Fix router doc

* refactor: Refactors args
  • Loading branch information
frgfm authored Nov 5, 2023
1 parent 9852d78 commit a8840fa
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/app/api/api_v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
api_router.include_router(users.router, prefix="/users", tags=["access"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(repos.router, prefix="/repos", tags=["repos"])
api_router.include_router(guidelines.router, prefix="/guidelines", tags=["guidelines"])
api_router.include_router(compute.router, prefix="/compute", tags=["compute"])
5 changes: 1 addition & 4 deletions src/app/schemas/guidelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from pydantic import BaseModel, Field

from app.schemas.base import _CreatedAt

__all__ = ["GuidelineCreate", "GuidelineEdit", "ContentUpdate", "OrderUpdate"]


Expand All @@ -17,9 +15,8 @@ class GuidelineEdit(BaseModel):
details: str = Field(..., min_length=6, max_length=1000)


class GuidelineCreate(_CreatedAt, GuidelineEdit):
class GuidelineCreate(GuidelineEdit):
repo_id: int = Field(..., gt=0)
updated_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)
order: int = Field(..., ge=0, nullable=False)


Expand Down
1 change: 0 additions & 1 deletion src/app/schemas/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
class RepoCreate(_Id):
owner_id: int = Field(..., gt=0)
full_name: str = Field(..., example="frgfm/torch-cam")
installed_at: datetime = Field(default_factory=datetime.utcnow, nullable=False)


class RepoCreation(RepoCreate):
Expand Down
20 changes: 14 additions & 6 deletions src/app/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def __init__(
)

def analyze_multi(
self, code: str, guidelines: List[Guideline], timeout: int = 10, mode: ExecutionMode = ExecutionMode.SINGLE
self,
code: str,
guidelines: List[Guideline],
mode: ExecutionMode = ExecutionMode.SINGLE,
**kwargs: Any,
) -> List[ComplianceResult]:
# Check args before sending a request
if len(code) == 0 or len(guidelines) == 0 or any(len(guideline.details) == 0 for guideline in guidelines):
Expand All @@ -121,15 +125,19 @@ def analyze_multi(
MULTI_PROMPT,
{"code": code, "guidelines": [guideline.details for guideline in guidelines]},
MULTI_SCHEMA,
timeout,
**kwargs,
)["result"]
if len(parsed_response) != len(guidelines):
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid model response")
elif mode == ExecutionMode.MULTI:
with ThreadPoolExecutor() as executor:
tasks = [
executor.submit(
self._analyze, MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, timeout
self._analyze,
MONO_PROMPT,
{"code": code, "guideline": guideline.details},
MONO_SCHEMA,
**kwargs,
)
for guideline in guidelines
]
Expand All @@ -145,17 +153,17 @@ def analyze_multi(
for guideline, res in zip(guidelines, parsed_response)
]

def analyze_mono(self, code: str, guideline: Guideline, timeout: int = 10) -> ComplianceResult:
def analyze_mono(self, code: str, guideline: Guideline, **kwargs: Any) -> ComplianceResult:
# Check args before sending a request
if len(code) == 0 or len(guideline.details) == 0:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No code or guideline provided for analysis."
)
res = self._analyze(MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, timeout)
res = self._analyze(MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, **kwargs)
# Return with pydantic validation
return ComplianceResult(guideline_id=guideline.id, **res)

def _analyze(self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 10) -> Dict[str, Any]:
def _analyze(self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 20) -> Dict[str, Any]:
# Prepare the request
_payload = ChatCompletion(
model=self.model,
Expand Down

0 comments on commit a8840fa

Please sign in to comment.