Skip to content

Commit

Permalink
Further merge updates
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Jan 26, 2024
1 parent a16f2a1 commit 84eeae6
Show file tree
Hide file tree
Showing 14 changed files with 38 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@ use std::fmt;
/// and VecFS entries (source/vector resource stored in the DB, accessible to all jobs)
pub struct JobScope {
pub local: Vec<LocalScopeEntry>,
pub vec_fs: Vec<VectorFSScopeEntry>,
pub vector_fs: Vec<VectorFSScopeEntry>,
}

impl JobScope {}
impl JobScope {
pub fn new(local: Vec<LocalScopeEntry>, vec_fs: Vec<VectorFSScopeEntry>) -> Self {
Self { local, vec_fs }
pub fn new(local: Vec<LocalScopeEntry>, vector_fs: Vec<VectorFSScopeEntry>) -> Self {
Self { local, vector_fs }
}

pub fn new_default() -> Self {
Self {
local: Vec::new(),
vec_fs: Vec::new(),
vector_fs: Vec::new(),
}
}

/// Checks if the Job Scope is empty (has no entries pointing to VRs)
pub fn is_empty(&self) -> bool {
self.local.is_empty() && self.database.is_empty()
self.local.is_empty() && self.vector_fs.is_empty()
}

pub fn to_bytes(&self) -> serde_json::Result<Vec<u8>> {
Expand Down Expand Up @@ -65,15 +65,15 @@ impl fmt::Debug for JobScope {
})
.collect();

let vec_fs_ids: Vec<String> = self
.vec_fs
let vector_fs_ids: Vec<String> = self
.vector_fs
.iter()
.map(|entry| entry.resource_header.reference_string())
.collect();

f.debug_struct("JobScope")
.field("local", &format_args!("{:?}", local_ids))
.field("vector_fs", &format_args!("{:?}", vec_fs_ids))
.field("vector_fs", &format_args!("{:?}", vector_fs_ids))
.finish()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ impl PyJobScope {
}

pub fn is_empty(&self) -> bool {
self.inner.local.is_empty() && self.inner.vec_fs.is_empty()
self.inner.local.is_empty() && self.inner.vector_fs.is_empty()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_job_creation(self):
result_json = json.loads(result)

# Add assertions to check the fields of the result
self.assertEqual(result_json["body"]["unencrypted"]["message_data"]["unencrypted"]["message_raw_content"], "{\"scope\":{\"local\":[],\"database\":[]}}")
self.assertEqual(result_json["body"]["unencrypted"]["message_data"]["unencrypted"]["message_raw_content"], "{\"scope\":{\"local\":[],\"vector_fs\":[]}}")
self.assertEqual(result_json["body"]["unencrypted"]["message_data"]["unencrypted"]["message_content_schema"], "JobCreationSchema")
self.assertEqual(result_json["body"]["unencrypted"]["internal_metadata"]["sender_subidentity"], "main")
self.assertEqual(result_json["body"]["unencrypted"]["internal_metadata"]["recipient_subidentity"], "main/agent/agent_1")
Expand Down
12 changes: 6 additions & 6 deletions src/agent/execution/job_execution_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ impl JobManager {
agent_found: Option<SerializedAgent>,
full_job: &mut Job,
profile: ShinkaiName,
save_to_vec_fs_folder: Option<VRPath>,
save_to_vector_fs_folder: Option<VRPath>,
generator: RemoteEmbeddingGenerator,
unstructured_api: UnstructuredAPI,
) -> Result<(), AgentError> {
Expand All @@ -396,7 +396,7 @@ impl JobManager {
agent_found,
job_message.files_inbox.clone(),
profile,
save_to_vec_fs_folder,
save_to_vector_fs_folder,
generator,
unstructured_api,
)
Expand All @@ -418,8 +418,8 @@ impl JobManager {
}
}
ScopeEntry::VectorFS(fs_entry) => {
if !full_job.scope.vec_fs.contains(&fs_entry) {
full_job.scope.vec_fs.push(fs_entry);
if !full_job.scope.vector_fs.contains(&fs_entry) {
full_job.scope.vector_fs.push(fs_entry);
} else {
shinkai_log(
ShinkaiLogOption::JobExecution,
Expand Down Expand Up @@ -455,7 +455,7 @@ impl JobManager {
agent: Option<SerializedAgent>,
files_inbox: String,
profile: ShinkaiName,
save_to_vec_fs_folder: Option<VRPath>,
save_to_vector_fs_folder: Option<VRPath>,
generator: RemoteEmbeddingGenerator,
unstructured_api: UnstructuredAPI,
) -> Result<HashMap<String, ScopeEntry>, AgentError> {
Expand Down Expand Up @@ -499,7 +499,7 @@ impl JobManager {

// Now create Local/VectorFSScopeEntry depending on setting
let text_chunking_strategy = TextChunkingStrategy::V1;
if let Some(folder_path) = &save_to_vec_fs_folder {
if let Some(folder_path) = &save_to_vector_fs_folder {
// TODO: Save to VectorFS
let resource_header = resource.as_trait_object().generate_resource_header();
let fs_scope_entry = VectorFSScopeEntry {
Expand Down
2 changes: 1 addition & 1 deletion src/agent/execution/job_vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl JobManager {

// Fetch DB resources and add them to the list
// let db = db.lock().await;
// for fs_entry in &job_scope.vec_fs {
// for fs_entry in &job_scope.vector_fs {
// resources.push(resource);
// }
// std::mem::drop(db);
Expand Down
1 change: 0 additions & 1 deletion src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ pub mod db_identity;
pub mod db_identity_registration;
pub mod db_inbox;
pub mod db_inbox_get_messages;
pub mod db_inbox_get_messages;
pub mod db_job_queue;
pub mod db_jobs;
pub mod db_profile_bound;
Expand Down
4 changes: 2 additions & 2 deletions src/vector_fs/db/fs_db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::super::{vector_fs_error::VectorFSError, vector_fs_internals::VectorFSInternals};
use crate::db::ShinkaiDB;
use crate::db::db_profile_bound::ProfileBoundWriteBatch;
use crate::db::ShinkaiDB;
use rand::Rng;
use rand::{distributions::Alphanumeric, thread_rng};
use rocksdb::{
Expand Down Expand Up @@ -80,7 +80,7 @@ impl VectorFSDB {
.take(8)
.map(char::from)
.collect();
let db_path = format!("db_tests/empty_vec_fs_db_{}", random_string);
let db_path = format!("db_tests/empty_vector_fs_db_{}", random_string);
Self {
db: DB::open_default(&db_path).unwrap(),
path: db_path,
Expand Down
2 changes: 1 addition & 1 deletion src/vector_fs/vector_fs_internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl VectorFSInternals {

/// A hard-coded DB key for the profile-wide VectorFSInternals.
pub fn profile_fs_internals_shinkai_db_key() -> String {
"profile_vec_fs_internals".to_string()
"profile_vector_fs_internals".to_string()
}

pub fn to_json(&self) -> serde_json::Result<String> {
Expand Down
6 changes: 4 additions & 2 deletions tests/it/agent_integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{JobMe
use shinkai_message_primitives::shinkai_utils::encryption::{
clone_static_secret_key, unsafe_deterministic_encryption_keypair, EncryptionMethod,
};
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption, init_default_tracing};
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{
init_default_tracing, shinkai_log, ShinkaiLogLevel, ShinkaiLogOption,
};
use shinkai_message_primitives::shinkai_utils::shinkai_message_builder::ShinkaiMessageBuilder;
use shinkai_message_primitives::shinkai_utils::signatures::{
clone_signature_secret_key, unsafe_deterministic_signature_keypair,
Expand Down Expand Up @@ -64,7 +66,7 @@ fn node_agent_registration() {
let (node1_device_encryption_sk, node1_device_encryption_pk) = unsafe_deterministic_encryption_keypair(200);

let node1_db_path = format!("db_tests/{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node1_identity_name.clone()));

// Agent pre-creation

Expand Down
4 changes: 2 additions & 2 deletions tests/it/db_restore_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mod tests {
let (node1_device_encryption_sk, _) = unsafe_deterministic_encryption_keypair(200);

let node1_db_path = "tests/db_for_testing/test".to_string();
let node1_vec_fs_path = "tests/vec_fs_db_for_testing/test".to_string();
let node1_vector_fs_path = "tests/vector_fs_db_for_testing/test".to_string();

// Create node1
let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand All @@ -56,7 +56,7 @@ mod tests {
true,
vec![],
None,
node1_vec_fs_path,
node1_vector_fs_path,
None,
None,
);
Expand Down
8 changes: 4 additions & 4 deletions tests/it/node_integration_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use async_channel::{bounded, Receiver, Sender};
use async_std::println;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::init_default_tracing;
use core::panic;
use ed25519_dalek::{SigningKey, VerifyingKey};
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
Expand All @@ -12,6 +11,7 @@ use shinkai_message_primitives::shinkai_utils::encryption::{
encryption_public_key_to_string, encryption_secret_key_to_string, unsafe_deterministic_encryption_keypair,
EncryptionMethod,
};
use shinkai_message_primitives::shinkai_utils::shinkai_logging::init_default_tracing;
use shinkai_message_primitives::shinkai_utils::shinkai_message_builder::ShinkaiMessageBuilder;
use shinkai_message_primitives::shinkai_utils::signatures::{
clone_signature_secret_key, signature_public_key_to_string, signature_secret_key_to_string,
Expand Down Expand Up @@ -42,7 +42,7 @@ fn setup() {

#[test]
fn subidentity_registration() {
init_default_tracing();
init_default_tracing();
setup();
let rt = Runtime::new().unwrap();

Expand Down Expand Up @@ -86,9 +86,9 @@ fn subidentity_registration() {
bounded(100);

let node1_db_path = format!("db_tests/{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node1_identity_name.clone()));
let node2_db_path = format!("db_tests/{}", hash_string(node2_identity_name.clone()));
let node2_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node2_identity_name.clone()));
let node2_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node2_identity_name.clone()));

// Create node1 and node2
let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand Down
6 changes: 3 additions & 3 deletions tests/it/node_retrying_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use super::utils::node_test_local::local_registration_profile_node;

#[test]
fn node_retrying_test() {
init_default_tracing();
init_default_tracing();
utils::db_handlers::setup();
let rt = Runtime::new().unwrap();

Expand Down Expand Up @@ -80,9 +80,9 @@ fn node_retrying_test() {
bounded(100);

let node1_db_path = format!("db_tests/{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node1_identity_name.clone()));
let node2_db_path = format!("db_tests/{}", hash_string(node2_identity_name.clone()));
let node2_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node2_identity_name.clone()));
let node2_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node2_identity_name.clone()));

// Create node1 and node2
let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand Down
2 changes: 1 addition & 1 deletion tests/it/utils/test_boilerplate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ where
let (node1_device_encryption_sk, node1_device_encryption_pk) = unsafe_deterministic_encryption_keypair(200);

let node1_db_path = format!("db_tests/{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vec_fs{}", hash_string(node1_identity_name.clone()));
let node1_fs_db_path = format!("db_tests/vector_fs{}", hash_string(node1_identity_name.clone()));

// Create node1 and node2
let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand Down
9 changes: 4 additions & 5 deletions tests/it/vec_fs_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use serde_json::Value as JsonValue;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::init_tracing;
use shinkai_node::agent::file_parsing::ParsingHelper;
use shinkai_node::db::ShinkaiDB;
use shinkai_node::vector_fs::vector_fs_internals::VectorFSInternals;
Expand Down Expand Up @@ -37,9 +36,9 @@ fn node_name() -> ShinkaiName {
ShinkaiName::new("@@localhost.shinkai".to_string()).unwrap()
}

fn setup_default_vec_fs() -> VectorFS {
fn setup_default_vector_fs() -> VectorFS {
let generator = RemoteEmbeddingGenerator::new_default();
let fs_db_path = format!("db_tests/{}", "vec_fs");
let fs_db_path = format!("db_tests/{}", "vector_fs");
let profile_list = vec![default_test_profile()];
let supported_embedding_models = vec![EmbeddingModelType::TextEmbeddingsInference(
TextEmbeddingsInference::AllMiniLML6v2,
Expand Down Expand Up @@ -98,7 +97,7 @@ pub fn get_shinkai_intro_doc(generator: &RemoteEmbeddingGenerator, data_tags: &V
async fn test_vector_fs_initializes_new_profile_automatically() {
setup();
let generator = RemoteEmbeddingGenerator::new_default();
let mut vector_fs = setup_default_vec_fs();
let mut vector_fs = setup_default_vector_fs();

let fs_internals = vector_fs._get_profile_fs_internals(&default_test_profile());
assert!(fs_internals.is_ok())
Expand All @@ -108,7 +107,7 @@ async fn test_vector_fs_initializes_new_profile_automatically() {
async fn test_vector_fs_saving_reading() {
setup();
let generator = RemoteEmbeddingGenerator::new_default();
let mut vector_fs = setup_default_vec_fs();
let mut vector_fs = setup_default_vector_fs();

let path = VRPath::new();
let writer = vector_fs
Expand Down

0 comments on commit 84eeae6

Please sign in to comment.