Skip to content

Commit

Permalink
Cleanup + begin adding new embedding gen support
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Oct 16, 2023
1 parent a69ddb1 commit 6fd505c
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::base_vector_resources::{BaseVectorResource, VectorResourceBaseType};
use crate::data_tags::{DataTag, DataTagIndex};
use crate::embeddings::Embedding;
use crate::model_type::{EmbeddingModelType, RemoteModel};
use crate::model_type::{EmbeddingModelType, TextEmbeddingsInference};
use crate::resource_errors::VectorResourceError;
use crate::source::VRSource;
use crate::vector_resource::{DataChunk, DataContent, RetrievedDataChunk, TraversalMethod, VRPath, VectorResource};
Expand Down Expand Up @@ -139,7 +139,7 @@ impl DocumentVectorResource {
Embedding::new(&String::new(), vec![]),
Vec::new(),
Vec::new(),
EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2),
EmbeddingModelType::TextEmbeddingsInference(TextEmbeddingsInference::AllMiniLML6v2),
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::embeddings::Embedding;
use crate::model_type::{EmbeddingModelType, LocalModel, RemoteModel};
use crate::model_type::{EmbeddingModelType, TextEmbeddingsInference};
use crate::resource_errors::VectorResourceError;
use byteorder::{LittleEndian, ReadBytesExt};
use lazy_static::lazy_static;
Expand Down Expand Up @@ -84,10 +84,7 @@ impl EmbeddingGenerator for RemoteEmbeddingGenerator {
/// Generate an Embedding for an input string by using the external API.
fn generate_embedding(&self, input_string: &str, id: &str) -> Result<Embedding, VectorResourceError> {
// If we're using a Bert model with a Bert-CPP server
if self.model_type == EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2)
|| self.model_type == EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2)
|| self.model_type == EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2)
{
if let EmbeddingModelType::BertCPP(_) = self.model_type {
let vector = self.generate_embedding_bert_cpp(input_string)?;
return Ok(Embedding {
vector,
Expand Down Expand Up @@ -122,7 +119,7 @@ impl RemoteEmbeddingGenerator {
///
/// Expected to have downloaded & be using the AllMiniLML6v2 model.
pub fn new_default() -> RemoteEmbeddingGenerator {
let model_architecture = EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2);
let model_architecture = EmbeddingModelType::TextEmbeddingsInference(TextEmbeddingsInference::AllMiniLML6v2);
let url = format!("localhost:{}", DEFAULT_LOCAL_EMBEDDINGS_PORT.to_string());
RemoteEmbeddingGenerator {
model_type: model_architecture,
Expand Down
4 changes: 2 additions & 2 deletions shinkai-libs/shinkai-vector-resources/src/map_resource.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::base_vector_resources::{BaseVectorResource, VectorResourceBaseType};
use crate::data_tags::{DataTag, DataTagIndex};
use crate::embeddings::Embedding;
use crate::model_type::{EmbeddingModelType, RemoteModel};
use crate::model_type::{EmbeddingModelType, TextEmbeddingsInference};
use crate::resource_errors::VectorResourceError;
use crate::source::VRSource;
use crate::vector_resource::{DataChunk, DataContent, RetrievedDataChunk, VRPath, VectorResource};
Expand Down Expand Up @@ -136,7 +136,7 @@ impl MapVectorResource {
Embedding::new(&String::new(), vec![]),
HashMap::new(),
HashMap::new(),
EmbeddingModelType::RemoteModel(RemoteModel::AllMiniLML6v2),
EmbeddingModelType::TextEmbeddingsInference(TextEmbeddingsInference::AllMiniLML6v2),
)
}

Expand Down
90 changes: 43 additions & 47 deletions shinkai-libs/shinkai-vector-resources/src/model_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,72 @@ use std::fmt;

/// The type of embedding model, tagged by which backend/server serves it.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum EmbeddingModelType {
    /// Served via Hugging Face's Text Embeddings Inference server.
    TextEmbeddingsInference(TextEmbeddingsInference),
    /// Served via a Bert.CPP server.
    BertCPP(BertCPP),
    /// Served via the OpenAI API.
    OpenAI(OpenAI),
}

impl fmt::Display for EmbeddingModelType {
    /// Delegates to the `Display` impl of the wrapped backend-specific model enum,
    /// so the printed string is always the canonical model-name string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            EmbeddingModelType::TextEmbeddingsInference(model) => write!(f, "{}", model),
            EmbeddingModelType::BertCPP(model) => write!(f, "{}", model),
            EmbeddingModelType::OpenAI(model) => write!(f, "{}", model),
        }
    }
}

/// Hugging Face's Text Embeddings Inference
/// (https://github.com/huggingface/text-embeddings-inference)
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LocalModel {
Bloom,
Gpt2,
GptJ,
GptNeoX,
Llama,
Mpt,
Falcon,
pub enum TextEmbeddingsInference {
AllMiniLML6v2,
AllMiniLML12v2,
MultiQAMiniLML6,
Other(String),
}

/// Bert.CPP (https://github.com/skeskinen/bert.cpp)
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum RemoteModel {
OpenAITextEmbeddingAda002,
pub enum BertCPP {
AllMiniLML6v2,
AllMiniLML12v2,
MultiQAMiniLML6,
Other(String),
}

impl fmt::Display for RemoteModel {
/// OpenAI
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum OpenAI {
OpenAITextEmbeddingAda002,
}

impl fmt::Display for TextEmbeddingsInference {
    /// Writes the canonical Hugging Face model-name string for each variant;
    /// `Other` passes its stored name through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            TextEmbeddingsInference::AllMiniLML6v2 => write!(f, "all-MiniLM-L6-v2"),
            TextEmbeddingsInference::AllMiniLML12v2 => write!(f, "all-MiniLM-L12-v2"),
            TextEmbeddingsInference::MultiQAMiniLML6 => write!(f, "multi-qa-MiniLM-L6-cos-v1"),
            TextEmbeddingsInference::Other(name) => write!(f, "{}", name),
        }
    }
}

impl fmt::Display for BertCPP {
    /// Writes the canonical model-name string for each variant;
    /// `Other` passes its stored name through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            BertCPP::AllMiniLML6v2 => write!(f, "all-MiniLM-L6-v2"),
            BertCPP::AllMiniLML12v2 => write!(f, "all-MiniLM-L12-v2"),
            BertCPP::MultiQAMiniLML6 => write!(f, "multi-qa-MiniLM-L6-cos-v1"),
            BertCPP::Other(name) => write!(f, "{}", name),
        }
    }
}

impl fmt::Display for OpenAI {
    /// Writes the OpenAI API model identifier string (as used in API requests).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            OpenAI::OpenAITextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
        }
    }
}
24 changes: 11 additions & 13 deletions tests/encrypted_files_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ fn sandwich_messages_with_files_test() {
perform_locally: false,
external_url: Some("https://api.openai.com".to_string()),
// external_url: Some(server.url()),
api_key: Some("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".to_string()),
api_key: Some("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".to_string()),
// api_key: Some("mockapikey".to_string()),
model: AgentLLMInterface::OpenAI(open_ai),
toolkit_permissions: vec![],
Expand Down Expand Up @@ -397,20 +397,18 @@ fn sandwich_messages_with_files_test() {
eprintln!("node1_last_messages: {:?}", node1_last_messages);

match node1_last_messages[0].get_message_content() {
Ok(message_content) => {
match serde_json::from_str::<JobMessage>(&message_content) {
Ok(job_message) => {
eprintln!("message_content: {}", message_content);
if job_message.content != job_message_content {
assert!(true);
break;
}
}
Err(_) => {
eprintln!("error: message_content: {}", message_content);
Ok(message_content) => match serde_json::from_str::<JobMessage>(&message_content) {
Ok(job_message) => {
eprintln!("message_content: {}", message_content);
if job_message.content != job_message_content {
assert!(true);
break;
}
}
}
Err(_) => {
eprintln!("error: message_content: {}", message_content);
}
},
Err(_) => {
// nothing
}
Expand Down

0 comments on commit 6fd505c

Please sign in to comment.