From cbfec5f380d7db1bf4b091a6fcbb72eb07993b3a Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Tue, 31 Oct 2023 20:06:13 +0100 Subject: [PATCH] style: Fixes typing --- src/app/api/api_v1/endpoints/compute.py | 6 +++--- src/app/services/openai.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/app/api/api_v1/endpoints/compute.py b/src/app/api/api_v1/endpoints/compute.py index 656e0ab..b2cd595 100644 --- a/src/app/api/api_v1/endpoints/compute.py +++ b/src/app/api/api_v1/endpoints/compute.py @@ -3,13 +3,13 @@ # All rights reserved. # Copying and/or distributing is strictly prohibited without the express permission of its copyright owner. -from typing import List +from typing import List, cast from fastapi import APIRouter, Depends, Path, Security, status from app.api.dependencies import get_current_user, get_guideline_crud, get_repo_crud from app.crud import GuidelineCRUD, RepositoryCRUD -from app.models import UserScope +from app.models import Guideline, UserScope from app.schemas.compute import ComplianceResult, Snippet from app.services.openai import ExecutionMode, openai_client from app.services.telemetry import telemetry_client @@ -42,7 +42,7 @@ async def check_code_against_guideline( user=Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> ComplianceResult: # Check repo - guideline = await guidelines.get(guideline_id, strict=True) + guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) telemetry_client.capture( user.id, event="compute-check", properties={"repo_id": guideline.repo_id, "guideline_id": guideline_id} ) diff --git a/src/app/services/openai.py b/src/app/services/openai.py index d0d43e9..d05706c 100644 --- a/src/app/services/openai.py +++ b/src/app/services/openai.py @@ -155,9 +155,7 @@ def analyze_mono(self, code: str, guideline: Guideline, timeout: int = 10) -> Co # Return with pydantic validation return ComplianceResult(guideline_id=guideline.id, **res) - def _analyze( - self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 10 - ) -> ComplianceResult: + def _analyze(self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 10) -> Dict[str, Any]: # Prepare the request _payload = ChatCompletion( model=self.model,