feat(llm): add integration for Claude #171
base: main
Changes from all commits: fdd53de, e19973d, 37b5bb3, 004fa2f, 5bf2234, bdb71d4
@@ -0,0 +1,71 @@
```python
# Copyright (C) 2023-2024, Quack AI.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import logging
from enum import Enum
from typing import Dict, Generator, List, Union

from anthropic import BaseModel, Client

from .utils import CHAT_PROMPT

logger = logging.getLogger("uvicorn.error")


class ClaudeModel(str, Enum):
    OPUS: str = "claude-3-opus-20240229"
    SONNET: str = "claude-3-sonnet-20240229"
    HAIKU: str = "claude-3-haiku-20240307"


class ClaudeClient:
    def __init__(
        self,
        api_key: str,
        model: ClaudeModel,
        temperature: float = 0.0,
    ) -> None:
        self._client = Client(api_key=api_key)
        self.model = model

        self._validate_model()

        self.temperature = temperature
        # model_card = BaseModel.retrieve(model)
        logger.info(
            f"Using Claude w/ {self.model} (created at "
            # {datetime.fromtimestamp(model_card.created).isoformat()})",
        )
```
Review comment: if the information is available, it would be worth displaying it :) I was only saying earlier that I'm not sure the attribute is named the same way as in the groq or openai integrations.
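For reference, the attribute the commented-out line mimics comes from the OpenAI-style SDKs, where a retrieved model card carries a `created` Unix timestamp; the anthropic SDK exposes no equivalent retrieve call here, which is presumably why the line is disabled. A sketch of that OpenAI-side pattern (assumed to match what the groq/openai integrations do):

```python
from datetime import datetime

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
model_card = client.models.retrieve("gpt-4")
# OpenAI model cards expose a `created` Unix timestamp
print(datetime.fromtimestamp(model_card.created).isoformat())
```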
```python
    def _validate_model(self) -> None:
```
Review comment: Let's keep it as a static method (with the corresponding arg, like in the groq/openai integrations).
```python
        input_dict = {"model_type": self.model}
        validation_result = BaseModel.model_validate(input_dict)
```
Review comment: I think there is a confusion here: we want to validate the model name referenced by the API (the LLM identifier), not the Pydantic schema model.
```python
        if not validation_result:
            raise ValueError(f"Invalid model: {self.model}")
```
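Putting the two comments above together, the validation would likely become a static method that checks the requested name against the `ClaudeModel` enum rather than calling Pydantic's `model_validate`. A minimal sketch, with the exact signature of the groq/openai counterparts assumed:

```python
@staticmethod
def _validate_model(model: str) -> None:
    # Validate the API model name against the supported enum values,
    # not the Pydantic schema model.
    if model not in {m.value for m in ClaudeModel}:
        raise ValueError(f"Invalid model: {model}")
```

The call site in `__init__` would then pass the name explicitly, e.g. `self._validate_model(self.model)`.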
```python
    def chat(
        self,
        messages: List[Dict[str, str]],
        system: Union[str, None] = None,
    ) -> Generator[str, None, None]:
        # Prepare the request: prepend the system prompt as the first user turn
        _system = CHAT_PROMPT if not system else f"{CHAT_PROMPT} {system}"

        stream = self._client.messages.create(
            messages=[{"role": "user", "content": _system}, *messages],
            model=self.model,
            temperature=self.temperature,
            max_tokens=2048,
            stream=True,
            top_p=1.0,
        )

        # The anthropic SDK streams typed events rather than OpenAI-style
        # chunks with `.choices`: text arrives in content_block_delta events,
        # prompt tokens in message_start, completion tokens in message_delta.
        input_tokens = 0
        for event in stream:
            if event.type == "message_start":
                input_tokens = event.message.usage.input_tokens
            elif event.type == "content_block_delta" and event.delta.type == "text_delta":
                yield event.delta.text
            elif event.type == "message_delta":
                logger.info(
                    f"Claude ({self.model}): {input_tokens} prompt tokens | {event.usage.output_tokens} completion tokens",
                )
```
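For context, a quick usage sketch of the client above (the API key placeholder and prompt are illustrative only):

```python
# Hypothetical usage of the new integration
client = ClaudeClient(api_key="sk-ant-...", model=ClaudeModel.HAIKU)
for token in client.chat([{"role": "user", "content": "Hello!"}]):
    print(token, end="")
```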
Review comment: We still need to verify whether the model selection is available.
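Since the anthropic SDK used here exposes no model-listing endpoint, one assumed way to make that check explicit is a minimal probe request; unknown model names come back as a 404 `NotFoundError`:

```python
import anthropic


def is_model_available(client: anthropic.Client, model: str) -> bool:
    """Probe the API with a 1-token request; unknown models raise NotFoundError."""
    try:
        client.messages.create(
            model=model,
            max_tokens=1,
            messages=[{"role": "user", "content": "ping"}],
        )
        return True
    except anthropic.NotFoundError:
        return False
```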