diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs
index a478cbf0fc34..31d444faca7f 100644
--- a/crates/http-api-bindings/src/embedding/llama.rs
+++ b/crates/http-api-bindings/src/embedding/llama.rs
@@ -1,8 +1,9 @@
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
 use tabby_inference::Embedding;
+use tracing::Instrument;
 
-use crate::create_reqwest_client;
+use crate::{create_reqwest_client, embedding_info_span};
 
 pub struct LlamaCppEngine {
     client: reqwest::Client,
@@ -44,7 +45,10 @@ impl Embedding for LlamaCppEngine {
             request = request.bearer_auth(api_key);
         }
 
-        let response = request.send().await?;
+        let response = request
+            .send()
+            .instrument(embedding_info_span!("llamacpp"))
+            .await?;
         if response.status().is_server_error() {
             let error = response.text().await?;
             return Err(anyhow::anyhow!(
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index f6e3c9695b60..c68781e3eb6e 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,15 +1,14 @@
 mod llama;
 mod openai;
-mod voyage;
 
 use core::panic;
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
+use openai::OpenAIEmbeddingEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
-use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 use super::rate_limit;
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
@@ -30,16 +29,16 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
             config.api_key.as_deref(),
         ),
         "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
-        "voyage/embedding" => VoyageEmbeddingEngine::create(
-            config.api_endpoint.as_deref(),
+        "voyage/embedding" => OpenAIEmbeddingEngine::create(
+            config
+                .api_endpoint
+                .as_deref()
+                .unwrap_or("https://api.voyageai.com/v1"),
             config
                 .model_name
                 .as_deref()
                 .expect("model_name must be set for voyage/embedding"),
-            config
-                .api_key
-                .clone()
-                .expect("api_key must be set for voyage/embedding"),
+            config.api_key.as_deref(),
         ),
         unsupported_kind => panic!(
             "Unsupported kind for http embedding model: {}",
@@ -52,3 +51,10 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         config.rate_limit.request_per_minute,
     ))
 }
+
+#[macro_export]
+macro_rules! embedding_info_span {
+    ($kind:expr) => {
+        tracing::info_span!("embedding", kind = $kind)
+    };
+}
diff --git a/crates/http-api-bindings/src/embedding/openai.rs b/crates/http-api-bindings/src/embedding/openai.rs
index 0df9327e504a..ad8f77a9548f 100644
--- a/crates/http-api-bindings/src/embedding/openai.rs
+++ b/crates/http-api-bindings/src/embedding/openai.rs
@@ -1,14 +1,16 @@
 use anyhow::Context;
-use async_openai::{
-    config::OpenAIConfig,
-    types::{CreateEmbeddingRequest, EmbeddingInput},
-};
 use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
 use tabby_inference::Embedding;
-use tracing::{info_span, Instrument};
+use tracing::Instrument;
+
+use crate::embedding_info_span;
 
 pub struct OpenAIEmbeddingEngine {
-    client: async_openai::Client<OpenAIConfig>,
+    client: Client,
+    api_endpoint: String,
+    api_key: String,
     model_name: String,
 }
 
@@ -18,41 +20,69 @@ impl OpenAIEmbeddingEngine {
         model_name: &str,
         api_key: Option<&str>,
     ) -> Box<dyn Embedding> {
-        let config = OpenAIConfig::default()
-            .with_api_base(api_endpoint)
-            .with_api_key(api_key.unwrap_or_default());
-
-        let client = async_openai::Client::with_config(config);
-
+        let client = Client::new();
         Box::new(Self {
             client,
+            api_endpoint: format!("{}/embeddings", api_endpoint),
+            api_key: api_key.unwrap_or_default().to_owned(),
             model_name: model_name.to_owned(),
         })
     }
 }
 
+#[derive(Debug, Serialize)]
+struct EmbeddingRequest {
+    input: Vec<String>,
+    model: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct EmbeddingResponse {
+    data: Vec<EmbeddingData>,
+}
+
+#[derive(Debug, Deserialize)]
+struct EmbeddingData {
+    embedding: Vec<f32>,
+}
+
 #[async_trait]
 impl Embedding for OpenAIEmbeddingEngine {
     async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
-        let request = CreateEmbeddingRequest {
+        let request = EmbeddingRequest {
+            input: vec![prompt.to_owned()],
             model: self.model_name.clone(),
-            input: EmbeddingInput::String(prompt.to_owned()),
-            encoding_format: None,
-            user: None,
-            dimensions: None,
         };
-        let resp = self
+
+        let request_builder = self
             .client
-            .embeddings()
-            .create(request)
-            .instrument(info_span!("embedding", kind = "openai"))
+            .post(&self.api_endpoint)
+            .json(&request)
+            .header("content-type", "application/json")
+            .bearer_auth(&self.api_key);
+
+        let response = request_builder
+            .send()
+            .instrument(embedding_info_span!("openai"))
             .await?;
-        let data = resp
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let error = response.text().await?;
+            return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
+        }
+
+        let response_body = response
+            .json::<EmbeddingResponse>()
+            .await
+            .context("Failed to parse response body")?;
+
+        response_body
             .data
             .into_iter()
             .next()
-            .context("Failed to get embedding")?;
-        Ok(data.embedding)
+            .map(|data| data.embedding)
+            .ok_or_else(|| anyhow::anyhow!("No embedding data found"))
     }
 }
 
@@ -73,4 +103,17 @@ mod tests {
         let embedding = engine.embed("Hello, world!").await.unwrap();
         assert_eq!(embedding.len(), 768);
     }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_voyage_embedding() {
+        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
+        let engine = OpenAIEmbeddingEngine::create(
+            "https://api.voyageai.com/v1",
+            "voyage-code-2",
+            Some(&api_key),
+        );
+        let embedding = engine.embed("Hello, world!").await.unwrap();
+        assert_eq!(embedding.len(), 1536);
+    }
 }
diff --git a/crates/http-api-bindings/src/embedding/voyage.rs b/crates/http-api-bindings/src/embedding/voyage.rs
deleted file mode 100644
index 0ca9d2d8e050..000000000000
--- a/crates/http-api-bindings/src/embedding/voyage.rs
+++ /dev/null
@@ -1,98 +0,0 @@
-use anyhow::Context;
-use async_trait::async_trait;
-use reqwest::Client;
-use serde::{Deserialize, Serialize};
-use tabby_inference::Embedding;
-
-const DEFAULT_VOYAGE_API_ENDPOINT: &str = "https://api.voyageai.com";
-
-pub struct VoyageEmbeddingEngine {
-    client: Client,
-    api_endpoint: String,
-    api_key: String,
-    model_name: String,
-}
-
-impl VoyageEmbeddingEngine {
-    pub fn create(
-        api_endpoint: Option<&str>,
-        model_name: &str,
-        api_key: String,
-    ) -> Box<dyn Embedding> {
-        let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT);
-        let client = Client::new();
-        Box::new(Self {
-            client,
-            api_endpoint: format!("{}/v1/embeddings", api_endpoint),
-            api_key,
-            model_name: model_name.to_owned(),
-        })
-    }
-}
-
-#[derive(Debug, Serialize)]
-struct EmbeddingRequest {
-    input: Vec<String>,
-    model: String,
-}
-
-#[derive(Debug, Deserialize)]
-struct EmbeddingResponse {
-    data: Vec<EmbeddingData>,
-}
-
-#[derive(Debug, Deserialize)]
-struct EmbeddingData {
-    embedding: Vec<f32>,
-}
-
-#[async_trait]
-impl Embedding for VoyageEmbeddingEngine {
-    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
-        let request = EmbeddingRequest {
-            input: vec![prompt.to_owned()],
-            model: self.model_name.clone(),
-        };
-
-        let request_builder = self
-            .client
-            .post(&self.api_endpoint)
-            .json(&request)
-            .header("content-type", "application/json")
-            .bearer_auth(&self.api_key);
-
-        let response = request_builder.send().await?;
-        if !response.status().is_success() {
-            let status = response.status();
-            let error = response.text().await?;
-            return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
-        }
-
-        let response_body = response
-            .json::<EmbeddingResponse>()
-            .await
-            .context("Failed to parse response body")?;
-
-        response_body
-            .data
-            .into_iter()
-            .next()
-            .map(|data| data.embedding)
-            .ok_or_else(|| anyhow::anyhow!("No embedding data found"))
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    /// VOYAGE_API_KEY=xxx cargo test test_voyage_embedding -- --ignored
-    #[tokio::test]
-    #[ignore]
-    async fn test_voyage_embedding() {
-        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
-        let engine = VoyageEmbeddingEngine::create(None, "voyage-code-2", api_key);
-        let embedding = engine.embed("Hello, world!").await.unwrap();
-        assert_eq!(embedding.len(), 1536);
-    }
-}