Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More callbacks #9

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 243 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@ name = "mavlink-server"
version = "0.1.0"
edition = "2021"

[lib]
name = "mavlink_server"
path = "src/lib/mod.rs"
bench = false

[[bench]]
name = "callbacks_bench"
harness = false

[[bin]]
name = "mavlink-server"
path = "src/main.rs"
bench = false

[dependencies]
anyhow = "1"
async-trait = "0.1.81"
Expand All @@ -12,6 +26,7 @@ clap = { version = "4.5", features = ["derive"] }
ctrlc = "3.4"
futures = "0.3"
lazy_static = "1.5.0"
indexmap = "2.5.0"
# mavlink = { version = "0.13.1", default-features = false, features = ["ardupilotmega", "std"] }
# mavlink = { default-features = false, features = ["ardupilotmega", "std", "tokio-1"], path = "../rust-mavlink/mavlink" }
mavlink = { default-features = false, features = ["ardupilotmega", "std", "tokio-1"], git = "https://github.com/joaoantoniocardoso/rust-mavlink", branch = "add-tokio" }
Expand All @@ -29,6 +44,9 @@ tracing-log = "0.2.0"
# Reference: https://github.com/tokio-rs/tracing/issues/2441
tracing-appender = { git = "https://github.com/joaoantoniocardoso/tracing", branch = "tracing-appender-0.2.2-with-filename-suffix" }

[dev-dependencies]
criterion = "0.5"
tokio = { version = "1", features = ["full"] }

[build-dependencies]
vergen-gix = { version = "1.0.0-beta.2", default-features = false, features = ["build", "cargo"] }
48 changes: 48 additions & 0 deletions benches/callbacks_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use mavlink_server::callbacks::Callbacks;
use tokio::runtime::Runtime;

fn bench_call_all(c: &mut Criterion) {
let mut group = c.benchmark_group("callbacks");

let callback_counts = vec![0, 1, 3, 5, 10, 20, 50, 100];

for number_of_callbacks in &callback_counts {
group.throughput(criterion::Throughput::Elements(*number_of_callbacks));
group.bench_with_input(
BenchmarkId::from_parameter(number_of_callbacks),
number_of_callbacks,
|b, &number_of_callbacks| {
let rt = Runtime::new().unwrap();
let callbacks = Callbacks::<String>::new();

for _ in 0..number_of_callbacks {
callbacks.add_callback({
move |msg: String| async move {
if msg != "test" {
panic!("Wrong message");
}
Ok(())
}
});
}

// Benchmark calling all callbacks
b.iter(|| {
rt.block_on(async {
for future in callbacks.call_all("test".to_string()) {
if let Err(_error) = future.await {
continue;
}
}
});
});
},
);
}

group.finish();
}

criterion_group!(benches, bench_call_all);
criterion_main!(benches);
58 changes: 34 additions & 24 deletions src/drivers/fake.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use std::sync::Arc;

use anyhow::Result;
use mavlink_server::callbacks::{Callbacks, MessageCallback};
use tokio::sync::broadcast;
use tracing::*;

use crate::{
drivers::{Driver, DriverInfo, OnMessageCallback, OnMessageCallbackExt},
drivers::{Driver, DriverInfo},
protocol::{read_all_messages, Protocol},
};

#[derive(Default)]
pub struct FakeSink {
on_message: OnMessageCallback<Arc<Protocol>>,
on_message: Callbacks<Arc<Protocol>>,
}

impl FakeSink {
pub fn new() -> FakeSinkBuilder {
FakeSinkBuilder(Self { on_message: None })
pub fn builder() -> FakeSinkBuilder {
FakeSinkBuilder(Self {
on_message: Callbacks::new(),
})
}
}

Expand All @@ -27,11 +29,11 @@ impl FakeSinkBuilder {
self.0
}

pub fn on_message<F>(mut self, callback: F) -> Self
pub fn on_message<C>(self, callback: C) -> Self
where
F: OnMessageCallbackExt<Arc<Protocol>> + Send + Sync + 'static,
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message = Some(Box::new(callback));
self.0.on_message.add_callback(callback.into_boxed());
self
}
}
Expand All @@ -42,11 +44,14 @@ impl Driver for FakeSink {
let mut hub_receiver = hub_sender.subscribe();

while let Ok(message) = hub_receiver.recv().await {
debug!("Message received: {message:?}");

if let Some(callback) = &self.on_message {
callback.call(Arc::clone(&message)).await?;
for future in self.on_message.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message callback returned error: {error:?}");
continue;
}
}

trace!("Message received: {message:?}");
}

Ok(())
Expand Down Expand Up @@ -77,17 +82,16 @@ impl DriverInfo for FakeSinkInfo {
}
}

#[derive(Default)]
pub struct FakeSource {
period: std::time::Duration,
on_message: OnMessageCallback<Arc<Protocol>>,
on_message: Callbacks<Arc<Protocol>>,
}

impl FakeSource {
pub fn new(period: std::time::Duration) -> FakeSourceBuilder {
pub fn builder(period: std::time::Duration) -> FakeSourceBuilder {
FakeSourceBuilder(Self {
period,
on_message: None,
on_message: Callbacks::new(),
})
}
}
Expand All @@ -99,11 +103,11 @@ impl FakeSourceBuilder {
self.0
}

pub fn on_message<F>(mut self, callback: F) -> Self
pub fn on_message<C>(self, callback: C) -> Self
where
F: OnMessageCallbackExt<Arc<Protocol>> + Send + Sync + 'static,
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message = Some(Box::new(callback));
self.0.on_message.add_callback(callback.into_boxed());
self
}
}
Expand Down Expand Up @@ -149,8 +153,13 @@ impl Driver for FakeSource {
async move {
trace!("Fake message created: {message:?}");

if let Some(callback) = &self.on_message {
callback.call(Arc::clone(&message)).await.unwrap();
for future in self.on_message.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!(
"Dropping message: on_message callback returned error: {error:?}"
);
continue;
}
}

if let Err(error) = hub_sender.send(message) {
Expand Down Expand Up @@ -191,8 +200,9 @@ impl DriverInfo for FakeSourceInfo {

#[cfg(test)]
mod test {
use anyhow::Result;
use std::sync::Arc;

use anyhow::Result;
use tokio::sync::{broadcast, RwLock};

use super::*;
Expand All @@ -210,7 +220,7 @@ mod test {

// FakeSink and task
let sink_messages_clone = sink_messages.clone();
let sink = FakeSink::new()
let sink = FakeSink::builder()
.on_message(move |message: Arc<Protocol>| {
let sink_messages = sink_messages_clone.clone();

Expand All @@ -228,7 +238,7 @@ mod test {

// FakeSource and task
let source_messages_clone = source_messages.clone();
let source = FakeSource::new(message_period)
let source = FakeSource::builder(message_period)
.on_message(move |message: Arc<Protocol>| {
let source_messages = source_messages_clone.clone();

Expand Down
35 changes: 31 additions & 4 deletions src/drivers/file/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{path::PathBuf, sync::Arc};

use anyhow::Result;
use mavlink_server::callbacks::{Callbacks, MessageCallback};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::broadcast,
Expand All @@ -14,6 +15,7 @@ use crate::{

pub struct FileClient {
pub path: PathBuf,
on_message: Callbacks<(u64, Arc<Protocol>)>,
}

pub struct FileClientBuilder(FileClient);
Expand All @@ -22,12 +24,23 @@ impl FileClientBuilder {
pub fn build(self) -> FileClient {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
where
C: MessageCallback<(u64, Arc<Protocol>)>,
{
self.0.on_message.add_callback(callback.into_boxed());
self
}
}

impl FileClient {
#[instrument(level = "debug")]
pub fn new(path: PathBuf) -> FileClientBuilder {
FileClientBuilder(Self { path })
pub fn builder(path: PathBuf) -> FileClientBuilder {
FileClientBuilder(Self {
path,
on_message: Callbacks::new(),
})
}

#[instrument(level = "debug", skip(self, writer, hub_receiver))]
Expand All @@ -41,8 +54,22 @@ impl FileClient {
loop {
match hub_receiver.recv().await {
Ok(message) => {
let raw_bytes = message.raw_bytes();
let timestamp = chrono::Utc::now().timestamp_micros() as u64;
let message = Arc::new(message);

for future in self
.on_message
.call_all((timestamp, (Arc::clone(&message))))
{
if let Err(error) = future.await {
debug!(
"Dropping message: on_message callback returned error: {error:?}"
);
continue;
}
}

let raw_bytes = message.raw_bytes();
writer.write_all(&timestamp.to_be_bytes()).await?;
writer.write_all(raw_bytes).await?;
writer.flush().await?;
Expand Down Expand Up @@ -91,6 +118,6 @@ impl DriverInfo for FileClientInfo {
}

fn create_endpoint_from_url(&self, url: &url::Url) -> Option<Arc<dyn Driver>> {
Some(Arc::new(FileClient::new(url.path().into()).build()))
Some(Arc::new(FileClient::builder(url.path().into()).build()))
}
}
Loading
Loading