This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

feat(llm): add integration for Claude #171

Open · wants to merge 6 commits into `main`

Changes from all commits
2 changes: 2 additions & 0 deletions .env.example
@@ -25,6 +25,8 @@ GROQ_API_KEY=
GROQ_MODEL='llama3-8b-8192'
OPENAI_API_KEY=
OPENAI_MODEL='gpt-4o-2024-05-13'
CLAUDE_API_KEY=
CLAUDE_MODEL='claude-3-opus-20240229'
LLM_TEMPERATURE=0
JWT_SECRET=
SENTRY_DSN=
4 changes: 3 additions & 1 deletion CONTRIBUTING.md
@@ -81,7 +81,7 @@ If you are wondering how to do something with Companion API, or a more general q
- [Poetry](https://python-poetry.org/docs/)
- [Make](https://www.gnu.org/software/make/) (optional)

_*If you don't have a GPU, you can use alternative LLM providers (currently supported: Groq, OpenAI)_
_*If you don't have a GPU, you can use alternative LLM providers (currently supported: Groq, OpenAI, Claude)_


### Configure your fork
@@ -138,6 +138,8 @@ This file contains all the information to run the project.
- `GROQ_MODEL`: the model tag in [Groq supported models](https://console.groq.com/docs/models) that will be used for the API.
- `OPENAI_API_KEY`: your [OpenAI API KEY](https://platform.openai.com/api-keys), required if you select `openai` as `LLM_PROVIDER`.
- `OPENAI_MODEL`: the model tag in [OpenAI supported models](https://platform.openai.com/docs/models) that will be used for the API.
- `CLAUDE_API_KEY`: your [Claude API key](https://console.anthropic.com/settings/keys), required if you select `claude` as `LLM_PROVIDER`.
- `CLAUDE_MODEL`: the model tag in [Claude supported models](https://docs.anthropic.com/en/docs/models-overview) that will be used for the API.
- `SENTRY_DSN`: the DSN for your [Sentry](https://sentry.io/) project, which monitors back-end errors and reports them back.
- `SERVER_NAME`: the server tag that will be used to report events to Sentry.
- `POSTHOG_HOST`: the host for your [PostHog](https://eu.posthog.com/settings/project-details) project.
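For reference, a Claude-backed configuration in your `.env` could then look like the following sketch (the key value is a placeholder, and `LLM_PROVIDER='claude'` is assumed from the provider descriptions above):

```
LLM_PROVIDER='claude'
CLAUDE_API_KEY='sk-ant-your-key-here'
CLAUDE_MODEL='claude-3-opus-20240229'
LLM_TEMPERATURE=0
```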
2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -58,6 +58,8 @@ services:
- GROQ_MODEL=${GROQ_MODEL}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- OPENAI_MODEL=${OPENAI_MODEL}
- CLAUDE_API_KEY=${CLAUDE_API_KEY}
- CLAUDE_MODEL=${CLAUDE_MODEL}
- OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60}
- SUPPORT_EMAIL=${SUPPORT_EMAIL}
- DEBUG=true
2 changes: 1 addition & 1 deletion docs/developers/self-hosting.mdx
@@ -10,7 +10,7 @@ Whatever your installation method, you'll need at least the following to be inst
1. [Docker](https://docs.docker.com/engine/install/) (and [Docker compose](https://docs.docker.com/compose/) if you're using an old version)
2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU

_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance. Please note that by default, this will run your LLM locally (available offline) but if you don't have a GPU, you can use online LLM providers (currently supported: Groq, OpenAI)_
_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance. Please note that by default, this will run your LLM locally (available offline) but if you don't have a GPU, you can use online LLM providers (currently supported: Groq, OpenAI, Claude)_

### 60 seconds setup ⏱️

216 changes: 214 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ ollama = "^0.2.0"
openai = "^1.29.0"
uvloop = "^0.19.0"
httptools = "^0.6.1"
anthropic = "^0.26.1"

[tool.poetry.group.quality]
optional = true
71 changes: 71 additions & 0 deletions src/app/services/llm/claude.py
@@ -0,0 +1,71 @@
# Copyright (C) 2023-2024, Quack AI.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import logging
from enum import Enum
from typing import Dict, Generator, List, Union

from anthropic import BaseModel, Client

from .utils import CHAT_PROMPT

logger = logging.getLogger("uvicorn.error")


class ClaudeModel(str, Enum):
    OPUS: str = "claude-3-opus-20240229"
    SONNET: str = "claude-3-sonnet-20240229"
    HAIKU: str = "claude-3-haiku-20240307"


class ClaudeClient:
    def __init__(
        self,
        api_key: str,
        model: ClaudeModel,
        temperature: float = 0.0,
    ) -> None:
        self._client = Client(api_key=api_key)
        self.model = model

        self._validate_model()

        self.temperature = temperature
        # model_card = BaseModel.retrieve(model)
Member comment: We still need to verify whether the model selection is available

        logger.info(
            f"Using Claude w/ {self.model} (created at "
            # {datetime.fromtimestamp(model_card.created).isoformat()})",
        )

Member comment: if the information is available, it would be worth displaying it :) I was only saying earlier that I'm not sure the attribute is named the same way as in groq or openai
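As a hedged sketch only: recent `anthropic` SDK releases (well after the `^0.26.1` pinned in this PR) expose a Models API, so the commented-out metadata lookup might eventually be expressed as below. `models.retrieve` and the `created_at` attribute are assumptions about those newer versions, not something this diff provides.

```python
from anthropic import Client

client = Client(api_key="sk-ant-your-key-here")  # placeholder key
# Hypothetical: requires an anthropic version that ships the Models API
model_card = client.models.retrieve(model_id="claude-3-opus-20240229")
print(model_card.created_at.isoformat())
```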

    def _validate_model(self) -> None:
Member comment: Let's keep it as a static method (with the corresponding arg, like in groq/openai integration)

        input_dict = {"model_type": self.model}
        validation_result = BaseModel.model_validate(input_dict)
Member comment: I think there is a confusion here: we want to validate the LLM name reference in the API (= model name), not the Pydantic schema model

        if not validation_result:
            raise ValueError(f"Invalid model: {self.model}")
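For illustration, the static validation the reviewers describe might look like the sketch below, checking the model tag against the `ClaudeModel` enum defined above rather than running Pydantic schema validation (this sketch is not part of the diff):

```python
class ClaudeClient:  # sketch: only the validation helper shown
    @staticmethod
    def _validate_model(model: str) -> None:
        # Compare against the known Claude model tags instead of
        # validating an arbitrary dict against a Pydantic schema.
        if model not in {m.value for m in ClaudeModel}:
            raise ValueError(f"Invalid model: {model}")
```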

    def chat(
        self,
        messages: List[Dict[str, str]],
        system: Union[str, None] = None,
    ) -> Generator[str, None, None]:
        # Prepare the request: the system prompt is injected as the first user message
        _system = CHAT_PROMPT if not system else f"{CHAT_PROMPT} {system}"

        stream = self._client.messages.create(
            messages=[{"role": "user", "content": _system}, *messages],
            model=self.model,
            temperature=self.temperature,
            max_tokens=2048,
            stream=True,
            top_p=1.0,
        )

        # The Anthropic SDK streams typed events rather than OpenAI-style
        # chunks with a `choices` list, so branch on the event type.
        input_tokens = 0
        for event in stream:
            if event.type == "message_start":
                input_tokens = event.message.usage.input_tokens
            elif event.type == "content_block_delta" and event.delta.type == "text_delta":
                yield event.delta.text
            elif event.type == "message_delta":
                logger.info(
                    f"Claude ({self.model}): {input_tokens} prompt tokens | {event.usage.output_tokens} completion tokens",
                )
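A hypothetical usage sketch of the client above (assuming a valid Anthropic key; note that since `chat` prepends the system prompt as a user message, starting with an assistant turn keeps the roles alternating, as the test below does):

```python
from app.services.llm.claude import ClaudeClient, ClaudeModel

client = ClaudeClient("sk-ant-your-key-here", ClaudeModel.HAIKU)  # placeholder key
messages = [{"role": "assistant", "content": "Hello, how are you?"}]
# Stream the completion token by token
for token in client.chat(messages=messages):
    print(token, end="", flush=True)
```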
24 changes: 24 additions & 0 deletions src/tests/services/test_llm.py
@@ -1,6 +1,8 @@
import types

import pytest
from anthropic import AuthenticationError as CAuthError
from anthropic import NotFoundError as CNotFoundError
from groq import AuthenticationError as GAuthError
from groq import NotFoundError as GNotFoundError
from httpx import ConnectError
@@ -9,6 +11,7 @@
from openai import NotFoundError as OAINotFounderError

from app.core.config import settings
from app.services.llm.claude import ClaudeClient
from app.services.llm.groq import GroqClient
from app.services.llm.ollama import OllamaClient
from app.services.llm.openai import OpenAIClient
@@ -71,3 +74,24 @@ def test_openaiclient_chat():
    assert isinstance(stream, types.GeneratorType)
    for chunk in stream:
        assert isinstance(chunk, str)


def test_claudeclient_constructor():
    with pytest.raises(CAuthError):
        ClaudeClient("api_key", settings.CLAUDE_MODEL)
    if isinstance(settings.CLAUDE_API_KEY, str):
        with pytest.raises(CNotFoundError):
            ClaudeClient(settings.CLAUDE_API_KEY, "quack")
        ClaudeClient(settings.CLAUDE_API_KEY, settings.CLAUDE_MODEL)


@pytest.mark.skipif("settings.CLAUDE_API_KEY is None")
def test_claudeclient_chat():
    llm_client = ClaudeClient(settings.CLAUDE_API_KEY, settings.CLAUDE_MODEL)
    messages = [
        {"role": "assistant", "content": "Hello, how are you?"},
    ]
    stream = llm_client.chat(messages=messages)
    assert isinstance(stream, types.GeneratorType)
    for chunk in stream:
        assert isinstance(chunk, str)