feat(2658): introduce Adapter (tailcallhq#2659)
Co-authored-by: Tushar Mathur <[email protected]>
ssddOnTop and tusharmath authored Aug 12, 2024
1 parent 5d3f1ad commit e1d7f69
Showing 5 changed files with 103 additions and 8 deletions.
7 changes: 6 additions & 1 deletion src/cli/generator/generator.rs
@@ -165,7 +165,12 @@ impl Generator {
         let mut config = config_gen.generate(true)?;
 
         if infer_type_names {
-            let mut llm_gen = InferTypeName::default();
+            let key = self
+                .runtime
+                .env
+                .get("TAILCALL_SECRET")
+                .map(|s| s.into_owned());
+            let mut llm_gen = InferTypeName::new(key);
             let suggested_names = llm_gen.generate(config.config()).await?;
             let cfg = RenameTypes::new(suggested_names.iter())
                 .transform(config.config().to_owned())
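For readers unfamiliar with the generator's runtime wiring, a rough equivalent of the new lookup using the plain process environment (illustrative only; the actual code goes through self.runtime.env, as shown above):

// Rough equivalent (not part of this commit): the generator resolves the key
// via its runtime env abstraction rather than std::env.
fn tailcall_secret() -> Option<String> {
    std::env::var("TAILCALL_SECRET").ok()
}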
14 changes: 10 additions & 4 deletions src/cli/llm/infer_type_name.rs
@@ -3,13 +3,14 @@ use std::collections::HashMap;
 use genai::chat::{ChatMessage, ChatRequest, ChatResponse};
 use serde::{Deserialize, Serialize};
 
+use super::model::groq;
 use super::{Error, Result, Wizard};
 use crate::core::config::Config;
 
-const MODEL: &str = "llama3-8b-8192";
-
 #[derive(Default)]
-pub struct InferTypeName {}
+pub struct InferTypeName {
+    secret: Option<String>,
+}
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 struct Answer {
@@ -73,8 +74,13 @@ impl TryInto<ChatRequest> for Question {
 }
 
 impl InferTypeName {
+    pub fn new(secret: Option<String>) -> InferTypeName {
+        Self { secret }
+    }
     pub async fn generate(&mut self, config: &Config) -> Result<HashMap<String, String>> {
-        let wizard: Wizard<Question, Answer> = Wizard::new(MODEL.to_string());
+        let secret = self.secret.as_ref().map(|s| s.to_owned());
+
+        let wizard: Wizard<Question, Answer> = Wizard::new(groq::LLAMA38192, secret);
 
         let mut new_name_mappings: HashMap<String, String> = HashMap::new();
 
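A minimal sketch of how a caller might use the reworked API, assuming the module's Result is the usual alias over its Error type; the Config value and async plumbing are the caller's concern and are not part of this diff:

use std::collections::HashMap;

use crate::cli::llm::{Error, InferTypeName};
use crate::core::config::Config;

// Hypothetical helper (not part of this commit): resolve the secret up front,
// then ask the LLM for type-name suggestions keyed by the current type names.
async fn suggest_type_names(
    config: &Config,
    secret: Option<String>,
) -> Result<HashMap<String, String>, Error> {
    let mut llm_gen = InferTypeName::new(secret);
    llm_gen.generate(config).await
}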
2 changes: 2 additions & 0 deletions src/cli/llm/mod.rs
@@ -3,5 +3,7 @@ pub mod infer_type_name;
 pub use error::Error;
 use error::Result;
 pub use infer_type_name::InferTypeName;
+mod model;
 mod wizard;
+
 pub use wizard::Wizard;
73 changes: 73 additions & 0 deletions src/cli/llm/model.rs
@@ -0,0 +1,73 @@
+#![allow(unused)]
+
+use std::borrow::Cow;
+use std::fmt::{Display, Formatter};
+use std::marker::PhantomData;
+
+use derive_setters::Setters;
+use genai::adapter::AdapterKind;
+
+#[derive(Clone)]
+pub struct Model(&'static str);
+
+pub mod open_ai {
+    use super::*;
+    pub const GPT3_5_TURBO: Model = Model("gp-3.5-turbo");
+    pub const GPT4: Model = Model("gpt-4");
+    pub const GPT4_TURBO: Model = Model("gpt-4-turbo");
+    pub const GPT4O_MINI: Model = Model("gpt-4o-mini");
+    pub const GPT4O: Model = Model("gpt-4o");
+}
+
+pub mod ollama {
+    use super::*;
+    pub const GEMMA2B: Model = Model("gemma:2b");
+}
+
+pub mod anthropic {
+    use super::*;
+    pub const CLAUDE3_HAIKU_20240307: Model = Model("claude-3-haiku-20240307");
+    pub const CLAUDE3_SONNET_20240229: Model = Model("claude-3-sonnet-20240229");
+    pub const CLAUDE3_OPUS_20240229: Model = Model("claude-3-opus-20240229");
+    pub const CLAUDE35_SONNET_20240620: Model = Model("claude-3-5-sonnet-20240620");
+}
+
+pub mod cohere {
+    use super::*;
+    pub const COMMAND_LIGHT_NIGHTLY: Model = Model("command-light-nightly");
+    pub const COMMAND_LIGHT: Model = Model("command-light");
+    pub const COMMAND_NIGHTLY: Model = Model("command-nightly");
+    pub const COMMAND: Model = Model("command");
+    pub const COMMAND_R: Model = Model("command-r");
+    pub const COMMAND_R_PLUS: Model = Model("command-r-plus");
+}
+
+pub mod gemini {
+    use super::*;
+    pub const GEMINI15_FLASH_LATEST: Model = Model("gemini-1.5-flash-latest");
+    pub const GEMINI10_PRO: Model = Model("gemini-1.0-pro");
+    pub const GEMINI15_FLASH: Model = Model("gemini-1.5-flash");
+    pub const GEMINI15_PRO: Model = Model("gemini-1.5-pro");
+}
+
+pub mod groq {
+    use super::*;
+    pub const LLAMA708192: Model = Model("llama3-70b-8192");
+    pub const LLAMA38192: Model = Model("llama3-8b-8192");
+    pub const LLAMA_GROQ8B8192_TOOL_USE_PREVIEW: Model =
+        Model("llama3-groq-8b-8192-tool-use-preview");
+    pub const LLAMA_GROQ70B8192_TOOL_USE_PREVIEW: Model =
+        Model("llama3-groq-70b-8192-tool-use-preview");
+    pub const GEMMA29B_IT: Model = Model("gemma2-9b-it");
+    pub const GEMMA7B_IT: Model = Model("gemma-7b-it");
+    pub const MIXTRAL_8X7B32768: Model = Model("mixtral-8x7b-32768");
+    pub const LLAMA8B_INSTANT: Model = Model("llama-3.1-8b-instant");
+    pub const LLAMA70B_VERSATILE: Model = Model("llama-3.1-70b-versatile");
+    pub const LLAMA405B_REASONING: Model = Model("llama-3.1-405b-reasoning");
+}
+
+impl Model {
+    pub fn as_str(&self) -> &'static str {
+        self.0
+    }
+}
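A hedged sketch of how these constants feed genai's adapter inference, mirroring the fallback used in wizard.rs below; it assumes code sitting inside the cli::llm module tree, since model is a private module:

use genai::adapter::AdapterKind;

use crate::cli::llm::model::{groq, Model};

// Illustrative only: map a model constant to the adapter genai should use,
// with the same Ollama fallback that Wizard::new applies.
fn adapter_for(model: &Model) -> AdapterKind {
    AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama)
}

// groq::LLAMA38192.as_str() yields "llama3-8b-8192"; if genai does not
// recognize the name, the fallback is Ollama.
fn groq_adapter() -> AdapterKind {
    adapter_for(&groq::LLAMA38192)
}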
15 changes: 12 additions & 3 deletions src/cli/llm/wizard.rs
@@ -1,27 +1,36 @@
 use derive_setters::Setters;
+use genai::adapter::AdapterKind;
 use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
 use genai::Client;
 
 use super::Result;
+use crate::cli::llm::model::Model;
 
 #[derive(Setters, Clone)]
 pub struct Wizard<Q, A> {
     client: Client,
-    // TODO: change model to enum
-    model: String,
+    model: Model,
     _q: std::marker::PhantomData<Q>,
     _a: std::marker::PhantomData<A>,
 }
 
 impl<Q, A> Wizard<Q, A> {
-    pub fn new(model: String) -> Self {
+    pub fn new(model: Model, secret: Option<String>) -> Self {
+        let mut config = genai::adapter::AdapterConfig::default();
+        if let Some(key) = secret {
+            config = config.with_auth_env_name(key);
+        }
+
+        let adapter = AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama);
+
         Self {
             client: Client::builder()
                 .with_chat_options(
                     ChatOptions::default()
                         .with_json_mode(true)
                         .with_temperature(0.0),
                 )
+                .insert_adapter_config(adapter, config)
                 .build(),
             model,
             _q: Default::default(),
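A hedged sketch of constructing a Wizard directly with a model constant and an optional secret; per the hunk above, the secret is forwarded to AdapterConfig::with_auth_env_name and the adapter kind is inferred from the model name, falling back to Ollama. MyQuestion and MyAnswer are placeholder types, since new() places no bounds on Q and A:

use crate::cli::llm::model::groq;
use crate::cli::llm::Wizard;

// Placeholder request/response types for illustration; real callers supply
// types convertible to ChatRequest and from ChatResponse.
struct MyQuestion;
struct MyAnswer;

// Build a Wizard against Groq's llama3-8b-8192. When `secret` is Some, it is
// handed to AdapterConfig::with_auth_env_name, as in the hunk above.
fn build_wizard(secret: Option<String>) -> Wizard<MyQuestion, MyAnswer> {
    Wizard::new(groq::LLAMA38192, secret)
}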
