Skip to content

Commit

Permalink
[feat] Add Weighted Median Processor (#82)
Browse files Browse the repository at this point in the history
* add weighted median processor

* [feat] Add Monitoring (#73)

* add monitoring

* fix error msgs

* cleanup imports

* add todo and minor refactor

* add uuid to response

* add default

* remove todo

* fix

* minor fixes

* add supported sources

* add enabled option for monitoring

* minor format and removed copy trait

* add workflow (#81)

* minor format for htx

* change weights to u32

* [chore] Add Startup Docs (#77)

* add docs and default api_key value for coingecko

* fix

* [chore] add sources to api (#79)

* add sources to api, fix logging, add ping to coinbase

* modify cryptocompare opts

* remove trace, add logging filter

* regen cargo.lock

* skip empty ids for rest sources

* add message to delete config file

* add example config (#80)

Co-authored-by: Ongart Pisansathienwong <[email protected]>

---------

Co-authored-by: Warittorn Cheevachaipimol <[email protected]>
Co-authored-by: warittornc <[email protected]>

* add comments for clarity

* add weighted median processor

* minor format and removed copy trait

* minor format for htx

* change weights to u32

* add comments for clarity

---------

Co-authored-by: Ongart Pisansathienwong <[email protected]>
  • Loading branch information
warittornc and colmazia authored Oct 17, 2024
1 parent 4f9f87b commit 75b3ad3
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 13 deletions.
6 changes: 3 additions & 3 deletions bothan-core/src/manager/crypto_asset_info/price/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ async fn compute_source_result<'a>(
workers: &WorkerMap<'a>,
cache: &PriceCache<String>,
stale_cutoff: i64,
) -> Result<Vec<Decimal>, Error> {
// Check if all prerequisites are available, if not add the missing ones to the queue
) -> Result<Vec<(String, Decimal)>, Error> {
// Check if all prerequisites are available, if not, add the missing ones to the queue
// and continue to the next signal
let mut source_results = Vec::with_capacity(signal.source_queries.len());
let mut missing_pids = HashSet::new();
Expand All @@ -127,7 +127,7 @@ async fn compute_source_result<'a>(
Ok(AssetState::Available(a)) => {
if a.timestamp.ge(&stale_cutoff) {
match compute_source_routes(&source_query.routes, a.price, cache) {
Ok(Some(price)) => source_results.push(price),
Ok(Some(price)) => source_results.push((sid.clone(), price)),
Ok(None) => {} // If unable to calculate the price, ignore the source
Err(Error::PrerequisiteRequired(ids)) => missing_pids.extend(ids),
Err(e) => return Err(e),
Expand Down
23 changes: 20 additions & 3 deletions bothan-core/src/registry/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;

pub mod median;
pub mod weighted_median;

#[derive(Debug, Error, PartialEq, Clone)]
#[error("{msg}")]
Expand All @@ -18,21 +19,37 @@ impl ProcessError {
}

/// The Processor trait defines the methods that a processor must implement.
pub trait Process<T> {
fn process(&self, data: Vec<T>) -> Result<T, ProcessError>;
pub trait Process<T, U> {
fn process(&self, data: Vec<T>) -> Result<U, ProcessError>;
}

/// The `Processor` enum represents the different types of processors that can be used.
///
/// With the serde attributes below, a processor is (de)serialized in adjacently tagged form: the
/// snake_case variant name is stored under the `function` key and the variant's payload under
/// the `params` key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Encode, Decode)]
#[serde(rename_all = "snake_case", tag = "function", content = "params")]
pub enum Processor {
/// Delegates processing to a [`median::MedianProcessor`].
Median(median::MedianProcessor),
/// Delegates processing to a [`weighted_median::WeightedMedianProcessor`].
WeightedMedian(weighted_median::WeightedMedianProcessor),
}

impl Process<Decimal> for Processor {
impl Process<Decimal, Decimal> for Processor {
fn process(&self, data: Vec<Decimal>) -> Result<Decimal, ProcessError> {
match self {
Processor::Median(median) => median.process(data),
Processor::WeightedMedian(_) => Err(ProcessError::new(
"Weighted median not implemented for T: Decimal",
)),
}
}
}

impl Process<(String, Decimal), Decimal> for Processor {
    /// Runs the selected processor over `(source, value)` pairs.
    ///
    /// The median variant has no use for the source, so the source is dropped and only the
    /// values are forwarded; the weighted median variant receives the pairs unchanged.
    fn process(&self, data: Vec<(String, Decimal)>) -> Result<Decimal, ProcessError> {
        match self {
            Processor::WeightedMedian(processor) => processor.process(data),
            Processor::Median(processor) => {
                let mut values: Vec<Decimal> = Vec::with_capacity(data.len());
                for (_, value) in data {
                    values.push(value);
                }
                processor.process(values)
            }
        }
    }
}
2 changes: 1 addition & 1 deletion bothan-core/src/registry/processor/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl MedianProcessor {
}
}

impl Process<Decimal> for MedianProcessor {
impl Process<Decimal, Decimal> for MedianProcessor {
/// Processes the given data and returns the median. If there are not enough sources, it
/// returns an error.
fn process(&self, data: Vec<Decimal>) -> Result<Decimal, ProcessError> {
Expand Down
164 changes: 164 additions & 0 deletions bothan-core/src/registry/processor/weighted_median.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use std::ops::{Add, Div};

use bincode::{Decode, Encode};
use num_traits::{FromPrimitive, Zero};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};

use crate::registry::processor::{Process, ProcessError};

/// The `WeightedMedianProcessor` finds the weighted median of a given data set, where each entry
/// of the data set is a `(source, value)` pair and every source is weighted according to
/// `source_weights`.
///
/// It also has a `minimum_cumulative_weight`, which is the minimum cumulative weight required to
/// calculate the weighted median. If the cumulative weight of the data sources is less than
/// `minimum_cumulative_weight`, or the source associated with a value does not have an assigned
/// weight, processing returns an error.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Encode, Decode)]
pub struct WeightedMedianProcessor {
/// Weight assigned to each source, keyed by the source name that accompanies each value in the
/// data passed to `process`.
pub source_weights: HashMap<String, u32>,
/// Minimum sum of the weights of the supplied data points that must be reached before a
/// weighted median is calculated.
pub minimum_cumulative_weight: u32,
}

impl WeightedMedianProcessor {
/// Creates a new `WeightedMedianProcessor`.
pub fn new(source_weights: HashMap<String, u32>, minimum_cumulative_weight: u32) -> Self {
WeightedMedianProcessor {
source_weights,
minimum_cumulative_weight,
}
}
}

impl Process<(String, Decimal), Decimal> for WeightedMedianProcessor {
    /// Processes the given `(source, value)` pairs and returns their weighted median.
    ///
    /// # Errors
    ///
    /// Returns a `ProcessError` if:
    /// - a source in `data` does not have an assigned weight,
    /// - the cumulative weight of the data is less than `minimum_cumulative_weight`, or
    /// - the cumulative weight is zero (no data, or only zero-weight sources), in which case a
    ///   weighted median is undefined.
    fn process(&self, data: Vec<(String, Decimal)>) -> Result<Decimal, ProcessError> {
        // Resolve every weight in a single pass. The running total is kept in a u64 so that a
        // large set of u32 weights cannot overflow it.
        let mut cumulative_weight: u64 = 0;
        let mut values: Vec<(Decimal, u32)> = Vec::with_capacity(data.len());
        for (source, value) in data {
            let weight = *self
                .source_weights
                .get(&source)
                .ok_or_else(|| ProcessError::new(format!("Unknown source {source}")))?;
            cumulative_weight += u64::from(weight);
            values.push((value, weight));
        }

        if cumulative_weight < u64::from(self.minimum_cumulative_weight) {
            return Err(ProcessError::new(
                "Not enough sources to calculate weighted median",
            ));
        }

        // With a minimum of zero the check above lets empty or all-zero-weight data through;
        // reject it here rather than handing `compute_weighted_median` an input it cannot handle.
        if cumulative_weight == 0 {
            return Err(ProcessError::new(
                "Cumulative weight must be greater than zero to calculate weighted median",
            ));
        }

        Ok(compute_weighted_median(values))
    }
}

// This function requires that values passed is not an empty vector, if an empty vector is passed,
// it will panic
fn compute_weighted_median<T>(mut values: Vec<(T, u32)>) -> T
where
T: Ord + Add<Output = T> + Div<Output = T> + FromPrimitive,
{
values.sort_by(|(v1, _), (v2, _)| v1.cmp(v2));

// We use the sum of the weights as the mid-value to find the median to avoid rounding when
// dividing by two.
let effective_mid = values
.iter()
.fold(u32::zero(), |acc, (_, weight)| acc + weight);

let mut effective_cumulative_weight = u32::zero();
let mut iter = values.into_iter();
while let Some((value, weight)) = iter.next() {
// We multiply the weight by 2 to avoid rounding when dividing by two.
effective_cumulative_weight += weight * 2;
match effective_cumulative_weight.cmp(&effective_mid) {
Ordering::Greater => return value,
Ordering::Equal => {
return if let Some((next_value, _)) = iter.next() {
(value + next_value) / FromPrimitive::from_u32(2).unwrap()
} else {
value
};
}
Ordering::Less => (),
}
}

unreachable!()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a `WeightedMedianProcessor` built from `weights` (with a minimum cumulative weight of
    /// zero) over `data`, where every entry of `data` is a `(source, value)` pair.
    fn weighted_median_of(
        weights: &[(&str, u32)],
        data: &[(&str, i64)],
    ) -> Result<Decimal, ProcessError> {
        let source_weights: HashMap<String, u32> = weights
            .iter()
            .map(|(source, weight)| (source.to_string(), *weight))
            .collect();
        let data: Vec<(String, Decimal)> = data
            .iter()
            .map(|(source, value)| (source.to_string(), Decimal::from(*value)))
            .collect();

        WeightedMedianProcessor::new(source_weights, 0).process(data)
    }

    #[test]
    fn test_weighted_median() {
        let res = weighted_median_of(
            &[("a", 15), ("b", 10), ("c", 20), ("d", 30), ("e", 25)],
            &[("a", 1), ("b", 2), ("c", 3), ("d", 4), ("e", 5)],
        );

        assert_eq!(res.unwrap(), Decimal::from(4));
    }

    #[test]
    fn test_median_with_even_weight() {
        // Every value comes from the same source, so all values carry the same weight.
        let res = weighted_median_of(
            &[("a", 2)],
            &[("a", 1), ("a", 2), ("a", 3), ("a", 4), ("a", 5)],
        );

        assert_eq!(res.unwrap(), Decimal::from(3));
    }

    #[test]
    fn test_weighted_median_with_intersect() {
        // Half of the total weight is reached exactly at "b", so the median is the average of the
        // values of "b" and "c".
        let res = weighted_median_of(
            &[("a", 49), ("b", 1), ("c", 25), ("d", 25)],
            &[("a", 1), ("b", 2), ("c", 3), ("d", 4)],
        );

        assert_eq!(res.unwrap(), Decimal::from_str_exact("2.5").unwrap());
    }
}
9 changes: 3 additions & 6 deletions bothan-htx/src/api/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,9 @@ impl HtxWebSocketConnection {
}
Ok(Message::Text(msg)) => serde_json::from_str::<HtxResponse>(&msg)
.map_err(|_| MessageError::UnsupportedMessage),
Err(err) => match err {
TungsteniteError::Protocol(..) | TungsteniteError::ConnectionClosed => {
Err(MessageError::ChannelClosed)
}
_ => Err(MessageError::UnsupportedMessage),
},
Err(TungsteniteError::Protocol(_)) | Err(TungsteniteError::ConnectionClosed) => {
Err(MessageError::ChannelClosed)
}
_ => Err(MessageError::UnsupportedMessage),
};
}
Expand Down

0 comments on commit 75b3ad3

Please sign in to comment.