This repository has been archived by the owner on Oct 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adds docker & client for Ollama (preview) (#42)
* feat: Improves example prompt * feat: Adds client for Ollama * feat: Adds config vars * feat: Adds ollama commented sections * feat: Adds docker with Ollama support * style: Fixes typing
- Loading branch information
Showing
5 changed files
with
257 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Dev/preview stack: FastAPI backend + Ollama (GPU) + Postgres.
version: '3.7'  # NOTE(review): the top-level `version` key is obsolete in Compose v2 — harmless but removable

services:
  backend:
    build: .
    # Dev-mode server: --reload watches the ./src bind mount below; --proxy-headers
    # trusts X-Forwarded-* from a fronting proxy.
    command: uvicorn app.main:app --reload --host 0.0.0.0 --port 8050 --proxy-headers
    volumes:
      - ./src/:/app/
    ports:
      - "8050:8050"
    environment:
      - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT}
      - SUPERADMIN_PWD=${SUPERADMIN_PWD}
      - GH_OAUTH_ID=${GH_OAUTH_ID}
      - GH_OAUTH_SECRET=${GH_OAUTH_SECRET}
      # Async SQLAlchemy DSN pointing at the `db` service below
      - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB}
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - SECRET_KEY=${SECRET_KEY}
      - SENTRY_DSN=${SENTRY_DSN}
      - SERVER_NAME=${SERVER_NAME}
      - POSTHOG_KEY=${POSTHOG_KEY}
      - SLACK_API_TOKEN=${SLACK_API_TOKEN}
      - SLACK_CHANNEL=${SLACK_CHANNEL}
      # Service-internal address of the ollama container (see `expose` below)
      - OLLAMA_ENDPOINT=http://ollama:11434
      - OLLAMA_MODEL=${OLLAMA_MODEL}
      - SUPPORT_EMAIL=${SUPPORT_EMAIL}
      - DEBUG=true
    depends_on:
      db:
        condition: service_healthy
      ollama:
        condition: service_healthy

  ollama:
    image: ollama/ollama:latest
    # NOTE(review): the image's entrypoint is the `ollama` binary and no shell interprets
    # `&&` here, so everything after `serve` is passed as extra args to `ollama serve`
    # rather than running a second `ollama pull` command — confirm the model actually
    # gets pulled (otherwise the healthcheck below can never pass and `backend` never starts).
    command: serve && ollama pull ${OLLAMA_MODEL}
    volumes:
      # Persist models on the host so restarts don't re-download them
      - "$HOME/.ollama:/root/.ollama"
    expose:
      # Internal-only: reachable by other services, not published on the host
      - 11434
    healthcheck:
      # Healthy only once the target model shows up in the local model list
      test: ["CMD-SHELL", "ollama list | grep -q '${OLLAMA_MODEL}'"]
      interval: 10s
      timeout: 3s
      retries: 3
    deploy:
      resources:
        reservations:
          devices:
            # Requires the NVIDIA container toolkit on the host
            - driver: nvidia
              count: 1
              capabilities: [gpu]

#  ollama-webui:
#    image: ghcr.io/ollama-webui/ollama-webui:main
#    container_name: ollama-webui
#    depends_on:
#      - ollama
#      # Uncomment below for WIP: Auth support
#      # - ollama-webui-db
#    ports:
#      - 3000:8080
#    environment:
#      - "OLLAMA_API_BASE_URL=http://ollama:11434/api"
#      # Uncomment below for WIP: Auth support
#      # - "WEBUI_AUTH=TRUE"
#      # - "WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/"
#      # - "WEBUI_JWT_SECRET_KEY=SECRET_KEY"
#      # extra_hosts:
#      #   - host.docker.internal:host-gateway
#      # restart: unless-stopped

  db:
    image: postgres:15-alpine
    volumes:
      - postgres_data:/var/lib/postgresql/data/
    expose:
      - 5432
    environment:
      - POSTGRES_USER=${POSTGRES_USER}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
      - POSTGRES_DB=${POSTGRES_DB}
    healthcheck:
      test: ["CMD-SHELL", "sh -c 'pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}'"]
      interval: 10s
      timeout: 3s
      retries: 3

volumes:
  postgres_data:
  # NOTE(review): this named volume is declared but never mounted — the ollama service
  # uses the $HOME bind mount above instead. Confirm whether it can be removed.
  ollama:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (C) 2023, Quack AI. | ||
|
||
# This program is licensed under the Apache License 2.0. | ||
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details. | ||
|
||
import json | ||
import logging | ||
import re | ||
from typing import Callable, Dict, List, TypeVar | ||
|
||
import requests | ||
from fastapi import HTTPException, status | ||
|
||
from app.schemas.guidelines import GuidelineContent, GuidelineExample | ||
|
||
# Reuse uvicorn's error logger so this module's logs share the server's handlers/formatting.
logger = logging.getLogger("uvicorn.error")

# Generic return type produced by a response-validation callback (see OllamaClient._request).
ValidationOut = TypeVar("ValidationOut")
# __all__ = ["ollama_client"]
|
||
|
||
# System prompt asking the model for one compliant and one non-compliant snippet
# illustrating a single guideline.
EXAMPLE_PROMPT = (
    "You are responsible for producing concise illustrations of the company coding guidelines. "
    "This will be used to teach new developers our way of engineering software. "
    "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
    # Format
    "You should output two code blocks: "
    "a minimal code snippet where the instruction was correctly followed, "
    "and the same snippet with minimal modifications that invalidates the instruction."
)
# Captures the two fenced code blocks from the answer: first -> "positive", second -> "negative".
# Used with re.DOTALL in validate_example_response.
# Strangely, this doesn't work when compiled
# NOTE(review): per the original comment above, this is kept as a raw string instead of re.compile — confirm before changing.
EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"
|
||
|
||
def validate_example_response(response: str) -> Dict[str, str]:
    """Extract the positive/negative code snippets from a raw model completion.

    Args:
        response: raw completion text, expected to contain two fenced code blocks.

    Returns:
        Mapping with "positive" and "negative" keys holding the two snippets.

    Raises:
        HTTPException: 500 when the two expected code blocks cannot be located.
    """
    if (match := re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)) is not None:
        return match.groupdict()
    raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
|
||
|
||
# System prompt steering the model towards emitting a JSON list of guideline dicts.
PARSING_PROMPT = (
    "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. "
    "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. "
    "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) "
    "by a human developer without running additional commands or tools, it should only relate to the code within each file. "
    "Only include guidelines for which you could generate positive and negative code snippets, "
    "don't invent anything that isn't present in the input text.\n"
    # Format
    "You should answer with a list of JSON dictionaries, one dictionary per guideline, where each dictionary has two keys with string values:\n"
    "- title: a short summary title of the guideline\n"
    # Fixed typo: "inambiguous" (not an English word) -> "unambiguous"
    "- details: a descriptive, comprehensive and unambiguous explanation of the guideline."
)
# Fallback regex to pull {"title": ..., "details": ...} objects out of a loosely formatted answer.
PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P<title>.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}"
|
||
|
||
def validate_parsing_response(response: str) -> List[Dict[str, str]]:
    """Parse the model response as a JSON list of string-valued dictionaries.

    Args:
        response: raw completion text, expected to be a JSON array of objects.

    Returns:
        The parsed list of guideline dictionaries.

    Raises:
        HTTPException: 500 when the response is not valid JSON or does not match
            the expected schema (list of dicts with string values).
    """
    # Parse once and reuse the result (previously the payload was json.loads-ed twice).
    try:
        guideline_list = json.loads(response.strip())
    except json.JSONDecodeError:
        # Invalid JSON previously leaked a raw JSONDecodeError; surface the same
        # schema-validation error as any other malformed output.
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation") from None
    if (
        not isinstance(guideline_list, list)
        # Non-dict elements previously crashed with AttributeError on .values()
        or any(not isinstance(guideline, dict) for guideline in guideline_list)
        or any(not isinstance(val, str) for guideline in guideline_list for val in guideline.values())
    ):
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")

    return guideline_list
|
||
|
||
class OllamaClient:
    """Thin HTTP client around an Ollama server.

    On construction it checks that the endpoint is reachable and pre-pulls the
    requested model; afterwards it sends system prompt + message pairs to
    ``/api/generate`` and validates the model's raw text output.
    """

    def __init__(self, endpoint: str, model_name: str, temperature: float = 0.0) -> None:
        """Probe the endpoint and pull the model.

        Args:
            endpoint: base URL of the Ollama server (e.g. "http://ollama:11434").
            model_name: name of the model to pull and use for generation.
            temperature: sampling temperature forwarded to every generation call.

        Raises:
            HTTPException: 404 when the endpoint is unreachable or the model pull fails.
        """
        self.endpoint = endpoint
        # Check endpoint availability before paying the model-pull cost
        response = requests.get(f"{self.endpoint}/api/tags", timeout=2)
        if response.status_code != 200:
            # Fix: `status.HTTP_404` does not exist in fastapi/starlette (it raised
            # AttributeError instead of the intended 404); use HTTP_404_NOT_FOUND.
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable endpoint")
        # Pull the model so the first generation request doesn't block on a download
        logger.info("Loading Ollama model...")
        response = requests.post(f"{self.endpoint}/api/pull", json={"name": model_name, "stream": False}, timeout=10)
        if response.status_code != 200 or response.json()["status"] != "success":
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable model")
        self.temperature = temperature
        self.model_name = model_name
        # Lazy %-style args: no string formatting when the log level is disabled
        logger.info("Using Ollama model: %s", self.model_name)

    def _request(
        self,
        system_prompt: str,
        message: str,
        validate_fn: Callable[[str], ValidationOut],
        timeout: int = 20,
    ) -> ValidationOut:
        """Send a non-streaming generation request and validate its text output.

        Args:
            system_prompt: system instruction for the model.
            message: user prompt to complete.
            validate_fn: callback that parses/validates the raw response text.
            timeout: request timeout in seconds.

        Returns:
            Whatever ``validate_fn`` produces from the response text.

        Raises:
            HTTPException: forwarded with Ollama's status code and error message
                when the generation request fails.
        """
        # Send the request
        response = requests.post(
            f"{self.endpoint}/api/generate",
            json={
                "model": self.model_name,
                "stream": False,
                "options": {"temperature": self.temperature},
                "system": system_prompt,
                "prompt": message,
            },
            timeout=timeout,
        )

        # Forward Ollama's own error message with its original status code
        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code, detail=response.json()["error"])

        # Parse the JSON body once (previously parsed on every access)
        json_response = response.json()
        logger.info(json_response["response"].strip())
        return validate_fn(json_response["response"])

    def parse_guidelines_from_text(
        self,
        corpus: str,
        timeout: int = 20,
    ) -> List[GuidelineContent]:
        """Extract structured coding guidelines from a documentation corpus.

        Args:
            corpus: free-form documentation text to summarize.
            timeout: request timeout in seconds.

        Returns:
            One GuidelineContent per guideline found (empty list for an empty corpus).

        Raises:
            HTTPException: 422 when the corpus is not a string; 500 when the model
                output fails schema validation.
        """
        if not isinstance(corpus, str):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="The input corpus needs to be a string.",
            )
        # Nothing to parse: skip the network round-trip entirely
        if len(corpus) == 0:
            return []

        response = self._request(
            PARSING_PROMPT,
            json.dumps(corpus),
            validate_parsing_response,
            timeout,
        )

        return [GuidelineContent(**elt) for elt in response]

    def generate_examples_for_instruction(
        self,
        instruction: str,
        language: str,
        timeout: int = 20,
    ) -> GuidelineExample:
        """Produce a positive and a negative code example for one guideline.

        Args:
            instruction: the guideline to illustrate.
            language: programming language the snippets should be written in.
            timeout: request timeout in seconds.

        Returns:
            A GuidelineExample holding the compliant and non-compliant snippets.

        Raises:
            HTTPException: 422 when instruction/language are not non-empty strings;
                500 when the model output fails schema validation.
        """
        if (
            not isinstance(instruction, str)
            or len(instruction) == 0
            or not isinstance(language, str)
            or len(language) == 0
        ):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="The instruction and language need to be non-empty strings.",
            )

        return GuidelineExample(
            **self._request(
                EXAMPLE_PROMPT,
                json.dumps({"guideline": instruction, "language": language}),
                validate_example_response,
                timeout,
            ),
        )
|
||
|
||
# ollama_client = OllamaClient(settings.OLLAMA_ENDPOINT, settings.OLLAMA_MODEL) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters