Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
feat: Adds docker & client for Ollama (preview) (#42)
Browse files Browse the repository at this point in the history
* feat: Improves example prompt

* feat: Adds client for Ollama

* feat: Adds config vars

* feat: Adds ollama commented sections

* feat: Adds docker with Ollama support

* style: Fixes typing
  • Loading branch information
frgfm authored Dec 11, 2023
1 parent dc080b1 commit 99a99b3
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 1 deletion.
91 changes: 91 additions & 0 deletions docker-compose.ollama.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Compose stack running the API backend against a self-hosted Ollama server
# and a Postgres database.
# NOTE(review): the listing had its YAML indentation stripped; nesting below
# is reconstructed to the standard compose layout.
version: '3.7'

services:
  backend:
    build: .
    command: uvicorn app.main:app --reload --host 0.0.0.0 --port 8050 --proxy-headers
    volumes:
      - ./src/:/app/
    ports:
      - "8050:8050"
    environment:
      - SUPERADMIN_GH_PAT=${SUPERADMIN_GH_PAT}
      - SUPERADMIN_PWD=${SUPERADMIN_PWD}
      - GH_OAUTH_ID=${GH_OAUTH_ID}
      - GH_OAUTH_SECRET=${GH_OAUTH_SECRET}
      - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB}
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - SECRET_KEY=${SECRET_KEY}
      - SENTRY_DSN=${SENTRY_DSN}
      - SERVER_NAME=${SERVER_NAME}
      - POSTHOG_KEY=${POSTHOG_KEY}
      - SLACK_API_TOKEN=${SLACK_API_TOKEN}
      - SLACK_CHANNEL=${SLACK_CHANNEL}
      # Service-internal hostname on the compose network; 11434 is Ollama's default port
      - OLLAMA_ENDPOINT=http://ollama:11434
      - OLLAMA_MODEL=${OLLAMA_MODEL}
      - SUPPORT_EMAIL=${SUPPORT_EMAIL}
      - DEBUG=true
    depends_on:
      db:
        condition: service_healthy
      ollama:
        # Wait until the model is actually pulled (see ollama healthcheck)
        condition: service_healthy

  ollama:
    image: ollama/ollama:latest
    # NOTE(review): the image entrypoint is the `ollama` binary, so shell
    # operators in `command` are passed as literal arguments — `&& ollama pull`
    # is unlikely to execute. Confirm the model really gets pulled; the
    # healthcheck below would never pass otherwise.
    command: serve && ollama pull ${OLLAMA_MODEL}
    volumes:
      # Reuse the host's model cache to avoid re-downloading weights
      - "$HOME/.ollama:/root/.ollama"
    expose:
      - 11434
    healthcheck:
      # Healthy only once the requested model is present locally
      test: ["CMD-SHELL", "ollama list | grep -q '${OLLAMA_MODEL}'"]
      interval: 10s
      timeout: 3s
      retries: 3
    deploy:
      resources:
        reservations:
          devices:
            # Reserve one NVIDIA GPU for inference
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # ollama-webui:
  #   image: ghcr.io/ollama-webui/ollama-webui:main
  #   container_name: ollama-webui
  #   depends_on:
  #     - ollama
  #     # Uncomment below for WIP: Auth support
  #     # - ollama-webui-db
  #   ports:
  #     - 3000:8080
  #   environment:
  #     - "OLLAMA_API_BASE_URL=http://ollama:11434/api"
  #     # Uncomment below for WIP: Auth support
  #     # - "WEBUI_AUTH=TRUE"
  #     # - "WEBUI_DB_URL=mongodb://root:example@ollama-webui-db:27017/"
  #     # - "WEBUI_JWT_SECRET_KEY=SECRET_KEY"
  #   # extra_hosts:
  #   #   - host.docker.internal:host-gateway
  #   # restart: unless-stopped

  db:
    image: postgres:15-alpine
    volumes:
      - postgres_data:/var/lib/postgresql/data/
    expose:
      - 5432
    environment:
      - POSTGRES_USER=${POSTGRES_USER}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
      - POSTGRES_DB=${POSTGRES_DB}
    healthcheck:
      test: ["CMD-SHELL", "sh -c 'pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}'"]
      interval: 10s
      timeout: 3s
      retries: 3

volumes:
  postgres_data:
  # NOTE(review): this named volume is declared but the ollama service mounts
  # a host bind ($HOME/.ollama) instead — presumably leftover; confirm before removing.
  ollama:
2 changes: 2 additions & 0 deletions src/app/api/api_v1/endpoints/guidelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ async def parse_guidelines_from_text(
telemetry_client.capture(user.id, event="guideline-parse")
# Analyze with LLM
return openai_client.parse_guidelines_from_text(payload.content, user_id=str(user.id))
# return ollama_client.parse_guidelines_from_text(payload.content)


@router.post("/examples", status_code=status.HTTP_200_OK, summary="Request examples for a guideline")
Expand All @@ -133,3 +134,4 @@ async def generate_examples_for_text(
telemetry_client.capture(user.id, event="guideline-examples")
# Analyze with LLM
return openai_client.generate_examples_for_instruction(payload.content, payload.language, user_id=str(user.id))
# return ollama_client.generate_examples_for_instruction(payload.content, payload.language)
2 changes: 2 additions & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def sqlachmey_uri(cls, v: str) -> str:
# Compute
OPENAI_API_KEY: str = os.environ["OPENAI_API_KEY"]
OPENAI_MODEL: OpenAIModel = OpenAIModel.GPT3_5_TURBO
# OLLAMA_ENDPOINT: str = os.environ["OLLAMA_ENDPOINT"]
# OLLAMA_MODEL: str = os.environ.get("OLLAMA_MODEL", "starling-lm:7b-alpha")

# Error monitoring
SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN")
Expand Down
161 changes: 161 additions & 0 deletions src/app/services/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (C) 2023, Quack AI.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import json
import logging
import re
from typing import Callable, Dict, List, TypeVar

import requests
from fastapi import HTTPException, status

from app.schemas.guidelines import GuidelineContent, GuidelineExample

logger = logging.getLogger("uvicorn.error")

ValidationOut = TypeVar("ValidationOut")
# __all__ = ["ollama_client"]


# Instruction given to the LLM when asked to illustrate a guideline with code.
EXAMPLE_PROMPT = (
    "You are responsible for producing concise illustrations of the company coding guidelines. "
    "This will be used to teach new developers our way of engineering software. "
    "Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
    # Format
    "You should output two code blocks: "
    "a minimal code snippet where the instruction was correctly followed, "
    "and the same snippet with minimal modifications that invalidates the instruction."
)
# Captures the two expected fenced code blocks: compliant snippet first
# ("positive"), violating snippet second ("negative").
# Kept as a raw string — compiling it with re.compile reportedly misbehaved.
EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"


def validate_example_response(response: str) -> Dict[str, str]:
    """Extract the positive/negative snippet pair from a raw model completion.

    Args:
        response: completion text expected to contain two fenced code blocks.

    Returns:
        Mapping with "positive" and "negative" snippet strings.

    Raises:
        HTTPException: 500 when the two code blocks cannot be located.
    """
    found = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
    if found is not None:
        return found.groupdict()
    raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")


# System prompt instructing the model to distill verifiable coding guidelines
# out of free-form documentation, answering as a JSON list of dicts.
PARSING_PROMPT = (
    "You are responsible for summarizing the list of distinct coding guidelines for the company, by going through documentation. "
    "This list will be used by developers to avoid hesitations in code reviews and to onboard new members. "
    "Consider only guidelines that can be verified for a specific snippet of code (nothing about git, commits or community interactions) "
    "by a human developer without running additional commands or tools, it should only relate to the code within each file. "
    "Only include guidelines for which you could generate positive and negative code snippets, "
    "don't invent anything that isn't present in the input text.\n"
    # Format
    "You should answer with a list of JSON dictionaries, one dictionary per guideline, where each dictionary has two keys with string values:\n"
    "- title: a short summary title of the guideline\n"
    "- details: a descriptive, comprehensive and inambiguous explanation of the guideline."
)
# Regex matching a single {"title": ..., "details": ...} object in the answer.
# NOTE(review): this pattern is not referenced anywhere in this module —
# presumably kept for a regex-based fallback extraction; confirm before removing.
PARSING_PATTERN = r"\{\s*\"title\":\s+\"(?P<title>.*?)\",\s+\"details\":\s+\"(?P<details>.*?)\"\s*\}"


def validate_parsing_response(response: str) -> List[Dict[str, str]]:
    """Decode and validate the JSON guideline list returned by the model.

    Args:
        response: raw completion text, expected to be a JSON list of
            {"title": str, "details": str} dictionaries.

    Returns:
        The decoded list of guideline dictionaries.

    Raises:
        HTTPException: 500 when the text is not valid JSON or does not match
            the expected schema.
    """
    # Parse once and reuse (the original parsed the payload twice); surface
    # malformed JSON as the same 500 instead of an unhandled JSONDecodeError.
    try:
        guideline_list = json.loads(response.strip())
    except json.JSONDecodeError:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")
    # Must be a list of dicts whose values are all strings; the explicit dict
    # check prevents an AttributeError on `.values()` for non-dict items.
    if not isinstance(guideline_list, list) or any(
        not isinstance(guideline, dict) or any(not isinstance(val, str) for val in guideline.values())
        for guideline in guideline_list
    ):
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")

    return guideline_list


class OllamaClient:
    """Minimal client for a self-hosted Ollama inference server.

    On construction it verifies the endpoint is reachable and pulls the
    requested model so later generation calls don't fail lazily.
    """

    def __init__(self, endpoint: str, model_name: str, temperature: float = 0.0) -> None:
        """
        Args:
            endpoint: base URL of the Ollama server, e.g. "http://ollama:11434".
            model_name: name of the Ollama model to pull and use.
            temperature: sampling temperature forwarded on each generation.

        Raises:
            HTTPException: 404 when the endpoint is unreachable or the model
                cannot be pulled.
        """
        self.endpoint = endpoint
        # Check endpoint reachability before pulling anything
        response = requests.get(f"{self.endpoint}/api/tags", timeout=2)
        if response.status_code != 200:
            # Bug fix: fastapi/starlette's status module has no `HTTP_404`
            # attribute — the original raised AttributeError here instead of
            # the intended HTTP error. The correct constant is HTTP_404_NOT_FOUND.
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable endpoint")
        # Pull the model upfront so the first generation call doesn't pay the download cost
        logger.info("Loading Ollama model...")
        # NOTE(review): 10s is likely too short for a cold model download —
        # presumably the weights are pre-fetched via the mounted cache; confirm.
        response = requests.post(f"{self.endpoint}/api/pull", json={"name": model_name, "stream": False}, timeout=10)
        if response.status_code != 200 or response.json()["status"] != "success":
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Unavailable model")
        self.temperature = temperature
        self.model_name = model_name
        logger.info(f"Using Ollama model: {self.model_name}")

    def _request(
        self,
        system_prompt: str,
        message: str,
        validate_fn: Callable[[str], ValidationOut],
        timeout: int = 20,
    ) -> ValidationOut:
        """Send a single non-streaming generation request and validate its output.

        Args:
            system_prompt: system instruction for the model.
            message: user prompt.
            validate_fn: callback parsing/validating the raw completion text.
            timeout: request timeout in seconds.

        Returns:
            Whatever ``validate_fn`` returns for the completion text.

        Raises:
            HTTPException: mirroring the Ollama error payload on non-200
                responses, or raised by ``validate_fn`` on schema violations.
        """
        # Send the request
        response = requests.post(
            f"{self.endpoint}/api/generate",
            json={
                "model": self.model_name,
                "stream": False,
                "options": {"temperature": self.temperature},
                "system": system_prompt,
                "prompt": message,
            },
            timeout=timeout,
        )

        # Check status
        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code, detail=response.json()["error"])

        # Decode once instead of calling response.json() per access; log the
        # raw completion to ease prompt debugging, then validate it
        completion = response.json()["response"]
        logger.info(completion.strip())
        return validate_fn(completion)

    def parse_guidelines_from_text(
        self,
        corpus: str,
        timeout: int = 20,
    ) -> List[GuidelineContent]:
        """Extract structured guidelines from a free-form documentation corpus.

        Args:
            corpus: documentation text to analyze.
            timeout: request timeout in seconds.

        Returns:
            One GuidelineContent per guideline found (empty list for empty input).

        Raises:
            HTTPException: 422 for non-string input, or errors bubbled up from
                the generation request / output validation.
        """
        if not isinstance(corpus, str):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="The input corpus needs to be a string.",
            )
        # Nothing to analyze: skip the LLM round-trip entirely
        if len(corpus) == 0:
            return []

        response = self._request(
            PARSING_PROMPT,
            json.dumps(corpus),
            validate_parsing_response,
            timeout,
        )

        return [GuidelineContent(**elt) for elt in response]

    def generate_examples_for_instruction(
        self,
        instruction: str,
        language: str,
        timeout: int = 20,
    ) -> GuidelineExample:
        """Produce a compliant and a violating code snippet for a guideline.

        Args:
            instruction: the guideline text to illustrate.
            language: target programming language for the snippets.
            timeout: request timeout in seconds.

        Returns:
            A GuidelineExample holding the positive and negative snippets.

        Raises:
            HTTPException: 422 when either input is not a non-empty string, or
                errors bubbled up from the generation request / output validation.
        """
        if (
            not isinstance(instruction, str)
            or len(instruction) == 0
            or not isinstance(language, str)
            or len(language) == 0
        ):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="The instruction and language need to be non-empty strings.",
            )

        return GuidelineExample(
            **self._request(
                EXAMPLE_PROMPT,
                json.dumps({"guideline": instruction, "language": language}),
                validate_example_response,
                timeout,
            ),
        )


# ollama_client = OllamaClient(settings.OLLAMA_ENDPOINT, settings.OLLAMA_MODEL)
2 changes: 1 addition & 1 deletion src/app/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class ExecutionMode(str, Enum):
"This will be used to teach new developers our way of engineering software. "
"You should answer in JSON format with only two short code snippets in the specified programming language: one that follows the rule correctly, "
"and a similar version with minimal modifications that violates the rule. "
"Make sure your code is functional, don't extra comments or explanation."
"Make sure your code is functional, don't add extra comments or explanation, or someone will die."
)

ModelInp = TypeVar("ModelInp")
Expand Down

0 comments on commit 99a99b3

Please sign in to comment.