diff --git a/server/Cargo.lock b/server/Cargo.lock
index 6e85b2cc12..1a61a1c36f 100644
--- a/server/Cargo.lock
+++ b/server/Cargo.lock
@@ -722,6 +722,15 @@ dependencies = [
  "redis 0.25.4",
 ]
 
+[[package]]
+name = "bincode"
+version = "1.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -1158,6 +1167,19 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "crossbeam"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
+dependencies = [
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.13"
@@ -1186,6 +1208,15 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.20"
@@ -1746,9 +1777,9 @@ dependencies = [
 
 [[package]]
 name = "flate2"
-version = "1.0.30"
+version = "1.0.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
+checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
@@ -4089,28 +4120,6 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
-[[package]]
-name = "rmp"
-version = "0.8.14"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4"
-dependencies = [
- "byteorder",
- "num-traits",
- "paste",
-]
-
-[[package]]
-name = "rmp-serde"
-version = "1.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db"
-dependencies = [
- "byteorder",
- "rmp",
- "serde",
-]
-
 [[package]]
 name = "rsa"
 version = "0.9.6"
@@ -5744,12 +5753,14 @@ dependencies = [
  "async-stripe",
  "base64 0.22.1",
  "bb8-redis",
+ "bincode",
  "bktree",
  "blake3",
  "cfg-if",
  "chm",
  "chrono",
  "clickhouse 0.12.0",
+ "crossbeam",
  "crossbeam-channel",
  "dateparser",
  "derive_more",
@@ -5757,6 +5768,7 @@ dependencies = [
  "diesel-async",
  "diesel_migrations",
  "dotenvy",
+ "flate2",
  "futures",
  "futures-util",
  "glob",
@@ -5776,11 +5788,11 @@ dependencies = [
  "prometheus",
  "qdrant-client",
  "rand 0.8.5",
+ "rayon",
  "redis 0.25.4",
  "regex",
  "regex-split",
  "reqwest 0.12.5",
- "rmp-serde",
  "rust-argon2",
  "rust-s3",
  "scraper",
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 67595fe579..39c6512294 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -159,7 +159,11 @@ tantivy = "0.22.0"
 strsim = "0.11.1"
 levenshtein_automata = "0.2.1"
 bktree = "1.0.1"
-rmp-serde = "1.3.0"
+flate2 = "1.0.31"
+bincode = "1.3"
+rayon = "1.10.0"
+crossbeam = "0.8.4"
+
 
 [build-dependencies]
 dotenvy = "0.15.7"
diff --git a/server/src/bin/bktree-worker.rs b/server/src/bin/bktree-worker.rs
index f39fe02a1f..122fb6def7 100644
--- a/server/src/bin/bktree-worker.rs
+++ b/server/src/bin/bktree-worker.rs
@@ -4,6 +4,7 @@ use std::sync::{
 };
 
 use chm::tools::migrations::SetupArgs;
+use rand::Rng;
 use sentry::{Hub, SentryFutureExt};
 use signal_hook::consts::SIGTERM;
 use tracing_subscriber::{prelude::*, EnvFilter, Layer};
@@ -12,8 +13,9 @@ use trieve_server::{
     errors::ServiceError,
     get_env,
     operators::{
+        chunk_operator::get_last_processed_from_clickhouse,
         dataset_operator::{scroll_words_from_dataset, update_dataset_last_processed_query},
-        words_operator::{get_bktree_from_redis_query, BkTree, CreateBkTreeMessage},
+        words_operator::{BkTree, CreateBkTreeMessage},
     },
 };
 
@@ -116,6 +118,7 @@ fn main() {
     );
 }
 
+#[allow(clippy::print_stdout)]
 async fn bktree_worker(
     should_terminate: Arc<AtomicBool>,
     redis_pool: actix_web::web::Data,
@@ -207,7 +210,7 @@ async fn bktree_worker(
         log::info!("Processing dataset {}", create_tree_msg.dataset_id);
 
         let mut bk_tree = if let Ok(Some(bktree)) =
-            get_bktree_from_redis_query(create_tree_msg.dataset_id, redis_pool.clone()).await
+            BkTree::from_redis(create_tree_msg.dataset_id, redis_pool.clone()).await
         {
             bktree
         } else {
@@ -216,10 +219,27 @@ async fn bktree_worker(
         let mut failed = false;
 
+        let last_processed =
+            get_last_processed_from_clickhouse(&clickhouse_client, create_tree_msg.dataset_id)
+                .await;
+
+        let last_processed = match last_processed {
+            Ok(last_processed) => last_processed.map(|lp| lp.last_processed),
+            Err(err) => {
+                let _ = readd_error_to_queue(create_tree_msg.clone(), &err, redis_pool.clone())
+                    .await
+                    .map_err(|e| {
+                        eprintln!("Failed to readd error to queue: {:?}", e);
+                    });
+                continue;
+            }
+        };
+
         while let Ok(Some(word_and_counts)) = scroll_words_from_dataset(
             create_tree_msg.dataset_id,
             id_offset,
-            10000,
+            last_processed,
+            5000,
             &clickhouse_client,
         )
         .await
@@ -236,7 +256,7 @@ async fn bktree_worker(
             });
             failed = true;
         }) {
-            println!("Processing offset: {:?}", id_offset);
+            dbg!(id_offset);
             if let Some(last_word) = word_and_counts.last() {
                 id_offset = last_word.id;
             }
@@ -253,43 +273,22 @@ async fn bktree_worker(
             continue;
         }
 
-        match rmp_serde::to_vec(&bk_tree) {
-            Ok(serialized_tree) => {
-                match redis::cmd("SET")
-                    .arg(format!("bk_tree_{}", create_tree_msg.dataset_id))
-                    .arg(serialized_tree)
-                    .query_async::(
-                        &mut *redis_connection,
-                    )
-                    .await
-                {
-                    Ok(_) => {
-                        let _ = redis::cmd("LREM")
-                            .arg("bktree_processing")
-                            .arg(1)
-                            .arg(serialized_message.clone())
-                            .query_async::(
-                                &mut *redis_connection,
-                            )
-                            .await;
-
-                        log::info!(
-                            "Succesfully created bk-tree for {}",
-                            create_tree_msg.dataset_id
-                        );
-                    }
-                    Err(err) => {
-                        let _ = readd_error_to_queue(
-                            create_tree_msg.clone(),
-                            &ServiceError::InternalServerError(format!(
-                                "Failed to serialize tree: {:?}",
-                                err
-                            )),
-                            redis_pool.clone(),
-                        )
-                        .await;
-                    }
-                }
+        match bk_tree
+            .save(create_tree_msg.dataset_id, redis_pool.clone())
+            .await
+        {
+            Ok(()) => {
+                let _ = redis::cmd("LREM")
+                    .arg("bktree_processing")
+                    .arg(1)
+                    .arg(serialized_message.clone())
+                    .query_async::(&mut *redis_connection)
+                    .await;
+
+                log::info!(
+                    "Succesfully created bk-tree for {}",
+                    create_tree_msg.dataset_id
+                );
             }
             Err(err) => {
                 let _ = readd_error_to_queue(
@@ -303,6 +302,7 @@ async fn bktree_worker(
                 .await;
             }
         }
+
         match update_dataset_last_processed_query(create_tree_msg.dataset_id, &clickhouse_client)
             .await
         {
@@ -311,7 +311,8 @@ async fn bktree_worker(
                 log::error!("Failed to update last processed {:?}", err);
             }
         }
-        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
+        let sleep_duration = rand::thread_rng().gen_range(1..=10);
+        tokio::time::sleep(std::time::Duration::from_secs(sleep_duration)).await;
     }
 }
 
@@ -331,7 +332,7 @@ pub async fn readd_error_to_queue(
         .await
         .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
 
-    let _ = redis::cmd("LREM")
+    let _ = redis::cmd("SREM")
         .arg("bktree_processing")
         .arg(1)
         .arg(old_payload_message.clone())
         .query_async(&mut *redis_conn)
         .await
         .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
@@ -347,7 +348,7 @@ pub async fn readd_error_to_queue(
         .await
         .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
 
-    redis::cmd("lpush")
+    redis::cmd("SADD")
         .arg("bktree_dead_letters")
         .arg(old_payload_message)
         .query_async(&mut *redis_conn)
@@ -374,7 +375,7 @@ pub async fn readd_error_to_queue(
         message.attempt_number
     );
 
-    redis::cmd("lpush")
+    redis::cmd("SADD")
         .arg("bktree_creation")
         .arg(&new_payload_message)
         .query_async(&mut *redis_conn)
diff --git a/server/src/bin/word-worker.rs b/server/src/bin/word-worker.rs
index b1f8d8b0bd..157ad0dc5a 100644
--- a/server/src/bin/word-worker.rs
+++ b/server/src/bin/word-worker.rs
@@ -379,7 +379,7 @@ pub async fn readd_error_to_queue(
         .await
         .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
 
-    let _ = redis::cmd("SPOP")
+    let _ = redis::cmd("lrem")
         .arg("process_dictionary")
         .arg(1)
         .arg(old_payload_message.clone())
@@ -406,7 +406,7 @@ pub async fn readd_error_to_queue(
         ServiceError::InternalServerError("Failed to reserialize input for retry".to_string())
     })?;
 
-    redis::cmd("SADD")
+    redis::cmd("lpush")
        .arg("create_dictionary")
        .arg(&new_payload_message)
        .query_async(&mut *redis_conn)
diff --git a/server/src/operators/chunk_operator.rs b/server/src/operators/chunk_operator.rs
index 9f37438382..0df70c7bc9 100644
--- a/server/src/operators/chunk_operator.rs
+++ b/server/src/operators/chunk_operator.rs
@@ -2448,7 +2448,7 @@ pub async fn get_last_processed_from_clickhouse(
     dataset_id: uuid::Uuid,
 ) -> Result, ServiceError> {
     let query = format!(
-        "SELECT ?fields FROM dataset_words_last_processed WHERE dataset_id = '{}' LIMIT 1",
+        "SELECT dataset_id, min(last_processed) as last_processed FROM dataset_words_last_processed WHERE dataset_id = '{}' GROUP BY dataset_id LIMIT 1",
         dataset_id
     );
diff --git a/server/src/operators/dataset_operator.rs b/server/src/operators/dataset_operator.rs
index 9aab7e34c3..2f044c65a3 100644
--- a/server/src/operators/dataset_operator.rs
+++ b/server/src/operators/dataset_operator.rs
@@ -20,7 +20,7 @@ use diesel::result::{DatabaseErrorKind, Error as DBError};
 use diesel_async::RunQueryDsl;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
-use time::OffsetDateTime;
+use time::{format_description, OffsetDateTime};
 
 use super::clickhouse_operator::EventQueue;
@@ -760,38 +760,55 @@ pub async fn add_words_to_dataset(
 pub struct WordDatasetCount {
     #[serde(with = "clickhouse::serde::uuid")]
     pub id: uuid::Uuid,
-    #[serde(with = "clickhouse::serde::uuid")]
-    pub dataset_id: uuid::Uuid,
     pub word: String,
     pub count: i32,
-    #[serde(with = "clickhouse::serde::time::datetime")]
-    pub created_at: OffsetDateTime,
 }
 
 #[tracing::instrument(skip(clickhouse_client))]
 pub async fn scroll_words_from_dataset(
     dataset_id: uuid::Uuid,
     offset: uuid::Uuid,
+    last_processed: Option<OffsetDateTime>,
     limit: i64,
     clickhouse_client: &clickhouse::Client,
 ) -> Result<Option<Vec<WordDatasetCount>>, ServiceError> {
-    let query = format!(
+    let mut query = format!(
         "
         SELECT
             id,
-            dataset_id,
            word,
            count,
-            created_at
        FROM words_datasets
-        LEFT JOIN dataset_words_last_processed ON dataset_words_last_processed.dataset_id = words_datasets.dataset_id
        WHERE dataset_id = '{}'
            AND id > '{}'
-            AND (created_at > last_processed OR last_processed IS NULL)
-        ORDER BY id ASC LIMIT {}
        ",
-        dataset_id, offset, limit
+        dataset_id, offset,
     );
 
+    if let Some(last_processed) = last_processed {
+        query = format!(
+            "{} AND created_at >= '{}'",
+            query,
+            last_processed
+                .format(
+                    &format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]",)
+                        .unwrap()
+                )
+                .map_err(|e| {
+                    log::error!("Error formatting last processed time: {:?}", e);
+                    sentry::capture_message(
+                        "Error formatting last processed time",
+                        sentry::Level::Error,
+                    );
+                    ServiceError::InternalServerError(format!(
+                        "Error formatting last processed time: {:?}",
+                        e
+                    ))
+                })?
+        );
+    }
+
+    query = format!("{} ORDER BY id LIMIT {}", query, limit);
+
     let words = clickhouse_client
         .query(&query)
         .fetch_all::<WordDatasetCount>()
diff --git a/server/src/operators/words_operator.rs b/server/src/operators/words_operator.rs
index a60a65ddbb..1340392955 100644
--- a/server/src/operators/words_operator.rs
+++ b/server/src/operators/words_operator.rs
@@ -1,5 +1,7 @@
 use std::{
     collections::{HashMap, HashSet},
+    io::Write,
+    sync::{Arc, Mutex},
     time::{Duration, Instant},
 };
@@ -8,12 +10,18 @@ use crate::{
     errors::ServiceError,
 };
 use actix_web::web;
+use flate2::{
+    write::{GzDecoder, GzEncoder},
+    Compression,
+};
 use itertools::Itertools;
 use lazy_static::lazy_static;
-use serde::{Deserialize, Serialize};
+use rayon::prelude::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::collections::VecDeque;
 use tokio::sync::RwLock;
 
-#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+#[derive(Clone, Debug, Eq, PartialEq)]
 struct Node {
     word: String,
     count: i32,
@@ -22,17 +30,142 @@ struct Node {
 
 /// A BK-tree datastructure
 ///
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Clone, Debug)]
 pub struct BkTree {
     root: Option<Box<Node>>,
 }
 
+#[derive(Serialize, Deserialize)]
+struct FlatNode {
+    parent_index: Option<usize>,
+    distance: Option<isize>,
+    word: String,
+    count: i32,
+}
+
+impl Serialize for BkTree {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut queue = VecDeque::new();
+        let mut flat_tree = Vec::new();
+
+        if let Some(root) = &self.root {
+            queue.push_back((None, None, root.as_ref()));
+        }
+
+        while let Some((parent_index, distance, node)) = queue.pop_front() {
+            let current_index = flat_tree.len();
+            flat_tree.push(FlatNode {
+                parent_index,
+                distance,
+                word: node.word.clone(),
+                count: node.count,
+            });
+
+            for (child_distance, child) in &node.children {
+                queue.push_back((Some(current_index), Some(*child_distance), child));
+            }
+        }
+
+        let binary_data = bincode::serialize(&flat_tree).map_err(serde::ser::Error::custom)?;
+        serializer.serialize_bytes(&binary_data)
+    }
+}
+
+impl<'de> Deserialize<'de> for BkTree {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let binary_data: Vec<u8> = Vec::deserialize(deserializer)?;
+        let flat_tree: Vec<FlatNode> =
+            bincode::deserialize(&binary_data).map_err(serde::de::Error::custom)?;
+
+        if flat_tree.is_empty() {
+            return Ok(BkTree { root: None });
+        }
+
+        let mut nodes: Vec<Node> = flat_tree
+            .iter()
+            .map(|flat_node| Node {
+                word: flat_node.word.clone(),
+                count: flat_node.count,
+                children: Vec::new(),
+            })
+            .collect();
+
+        // Reconstruct the tree structure
+        for i in (1..nodes.len()).rev() {
+            let parent_index = flat_tree[i].parent_index.unwrap();
+            let distance = flat_tree[i].distance.unwrap();
+            let child = nodes.remove(i);
+            nodes[parent_index].children.push((distance, child));
+        }
+
+        Ok(BkTree {
+            root: Some(Box::new(nodes.remove(0))),
+        })
+    }
+}
+
 impl Default for BkTree {
     fn default() -> Self {
         Self::new()
     }
 }
 
+pub fn levenshtein_distance<S: AsRef<str>>(a: &S, b: &S) -> isize {
+    let a = a.as_ref().to_lowercase();
+    let b = b.as_ref().to_lowercase();
+
+    if a == b {
+        return 0;
+    }
+
+    let a_len = a.chars().count();
+    let b_len = b.chars().count();
+
+    if a_len == 0 {
+        return b_len as isize;
+    }
+
+    if b_len == 0 {
+        return a_len as isize;
+    }
+
+    let mut res = 0;
+    let mut cache: Vec<usize> = (1..).take(a_len).collect();
+    let mut a_dist;
+    let mut b_dist;
+
+    for (ib, cb) in b.chars().enumerate() {
+        res = ib;
+        a_dist = ib;
+        for (ia, ca) in a.chars().enumerate() {
+            b_dist = if ca == cb { a_dist } else { a_dist + 1 };
+            a_dist = cache[ia];
+
+            res = if a_dist > res {
+                if b_dist > res {
+                    res + 1
+                } else {
+                    b_dist
+                }
+            } else if b_dist > a_dist {
+                a_dist + 1
+            } else {
+                b_dist
+            };
+
+            cache[ia] = res;
+        }
+    }
+
+    res as isize
+}
+
 impl BkTree {
     /// Create a new BK-tree
     pub fn new() -> Self {
@@ -59,12 +192,16 @@ impl BkTree {
             Some(ref mut root_node) => {
                 let mut u = &mut **root_node;
                 loop {
-                    let k = bktree::levenshtein_distance(&u.word, &val.0);
+                    let k = levenshtein_distance(&u.word, &val.0);
                     if k == 0 {
                         u.count = val.1;
                         return;
                     }
+
+                    if val.1 == 1 {
+                        return;
+                    }
+
                     let v = u.children.iter().position(|(dist, _)| *dist == k);
                     match v {
                         None => {
@@ -95,26 +232,53 @@ impl BkTree {
         match self.root {
             None => Vec::new(),
             Some(ref root) => {
-                let mut found = Vec::new();
-
-                let mut candidates: std::collections::VecDeque<&Node> =
-                    std::collections::VecDeque::new();
-                candidates.push_back(root);
-
-                while let Some(n) = candidates.pop_front() {
-                    let distance = bktree::levenshtein_distance(&n.word, &val);
-                    if distance <= max_dist {
-                        found.push(((&n.word, &n.count), distance));
-                    }
-
-                    candidates.extend(
-                        n.children
+                let found = Arc::new(Mutex::new(Vec::new()));
+                let mut candidates: Vec<&Node> = vec![root];
+
+                while !candidates.is_empty() {
+                    let next_candidates: Vec<&Node> = if candidates.len() > 1000 {
+                        candidates
+                            .par_iter()
+                            .flat_map(|&n| {
+                                let distance = levenshtein_distance(&n.word, &val);
+                                let mut local_candidates = Vec::new();
+
+                                if distance <= max_dist {
+                                    found.lock().unwrap().push(((&n.word, &n.count), distance));
+                                }
+
+                                for (arc, node) in &n.children {
+                                    if (*arc - distance).abs() <= max_dist {
+                                        local_candidates.push(node);
+                                    }
+                                }
+
+                                local_candidates
+                            })
+                            .collect()
+                    } else {
+                        candidates
                             .iter()
-                            .filter(|(arc, _)| (*arc - distance).abs() <= max_dist)
-                            .map(|(_, node)| node),
-                    );
+                            .flat_map(|&n| {
+                                let distance = levenshtein_distance(&n.word, &val);
+                                if distance <= max_dist {
+                                    found.lock().unwrap().push(((&n.word, &n.count), distance));
+                                }
+                                n.children
+                                    .iter()
+                                    .filter(|(arc, _)| (*arc - distance).abs() <= max_dist)
+                                    .map(|(_, node)| node)
+                                    .collect::<Vec<_>>()
+                            })
+                            .collect()
+                    };
+
+                    candidates = next_candidates;
                 }
-                found
+
+                let mut result = Arc::try_unwrap(found).unwrap().into_inner().unwrap();
+                result.sort_by_key(|&(_, dist)| dist);
+                result
             }
         }
     }
@@ -127,6 +291,79 @@ impl BkTree {
         }
         Iter { queue }
     }
+
+    pub async fn from_redis(
+        dataset_id: uuid::Uuid,
+        redis_pool: web::Data,
+    ) -> Result<Option<BkTree>, ServiceError> {
+        let mut redis_conn = redis_pool.get().await.map_err(|_| {
+            ServiceError::InternalServerError("Failed to get redis connection".to_string())
+        })?;
+
+        let compressed_bk_tree: Option<Vec<u8>> = redis::cmd("GET")
+            .arg(format!("bk_tree_{}", dataset_id))
+            .query_async(&mut *redis_conn)
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+
+        if let Some(compressed_bk_tree) = compressed_bk_tree {
+            let buf = Vec::new();
+            let mut decoder = GzDecoder::new(buf);
+            decoder.write_all(&compressed_bk_tree).map_err(|err| {
+                ServiceError::InternalServerError(format!("Failed to decompress bk tree {}", err))
+            })?;
+
+            let serialized_bk_tree = decoder.finish().map_err(|err| {
+                ServiceError::InternalServerError(format!(
+                    "Failed to finish decompressing bk tree {}",
+                    err
+                ))
+            })?;
+
+            let tree = bincode::deserialize(&serialized_bk_tree).map_err(|err| {
+                ServiceError::InternalServerError(format!("Failed to deserialize bk tree {}", err))
+            })?;
+
+            Ok(Some(tree))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub async fn save(
+        &self,
+        dataset_id: uuid::Uuid,
+        redis_pool: web::Data,
+    ) -> Result<(), ServiceError> {
+        if self.root.is_none() {
+            return Ok(());
+        }
+        let mut redis_conn = redis_pool.get().await.map_err(|_| {
+            ServiceError::InternalServerError("Failed to get redis connection".to_string())
+        })?;
+
+        let uncompressed_bk_tree = bincode::serialize(self).map_err(|_| {
+            ServiceError::InternalServerError("Failed to serialize bk tree".to_string())
+        })?;
+
+        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
+        encoder.write_all(&uncompressed_bk_tree).map_err(|_| {
+            ServiceError::InternalServerError("Failed to compress bk tree".to_string())
+        })?;
+
+        let serialized_bk_tree = encoder.finish().map_err(|_| {
+            ServiceError::InternalServerError("Failed to finish compressing bk tree".to_string())
+        })?;
+
+        redis::cmd("SET")
+            .arg(format!("bk_tree_{}", dataset_id))
+            .arg(serialized_bk_tree)
+            .query_async(&mut *redis_conn)
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+
+        Ok(())
+    }
 }
 
 /// Iterator over BK-tree elements
@@ -171,29 +408,6 @@ pub struct CreateBkTreeMessage {
     pub attempt_number: usize,
 }
 
-pub async fn get_bktree_from_redis_query(
-    dataset_id: uuid::Uuid,
-    redis_pool: web::Data,
-) -> Result<Option<BkTree>, ServiceError> {
-    let mut redis_conn = redis_pool.get().await.map_err(|_| {
-        ServiceError::InternalServerError("Failed to get redis connection".to_string())
-    })?;
-
-    let serialized_bk_tree: Option<Vec<u8>> = redis::cmd("GET")
-        .arg(format!("bk_tree_{}", dataset_id))
-        .query_async(&mut *redis_conn)
-        .await
-        .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
-
-    if let Some(serialized_bk_tree) = serialized_bk_tree {
-        let tree: BkTree = rmp_serde::from_slice(&serialized_bk_tree)
-            .map_err(|_| ServiceError::BadRequest("Failed to deserialize bk tree".to_string()))?;
-        Ok(Some(tree))
-    } else {
-        Ok(None)
-    }
-}
-
 struct BKTreeCacheEntry {
     bktree: BkTree,
     expiration: Instant,
 }
@@ -340,15 +554,24 @@ pub async fn correct_query(
         None => {
             let dataset_id = dataset_id;
             let redis_pool = redis_pool.clone();
+            dbg!("Pulling new BK tree from Redis");
             tokio::spawn(async move {
-                if let Ok(Some(bktree)) = get_bktree_from_redis_query(dataset_id, redis_pool).await
-                {
+                match BkTree::from_redis(dataset_id, redis_pool).await {
                     // TTL of 1 day
-                    BKTREE_CACHE.insert_with_ttl(
-                        dataset_id,
-                        bktree,
-                        Duration::from_secs(60 * 60 * 24),
-                    );
+                    Ok(Some(bktree)) => {
+                        BKTREE_CACHE.insert_with_ttl(
+                            dataset_id,
+                            bktree,
+                            Duration::from_secs(60 * 60 * 24),
+                        );
+                        dbg!("Inserted new BK tree into cache for dataset_id: {:?}", dataset_id);
+                    }
+                    Ok(None) => {
+                        dbg!("No BK tree found in Redis for dataset_id: {:?}", dataset_id);
+                    }
+                    Err(e) => {
+                        dbg!("Failed to insert new BK tree into cache {:?} for dataset_id: {:?}", e, dataset_id);
+                    }
                 };
             });
             Ok(query)
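Reviewer sketch (not part of the patch): the round-trip that BkTree::save and BkTree::from_redis now perform is bincode-encode then gzip-compress on write, and gzip-decompress then bincode-decode on read. The snippet below is a minimal, self-contained illustration of that flow, assuming only the bincode 1.3 and flate2 1.0 APIs pinned in Cargo.toml; the Stats struct and the compress/decompress helpers are hypothetical stand-ins, not code from this PR.

// Illustrative only: mirrors the save/from_redis round-trip under the
// assumption of bincode 1.3 + flate2 1.0 (no Redis involved here).
use std::io::Write;

use flate2::{
    write::{GzDecoder, GzEncoder},
    Compression,
};
use serde::{Deserialize, Serialize};

// Hypothetical stand-in for the flattened BK-tree payload.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Stats {
    word: String,
    count: i32,
}

fn compress(value: &Stats) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    // bincode produces a compact binary payload; gzip shrinks it further.
    let raw = bincode::serialize(value)?;
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(&raw)?;
    Ok(encoder.finish()?)
}

fn decompress(bytes: &[u8]) -> Result<Stats, Box<dyn std::error::Error>> {
    // Write-side GzDecoder: feed compressed bytes in, collect inflated bytes via finish().
    let mut decoder = GzDecoder::new(Vec::new());
    decoder.write_all(bytes)?;
    let raw = decoder.finish()?;
    Ok(bincode::deserialize(&raw)?)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let original = Stats { word: "trieve".to_string(), count: 42 };
    let stored = compress(&original)?; // what save() would SET in Redis
    assert_eq!(decompress(&stored)?, original); // what from_redis() would hand back
    Ok(())
}

Using the write-side GzDecoder matches the patch's choice in from_redis: the compressed value fetched from Redis is written into the decoder and the inflated bytes are collected by finish() before bincode deserialization.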