Skip to content

Commit

Permalink
Merge pull request #63 from dcSpark/feature/simplify-resource-db-types
Browse files Browse the repository at this point in the history
Simplifying Resource DB Types/Interface
  • Loading branch information
robkorn authored Sep 5, 2023
2 parents a12da71 + c1e132c commit aa59f75
Show file tree
Hide file tree
Showing 14 changed files with 447 additions and 317 deletions.
143 changes: 67 additions & 76 deletions src/db/db_resources.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::db::{ShinkaiDB, Topic};
use crate::resources::base_vector_resources::BaseVectorResource;
use crate::resources::base_vector_resources::VectorResourceBaseType;
use crate::resources::document_resource::DocumentVectorResource;
use crate::resources::embeddings::Embedding;
use crate::resources::resource_errors::VectorResourceError;
use crate::resources::router::{VectorResourcePointer, VectorResourceRouter};
use crate::resources::vector_resource::RetrievedDataChunk;
use crate::resources::vector_resource::{VectorResource, VectorResourceType};
use serde_json::{from_str, to_string};
use crate::resources::vector_resource::VectorResource;
use serde_json::from_str;
use shinkai_message_wasm::schemas::shinkai_name::ShinkaiName;

use super::db::ProfileBoundWriteBatch;
Expand All @@ -18,10 +20,15 @@ impl ShinkaiDB {
router: &VectorResourceRouter,
profile: &ShinkaiName,
) -> Result<(), ShinkaiDBError> {
let (bytes, cf) = self._prepare_profile_resource_router(router, profile)?;
let (bytes, cf) = self._prepare_profile_resource_router(router)?;

// Insert into the "VectorResources" column family
self.put_cf_pb(cf, &VectorResourceRouter::profile_router_db_key(), bytes, profile)?;
self.put_cf_pb(
cf,
&VectorResourceRouter::profile_router_shinkai_db_key(),
bytes,
profile,
)?;

Ok(())
}
Expand All @@ -30,7 +37,6 @@ impl ShinkaiDB {
fn _prepare_profile_resource_router(
&self,
router: &VectorResourceRouter,
profile: &ShinkaiName,
) -> Result<(Vec<u8>, &rocksdb::ColumnFamily), ShinkaiDBError> {
// Convert JSON to bytes for storage
let json = router.to_json()?;
Expand All @@ -50,25 +56,24 @@ impl ShinkaiDB {
/// resource being saved and is implemented in `.save_resources`.
fn _save_resource_pointerless(
&self,
resource: &Box<dyn VectorResource>,
resource: &BaseVectorResource,
profile: &ShinkaiName,
) -> Result<(), ShinkaiDBError> {
let (bytes, cf) = self._prepare_resource_pointerless(resource, profile)?;
let (bytes, cf) = self._prepare_resource_pointerless(resource)?;

// Insert into the "VectorResources" column family
self.put_cf_pb(cf, &resource.db_key(), &bytes, profile)?;
self.put_cf_pb(cf, &resource.as_trait_object().shinkai_db_key(), &bytes, profile)?;

Ok(())
}

/// Prepares the `VectorResource` for saving into the ShinkaiDB in the resources topic as a JSON
/// Prepares the `BaseVectorResource` for saving into the ShinkaiDB in the resources topic as a JSON
/// string. Note this is only to be used internally.
fn _prepare_resource_pointerless(
&self,
resource: &Box<dyn VectorResource>,
profile: &ShinkaiName,
resource: &BaseVectorResource,
) -> Result<(Vec<u8>, &rocksdb::ColumnFamily), ShinkaiDBError> {
// Convert VectorResource JSON to bytes for storage
// Convert BaseVectorResource JSON to bytes for storage
let json = resource.to_json()?;
let bytes = json.as_bytes().to_vec();

Expand All @@ -78,16 +83,12 @@ impl ShinkaiDB {
Ok((bytes, cf))
}

/// Saves the `VectorResource` into the ShinkaiDB. This updates the
/// Saves the `BaseVectorResource` into the ShinkaiDB. This updates the
/// Global VectorResourceRouter with the resource pointers as well.
///
/// Of note, if a resource with the same name and resource_id already exists
/// in the DB, this will overwrite the old resource completely.
pub fn save_resource(
&self,
resource: Box<dyn VectorResource>,
profile: &ShinkaiName,
) -> Result<(), ShinkaiDBError> {
pub fn save_resource(&self, resource: BaseVectorResource, profile: &ShinkaiName) -> Result<(), ShinkaiDBError> {
self.save_resources(vec![resource], profile)
}

Expand All @@ -98,7 +99,7 @@ impl ShinkaiDB {
/// resource_id, this will overwrite the old resource completely.
pub fn save_resources(
&self,
resources: Vec<Box<dyn VectorResource>>,
resources: Vec<BaseVectorResource>,
profile: &ShinkaiName,
) -> Result<(), ShinkaiDBError> {
// Get the resource router
Expand All @@ -107,71 +108,46 @@ impl ShinkaiDB {
let mut pb_batch = ProfileBoundWriteBatch::new(profile)?;
for resource in resources {
// Adds the JSON of the resource to the batch
let (bytes, cf) = self._prepare_resource_pointerless(&resource, profile)?;
pb_batch.put_cf_pb(cf, &resource.db_key(), &bytes);
let (bytes, cf) = self._prepare_resource_pointerless(&resource)?;
pb_batch.put_cf_pb(cf, &resource.as_trait_object().shinkai_db_key(), &bytes);

// Add the pointer to the router, then put the updated router
// into the batch
let pointer = resource.get_resource_pointer();
let pointer = resource.as_trait_object().get_resource_pointer();
router.add_resource_pointer(&pointer)?;
let (bytes, cf) = self._prepare_profile_resource_router(&router, profile)?;
pb_batch.put_cf_pb(cf, &VectorResourceRouter::profile_router_db_key(), &bytes);
let (bytes, cf) = self._prepare_profile_resource_router(&router)?;
pb_batch.put_cf_pb(cf, &VectorResourceRouter::profile_router_shinkai_db_key(), &bytes);
}

self.write_pb(pb_batch)?;

Ok(())
}

/// Fetches the VectorResource from the DB using a VectorResourcePointer
/// Fetches the BaseVectorResource from the DB using a VectorResourcePointer
pub fn get_resource_by_pointer(
&self,
resource_pointer: &VectorResourcePointer,
profile: &ShinkaiName,
) -> Result<Box<dyn VectorResource>, ShinkaiDBError> {
self.get_resource(
&resource_pointer.db_key.clone(),
&resource_pointer.resource_type,
profile,
)
}

/// Fetches the VectorResource from the DB
pub fn get_resource(
&self,
key: &str,
resource_type: &VectorResourceType,
profile: &ShinkaiName,
) -> Result<Box<dyn VectorResource>, ShinkaiDBError> {
// Fetch and convert the bytes to a valid UTF-8 string
let bytes = self.get_cf_pb(Topic::VectorResources, key, profile)?;
let json_str = std::str::from_utf8(&bytes)?;

// Parse the JSON string into a VectorResource implementing struct
if resource_type == &VectorResourceType::Document {
let document_resource: DocumentVectorResource = from_str(json_str)?;
Ok(Box::new(document_resource))
} else {
Err(ShinkaiDBError::from(VectorResourceError::InvalidVectorResourceType))
}
) -> Result<BaseVectorResource, ShinkaiDBError> {
self.get_resource(&resource_pointer.shinkai_db_key.clone(), profile)
}

/// Fetches a DocumentVectorResource from the DB
pub fn get_document(&self, key: &str, profile: &ShinkaiName) -> Result<DocumentVectorResource, ShinkaiDBError> {
/// Fetches the BaseVectorResource from the DB
pub fn get_resource(&self, key: &str, profile: &ShinkaiName) -> Result<BaseVectorResource, ShinkaiDBError> {
// Fetch and convert the bytes to a valid UTF-8 string
let bytes = self.get_cf_pb(Topic::VectorResources, key, profile)?;
let json_str = std::str::from_utf8(&bytes)?;

// Parse the JSON string into a BaseVectorResource
Ok(from_str(json_str)?)
Ok(BaseVectorResource::from_json(json_str)?)
}

/// Fetches the Global VectorResource Router from the DB
pub fn get_profile_resource_router(&self, profile: &ShinkaiName) -> Result<VectorResourceRouter, ShinkaiDBError> {
// Fetch and convert the bytes to a valid UTF-8 string
let bytes = self.get_cf_pb(
Topic::VectorResources,
&VectorResourceRouter::profile_router_db_key(),
&VectorResourceRouter::profile_router_shinkai_db_key(),
profile,
)?;
let json_str = std::str::from_utf8(&bytes)?;
Expand All @@ -186,7 +162,7 @@ impl ShinkaiDB {
/// Only resources with matching data tags will be considered at all,
/// and likewise only data chunks with matching data tags inside of said
/// resources will be scored and potentially returned.
pub fn syntactic_vector_search_data(
pub fn syntactic_vector_search(
&self,
query: Embedding,
num_of_resources: u64,
Expand All @@ -199,8 +175,11 @@ impl ShinkaiDB {

let mut retrieved_chunks = Vec::new();
for resource in resources {
println!("VectorResource: {}", resource.name());
let results = resource.syntactic_vector_search(query.clone(), num_of_results, data_tag_names);
println!("VectorResource: {}", resource.as_trait_object().name());
let results =
resource
.as_trait_object()
.syntactic_vector_search(query.clone(), num_of_results, data_tag_names);
retrieved_chunks.extend(results);
}

Expand All @@ -212,7 +191,7 @@ impl ShinkaiDB {
/// From there a vector search is performed on each resource with the query embedding,
/// and the results from all resources are then collected, sorted, and the top num_of_results
/// RetrievedDataChunks based on similarity score are returned.
pub fn vector_search_data(
pub fn vector_search(
&self,
query: Embedding,
num_of_resources: u64,
Expand All @@ -223,7 +202,7 @@ impl ShinkaiDB {

let mut retrieved_chunks = Vec::new();
for resource in resources {
let results = resource.vector_search(query.clone(), num_of_results);
let results = resource.as_trait_object().vector_search(query.clone(), num_of_results);
retrieved_chunks.extend(results);
}

Expand All @@ -236,14 +215,14 @@ impl ShinkaiDB {
/// * `tolerance_range` - A float between 0 and 1, inclusive, that
/// determines the range of acceptable similarity scores as a percentage
/// of the highest score.
pub fn vector_search_data_tolerance_ranged(
pub fn vector_search_tolerance_ranged(
&self,
query: Embedding,
num_of_resources: u64,
tolerance_range: f32,
profile: &ShinkaiName,
) -> Result<Vec<RetrievedDataChunk>, ShinkaiDBError> {
let retrieved_chunks = self.vector_search_data(query.clone(), num_of_resources, 1, profile)?;
let retrieved_chunks = self.vector_search(query.clone(), num_of_resources, 1, profile)?;
let top_chunk = &retrieved_chunks.get(0).ok_or(ShinkaiDBError::VectorResourceError(
VectorResourceError::VectorResourceEmpty,
))?;
Expand All @@ -252,8 +231,11 @@ impl ShinkaiDB {
let resources = self.vector_search_resources(query.clone(), num_of_resources, profile)?;
let mut final_chunks = Vec::new();
for resource in resources {
let results =
resource.vector_search_tolerance_ranged_score(query.clone(), tolerance_range, top_chunk.score);
let results = resource.as_trait_object().vector_search_tolerance_ranged_score(
query.clone(),
tolerance_range,
top_chunk.score,
);
final_chunks.extend(results);
}

Expand All @@ -265,14 +247,19 @@ impl ShinkaiDB {
///
/// Note: This only searches DocumentVectorResources in Topic::VectorResources, not all resources. This is
/// because the proximity logic is not generic (potentially later we can have a Proximity trait).
pub fn vector_search_data_doc_proximity(
pub fn vector_search_proximity(
&self,
query: Embedding,
num_of_docs: u64,
proximity_window: u64,
profile: &ShinkaiName,
) -> Result<Vec<RetrievedDataChunk>, ShinkaiDBError> {
let docs = self.vector_search_docs(query.clone(), num_of_docs, profile)?;
let mut docs: Vec<DocumentVectorResource> = Vec::new();
for doc in self.vector_search_docs(query.clone(), num_of_docs, profile)? {
if let Ok(document_resource) = doc.as_document_resource() {
docs.push(document_resource.clone());
}
}

let mut retrieved_chunks = Vec::new();
for doc in &docs {
Expand All @@ -286,7 +273,7 @@ impl ShinkaiDB {
))?;

for doc in &docs {
if doc.db_key() == top_chunk.resource_pointer.db_key {
if doc.shinkai_db_key() == top_chunk.resource_pointer.shinkai_db_key {
return Ok(doc.vector_search_proximity(query, proximity_window)?);
}
}
Expand All @@ -304,32 +291,32 @@ impl ShinkaiDB {
num_of_resources: u64,
data_tag_names: &Vec<String>,
profile: &ShinkaiName,
) -> Result<Vec<Box<dyn VectorResource>>, ShinkaiDBError> {
) -> Result<Vec<BaseVectorResource>, ShinkaiDBError> {
let router = self.get_profile_resource_router(profile)?;
let resource_pointers = router.syntactic_vector_search(query, num_of_resources, data_tag_names);

let mut resources = vec![];
for res_pointer in resource_pointers {
resources.push(self.get_resource(&res_pointer.db_key, &(res_pointer.resource_type), profile)?);
resources.push(self.get_resource(&res_pointer.shinkai_db_key, profile)?);
}

Ok(resources)
}

/// Performs a vector search using a query embedding and returns the
/// num_of_resources amount of most similar VectorResources.
/// num_of_resources amount of most similar BaseVectorResources.
pub fn vector_search_resources(
&self,
query: Embedding,
num_of_resources: u64,
profile: &ShinkaiName,
) -> Result<Vec<Box<dyn VectorResource>>, ShinkaiDBError> {
) -> Result<Vec<BaseVectorResource>, ShinkaiDBError> {
let router = self.get_profile_resource_router(profile)?;
let resource_pointers = router.vector_search(query, num_of_resources);

let mut resources = vec![];
for res_pointer in resource_pointers {
resources.push(self.get_resource(&res_pointer.db_key, &(res_pointer.resource_type), profile)?);
resources.push(self.get_resource(&res_pointer.shinkai_db_key, profile)?);
}

Ok(resources)
Expand All @@ -342,13 +329,17 @@ impl ShinkaiDB {
query: Embedding,
num_of_docs: u64,
profile: &ShinkaiName,
) -> Result<Vec<DocumentVectorResource>, ShinkaiDBError> {
) -> Result<Vec<BaseVectorResource>, ShinkaiDBError> {
let router = self.get_profile_resource_router(profile)?;
let resource_pointers = router.vector_search(query, num_of_docs);
let resource_pointers = router.vector_search(query, num_of_docs * 2);

let mut resources = vec![];
for res_pointer in resource_pointers {
resources.push(self.get_document(&res_pointer.db_key, profile)?);
if res_pointer.resource_base_type == VectorResourceBaseType::Document {
if (resources.len() as u64) < num_of_docs {
resources.push(self.get_resource(&res_pointer.shinkai_db_key, profile)?);
}
}
}

Ok(resources)
Expand Down
Loading

0 comments on commit aa59f75

Please sign in to comment.