This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

feat(code): embed guidelines in the code chat #122

Merged on Mar 11, 2024 · 2 commits
10 changes: 7 additions & 3 deletions src/app/api/api_v1/endpoints/code.py
@@ -4,10 +4,11 @@
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
 
 
-from fastapi import APIRouter, Security, status
+from fastapi import APIRouter, Depends, Security, status
 from fastapi.responses import StreamingResponse
 
-from app.api.dependencies import get_token_payload
+from app.api.dependencies import get_guideline_crud, get_token_payload
+from app.crud.crud_guideline import GuidelineCRUD
 from app.models import UserScope
 from app.schemas.code import ChatHistory
 from app.schemas.login import TokenPayload
@@ -20,11 +21,14 @@
 @router.post("/chat", status_code=status.HTTP_200_OK, summary="Chat with our code model")
 async def chat(
     payload: ChatHistory,
+    guidelines: GuidelineCRUD = Depends(get_guideline_crud),
     token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
 ) -> StreamingResponse:
     telemetry_client.capture(token_payload.user_id, event="compute-chat")
+    # Retrieve the guidelines of this user
+    user_guidelines = [g.content for g in await guidelines.fetch_all(filter_pair=("creator_id", token_payload.user_id))]
     # Run analysis
     return StreamingResponse(
-        ollama_client.chat(payload.model_dump()["messages"]).iter_content(chunk_size=8192),
+        ollama_client.chat(payload.model_dump()["messages"], user_guidelines).iter_content(chunk_size=8192),
         media_type="text/event-stream",
     )
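For reviewers, a minimal sketch of how a client might exercise the updated endpoint; the base URL, port, token, and message payload shape are assumptions inferred from the `ChatHistory` schema and router setup, not confirmed by this diff:

```python
# Hypothetical client call; URL, port, and token are placeholders.
import requests

response = requests.post(
    "http://localhost:8050/api/v1/code/chat",
    json={"messages": [{"role": "user", "content": "How should I structure this module?"}]},
    headers={"Authorization": "Bearer <ACCESS_TOKEN>"},
    stream=True,
    timeout=30,
)
# The endpoint now prepends the caller's stored guidelines to the system
# prompt before streaming the model's answer back as an event stream.
for chunk in response.iter_content(chunk_size=8192):
    print(chunk.decode("utf-8"), end="")
```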
10 changes: 8 additions & 2 deletions src/app/services/ollama.py
@@ -61,6 +61,8 @@ def validate_example_response(response: str) -> Dict[str, str]:
     "(refuse to answer for the rest)."
 )
 
+GUIDELINE_PROMPT = "When answering user requests, you should at all times keep in mind the following software development guidelines:\n"
+
 
 def validate_parsing_response(response: str) -> List[Dict[str, str]]:
     guideline_list = json.loads(response.strip())
@@ -122,15 +124,18 @@ def _chat(
         self,
         system_prompt: str,
         messages: List[Dict[str, str]],
+        guidelines: List[str],
         timeout: int = 20,
     ) -> requests.Response:
+        _guideline_str = "\n-".join(guidelines)
+        _system = system_prompt if len(guidelines) == 0 else f"{system_prompt} {GUIDELINE_PROMPT}-{_guideline_str}"
         return requests.post(
             f"{self.endpoint}/api/chat",
             json={
                 "model": self.model_name,
                 "stream": True,
                 "options": {"temperature": self.temperature},
-                "messages": [{"role": "system", "content": system_prompt}, *messages],
+                "messages": [{"role": "system", "content": _system}, *messages],
                 "keep_alive": "30s",
             },
             stream=True,
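To make the string handling above concrete, here is a standalone sketch of the prompt composition; the sample system prompt and guideline texts are made up, while `GUIDELINE_PROMPT` and the join/format logic come straight from the diff:

```python
# Standalone sketch of the composition logic in _chat above.
from typing import List

GUIDELINE_PROMPT = (
    "When answering user requests, you should at all times keep in mind "
    "the following software development guidelines:\n"
)

def compose_system_prompt(system_prompt: str, guidelines: List[str]) -> str:
    # Mirrors _chat: an empty guideline list leaves the prompt untouched.
    _guideline_str = "\n-".join(guidelines)
    return system_prompt if len(guidelines) == 0 else f"{system_prompt} {GUIDELINE_PROMPT}-{_guideline_str}"

print(compose_system_prompt("You are a code assistant.", ["Use type hints", "Write docstrings"]))
# You are a code assistant. When answering user requests, you should at all
# times keep in mind the following software development guidelines:
# -Use type hints
# -Write docstrings
```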
@@ -140,9 +145,10 @@ def _chat(
     def chat(
         self,
         messages: List[Dict[str, str]],
+        guidelines: List[str],
         **kwargs,
     ) -> requests.Response:
-        return self._chat(CHAT_PROMPT, messages, **kwargs)
+        return self._chat(CHAT_PROMPT, messages, guidelines, **kwargs)
 
     def parse_guidelines_from_text(
         self,
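One side effect worth flagging (an observation about the change, not something in the diff): `guidelines` is a new required positional parameter on `chat`, so any other call sites have to be updated alongside this PR. Passing an empty list reproduces the previous behavior, because the guideline suffix is only built when the list is non-empty:

```python
# Equivalent to the pre-change call: an empty guideline list takes the
# len(guidelines) == 0 branch in _chat and leaves the system prompt as-is.
ollama_client.chat(messages, [])
```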
2 changes: 1 addition & 1 deletion src/tests/endpoints/test_code.py
@@ -53,7 +53,7 @@
 @pytest.mark.asyncio()
 async def test_chat(
     async_client: AsyncClient,
-    user_session: AsyncSession,
+    guideline_session: AsyncSession,
     user_idx: Union[int, None],
     payload: Dict[str, Any],
     status_code: int,
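The test now requests a `guideline_session` fixture, presumably so the guideline table is seeded before the endpoint fetches per-user guidelines. The fixture itself is outside this diff; purely as an illustration, it might look roughly like this (building on `user_session` and the seeding step are both hypothetical):

```python
# Hypothetical sketch only; the real fixture lives in the suite's conftest.
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession

@pytest_asyncio.fixture()
async def guideline_session(user_session: AsyncSession) -> AsyncSession:
    # Seed a few guideline rows tied to the existing test users so the
    # /code/chat endpoint has something to fetch, then reuse the session.
    ...
    return user_session
```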