From 79981a0853c84766b499d2fc65b21e59b9f6b706 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Tue, 12 Dec 2023 00:36:23 -0500 Subject: [PATCH] Add together.xyz endpoint --- gptcli/assistant.py | 22 ++++++++++++++++++++-- gptcli/config.py | 1 + gptcli/gpt.py | 4 ++++ pyproject.toml | 2 ++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/gptcli/assistant.py b/gptcli/assistant.py index 6adac67..1c16ccd 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -9,6 +9,7 @@ from gptcli.llama import LLaMACompletionProvider from gptcli.openai import OpenAICompletionProvider from gptcli.anthropic import AnthropicCompletionProvider +from gptcli.together import TogetherCompletionProvider class AssistantConfig(TypedDict, total=False): @@ -16,6 +17,11 @@ class AssistantConfig(TypedDict, total=False): model: str temperature: float top_p: float + system_prefix: str + system_suffix: str + user_prefix: str + user_suffix: str + stop_tokens: List[str] CONFIG_DEFAULTS = { @@ -55,7 +61,9 @@ class AssistantConfig(TypedDict, total=False): } -def get_completion_provider(model: str) -> CompletionProvider: +def get_completion_provider( + model: str, assistant_config: AssistantConfig +) -> CompletionProvider: if model.startswith("gpt"): return OpenAICompletionProvider() elif model.startswith("claude"): @@ -64,6 +72,16 @@ def get_completion_provider(model: str) -> CompletionProvider: return LLaMACompletionProvider() elif model.startswith("chat-bison"): return GoogleCompletionProvider() + elif model.startswith("together"): + return TogetherCompletionProvider( + { + "system_prefix": assistant_config.get("system_prefix", ""), + "system_suffix": assistant_config.get("system_suffix", ""), + "user_prefix": assistant_config.get("user_prefix", ""), + "user_suffix": assistant_config.get("user_suffix", ""), + "stop_tokens": assistant_config.get("stop_tokens", []), + } + ) else: raise ValueError(f"Unknown model: {model}") @@ -103,7 +121,7 @@ def complete_chat( self, messages, override_params: ModelOverrides = {}, stream: bool = True ) -> Iterator[str]: model = self._param("model", override_params) - completion_provider = get_completion_provider(model) + completion_provider = get_completion_provider(model, self.config) return completion_provider.complete( messages, { diff --git a/gptcli/config.py b/gptcli/config.py index 3cb9070..b205114 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -22,6 +22,7 @@ class GptCliConfig: openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY") google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") + together_api_key: Optional[str] = os.environ.get("TOGETHER_API_KEY") log_file: Optional[str] = None log_level: str = "INFO" assistants: Dict[str, AssistantConfig] = {} diff --git a/gptcli/gpt.py b/gptcli/gpt.py index e9634e9..32a4e31 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -15,6 +15,7 @@ import datetime import google.generativeai as genai import gptcli.anthropic +import gptcli.together from gptcli.assistant import ( Assistant, DEFAULT_ASSISTANTS, @@ -178,6 +179,9 @@ def main(): ) sys.exit(1) + if config.together_api_key: + gptcli.together.api_key = config.together_api_key + if config.anthropic_api_key: gptcli.anthropic.api_key = config.anthropic_api_key diff --git a/pyproject.toml b/pyproject.toml index 612739d..4585571 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "pytest==7.3.1", "PyYAML==6.0", "rich==13.7.0", + "requests==2.31.0", + "sseclient-py==1.8.0", "tiktoken==0.5.2", "tokenizers==0.15.0", "typing_extensions==4.5.0",