diff --git a/src/app/services/openai.py b/src/app/services/openai.py index d05706c..b5fa6ec 100644 --- a/src/app/services/openai.py +++ b/src/app/services/openai.py @@ -108,7 +108,11 @@ def __init__( ) def analyze_multi( - self, code: str, guidelines: List[Guideline], timeout: int = 10, mode: ExecutionMode = ExecutionMode.SINGLE + self, + code: str, + guidelines: List[Guideline], + mode: ExecutionMode = ExecutionMode.SINGLE, + **kwargs: Any, ) -> List[ComplianceResult]: # Check args before sending a request if len(code) == 0 or len(guidelines) == 0 or any(len(guideline.details) == 0 for guideline in guidelines): @@ -121,7 +125,7 @@ def analyze_multi( MULTI_PROMPT, {"code": code, "guidelines": [guideline.details for guideline in guidelines]}, MULTI_SCHEMA, - timeout, + **kwargs, )["result"] if len(parsed_response) != len(guidelines): raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid model response") @@ -129,7 +133,11 @@ def analyze_multi( with ThreadPoolExecutor() as executor: tasks = [ executor.submit( - self._analyze, MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, timeout + self._analyze, + MONO_PROMPT, + {"code": code, "guideline": guideline.details}, + MONO_SCHEMA, + **kwargs, ) for guideline in guidelines ] @@ -145,17 +153,17 @@ def analyze_multi( for guideline, res in zip(guidelines, parsed_response) ] - def analyze_mono(self, code: str, guideline: Guideline, timeout: int = 10) -> ComplianceResult: + def analyze_mono(self, code: str, guideline: Guideline, **kwargs: Any) -> ComplianceResult: # Check args before sending a request if len(code) == 0 or len(guideline.details) == 0: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No code or guideline provided for analysis." ) - res = self._analyze(MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, timeout) + res = self._analyze(MONO_PROMPT, {"code": code, "guideline": guideline.details}, MONO_SCHEMA, **kwargs) # Return with pydantic validation return ComplianceResult(guideline_id=guideline.id, **res) - def _analyze(self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 10) -> Dict[str, Any]: + def _analyze(self, prompt: str, payload: Dict[str, Any], schema: ObjectSchema, timeout: int = 20) -> Dict[str, Any]: # Prepare the request _payload = ChatCompletion( model=self.model,