diff --git a/bothan-core/src/manager/crypto_asset_info/price/tasks.rs b/bothan-core/src/manager/crypto_asset_info/price/tasks.rs index 6edc5bc5..2791bc7b 100644 --- a/bothan-core/src/manager/crypto_asset_info/price/tasks.rs +++ b/bothan-core/src/manager/crypto_asset_info/price/tasks.rs @@ -114,8 +114,8 @@ async fn compute_source_result<'a>( workers: &WorkerMap<'a>, cache: &PriceCache, stale_cutoff: i64, -) -> Result, Error> { - // Check if all prerequisites are available, if not add the missing ones to the queue +) -> Result, Error> { + // Check if all prerequisites are available, if not, add the missing ones to the queue // and continue to the next signal let mut source_results = Vec::with_capacity(signal.source_queries.len()); let mut missing_pids = HashSet::new(); @@ -127,7 +127,7 @@ async fn compute_source_result<'a>( Ok(AssetState::Available(a)) => { if a.timestamp.ge(&stale_cutoff) { match compute_source_routes(&source_query.routes, a.price, cache) { - Ok(Some(price)) => source_results.push(price), + Ok(Some(price)) => source_results.push((sid.clone(), price)), Ok(None) => {} // If unable to calculate the price, ignore the source Err(Error::PrerequisiteRequired(ids)) => missing_pids.extend(ids), Err(e) => return Err(e), diff --git a/bothan-core/src/registry/processor.rs b/bothan-core/src/registry/processor.rs index 89777634..acc404be 100644 --- a/bothan-core/src/registry/processor.rs +++ b/bothan-core/src/registry/processor.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; pub mod median; +pub mod weighted_median; #[derive(Debug, Error, PartialEq, Clone)] #[error("{msg}")] @@ -18,8 +19,8 @@ impl ProcessError { } /// The Processor trait defines the methods that a processor must implement. -pub trait Process { - fn process(&self, data: Vec) -> Result; +pub trait Process { + fn process(&self, data: Vec) -> Result; } /// The Process enum represents the different types of processors that can be used. @@ -27,12 +28,28 @@ pub trait Process { #[serde(rename_all = "snake_case", tag = "function", content = "params")] pub enum Processor { Median(median::MedianProcessor), + WeightedMedian(weighted_median::WeightedMedianProcessor), } -impl Process for Processor { +impl Process for Processor { fn process(&self, data: Vec) -> Result { match self { Processor::Median(median) => median.process(data), + Processor::WeightedMedian(_) => Err(ProcessError::new( + "Weighted median not implemented for T: Decimal", + )), + } + } +} + +impl Process<(String, Decimal), Decimal> for Processor { + fn process(&self, data: Vec<(String, Decimal)>) -> Result { + match self { + Processor::Median(median) => { + let data = data.into_iter().map(|(_, value)| value).collect(); + median.process(data) + } + Processor::WeightedMedian(weighted_median) => weighted_median.process(data), } } } diff --git a/bothan-core/src/registry/processor/median.rs b/bothan-core/src/registry/processor/median.rs index 85be151b..1eae7b14 100644 --- a/bothan-core/src/registry/processor/median.rs +++ b/bothan-core/src/registry/processor/median.rs @@ -23,7 +23,7 @@ impl MedianProcessor { } } -impl Process for MedianProcessor { +impl Process for MedianProcessor { /// Processes the given data and returns the median. If there are not enough sources, it /// returns an error. fn process(&self, data: Vec) -> Result { diff --git a/bothan-core/src/registry/processor/weighted_median.rs b/bothan-core/src/registry/processor/weighted_median.rs new file mode 100644 index 00000000..0ef3a609 --- /dev/null +++ b/bothan-core/src/registry/processor/weighted_median.rs @@ -0,0 +1,164 @@ +use std::cmp::Ordering; +use std::collections::HashMap; +use std::ops::{Add, Div}; + +use bincode::{Decode, Encode}; +use num_traits::{FromPrimitive, Zero}; +use rust_decimal::Decimal; +use serde::{Deserialize, Serialize}; + +use crate::registry::processor::{Process, ProcessError}; + +/// The `WeightedMedianProcessor` finds the weighted median of a given data set where the dataset +/// contains the source and the value. It also has a `minimum_cumulative_weight` which is the +/// minimum cumulative weight required to calculate the weighted median. If the cumulative weight +/// of the data sources is less than `minimum_cumulative_weight` or the source associated with the +/// data does not have an assigned weight, it returns an error. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Encode, Decode)] +pub struct WeightedMedianProcessor { + pub source_weights: HashMap, + pub minimum_cumulative_weight: u32, +} + +impl WeightedMedianProcessor { + /// Creates a new `WeightedMedianProcessor`. + pub fn new(source_weights: HashMap, minimum_cumulative_weight: u32) -> Self { + WeightedMedianProcessor { + source_weights, + minimum_cumulative_weight, + } + } +} + +impl Process<(String, Decimal), Decimal> for WeightedMedianProcessor { + /// Processes the given data and returns the weighted median. If the cumulative weights of the + /// data sources are less than the minimum cumulative weight or the source associated + /// with the data does not have an assigned weight, it returns an error. + fn process(&self, data: Vec<(String, Decimal)>) -> Result { + let cumulative_weight = data.iter().try_fold(0, |acc, (source, _)| { + self.source_weights + .get(source) + .map(|weight| acc + weight) + .ok_or(ProcessError::new(format!("Unknown source {source}"))) + })?; + + if cumulative_weight < self.minimum_cumulative_weight { + return Err(ProcessError::new( + "Not enough sources to calculate weighted median", + )); + } + + let values = data + .into_iter() + .map(|(source, value)| { + self.source_weights + .get(&source) + .map(|weight| (value, *weight)) + .ok_or(ProcessError::new(format!("Unknown source {source}"))) + }) + .collect::, ProcessError>>()?; + + Ok(compute_weighted_median(values)) + } +} + +// This function requires that values passed is not an empty vector, if an empty vector is passed, +// it will panic +fn compute_weighted_median(mut values: Vec<(T, u32)>) -> T +where + T: Ord + Add + Div + FromPrimitive, +{ + values.sort_by(|(v1, _), (v2, _)| v1.cmp(v2)); + + // We use the sum of the weights as the mid-value to find the median to avoid rounding when + // dividing by two. + let effective_mid = values + .iter() + .fold(u32::zero(), |acc, (_, weight)| acc + weight); + + let mut effective_cumulative_weight = u32::zero(); + let mut iter = values.into_iter(); + while let Some((value, weight)) = iter.next() { + // We multiply the weight by 2 to avoid rounding when dividing by two. + effective_cumulative_weight += weight * 2; + match effective_cumulative_weight.cmp(&effective_mid) { + Ordering::Greater => return value, + Ordering::Equal => { + return if let Some((next_value, _)) = iter.next() { + (value + next_value) / FromPrimitive::from_u32(2).unwrap() + } else { + value + }; + } + Ordering::Less => (), + } + } + + unreachable!() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_weighted_median() { + let source_weights = HashMap::from([ + ("a".to_string(), 15), + ("b".to_string(), 10), + ("c".to_string(), 20), + ("d".to_string(), 30), + ("e".to_string(), 25), + ]); + + let weighted_median = WeightedMedianProcessor::new(source_weights, 0); + let data = vec![ + ("a".to_string(), Decimal::from(1)), + ("b".to_string(), Decimal::from(2)), + ("c".to_string(), Decimal::from(3)), + ("d".to_string(), Decimal::from(4)), + ("e".to_string(), Decimal::from(5)), + ]; + let res = weighted_median.process(data); + + assert_eq!(res.unwrap(), Decimal::from(4)); + } + + #[test] + fn test_median_with_even_weight() { + let source_weights = HashMap::from([("a".to_string(), 2)]); + + let weighted_median = WeightedMedianProcessor::new(source_weights, 0); + let data = vec![ + ("a".to_string(), Decimal::from(1)), + ("a".to_string(), Decimal::from(2)), + ("a".to_string(), Decimal::from(3)), + ("a".to_string(), Decimal::from(4)), + ("a".to_string(), Decimal::from(5)), + ]; + let res = weighted_median.process(data); + + assert_eq!(res.unwrap(), Decimal::from(3)); + } + + #[test] + fn test_weighted_median_with_intersect() { + let source_weights = HashMap::from([ + ("a".to_string(), 49), + ("b".to_string(), 1), + ("c".to_string(), 25), + ("d".to_string(), 25), + ]); + + let weighted_median = WeightedMedianProcessor::new(source_weights, 0); + let data = vec![ + ("a".to_string(), Decimal::from(1)), + ("b".to_string(), Decimal::from(2)), + ("c".to_string(), Decimal::from(3)), + ("d".to_string(), Decimal::from(4)), + ]; + let res = weighted_median.process(data); + + assert_eq!(res.unwrap(), Decimal::from_str_exact("2.5").unwrap()); + } +} diff --git a/bothan-htx/src/api/websocket.rs b/bothan-htx/src/api/websocket.rs index 65b8155c..7a702705 100644 --- a/bothan-htx/src/api/websocket.rs +++ b/bothan-htx/src/api/websocket.rs @@ -108,12 +108,9 @@ impl HtxWebSocketConnection { } Ok(Message::Text(msg)) => serde_json::from_str::(&msg) .map_err(|_| MessageError::UnsupportedMessage), - Err(err) => match err { - TungsteniteError::Protocol(..) | TungsteniteError::ConnectionClosed => { - Err(MessageError::ChannelClosed) - } - _ => Err(MessageError::UnsupportedMessage), - }, + Err(TungsteniteError::Protocol(_)) | Err(TungsteniteError::ConnectionClosed) => { + Err(MessageError::ChannelClosed) + } _ => Err(MessageError::UnsupportedMessage), }; }