-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] Add Weighted Median Processor (#82)
* add weighted median processor * [feat] Add Monitoring (#73) * add monitoring * fix error msgs * cleanup imports * add todo and minor refactor * add uuid to response * add default * remove todo * fix * minor fixes * add supported sources * add enabled option for monitoring * minor format and removed copy trait * add workflow (#81) * minor format for htx * change weights to u32 * [chore] Add Startup Docs (#77) * add docs and default api_key value for coingecko * fix * [chore] add sources to api (#79) * add sources to api, fix logging, add ping to coinbase * modify cryptocompare opts * remove trace, add logging filter * regen cargo.lock * skip empty ids for rest sources * add message to delete config file * add example config (#80) Co-authored-by: Ongart Pisansathienwong <[email protected]> --------- Co-authored-by: Warittorn Cheevachaipimol <[email protected]> Co-authored-by: warittornc <[email protected]> * add comments for clarity * add weighted median processor * minor format and removed copy trait * minor format for htx * change weights to u32 * add comments for clarity --------- Co-authored-by: Ongart Pisansathienwong <[email protected]>
- Loading branch information
1 parent
4f9f87b
commit 75b3ad3
Showing
5 changed files
with
191 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
use std::cmp::Ordering; | ||
use std::collections::HashMap; | ||
use std::ops::{Add, Div}; | ||
|
||
use bincode::{Decode, Encode}; | ||
use num_traits::{FromPrimitive, Zero}; | ||
use rust_decimal::Decimal; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::registry::processor::{Process, ProcessError}; | ||
|
||
/// The `WeightedMedianProcessor` finds the weighted median of a given data set where the dataset | ||
/// contains the source and the value. It also has a `minimum_cumulative_weight` which is the | ||
/// minimum cumulative weight required to calculate the weighted median. If the cumulative weight | ||
/// of the data sources is less than `minimum_cumulative_weight` or the source associated with the | ||
/// data does not have an assigned weight, it returns an error. | ||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Encode, Decode)] | ||
pub struct WeightedMedianProcessor { | ||
pub source_weights: HashMap<String, u32>, | ||
pub minimum_cumulative_weight: u32, | ||
} | ||
|
||
impl WeightedMedianProcessor { | ||
/// Creates a new `WeightedMedianProcessor`. | ||
pub fn new(source_weights: HashMap<String, u32>, minimum_cumulative_weight: u32) -> Self { | ||
WeightedMedianProcessor { | ||
source_weights, | ||
minimum_cumulative_weight, | ||
} | ||
} | ||
} | ||
|
||
impl Process<(String, Decimal), Decimal> for WeightedMedianProcessor { | ||
/// Processes the given data and returns the weighted median. If the cumulative weights of the | ||
/// data sources are less than the minimum cumulative weight or the source associated | ||
/// with the data does not have an assigned weight, it returns an error. | ||
fn process(&self, data: Vec<(String, Decimal)>) -> Result<Decimal, ProcessError> { | ||
let cumulative_weight = data.iter().try_fold(0, |acc, (source, _)| { | ||
self.source_weights | ||
.get(source) | ||
.map(|weight| acc + weight) | ||
.ok_or(ProcessError::new(format!("Unknown source {source}"))) | ||
})?; | ||
|
||
if cumulative_weight < self.minimum_cumulative_weight { | ||
return Err(ProcessError::new( | ||
"Not enough sources to calculate weighted median", | ||
)); | ||
} | ||
|
||
let values = data | ||
.into_iter() | ||
.map(|(source, value)| { | ||
self.source_weights | ||
.get(&source) | ||
.map(|weight| (value, *weight)) | ||
.ok_or(ProcessError::new(format!("Unknown source {source}"))) | ||
}) | ||
.collect::<Result<Vec<(Decimal, u32)>, ProcessError>>()?; | ||
|
||
Ok(compute_weighted_median(values)) | ||
} | ||
} | ||
|
||
// This function requires that values passed is not an empty vector, if an empty vector is passed, | ||
// it will panic | ||
fn compute_weighted_median<T>(mut values: Vec<(T, u32)>) -> T | ||
where | ||
T: Ord + Add<Output = T> + Div<Output = T> + FromPrimitive, | ||
{ | ||
values.sort_by(|(v1, _), (v2, _)| v1.cmp(v2)); | ||
|
||
// We use the sum of the weights as the mid-value to find the median to avoid rounding when | ||
// dividing by two. | ||
let effective_mid = values | ||
.iter() | ||
.fold(u32::zero(), |acc, (_, weight)| acc + weight); | ||
|
||
let mut effective_cumulative_weight = u32::zero(); | ||
let mut iter = values.into_iter(); | ||
while let Some((value, weight)) = iter.next() { | ||
// We multiply the weight by 2 to avoid rounding when dividing by two. | ||
effective_cumulative_weight += weight * 2; | ||
match effective_cumulative_weight.cmp(&effective_mid) { | ||
Ordering::Greater => return value, | ||
Ordering::Equal => { | ||
return if let Some((next_value, _)) = iter.next() { | ||
(value + next_value) / FromPrimitive::from_u32(2).unwrap() | ||
} else { | ||
value | ||
}; | ||
} | ||
Ordering::Less => (), | ||
} | ||
} | ||
|
||
unreachable!() | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn test_weighted_median() { | ||
let source_weights = HashMap::from([ | ||
("a".to_string(), 15), | ||
("b".to_string(), 10), | ||
("c".to_string(), 20), | ||
("d".to_string(), 30), | ||
("e".to_string(), 25), | ||
]); | ||
|
||
let weighted_median = WeightedMedianProcessor::new(source_weights, 0); | ||
let data = vec![ | ||
("a".to_string(), Decimal::from(1)), | ||
("b".to_string(), Decimal::from(2)), | ||
("c".to_string(), Decimal::from(3)), | ||
("d".to_string(), Decimal::from(4)), | ||
("e".to_string(), Decimal::from(5)), | ||
]; | ||
let res = weighted_median.process(data); | ||
|
||
assert_eq!(res.unwrap(), Decimal::from(4)); | ||
} | ||
|
||
#[test] | ||
fn test_median_with_even_weight() { | ||
let source_weights = HashMap::from([("a".to_string(), 2)]); | ||
|
||
let weighted_median = WeightedMedianProcessor::new(source_weights, 0); | ||
let data = vec![ | ||
("a".to_string(), Decimal::from(1)), | ||
("a".to_string(), Decimal::from(2)), | ||
("a".to_string(), Decimal::from(3)), | ||
("a".to_string(), Decimal::from(4)), | ||
("a".to_string(), Decimal::from(5)), | ||
]; | ||
let res = weighted_median.process(data); | ||
|
||
assert_eq!(res.unwrap(), Decimal::from(3)); | ||
} | ||
|
||
#[test] | ||
fn test_weighted_median_with_intersect() { | ||
let source_weights = HashMap::from([ | ||
("a".to_string(), 49), | ||
("b".to_string(), 1), | ||
("c".to_string(), 25), | ||
("d".to_string(), 25), | ||
]); | ||
|
||
let weighted_median = WeightedMedianProcessor::new(source_weights, 0); | ||
let data = vec![ | ||
("a".to_string(), Decimal::from(1)), | ||
("b".to_string(), Decimal::from(2)), | ||
("c".to_string(), Decimal::from(3)), | ||
("d".to_string(), Decimal::from(4)), | ||
]; | ||
let res = weighted_median.process(data); | ||
|
||
assert_eq!(res.unwrap(), Decimal::from_str_exact("2.5").unwrap()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters