From 99a99b3d9cdc89310b17329189a9c7ac88bfbdf4 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:14:42 +0100 Subject: [PATCH] feat: Adds docker & client for Ollama (preview) (#42) * feat: Improves example prompt * feat: Adds client for Ollama * feat: Adds config vars * feat: Adds ollama commented sections * feat: Adds docker with Ollama support * style: Fixes typing --- docker-compose.ollama.yml | 91 ++++++++++++ src/app/api/api_v1/endpoints/guidelines.py | 2 + src/app/core/config.py | 2 + src/app/services/ollama.py | 161 +++++++++++++++++++++ src/app/services/openai.py | 2 +- 5 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 docker-compose.ollama.yml create mode 100644 src/app/services/ollama.py diff --git a/docker-compose.ollama.yml b/docker-compose.ollama.yml new file mode 100644 index 0000000..82ffec3 --- /dev/null +++ b/docker-compose.ollama.yml @@ -0,0 +1,91 @@ +version: '3.7' + +services: + backend: + build: . 
+ command: uvicorn app.main:app --reload --host 0.0.0.0 --port 8050 --proxy-headers + volumes: + - ./src/:/app/ + ports: + - "8050:8050" + environment: + - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT} + - SUPERADMIN_PWD=${SUPERADMIN_PWD} + - GH_OAUTH_ID=${GH_OAUTH_ID} + - GH_OAUTH_SECRET=${GH_OAUTH_SECRET} + - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} + - OPENAI_API_KEY=${OPENAI_API_KEY} + - SECRET_KEY=${SECRET_KEY} + - SENTRY_DSN=${SENTRY_DSN} + - SERVER_NAME=${SERVER_NAME} + - POSTHOG_KEY=${POSTHOG_KEY} + - SLACK_API_TOKEN=${SLACK_API_TOKEN} + - SLACK_CHANNEL=${SLACK_CHANNEL} + - OLLAMA_ENDPOINT=http://ollama:11434 + - OLLAMA_MODEL=${OLLAMA_MODEL} + - SUPPORT_EMAIL=${SUPPORT_EMAIL} + - DEBUG=true + depends_on: + db: + condition: service_healthy + ollama: + condition: service_healthy + + ollama: + image: ollama/ollama:latest + command: serve && ollama pull ${OLLAMA_MODEL} + volumes: + - "$HOME/.ollama:/root/.ollama" + expose: + - 11434 + healthcheck: + test: ["CMD-SHELL", "ollama list | grep -q '${OLLAMA_MODEL}'"] + interval: 10s + timeout: 3s + retries: 3 + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + # ollama-webui: + # image: ghcr.io/ollama-webui/ollama-webui:main + # container_name: ollama-webui + # depends_on: + # - ollama + # # Uncomment below for WIP: Auth support + # # - ollama-webui-db + # ports: + # - 3000:8080 + # environment: + # - "OLLAMA_API_BASE_URL=http://ollama:11434/api" + # # Uncomment below for WIP: Auth support + # # - "WEBUI_AUTH=TRUE" + # # - "WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/" + # # - "WEBUI_JWT_SECRET_KEY=SECRET_KEY" + # # extra_hosts: + # # - host.docker.internal:host-gateway + # # restart: unless-stopped + + db: + image: postgres:15-alpine + volumes: + - postgres_data:/var/lib/postgresql/data/ + expose: + - 5432 + environment: + - POSTGRES_USER=${POSTGRES_USER} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD} + - 
POSTGRES_DB=${POSTGRES_DB} + healthcheck: + test: ["CMD-SHELL", "sh -c 'pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}'"] + interval: 10s + timeout: 3s + retries: 3 + +volumes: + postgres_data: + ollama: diff --git a/src/app/api/api_v1/endpoints/guidelines.py b/src/app/api/api_v1/endpoints/guidelines.py index 7e0a5d1..bd5a211 100644 --- a/src/app/api/api_v1/endpoints/guidelines.py +++ b/src/app/api/api_v1/endpoints/guidelines.py @@ -123,6 +123,7 @@ async def parse_guidelines_from_text( telemetry_client.capture(user.id, event="guideline-parse") # Analyze with LLM return openai_client.parse_guidelines_from_text(payload.content, user_id=str(user.id)) + # return ollama_client.parse_guidelines_from_text(payload.content) @router.post("/examples", status_code=status.HTTP_200_OK, summary="Request examples for a guideline") @@ -133,3 +134,4 @@ async def generate_examples_for_text( telemetry_client.capture(user.id, event="guideline-examples") # Analyze with LLM return openai_client.generate_examples_for_instruction(payload.content, payload.language, user_id=str(user.id)) + # return ollama_client.generate_examples_for_instruction(payload.content, payload.language) diff --git a/src/app/core/config.py b/src/app/core/config.py index 96358a8..1cb56fb 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -48,6 +48,8 @@ def sqlachmey_uri(cls, v: str) -> str: # Compute OPENAI_API_KEY: str = os.environ["OPENAI_API_KEY"] OPENAI_MODEL: OpenAIModel = OpenAIModel.GPT3_5_TURBO + # OLLAMA_ENDPOINT: str = os.environ["OLLAMA_ENDPOINT"] + # OLLAMA_MODEL: str = os.environ.get("OLLAMA_MODEL", "starling-lm:7b-alpha") # Error monitoring SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN") diff --git a/src/app/services/ollama.py b/src/app/services/ollama.py new file mode 100644 index 0000000..85bcf05 --- /dev/null +++ b/src/app/services/ollama.py @@ -0,0 +1,161 @@ +# Copyright (C) 2023, Quack AI. + +# This program is licensed under the Apache License 2.0. 
+# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details. + +import json +import logging +import re +from typing import Callable, Dict, List, TypeVar + +import requests +from fastapi import HTTPException, status + +from app.schemas.guidelines import GuidelineContent, GuidelineExample + +logger = logging.getLogger("uvicorn.error") + +ValidationOut = TypeVar("ValidationOut") +# __all__ = ["ollama_client"] + + +EXAMPLE_PROMPT = ( + "You are responsible for producing concise illustrations of the company coding guidelines. " + "This will be used to teach new developers our way of engineering software. " + "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n" + # Format + "You should output two code blocks: " + "a minimal code snippet where the instruction was correctly followed, " + "and the same snippet with minimal modifications that invalidates the instruction." +) +# Strangely, this doesn't work when compiled +EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```" + + +def validate_example_response(response: str) -> Dict[str, str]: + matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL) + if matches is None: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation") + + return matches.groupdict() + + +PARSING_PROMPT = ( + "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. " + "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. " + "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) " + "by a human developer without running additional commands or tools, it should only relate to the code within each file. 
" + "Only include guidelines for which you could generate positive and negative code snippets, " + "don't invent anything that isn't present in the input text.\n" + # Format + "You should answer with a list of JSON dictionaries, one dictionary per guideline, where each dictionary has two keys with string values:\n" + "- title: a short summary title of the guideline\n" + "- details: a descriptive, comprehensive and inambiguous explanation of the guideline." +) +PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}" + + +def validate_parsing_response(response: str) -> List[Dict[str, str]]: + guideline_list = json.loads(response.strip()) + if not isinstance(guideline_list, list) or any( + not isinstance(val, str) for guideline in guideline_list for val in guideline.values() + ): + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation") + + return json.loads(response.strip()) + + +class OllamaClient: + def __init__(self, endpoint: str, model_name: str, temperature: float = 0.0) -> None: + self.endpoint = endpoint + # Check endpoint + response = requests.get(f"{self.endpoint}/api/tags", timeout=2) + if response.status_code != 200: + raise HTTPException(status_code=status.HTTP_404, detail="Unavailable endpoint") + # Pull model + logger.info("Loading Ollama model...") + response = requests.post(f"{self.endpoint}/api/pull", json={"name": model_name, "stream": False}, timeout=10) + if response.status_code != 200 or response.json()["status"] != "success": + raise HTTPException(status_code=status.HTTP_404, detail="Unavailable model") + self.temperature = temperature + self.model_name = model_name + logger.info(f"Using Ollama model: {self.model_name}") + + def _request( + self, + system_prompt: str, + message: str, + validate_fn: Callable[[str], ValidationOut], + timeout: int = 20, + ) -> ValidationOut: + # Send the request + response = requests.post( + 
f"{self.endpoint}/api/generate", + json={ + "model": self.model_name, + "stream": False, + "options": {"temperature": self.temperature}, + "system": system_prompt, + "prompt": message, + }, + timeout=timeout, + ) + + # Check status + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.json()["error"]) + + # Regex to locate JSON string + logger.info(response.json()["response"].strip()) + return validate_fn(response.json()["response"]) + + def parse_guidelines_from_text( + self, + corpus: str, + timeout: int = 20, + ) -> List[GuidelineContent]: + if not isinstance(corpus, str): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="The input corpus needs to be a string.", + ) + if len(corpus) == 0: + return [] + + response = self._request( + PARSING_PROMPT, + json.dumps(corpus), + validate_parsing_response, + timeout, + ) + + return [GuidelineContent(**elt) for elt in response] + + def generate_examples_for_instruction( + self, + instruction: str, + language: str, + timeout: int = 20, + ) -> GuidelineExample: + if ( + not isinstance(instruction, str) + or len(instruction) == 0 + or not isinstance(language, str) + or len(language) == 0 + ): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="The instruction and language need to be non-empty strings.", + ) + + return GuidelineExample( + **self._request( + EXAMPLE_PROMPT, + json.dumps({"guideline": instruction, "language": language}), + validate_example_response, + timeout, + ), + ) + + +# ollama_client = OllamaClient(settings.OLLAMA_ENDPOINT, settings.OLLAMA_MODEL) diff --git a/src/app/services/openai.py b/src/app/services/openai.py index 6e4fc34..44cc79a 100644 --- a/src/app/services/openai.py +++ b/src/app/services/openai.py @@ -178,7 +178,7 @@ class ExecutionMode(str, Enum): "This will be used to teach new developers our way of engineering software. 
" "You should answer in JSON format with only two short code snippets in the specified programming language: one that follows the rule correctly, " "and a similar version with minimal modifications that violates the rule. " - "Make sure your code is functional, don't extra comments or explanation." + "Make sure your code is functional, don't extra comments or explanation, or someone will die." ) ModelInp = TypeVar("ModelInp")