Skip to content

Commit

Permalink
Mistral backend option
Browse files Browse the repository at this point in the history
  • Loading branch information
radbrt committed Mar 7, 2024
1 parent 9f7eadd commit c65e241
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
34 changes: 24 additions & 10 deletions dbtai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,42 @@ def setup():
),
inquirer.List('backend',
message ="LLM Backend",
choices = ["OpenAI", "Azure OpenAI"],
choices = ["OpenAI", "Azure OpenAI", "Mistral"],
default = "OpenAI"
),
inquirer.List("auth_type",
message = "Authentication Type",
choices = ["API Key", "Native Authentication (DefaultAzureCredential)"],
default = "API Key",
ignore = lambda answers: answers['backend'] == "OpenAI"
ignore = lambda answers: answers['backend'] in ["OpenAI", "Mistral"]
),
inquirer.Text('api_key',
message='OpenAI API Key',
message='API Key',
ignore = lambda answers: answers['auth_type'] == "Native Authentication (DefaultAzureCredential)"
),
inquirer.List("openai_model_name",
message = "Model Name",
choices = ["gpt-3.5-turbo", "gpt-4-turbo-preview"],
default = "gpt-4-turbo-preview",
ignore = lambda answers: answers['backend'] == "Azure OpenAI"
ignore = lambda answers: answers['backend'] != "OpenAI"
),
inquirer.List("mistral_model_name",
message = "Model Name",
choices = ["mistral-large-latest"],
default = "mistral-large-latest",
ignore = lambda answers: answers['backend'] != "Mistral"
),
inquirer.Text("azure_endpoint",
message = "Azure OpenAI Endpoint",
ignore = lambda answers: answers['backend'] == "OpenAI"
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
),
inquirer.Text("azure_openai_model",
message = "Azure OpenAI Model",
ignore = lambda answers: answers['backend'] == "OpenAI"
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
),
inquirer.Text("azure_openai_deployment",
message = "Azure OpenAI Deployment",
ignore = lambda answers: answers['backend'] == "OpenAI"
ignore = lambda answers: answers['backend'] != "Azure OpenAI"
),
]
answer = inquirer.prompt(question)
Expand Down Expand Up @@ -147,8 +153,6 @@ def gen(model_name, description, input):
@click.option("--diff", "-d", is_flag=True, help="Show the diff between existing and suggested code", default=False)
def fix(model_name, description, diff):
manifest = Manifest()
click.echo(model_name)
click.echo(description)

model = manifest.fix(model_name, description)

Expand Down Expand Up @@ -211,4 +215,14 @@ def hello():
\____ | |___ /__| \____|__ /___|
\/ \/ \/
"""
click.echo(greeting)
click.echo(greeting)


@dbtai.command(help="Generate a dbt test")
@click.argument("model", required=True)
@click.argument("description", required=True)
def test(model, description):
raise NotImplementedError("Not yet implemented")
manifest = Manifest()
test = manifest.generate_test(model, description)
click.echo(test)
35 changes: 26 additions & 9 deletions dbtai/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import appdirs
import yaml
from openai import OpenAI
from mistralai.client import MistralClient
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import io
Expand Down Expand Up @@ -40,21 +41,26 @@ def __init__(
self.manifest = json.load(file)

self.config = self._load_config()
self.client = self._make_openai_client()

if self.config['backend'] == "Mistral":
self.client = self._make_mistral_client()
elif self.config['backend'] == "Azure OpenAI":
raise NotImplementedError("Azure OpenAI not yet implemented")
else:
self.client = self._make_openai_client()
def _make_openai_client(self):
"""Make the OpenAI client with auth."""

if self.config['backend'] == "OpenAI":
api_key = self.config['api_key'] or os.getenv("OPENAI_API_KEY")
return OpenAI(api_key=api_key)
else:
raise NotImplementedError("Azure OpenAI not yet implemented")

# return OpenAI(
# endpoint=self.config['azure_endpoint'],
# model=self.config['azure_openai_model'],
# deployment=self.config['azure_openai_deployment']
# )
def _make_mistral_client(self):
"""Make the Mistral client with auth."""
client = MistralClient(api_key=self.config['api_key'])
return client


def chat_completion(self, messages, response_format_type="json_object"):
"""Convenience method to call the chat completion endpoint.
Expand All @@ -67,13 +73,24 @@ def chat_completion(self, messages, response_format_type="json_object"):
openai.ChatCompletion: The response from the chat API
"""
if self.config["backend"] == "OpenAI":
if not self.config.get("openai_model_name"):
raise ValueError("OpenAI model name not set in config")

return self.client.chat.completions.create(
model=self.config['openai_model_name'],
model=self.config.get('openai_model_name', 'gpt-4-turbo-preview'),
messages=messages,
response_format={"type": response_format_type}
)
elif self.config["backend"] == "Mistral":

return self.client.chat(
model=self.config.get("mistral_model_name", "mistral-large-latest"),
messages=messages,
response_format={"type": response_format_type},
)

else:
raise NotImplementedError("Azure OpenAI not yet implemented")
raise NotImplementedError("Your backend is set to Azure OpenAI not yet implemented")

def _load_config(self):
"""Convenience function to load the user config from the config file."""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dbtai"
version = "0.1.0"
version = "0.2.0"
description = "`dbtai` is a utility CLI command to generate dbt model documentation for a given model using OpenAI."
authors = ["Henning Holgersen"]
keywords = [
Expand All @@ -13,6 +13,7 @@ license = "Apache 2.0"
[tool.poetry.dependencies]
python = "<3.14,>=3.8.0"
openai = ">1.1.0"
mistralai = ">=0.1.3,<2"
click = "^8.1.3"
"ruamel.yaml" = "^0.18.6"
inquirer = "^3.2.4"
Expand Down

0 comments on commit c65e241

Please sign in to comment.