diff --git a/README.md b/README.md
index db40671..55d1316 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,9 @@ Command-line interface for chat LLMs.
 
 - OpenAI
 - Anthropic
+- Google Gemini
 - Cohere
-- Other APIs compatible with OpenAI
+- Other APIs compatible with OpenAI (e.g. Together, OpenRouter, local models with LM Studio)
 
 ![screenshot](https://github.com/kharvd/gpt-cli/assets/466920/ecbcccc4-7cfa-4c04-83c3-a822b6596f01)
 
@@ -214,8 +215,32 @@
 or a config line in `~/.config/gpt-cli/gpt.yml`:
 ```yaml
 anthropic_api_key:
 ```
-Now you should be able to run `gpt` with `--model claude-v1` or `--model claude-instant-v1`:
+Now you should be able to run `gpt` with `--model claude-3-(opus|sonnet|haiku)-<date>`:
 ```bash
-gpt --model claude-v1
+gpt --model claude-3-opus-20240229
+```
+
+### Google Gemini
+
+```bash
+export GOOGLE_API_KEY=
+```
+
+or
+
+```yaml
+google_api_key:
+```
+
+### Cohere
+
+```bash
+export COHERE_API_KEY=
+```
+
+or
+
+```yaml
+cohere_api_key:
 ```
diff --git a/gptcli/assistant.py b/gptcli/assistant.py
index 47fdc02..83a05a9 100644
--- a/gptcli/assistant.py
+++ b/gptcli/assistant.py
@@ -10,6 +10,7 @@
     ModelOverrides,
     Message,
 )
+from gptcli.providers.google import GoogleCompletionProvider
 from gptcli.providers.llama import LLaMACompletionProvider
 from gptcli.providers.openai import OpenAICompletionProvider
 from gptcli.providers.anthropic import AnthropicCompletionProvider
@@ -78,6 +79,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
         return LLaMACompletionProvider()
     elif model.startswith("command") or model.startswith("c4ai"):
         return CohereCompletionProvider()
+    elif model.startswith("gemini"):
+        return GoogleCompletionProvider()
     else:
         raise ValueError(f"Unknown model: {model}")
 
diff --git a/gptcli/config.py b/gptcli/config.py
index b6c83cd..a23753b 100644
--- a/gptcli/config.py
+++ b/gptcli/config.py
@@ -22,7 +22,7 @@ class GptCliConfig:
     openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
     openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL")
     anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
-    google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") # deprecated
+    google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
     cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
     log_file: Optional[str] = None
     log_level: str = "INFO"
diff --git a/gptcli/gpt.py b/gptcli/gpt.py
index 232d9d5..897694e 100755
--- a/gptcli/gpt.py
+++ b/gptcli/gpt.py
@@ -9,6 +9,7 @@
 import os
 from typing import cast
 import openai
+import google.generativeai as genai
 import argparse
 import sys
 import logging
@@ -189,6 +190,9 @@ def main():
     if config.cohere_api_key:
         gptcli.providers.cohere.api_key = config.cohere_api_key
 
+    if config.google_api_key:
+        genai.configure(api_key=config.google_api_key)
+
     if config.llama_models is not None:
         init_llama_models(config.llama_models)
 
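Before the new provider file itself, a note on how the pieces above compose: `get_completion_provider` dispatches on the model-name prefix, so any `--model` starting with `gemini` reaches the Google provider with no further wiring. A minimal sketch of exercising it directly (the message dicts follow gpt-cli's `Message` shape; the model name, prompt, and `print`-based event handling are illustrative only, and this assumes `genai.configure` has already run as in `gpt.py` above):

```python
from gptcli.assistant import get_completion_provider

# Prefix dispatch: "gemini-*" routes to GoogleCompletionProvider.
provider = get_completion_provider("gemini-1.5-pro")

# A leading "system" message becomes Gemini's system_instruction;
# the remaining messages are mapped to Gemini roles ("assistant" -> "model").
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one prime number."},
]

for event in provider.complete(messages, {"model": "gemini-1.5-pro"}, stream=True):
    print(event)  # MessageDeltaEvent chunks, then a UsageEvent if pricing is known
```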
"assistant": "model", +} + + +def map_message(message: Message) -> ContentDict: + return {"role": ROLE_MAP[message["role"]], "parts": [message["content"]]} + + +SAFETY_SETTINGS = [ + {"category": category, "threshold": HarmBlockThreshold.BLOCK_NONE} + for category in [ + HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_HATE_SPEECH, + ] +] + + +class GoogleCompletionProvider(CompletionProvider): + def complete( + self, messages: List[Message], args: dict, stream: bool = False + ) -> Iterator[CompletionEvent]: + generation_config = GenerationConfig( + temperature=args.get("temperature"), + top_p=args.get("top_p"), + ) + + model_name = args["model"] + + if messages[0]["role"] == "system": + system_instruction = messages[0]["content"] + messages = messages[1:] + else: + system_instruction = None + + chat_history = [map_message(m) for m in messages] + + model = genai.GenerativeModel(model_name, system_instruction=system_instruction) + + if stream: + response = model.generate_content( + chat_history, + generation_config=generation_config, + safety_settings=SAFETY_SETTINGS, + stream=True, + ) + + for chunk in response: + yield MessageDeltaEvent(chunk.text) + + else: + response = model.generate_content( + chat_history, + generation_config=generation_config, + safety_settings=SAFETY_SETTINGS, + ) + yield MessageDeltaEvent(response.text) + + prompt_tokens = response.usage_metadata.prompt_token_count + completion_tokens = response.usage_metadata.candidates_token_count + total_tokens = prompt_tokens + completion_tokens + pricing = get_gemini_pricing(model_name, prompt_tokens) + if pricing: + yield UsageEvent.with_pricing( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + pricing=pricing, + ) + + +def get_gemini_pricing(model: str, prompt_tokens: int) -> Optional[Pricing]: + if model.startswith("gemini-1.5-flash"): + return { + "prompt": (0.35 if prompt_tokens < 128000 else 0.7) / 1_000_000, + "response": (1.05 if prompt_tokens < 128000 else 2.10) / 1_000_000, + } + elif model.startswith("gemini-1.5-pro"): + return { + "prompt": (3.50 if prompt_tokens < 128000 else 7.00) / 1_000_000, + "response": (10.5 if prompt_tokens < 128000 else 21.0) / 1_000_000, + } + elif model.startswith("gemini-pro"): + return { + "prompt": 0.50 / 1_000_000, + "response": 1.50 / 1_000_000, + } + else: + return None diff --git a/pyproject.toml b/pyproject.toml index d2cb1e7..e3024c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "attrs==23.2.0", "black==24.4.2", "cohere==5.5.3", + "google-generativeai==0.5.4", "mistralai==0.1.8", "openai==1.30.1", "prompt-toolkit==3.0.43",