Add support for Gemini
kharvd committed May 29, 2024
1 parent bbb2447 commit 61aa610
Showing 6 changed files with 147 additions and 4 deletions.
README.md — 31 changes: 28 additions & 3 deletions
@@ -6,8 +6,9 @@ Command-line interface for chat LLMs.
 
 - OpenAI
 - Anthropic
+- Google Gemini
 - Cohere
-- Other APIs compatible with OpenAI
+- Other APIs compatible with OpenAI (e.g. Together, OpenRouter, local models with LM Studio)
 
 ![screenshot](https://github.com/kharvd/gpt-cli/assets/466920/ecbcccc4-7cfa-4c04-83c3-a822b6596f01)

@@ -214,8 +215,32 @@ or a config line in `~/.config/gpt-cli/gpt.yml`:
 anthropic_api_key: <your_key_here>
 ```
-Now you should be able to run `gpt` with `--model claude-v1` or `--model claude-instant-v1`:
+Now you should be able to run `gpt` with `--model claude-3-(opus|sonnet|haiku)-<date>`.
 
 ```bash
-gpt --model claude-v1
+gpt --model claude-3-opus-20240229
 ```
 
+### Google Gemini
+
+```bash
+export GOOGLE_API_KEY=<your_key_here>
+```
+
+or
+
+```yaml
+google_api_key: <your_key_here>
+```
+
+### Cohere
+
+```bash
+export COHERE_API_KEY=<your_key_here>
+```
+
+or
+
+```yaml
+cohere_api_key: <your_key_here>
+```
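For usage, a hedged sketch: the dispatch added in `gptcli/assistant.py` (below) selects the Google provider for any model id beginning with `gemini`, so invocations like the following should work. The exact model ids are assumptions inferred from the pricing table in this commit, not something the diff itself documents:

```bash
gpt --model gemini-pro
gpt --model gemini-1.5-flash-latest
```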
gptcli/assistant.py — 3 changes: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     ModelOverrides,
     Message,
 )
+from gptcli.providers.google import GoogleCompletionProvider
 from gptcli.providers.llama import LLaMACompletionProvider
 from gptcli.providers.openai import OpenAICompletionProvider
 from gptcli.providers.anthropic import AnthropicCompletionProvider
@@ -78,6 +79,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
         return LLaMACompletionProvider()
     elif model.startswith("command") or model.startswith("c4ai"):
         return CohereCompletionProvider()
+    elif model.startswith("gemini"):
+        return GoogleCompletionProvider()
     else:
         raise ValueError(f"Unknown model: {model}")

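A minimal sketch of how the new branch behaves, assuming only the code shown above; the model id is hypothetical, but any id with the `gemini` prefix matches:

```python
from gptcli.assistant import get_completion_provider
from gptcli.providers.google import GoogleCompletionProvider

# startswith("gemini") matching means versioned ids also route to Google.
provider = get_completion_provider("gemini-1.5-flash-latest")
assert isinstance(provider, GoogleCompletionProvider)
```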
gptcli/config.py — 2 changes: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class GptCliConfig:
     openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
     openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL")
     anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
-    google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")  # deprecated
+    google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
     cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
     log_file: Optional[str] = None
     log_level: str = "INFO"
gptcli/gpt.py — 4 changes: 4 additions & 0 deletions
@@ -9,6 +9,7 @@
 import os
 from typing import cast
 import openai
+import google.generativeai as genai
 import argparse
 import sys
 import logging
@@ -189,6 +190,9 @@ def main():
     if config.cohere_api_key:
         gptcli.providers.cohere.api_key = config.cohere_api_key
 
+    if config.google_api_key:
+        genai.configure(api_key=config.google_api_key)
+
     if config.llama_models is not None:
         init_llama_models(config.llama_models)
 
gptcli/providers/google.py — 110 changes: 110 additions & 0 deletions (new file)
@@ -0,0 +1,110 @@
import google.generativeai as genai
from google.generativeai.types.content_types import ContentDict
from google.generativeai.types.generation_types import GenerationConfig
from google.generativeai.types.safety_types import (
    HarmBlockThreshold,
    HarmCategory,
)
from typing import Iterator, List, Optional

from gptcli.completion import (
    CompletionEvent,
    CompletionProvider,
    Message,
    MessageDeltaEvent,
    Pricing,
    UsageEvent,
)

ROLE_MAP = {
    "user": "user",
    "assistant": "model",
}


def map_message(message: Message) -> ContentDict:
    return {"role": ROLE_MAP[message["role"]], "parts": [message["content"]]}


SAFETY_SETTINGS = [
    {"category": category, "threshold": HarmBlockThreshold.BLOCK_NONE}
    for category in [
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    ]
]


class GoogleCompletionProvider(CompletionProvider):
    def complete(
        self, messages: List[Message], args: dict, stream: bool = False
    ) -> Iterator[CompletionEvent]:
        generation_config = GenerationConfig(
            temperature=args.get("temperature"),
            top_p=args.get("top_p"),
        )

        model_name = args["model"]

        if messages[0]["role"] == "system":
            system_instruction = messages[0]["content"]
            messages = messages[1:]
        else:
            system_instruction = None

        chat_history = [map_message(m) for m in messages]

        model = genai.GenerativeModel(model_name, system_instruction=system_instruction)

        if stream:
            response = model.generate_content(
                chat_history,
                generation_config=generation_config,
                safety_settings=SAFETY_SETTINGS,
                stream=True,
            )

            for chunk in response:
                yield MessageDeltaEvent(chunk.text)

        else:
            response = model.generate_content(
                chat_history,
                generation_config=generation_config,
                safety_settings=SAFETY_SETTINGS,
            )
            yield MessageDeltaEvent(response.text)

        prompt_tokens = response.usage_metadata.prompt_token_count
        completion_tokens = response.usage_metadata.candidates_token_count
        total_tokens = prompt_tokens + completion_tokens
        pricing = get_gemini_pricing(model_name, prompt_tokens)
        if pricing:
            yield UsageEvent.with_pricing(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                pricing=pricing,
            )


def get_gemini_pricing(model: str, prompt_tokens: int) -> Optional[Pricing]:
    if model.startswith("gemini-1.5-flash"):
        return {
            "prompt": (0.35 if prompt_tokens < 128000 else 0.7) / 1_000_000,
            "response": (1.05 if prompt_tokens < 128000 else 2.10) / 1_000_000,
        }
    elif model.startswith("gemini-1.5-pro"):
        return {
            "prompt": (3.50 if prompt_tokens < 128000 else 7.00) / 1_000_000,
            "response": (10.5 if prompt_tokens < 128000 else 21.0) / 1_000_000,
        }
    elif model.startswith("gemini-pro"):
        return {
            "prompt": 0.50 / 1_000_000,
            "response": 1.50 / 1_000_000,
        }
    else:
        return None
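To make the event flow concrete, here is a hedged sketch of driving the provider directly, assuming `genai.configure` has already been called with a valid key (as `gptcli/gpt.py` does above); the `text` attribute on `MessageDeltaEvent`, the message-dict shape, and the model id are assumptions inferred from this file, not guaranteed by it:

```python
from gptcli.providers.google import GoogleCompletionProvider
from gptcli.completion import MessageDeltaEvent, UsageEvent

provider = GoogleCompletionProvider()
events = provider.complete(
    messages=[
        {"role": "system", "content": "You are a terse assistant."},  # becomes system_instruction
        {"role": "user", "content": "Say hello."},
    ],
    args={"model": "gemini-1.5-flash-latest", "temperature": 0.7},  # model id is an assumption
    stream=True,
)
for event in events:
    if isinstance(event, MessageDeltaEvent):
        print(event.text, end="", flush=True)  # assumed field name
    elif isinstance(event, UsageEvent):
        pass  # carries token counts and, via pricing, an estimated cost
```

At the rates in `get_gemini_pricing`, a `gemini-1.5-flash` call with 1,000 prompt tokens and 500 completion tokens would cost about 1,000 × 0.35/1M + 500 × 1.05/1M ≈ $0.0009.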
pyproject.toml — 1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies = [
     "attrs==23.2.0",
     "black==24.4.2",
     "cohere==5.5.3",
+    "google-generativeai==0.5.4",
     "mistralai==0.1.8",
     "openai==1.30.1",
     "prompt-toolkit==3.0.43",
