From 4db0e575af6ba26c160196cd16967b24b5993f14 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Mon, 11 Mar 2024 17:21:25 +0100 Subject: [PATCH 1/2] feat(code): add guidelines to the code chat route --- src/app/api/api_v1/endpoints/code.py | 10 +++++++--- src/app/services/ollama.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/app/api/api_v1/endpoints/code.py b/src/app/api/api_v1/endpoints/code.py index c72c644..ae0a22a 100644 --- a/src/app/api/api_v1/endpoints/code.py +++ b/src/app/api/api_v1/endpoints/code.py @@ -4,10 +4,11 @@ # See LICENSE or go to for full license details. -from fastapi import APIRouter, Security, status +from fastapi import APIRouter, Depends, Security, status from fastapi.responses import StreamingResponse -from app.api.dependencies import get_token_payload +from app.api.dependencies import get_guideline_crud, get_token_payload +from app.crud.crud_guideline import GuidelineCRUD from app.models import UserScope from app.schemas.code import ChatHistory from app.schemas.login import TokenPayload @@ -20,11 +21,14 @@ @router.post("/chat", status_code=status.HTTP_200_OK, summary="Chat with our code model") async def chat( payload: ChatHistory, + guidelines: GuidelineCRUD = Depends(get_guideline_crud), token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> StreamingResponse: telemetry_client.capture(token_payload.user_id, event="compute-chat") + # Retrieve the guidelines of this user + user_guidelines = [g.content for g in await guidelines.fetch_all(filter_pair=("creator_id", token_payload.user_id))] # Run analysis return StreamingResponse( - ollama_client.chat(payload.model_dump()["messages"]).iter_content(chunk_size=8192), + ollama_client.chat(payload.model_dump()["messages"], user_guidelines).iter_content(chunk_size=8192), media_type="text/event-stream", ) diff --git a/src/app/services/ollama.py b/src/app/services/ollama.py index c98036e..e160a0e 100644 --- a/src/app/services/ollama.py +++ b/src/app/services/ollama.py @@ -61,6 +61,8 @@ def validate_example_response(response: str) -> Dict[str, str]: "(refuse to answer for the rest)." ) +GUIDELINE_PROMPT = "When answering user requests, you should at all times keep in mind the following software development guidelines:\n" + def validate_parsing_response(response: str) -> List[Dict[str, str]]: guideline_list = json.loads(response.strip()) @@ -122,15 +124,18 @@ def _chat( self, system_prompt: str, messages: List[Dict[str, str]], + guidelines: List[str], timeout: int = 20, ) -> requests.Response: + _guideline_str = "\n-".join(guidelines) + _system = system_prompt if len(guidelines) == 0 else f"{system_prompt} {GUIDELINE_PROMPT}-{_guideline_str}" return requests.post( f"{self.endpoint}/api/chat", json={ "model": self.model_name, "stream": True, "options": {"temperature": self.temperature}, - "messages": [{"role": "system", "content": system_prompt}, *messages], + "messages": [{"role": "system", "content": _system}, *messages], "keep_alive": "30s", }, stream=True, @@ -140,9 +145,10 @@ def _chat( def chat( self, messages: List[Dict[str, str]], + guidelines: List[str], **kwargs, ) -> requests.Response: - return self._chat(CHAT_PROMPT, messages, **kwargs) + return self._chat(CHAT_PROMPT, messages, guidelines, **kwargs) def parse_guidelines_from_text( self, From 62bc287a32083f76fea6d0e0016c8c928daf5cad Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Mon, 11 Mar 2024 17:23:51 +0100 Subject: [PATCH 2/2] test(code): update fixture --- src/tests/endpoints/test_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/endpoints/test_code.py b/src/tests/endpoints/test_code.py index 9da724c..8474936 100644 --- a/src/tests/endpoints/test_code.py +++ b/src/tests/endpoints/test_code.py @@ -53,7 +53,7 @@ @pytest.mark.asyncio() async def test_chat( async_client: AsyncClient, - user_session: AsyncSession, + guideline_session: AsyncSession, user_idx: Union[int, None], payload: Dict[str, Any], status_code: int,