From 71e0693d642fbf799e9278c977d933f64216df99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 11:38:16 -0300 Subject: [PATCH 1/6] cargo: Add indexmap crate --- Cargo.lock | 33 +++++++++++++++++++++++++-------- Cargo.toml | 1 + 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ef1d5c0e..5ab6fadb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,9 +116,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "10f00e1f6e58a40e807377c75c6a7f97bf9044fab57816f2414e6f5f4499d7b8" [[package]] name = "arc-swap" @@ -233,9 +233,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.16" +version = "1.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9d013ecb737093c0e86b151a7b837993cf9ec6c502946cfb44bedc392421e0b" +checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" dependencies = [ "shlex", ] @@ -466,6 +466,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.9" @@ -1220,6 +1226,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "io-kit-sys" version = "0.4.1" @@ -1360,6 +1376,7 @@ dependencies = [ "clap", "ctrlc", "futures", + "indexmap", "lazy_static", "mavlink", "regex", @@ -1776,18 +1793,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 7bd5d6ae..904bfa92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ clap = { version = "4.5", features = ["derive"] } ctrlc = "3.4" futures = "0.3" lazy_static = "1.5.0" +indexmap = "2.5.0" # mavlink = { version = "0.13.1", default-features = false, features = ["ardupilotmega", "std"] } # mavlink = { default-features = false, features = ["ardupilotmega", "std", "tokio-1"], path = "../rust-mavlink/mavlink" } mavlink = { default-features = false, features = ["ardupilotmega", "std", "tokio-1"], git = "https://github.com/joaoantoniocardoso/rust-mavlink", branch = "add-tokio" } From 9a0ec6082df7ae40230eaeaa89ba4ce666fe4eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 14:37:51 -0300 Subject: [PATCH 2/6] src: rewrite callbacks into a crate lib --- Cargo.toml | 5 ++ src/drivers/mod.rs | 40 ++++------ src/lib/callbacks.rs | 174 +++++++++++++++++++++++++++++++++++++++++++ src/lib/mod.rs | 1 + 4 files changed, 195 insertions(+), 25 deletions(-) create mode 100644 src/lib/callbacks.rs create mode 100644 src/lib/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 904bfa92..1cd8db50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,11 @@ name = "mavlink-server" version = "0.1.0" edition = "2021" +[lib] +name = "mavlink_server" +path = "src/lib/mod.rs" +bench = false + [dependencies] anyhow = "1" async-trait = "0.1.81" diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 0a731958..6aca344d 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -39,24 +39,9 @@ pub trait Driver: Send + Sync { fn info(&self) -> Box; } -type OnMessageCallback = Option + Send + Sync>>; - -pub trait OnMessageCallbackExt: Send + Sync { - fn call(&self, msg: T) -> futures::future::BoxFuture<'static, Result<()>>; -} - -impl OnMessageCallbackExt for F -where - F: Fn(T) -> Fut + Send + Sync + 'static, - Fut: std::future::Future> + Send + 'static, -{ - fn call(&self, msg: T) -> futures::future::BoxFuture<'static, Result<()>> { - Box::pin(self(msg)) - } -} - pub trait DriverInfo: Sync + Send { fn name(&self) -> &str; + fn valid_schemes(&self) -> Vec; fn create_endpoint_from_url(&self, url: &Url) -> Option>; @@ -221,6 +206,7 @@ mod tests { use anyhow::{anyhow, Result}; use mavlink::MAVLinkV2MessageRaw; + use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::sync::RwLock; use tracing::*; @@ -239,14 +225,15 @@ mod tests { } // Example struct implementing Driver - #[derive(Default)] pub struct ExampleDriver { - on_message: OnMessageCallback>, + on_message: Callbacks>, } impl ExampleDriver { pub fn new() -> ExampleDriverBuilder { - ExampleDriverBuilder(Self { on_message: None }) + ExampleDriverBuilder(Self { + on_message: Callbacks::new(), + }) } } @@ -257,11 +244,11 @@ mod tests { self.0 } - pub fn on_message(mut self, callback: F) -> Self + pub fn on_message(self, callback: C) -> Self where - F: OnMessageCallbackExt> + Send + Sync + 'static, + C: MessageCallback>, { - self.0.on_message = Some(Box::new(callback)); + self.0.on_message.add_callback(callback.into_boxed()); self } } @@ -272,11 +259,14 @@ mod tests { let mut hub_receiver = hub_sender.subscribe(); while let Ok(message) = hub_receiver.recv().await { - if let Some(callback) = &self.on_message { - callback.call(Arc::clone(&message)).await?; + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } } - debug!("message:?"); + trace!("Message received: {message:?}"); } Ok(()) diff --git a/src/lib/callbacks.rs b/src/lib/callbacks.rs new file mode 100644 index 00000000..463b77dc --- /dev/null +++ b/src/lib/callbacks.rs @@ -0,0 +1,174 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, +}; + +use anyhow::Result; +use futures::future::BoxFuture; +use indexmap::IndexMap; + +#[derive(Clone)] +pub struct Callbacks { + callbacks: Arc>>>, +} + +impl Callbacks { + pub fn new() -> Self { + Self { + callbacks: Arc::new(Mutex::new(IndexMap::new())), + } + } + + pub fn add_callback(&self, callback: F) -> usize + where + F: Fn(T) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + let callback_id = COUNTER.fetch_add(1, Ordering::Relaxed); + + let mut callbacks = self.callbacks.lock().unwrap(); + callbacks.insert(callback_id, Callback::new(callback)); + + callback_id + } + + pub fn remove_callback(&self, id: usize) { + let mut callbacks = self.callbacks.lock().unwrap(); + callbacks.shift_remove(&id); + } + + pub fn call_all(&self, msg: T) -> Vec>> + where + T: Clone, + { + let callbacks = self.callbacks.lock().unwrap(); + callbacks + .values() + .map(|callback| callback.call(msg.clone())) + .collect() + } +} + +#[derive(Clone)] +struct Callback { + callback: Arc BoxFuture<'static, Result<()>> + Send + Sync>, +} + +impl Callback { + fn new(callback: F) -> Self + where + F: Fn(T) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + Self { + callback: Arc::new(move |msg: T| Box::pin(callback(msg))), + } + } + + fn call(&self, msg: T) -> BoxFuture<'static, Result<()>> { + (self.callback)(msg) + } +} + +pub trait MessageCallback: Send + Sync + 'static { + fn into_boxed(self) -> Box BoxFuture<'static, Result<()>> + Send + Sync>; +} + +impl MessageCallback for F +where + F: Fn(T) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + fn into_boxed(self) -> Box BoxFuture<'static, Result<()>> + Send + Sync> { + Box::new(move |msg: T| Box::pin(self(msg))) + } +} + +#[cfg(test)] +mod tests { + use tokio::sync::oneshot; + + use super::*; + + #[tokio::test] + async fn test_callbacks() { + let callbacks = Callbacks::::new(); + + let (tx1, rx1) = oneshot::channel(); + let tx1 = Arc::new(Mutex::new(Some(tx1))); + let id1 = callbacks.add_callback(move |msg: String| { + let tx1 = tx1.clone(); + async move { + if msg == "test" { + if let Some(sender) = tx1.lock().unwrap().take() { + sender.send(()).unwrap(); + } + } + Ok(()) + } + }); + + let (tx2, rx2) = oneshot::channel(); + let tx2 = Arc::new(Mutex::new(Some(tx2))); + let id2 = callbacks.add_callback(move |msg: String| { + let tx2 = tx2.clone(); + async move { + if msg == "test" { + if let Some(sender) = tx2.lock().unwrap().take() { + sender.send(()).unwrap(); + } + } + Ok(()) + } + }); + + let (tx3, rx3) = oneshot::channel(); + let tx3 = Arc::new(Mutex::new(Some(tx3))); + let id3 = callbacks.add_callback(move |msg: String| { + let tx3 = tx3.clone(); + async move { + if msg == "test" { + if let Some(sender) = tx3.lock().unwrap().take() { + sender.send(()).unwrap(); + } + } + Ok(()) + } + }); + + // Remove the secondly added callback + callbacks.remove_callback(id2); + + // Add a fourth callback + let (tx4, rx4) = oneshot::channel(); + let tx4 = Arc::new(Mutex::new(Some(tx4))); + let id4 = callbacks.add_callback(move |msg: String| { + let tx4 = tx4.clone(); + async move { + if msg == "test" { + if let Some(sender) = tx4.lock().unwrap().take() { + sender.send(()).unwrap(); + } + } + Ok(()) + } + }); + + // Remove the third callback + callbacks.remove_callback(id3); + + // Certify that only the callback functions 1 and 4 are called + let futures = callbacks.call_all("test".to_string()); + futures::future::join_all(futures).await; + + assert!(rx1.await.is_ok(), "Callback 1 should be called"); + assert!(rx2.await.is_err(), "Callback 2 should NOT be called"); + assert!(rx3.await.is_err(), "Callback 3 should NOT be called"); + assert!(rx4.await.is_ok(), "Callback 4 should be called"); + + // Remove remaining callbacks + callbacks.remove_callback(id1); + callbacks.remove_callback(id4); + } +} diff --git a/src/lib/mod.rs b/src/lib/mod.rs new file mode 100644 index 00000000..bd482ca3 --- /dev/null +++ b/src/lib/mod.rs @@ -0,0 +1 @@ +pub mod callbacks; From 5601e8a8c9e89b5fe15a276ba2cdc256a968b548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 14:51:18 -0300 Subject: [PATCH 3/6] benches: Add callbacks benchmark --- benches/callbacks_bench.rs | 48 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 benches/callbacks_bench.rs diff --git a/benches/callbacks_bench.rs b/benches/callbacks_bench.rs new file mode 100644 index 00000000..73e20863 --- /dev/null +++ b/benches/callbacks_bench.rs @@ -0,0 +1,48 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use mavlink_server::callbacks::Callbacks; +use tokio::runtime::Runtime; + +fn bench_call_all(c: &mut Criterion) { + let mut group = c.benchmark_group("callbacks"); + + let callback_counts = vec![0, 1, 3, 5, 10, 20, 50, 100]; + + for number_of_callbacks in &callback_counts { + group.throughput(criterion::Throughput::Elements(*number_of_callbacks)); + group.bench_with_input( + BenchmarkId::from_parameter(number_of_callbacks), + number_of_callbacks, + |b, &number_of_callbacks| { + let rt = Runtime::new().unwrap(); + let callbacks = Callbacks::::new(); + + for _ in 0..number_of_callbacks { + callbacks.add_callback({ + move |msg: String| async move { + if msg != "test" { + panic!("Wrong message"); + } + Ok(()) + } + }); + } + + // Benchmark calling all callbacks + b.iter(|| { + rt.block_on(async { + for future in callbacks.call_all("test".to_string()) { + if let Err(_error) = future.await { + continue; + } + } + }); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_call_all); +criterion_main!(benches); From fad7e1f4e60b70d28af555565ad9c8685e7cd7f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 14:50:49 -0300 Subject: [PATCH 4/6] cargo: Add criterion and enable benchmarks --- Cargo.lock | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 12 +++ 2 files changed, 230 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 5ab6fadb..e8475983 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.15" @@ -231,6 +237,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.1.18" @@ -266,6 +278,33 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.17" @@ -339,6 +378,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -348,12 +423,37 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "ctrlc" version = "3.4.5" @@ -466,6 +566,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "equivalent" version = "1.0.1" @@ -1156,6 +1262,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1178,6 +1294,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "home" version = "0.5.9" @@ -1246,12 +1368,32 @@ dependencies = [ "mach2", ] +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1374,6 +1516,7 @@ dependencies = [ "byteorder", "chrono", "clap", + "criterion", "ctrlc", "futures", "indexmap", @@ -1452,7 +1595,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -1556,6 +1699,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "option-ext" version = "0.2.0" @@ -1609,6 +1758,34 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1672,6 +1849,26 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.3" @@ -2018,6 +2215,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -2327,6 +2534,16 @@ version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +[[package]] +name = "web-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 1cd8db50..d0a9f05f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,15 @@ name = "mavlink_server" path = "src/lib/mod.rs" bench = false +[[bench]] +name = "callbacks_bench" +harness = false + +[[bin]] +name = "mavlink-server" +path = "src/main.rs" +bench = false + [dependencies] anyhow = "1" async-trait = "0.1.81" @@ -35,6 +44,9 @@ tracing-log = "0.2.0" # Reference: https://github.com/tokio-rs/tracing/issues/2441 tracing-appender = { git = "https://github.com/joaoantoniocardoso/tracing", branch = "tracing-appender-0.2.2-with-filename-suffix" } +[dev-dependencies] +criterion = "0.5" +tokio = { version = "1", features = ["full"] } [build-dependencies] vergen-gix = { version = "1.0.0-beta.2", default-features = false, features = ["build", "cargo"] } From f44a6aed7820e24f5a5e862279e2256e31f6007a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 14:48:58 -0300 Subject: [PATCH 5/6] src: drivers: update code to match new callbacks crate lib --- src/drivers/fake.rs | 47 +++++++++++++++++++++++--------------- src/drivers/file/server.rs | 29 ++++++++++++----------- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/drivers/fake.rs b/src/drivers/fake.rs index 12c5e7fa..b70de966 100644 --- a/src/drivers/fake.rs +++ b/src/drivers/fake.rs @@ -1,22 +1,24 @@ use std::sync::Arc; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::sync::broadcast; use tracing::*; use crate::{ - drivers::{Driver, DriverInfo, OnMessageCallback, OnMessageCallbackExt}, + drivers::{Driver, DriverInfo}, protocol::{read_all_messages, Protocol}, }; -#[derive(Default)] pub struct FakeSink { - on_message: OnMessageCallback>, + on_message: Callbacks>, } impl FakeSink { pub fn new() -> FakeSinkBuilder { - FakeSinkBuilder(Self { on_message: None }) + FakeSinkBuilder(Self { + on_message: Callbacks::new(), + }) } } @@ -27,11 +29,11 @@ impl FakeSinkBuilder { self.0 } - pub fn on_message(mut self, callback: F) -> Self + pub fn on_message(self, callback: C) -> Self where - F: OnMessageCallbackExt> + Send + Sync + 'static, + C: MessageCallback>, { - self.0.on_message = Some(Box::new(callback)); + self.0.on_message.add_callback(callback.into_boxed()); self } } @@ -42,11 +44,14 @@ impl Driver for FakeSink { let mut hub_receiver = hub_sender.subscribe(); while let Ok(message) = hub_receiver.recv().await { - debug!("Message received: {message:?}"); - - if let Some(callback) = &self.on_message { - callback.call(Arc::clone(&message)).await?; + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } } + + trace!("Message received: {message:?}"); } Ok(()) @@ -77,17 +82,16 @@ impl DriverInfo for FakeSinkInfo { } } -#[derive(Default)] pub struct FakeSource { period: std::time::Duration, - on_message: OnMessageCallback>, + on_message: Callbacks>, } impl FakeSource { pub fn new(period: std::time::Duration) -> FakeSourceBuilder { FakeSourceBuilder(Self { period, - on_message: None, + on_message: Callbacks::new(), }) } } @@ -99,11 +103,11 @@ impl FakeSourceBuilder { self.0 } - pub fn on_message(mut self, callback: F) -> Self + pub fn on_message(self, callback: C) -> Self where - F: OnMessageCallbackExt> + Send + Sync + 'static, + C: MessageCallback>, { - self.0.on_message = Some(Box::new(callback)); + self.0.on_message.add_callback(callback.into_boxed()); self } } @@ -149,8 +153,13 @@ impl Driver for FakeSource { async move { trace!("Fake message created: {message:?}"); - if let Some(callback) = &self.on_message { - callback.call(Arc::clone(&message)).await.unwrap(); + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!( + "Dropping message: on_message callback returned error: {error:?}" + ); + continue; + } } if let Err(error) = hub_sender.send(message) { diff --git a/src/drivers/file/server.rs b/src/drivers/file/server.rs index e72fe065..60c5dfaa 100644 --- a/src/drivers/file/server.rs +++ b/src/drivers/file/server.rs @@ -1,20 +1,20 @@ -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; use anyhow::Result; use chrono::DateTime; use mavlink::ardupilotmega::MavMessage; -use std::path::PathBuf; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::sync::broadcast; use tracing::*; use crate::{ - drivers::{Driver, DriverInfo, OnMessageCallback, OnMessageCallbackExt}, + drivers::{Driver, DriverInfo}, protocol::Protocol, }; pub struct FileServer { pub path: PathBuf, - on_message: OnMessageCallback<(u64, Arc)>, + on_message: Callbacks<(u64, Arc)>, } pub struct FileServerBuilder(FileServer); @@ -24,11 +24,11 @@ impl FileServerBuilder { self.0 } - pub fn on_message(mut self, callback: F) -> Self + pub fn on_message(self, callback: C) -> Self where - F: OnMessageCallbackExt<(u64, Arc)> + Send + Sync + 'static, + C: MessageCallback<(u64, Arc)>, { - self.0.on_message = Some(Box::new(callback)); + self.0.on_message.add_callback(callback.into_boxed()); self } } @@ -38,7 +38,7 @@ impl FileServer { pub fn new(path: PathBuf) -> FileServerBuilder { FileServerBuilder(Self { path, - on_message: None, + on_message: Callbacks::new(), }) } @@ -100,11 +100,14 @@ impl FileServer { let message = Arc::new(message); - if let Some(callback) = &self.on_message { - callback - .call((us_since_epoch, Arc::clone(&message))) - .await - .unwrap(); + for future in self + .on_message + .call_all((us_since_epoch, (Arc::clone(&message)))) + { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } } if let Err(error) = hub_sender.send(message) { From 7f3aa7fd38b10efac8b2ac2991f6fa72d64d3a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Sat, 7 Sep 2024 14:49:55 -0300 Subject: [PATCH 6/6] src: drivers: Move to builder and add callbacks --- src/drivers/fake.rs | 11 +++---- src/drivers/file/client.rs | 35 +++++++++++++++++++--- src/drivers/file/server.rs | 13 +++++---- src/drivers/serial/mod.rs | 33 +++++++++++++++++---- src/drivers/tcp/client.rs | 33 +++++++++++++++++---- src/drivers/tcp/mod.rs | 25 ++++++++++++++-- src/drivers/tcp/server.rs | 42 +++++++++++++++++++++------ src/drivers/udp/client.rs | 59 ++++++++++++++++++++++++++++++++------ src/drivers/udp/server.rs | 52 ++++++++++++++++++++++++++------- 9 files changed, 245 insertions(+), 58 deletions(-) diff --git a/src/drivers/fake.rs b/src/drivers/fake.rs index b70de966..e29afca7 100644 --- a/src/drivers/fake.rs +++ b/src/drivers/fake.rs @@ -15,7 +15,7 @@ pub struct FakeSink { } impl FakeSink { - pub fn new() -> FakeSinkBuilder { + pub fn builder() -> FakeSinkBuilder { FakeSinkBuilder(Self { on_message: Callbacks::new(), }) @@ -88,7 +88,7 @@ pub struct FakeSource { } impl FakeSource { - pub fn new(period: std::time::Duration) -> FakeSourceBuilder { + pub fn builder(period: std::time::Duration) -> FakeSourceBuilder { FakeSourceBuilder(Self { period, on_message: Callbacks::new(), @@ -200,8 +200,9 @@ impl DriverInfo for FakeSourceInfo { #[cfg(test)] mod test { - use anyhow::Result; use std::sync::Arc; + + use anyhow::Result; use tokio::sync::{broadcast, RwLock}; use super::*; @@ -219,7 +220,7 @@ mod test { // FakeSink and task let sink_messages_clone = sink_messages.clone(); - let sink = FakeSink::new() + let sink = FakeSink::builder() .on_message(move |message: Arc| { let sink_messages = sink_messages_clone.clone(); @@ -237,7 +238,7 @@ mod test { // FakeSource and task let source_messages_clone = source_messages.clone(); - let source = FakeSource::new(message_period) + let source = FakeSource::builder(message_period) .on_message(move |message: Arc| { let source_messages = source_messages_clone.clone(); diff --git a/src/drivers/file/client.rs b/src/drivers/file/client.rs index 268c420a..95bf41b5 100644 --- a/src/drivers/file/client.rs +++ b/src/drivers/file/client.rs @@ -1,6 +1,7 @@ use std::{path::PathBuf, sync::Arc}; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::{ io::{AsyncWriteExt, BufWriter}, sync::broadcast, @@ -14,6 +15,7 @@ use crate::{ pub struct FileClient { pub path: PathBuf, + on_message: Callbacks<(u64, Arc)>, } pub struct FileClientBuilder(FileClient); @@ -22,12 +24,23 @@ impl FileClientBuilder { pub fn build(self) -> FileClient { self.0 } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback<(u64, Arc)>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl FileClient { #[instrument(level = "debug")] - pub fn new(path: PathBuf) -> FileClientBuilder { - FileClientBuilder(Self { path }) + pub fn builder(path: PathBuf) -> FileClientBuilder { + FileClientBuilder(Self { + path, + on_message: Callbacks::new(), + }) } #[instrument(level = "debug", skip(self, writer, hub_receiver))] @@ -41,8 +54,22 @@ impl FileClient { loop { match hub_receiver.recv().await { Ok(message) => { - let raw_bytes = message.raw_bytes(); let timestamp = chrono::Utc::now().timestamp_micros() as u64; + let message = Arc::new(message); + + for future in self + .on_message + .call_all((timestamp, (Arc::clone(&message)))) + { + if let Err(error) = future.await { + debug!( + "Dropping message: on_message callback returned error: {error:?}" + ); + continue; + } + } + + let raw_bytes = message.raw_bytes(); writer.write_all(×tamp.to_be_bytes()).await?; writer.write_all(raw_bytes).await?; writer.flush().await?; @@ -91,6 +118,6 @@ impl DriverInfo for FileClientInfo { } fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { - Some(Arc::new(FileClient::new(url.path().into()).build())) + Some(Arc::new(FileClient::builder(url.path().into()).build())) } } diff --git a/src/drivers/file/server.rs b/src/drivers/file/server.rs index 60c5dfaa..6792ea96 100644 --- a/src/drivers/file/server.rs +++ b/src/drivers/file/server.rs @@ -35,7 +35,7 @@ impl FileServerBuilder { impl FileServer { #[instrument(level = "debug")] - pub fn new(path: PathBuf) -> FileServerBuilder { + pub fn builder(path: PathBuf) -> FileServerBuilder { FileServerBuilder(Self { path, on_message: Callbacks::new(), @@ -43,7 +43,7 @@ impl FileServer { } #[instrument(level = "debug", skip(self, reader, hub_sender))] - async fn handle_client( + async fn handle_file( &self, reader: tokio::io::BufReader, hub_sender: broadcast::Sender>, @@ -124,7 +124,7 @@ impl Driver for FileServer { let file = tokio::fs::File::open(self.path.clone()).await?; let reader = tokio::io::BufReader::with_capacity(1024, file); - FileServer::handle_client(self, reader, hub_sender).await + FileServer::handle_file(self, reader, hub_sender).await } #[instrument(level = "debug", skip(self))] @@ -179,7 +179,7 @@ impl DriverInfo for FileServerInfo { } fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { - Some(Arc::new(FileServer::new(url.path().into()).build())) + Some(Arc::new(FileServer::builder(url.path().into()).build())) } } @@ -204,7 +204,7 @@ mod tests { let tlog_file = PathBuf::from_str("tests/files/00025-2024-04-22_18-49-07.tlog").unwrap(); - let driver = FileServer::new(tlog_file.clone()) + let driver = FileServer::builder(tlog_file.clone()) .on_message(move |args: (u64, Arc)| { let messages_received = messages_received_cloned.clone(); @@ -232,7 +232,8 @@ mod tests { let file_v2_messages = 30437; let file_messages = file_v2_messages; let mut total_messages_read = 0; - let res = tokio::time::timeout(tokio::time::Duration::from_secs(1), async { + let timeout_time = tokio::time::Duration::from_secs(1); + let res = tokio::time::timeout(timeout_time, async { loop { let messages_received_per_id = messages_received_per_id.read().await.clone(); total_messages_read = messages_received_per_id diff --git a/src/drivers/serial/mod.rs b/src/drivers/serial/mod.rs index b10faa10..7fa8f1f7 100644 --- a/src/drivers/serial/mod.rs +++ b/src/drivers/serial/mod.rs @@ -1,8 +1,11 @@ use std::sync::Arc; use anyhow::Result; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::{broadcast, Mutex}; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::{broadcast, Mutex}, +}; use tokio_serial::{self, SerialPortBuilderExt}; use tracing::*; @@ -14,15 +17,33 @@ use crate::{ pub struct Serial { pub port_name: String, pub baud_rate: u32, + on_message: Callbacks>, +} + +pub struct SerialBuilder(Serial); + +impl SerialBuilder { + pub fn build(self) -> Serial { + self.0 + } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl Serial { #[instrument(level = "debug")] - pub fn new(port_name: &str, baud_rate: u32) -> Self { - Self { + pub fn builder(port_name: &str, baud_rate: u32) -> SerialBuilder { + SerialBuilder(Self { port_name: port_name.to_string(), baud_rate, - } + on_message: Callbacks::new(), + }) } #[instrument(level = "debug", skip(port))] @@ -148,6 +169,6 @@ impl DriverInfo for SerialInfo { }) .unwrap_or(115200); // Commun baudrate between flight controllers - Some(Arc::new(Serial::new(&port_name, baud_rate))) + Some(Arc::new(Serial::builder(&port_name, baud_rate).build())) } } diff --git a/src/drivers/tcp/client.rs b/src/drivers/tcp/client.rs index d4f87a58..4d2c5aa3 100644 --- a/src/drivers/tcp/client.rs +++ b/src/drivers/tcp/client.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::{net::TcpStream, sync::broadcast}; use tracing::*; @@ -14,14 +15,32 @@ use crate::{ pub struct TcpClient { pub remote_addr: String, + on_message: Callbacks>, +} + +pub struct TcpClientBuilder(TcpClient); + +impl TcpClientBuilder { + pub fn build(self) -> TcpClient { + self.0 + } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl TcpClient { #[instrument(level = "debug")] - pub fn new(remote_addr: &str) -> Self { - Self { + pub fn builder(remote_addr: &str) -> TcpClientBuilder { + TcpClientBuilder(Self { remote_addr: remote_addr.to_string(), - } + on_message: Callbacks::new(), + }) } } @@ -48,12 +67,12 @@ impl Driver for TcpClient { let hub_sender_cloned = Arc::clone(&hub_sender); tokio::select! { - result = tcp_receive_task(read, server_addr, hub_sender_cloned) => { + result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message) => { if let Err(e) = result { error!("Error in TCP receive task: {e:?}"); } } - result = tcp_send_task(write, server_addr, hub_receiver) => { + result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message) => { if let Err(e) = result { error!("Error in TCP send task: {e:?}"); } @@ -83,6 +102,8 @@ impl DriverInfo for TcpClientInfo { fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { let host = url.host_str().unwrap(); let port = url.port().unwrap(); - Some(Arc::new(TcpClient::new(&format!("{host}:{port}")))) + Some(Arc::new( + TcpClient::builder(&format!("{host}:{port}")).build(), + )) } } diff --git a/src/drivers/tcp/mod.rs b/src/drivers/tcp/mod.rs index ee0968f1..d82f1cbc 100644 --- a/src/drivers/tcp/mod.rs +++ b/src/drivers/tcp/mod.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::Result; +use mavlink_server::callbacks::Callbacks; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, @@ -14,11 +15,12 @@ pub mod client; pub mod server; /// Receives messages from the TCP Socket and sends them to the HUB Channel -#[instrument(level = "debug", skip(socket, hub_sender))] +#[instrument(level = "debug", skip(socket, hub_sender, on_message))] async fn tcp_receive_task( mut socket: OwnedReadHalf, remote_addr: &str, hub_sender: Arc>>, + on_message: &Callbacks>, ) -> Result<()> { let mut buf = Vec::with_capacity(1024); @@ -32,7 +34,16 @@ async fn tcp_receive_task( trace!("Received TCP packet: {buf:?}"); read_all_messages(remote_addr, &mut buf, |message| async { - if let Err(error) = hub_sender.send(Arc::new(message)) { + let message = Arc::new(message); + + for future in on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } + } + + if let Err(error) = hub_sender.send(message) { error!("Failed to send message to hub: {error:?}"); } }) @@ -44,11 +55,12 @@ async fn tcp_receive_task( } /// Receives messages from the HUB Channel and sends them to the TCP Socket -#[instrument(level = "debug", skip(socket, hub_receiver))] +#[instrument(level = "debug", skip(socket, hub_receiver, on_message))] async fn tcp_send_task( mut socket: OwnedWriteHalf, remote_addr: &str, mut hub_receiver: broadcast::Receiver>, + on_message: &Callbacks>, ) -> Result<()> { loop { let message = match hub_receiver.recv().await { @@ -67,6 +79,13 @@ async fn tcp_send_task( continue; // Don't do loopback } + for future in on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } + } + socket.write_all(message.raw_bytes()).await?; trace!( diff --git a/src/drivers/tcp/server.rs b/src/drivers/tcp/server.rs index 75d90da6..449d12bd 100644 --- a/src/drivers/tcp/server.rs +++ b/src/drivers/tcp/server.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::{ net::{TcpListener, TcpStream}, sync::broadcast, @@ -15,36 +16,56 @@ use crate::{ protocol::Protocol, }; +#[derive(Clone)] pub struct TcpServer { pub local_addr: String, + on_message: Callbacks>, +} + +pub struct TcpServerBuilder(TcpServer); + +impl TcpServerBuilder { + pub fn build(self) -> TcpServer { + self.0 + } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl TcpServer { #[instrument(level = "debug")] - pub fn new(local_addr: &str) -> Self { - Self { + pub fn builder(local_addr: &str) -> TcpServerBuilder { + TcpServerBuilder(Self { local_addr: local_addr.to_string(), - } + on_message: Callbacks::new(), + }) } /// Handles communication with a single client - #[instrument(level = "debug", skip(socket, hub_sender))] + #[instrument(level = "debug", skip(socket, hub_sender, on_message))] async fn handle_client( socket: TcpStream, remote_addr: String, hub_sender: Arc>>, + on_message: Callbacks>, ) -> Result<()> { let hub_receiver = hub_sender.subscribe(); let (read, write) = socket.into_split(); tokio::select! { - result = tcp_receive_task(read, &remote_addr, hub_sender) => { + result = tcp_receive_task(read, &remote_addr, hub_sender, &on_message) => { if let Err(e) = result { error!("Error in TCP receive task for {remote_addr}: {e:?}"); } } - result = tcp_send_task(write, &remote_addr, hub_receiver) => { + result = tcp_send_task(write, &remote_addr, hub_receiver, &on_message) => { if let Err(e) = result { error!("Error in TCP send task for {remote_addr}: {e:?}"); } @@ -67,12 +88,13 @@ impl Driver for TcpServer { match listener.accept().await { Ok((socket, remote_addr)) => { let remote_addr = remote_addr.to_string(); - let hub_sender_cloned = Arc::clone(&hub_sender); + let hub_sender = Arc::clone(&hub_sender); tokio::spawn(TcpServer::handle_client( socket, remote_addr, - hub_sender_cloned, + hub_sender, + self.on_message.clone(), )); } Err(error) => { @@ -102,6 +124,8 @@ impl DriverInfo for TcpServerInfo { fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { let host = url.host_str().unwrap(); let port = url.port().unwrap(); - Some(Arc::new(TcpServer::new(&format!("{host}:{port}")))) + Some(Arc::new( + TcpServer::builder(&format!("{host}:{port}")).build(), + )) } } diff --git a/src/drivers/udp/client.rs b/src/drivers/udp/client.rs index 53d81931..24044e04 100644 --- a/src/drivers/udp/client.rs +++ b/src/drivers/udp/client.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::{net::UdpSocket, sync::broadcast}; use tracing::*; @@ -11,18 +12,37 @@ use crate::{ pub struct UdpClient { pub remote_addr: String, + on_message: Callbacks>, +} + +pub struct UdpClientBuilder(UdpClient); + +impl UdpClientBuilder { + pub fn build(self) -> UdpClient { + self.0 + } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl UdpClient { #[instrument(level = "debug")] - pub fn new(remote_addr: &str) -> Self { - Self { + pub fn builder(remote_addr: &str) -> UdpClientBuilder { + UdpClientBuilder(Self { remote_addr: remote_addr.to_string(), - } + on_message: Callbacks::new(), + }) } - #[instrument(level = "debug", skip(socket))] + #[instrument(level = "debug", skip(self, socket))] async fn udp_receive_task( + &self, socket: Arc, hub_sender: Arc>>, ) -> Result<()> { @@ -34,7 +54,16 @@ impl UdpClient { let client_addr = &client_addr.to_string(); read_all_messages(client_addr, &mut buf, |message| async { - if let Err(error) = hub_sender.send(Arc::new(message)) { + let message = Arc::new(message); + + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } + } + + if let Err(error) = hub_sender.send(message) { error!("Failed to send message to hub: {error:?}"); } }) @@ -55,8 +84,9 @@ impl UdpClient { Ok(()) } - #[instrument(level = "debug", skip(socket))] + #[instrument(level = "debug", skip(self, socket, hub_receiver))] async fn udp_send_task( + &self, socket: Arc, mut hub_receiver: broadcast::Receiver>, ) -> Result<()> { @@ -67,6 +97,15 @@ impl UdpClient { continue; // Don't do loopback } + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!( + "Dropping message: on_message callback returned error: {error:?}" + ); + continue; + } + } + match socket.send(message.raw_bytes()).await { Ok(_) => { // Message sent successfully @@ -121,12 +160,12 @@ impl Driver for UdpClient { let hub_receiver = hub_sender.subscribe(); tokio::select! { - result = UdpClient::udp_receive_task(socket.clone(), hub_sender) => { + result = self.udp_receive_task(socket.clone(), hub_sender) => { if let Err(error) = result { error!("Error in receiving UDP messages: {error:?}"); } } - result = UdpClient::udp_send_task(socket, hub_receiver) => { + result = self.udp_send_task(socket, hub_receiver) => { if let Err(error) = result { error!("Error in sending UDP messages: {error:?}"); } @@ -158,6 +197,8 @@ impl DriverInfo for UdpClientInfo { fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { let host = url.host_str().unwrap(); let port = url.port().unwrap(); - Some(Arc::new(UdpClient::new(&format!("{host}:{port}")))) + Some(Arc::new( + UdpClient::builder(&format!("{host}:{port}")).build(), + )) } } diff --git a/src/drivers/udp/server.rs b/src/drivers/udp/server.rs index 2d22c496..7b020994 100644 --- a/src/drivers/udp/server.rs +++ b/src/drivers/udp/server.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Result; +use mavlink_server::callbacks::{Callbacks, MessageCallback}; use tokio::{ net::UdpSocket, sync::{broadcast, RwLock}, @@ -15,19 +16,38 @@ use crate::{ pub struct UdpServer { pub local_addr: String, clients: Arc>>, + on_message: Callbacks>, +} + +pub struct UdpServerBuilder(UdpServer); + +impl UdpServerBuilder { + pub fn build(self) -> UdpServer { + self.0 + } + + pub fn on_message(self, callback: C) -> Self + where + C: MessageCallback>, + { + self.0.on_message.add_callback(callback.into_boxed()); + self + } } impl UdpServer { #[instrument(level = "debug")] - pub fn new(local_addr: String) -> Self { - Self { + pub fn builder(local_addr: String) -> UdpServerBuilder { + UdpServerBuilder(Self { local_addr, clients: Arc::new(RwLock::new(HashMap::new())), - } + on_message: Callbacks::new(), + }) } - #[instrument(level = "debug", skip(socket, hub_sender, clients))] + #[instrument(level = "debug", skip(self, socket, hub_sender, clients))] async fn udp_receive_task( + &self, socket: Arc, hub_sender: Arc>>, clients: Arc>>, @@ -40,8 +60,16 @@ impl UdpServer { let client_addr = &client_addr.to_string(); read_all_messages(client_addr, &mut buf, |message| async { + let message = Arc::new(message); + + for future in self.on_message.call_all(Arc::clone(&message)) { + if let Err(error) = future.await { + debug!("Dropping message: on_message callback returned error: {error:?}"); + continue; + } + } + // Update clients - let header_buf = message.header(); let sysid = message.system_id(); let compid = message.component_id(); if let Some(old_client_addr) = clients @@ -54,7 +82,8 @@ impl UdpServer { debug!("Client added: ({sysid},{compid}) -> {client_addr:?}"); } - if let Err(error) = hub_sender.send(Arc::new(message)) { + + if let Err(error) = hub_sender.send(message) { error!("Failed to send message to hub: {error:?}"); } }) @@ -75,8 +104,9 @@ impl UdpServer { Ok(()) } - #[instrument(level = "debug", skip(socket, hub_receiver, clients))] + #[instrument(level = "debug", skip(self, socket, hub_receiver, clients))] async fn udp_send_task( + &self, socket: Arc, mut hub_receiver: broadcast::Receiver>, clients: Arc>>, @@ -134,12 +164,12 @@ impl Driver for UdpServer { let hub_receiver = hub_sender.subscribe(); tokio::select! { - result = UdpServer::udp_receive_task(socket.clone(), hub_sender, clients.clone()) => { + result = self.udp_receive_task(socket.clone(), hub_sender, clients.clone()) => { if let Err(error) = result { error!("Error in receiving UDP messages: {error:?}"); } } - result = UdpServer::udp_send_task(socket, hub_receiver, clients.clone()) => { + result = self.udp_send_task(socket, hub_receiver, clients.clone()) => { if let Err(error) = result { error!("Error in sending UDP messages: {error:?}"); } @@ -171,6 +201,8 @@ impl DriverInfo for UdpServerInfo { fn create_endpoint_from_url(&self, url: &url::Url) -> Option> { let host = url.host_str().unwrap(); let port = url.port().unwrap(); - Some(Arc::new(UdpServer::new(format!("{host}:{port}")))) + Some(Arc::new( + UdpServer::builder(format!("{host}:{port}")).build(), + )) } }