Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General improvements before adding Hub Stats #23

Merged
120 changes: 64 additions & 56 deletions src/drivers/fake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@ use tracing::*;
use crate::{
drivers::{Driver, DriverInfo},
protocol::{read_all_messages, Protocol},
stats::driver::{DriverStats, DriverStatsInfo},
stats::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider},
};

pub struct FakeSink {
on_message_input: Callbacks<Arc<Protocol>>,
print: bool,
stats: Arc<RwLock<DriverStatsInfo>>,
stats: Arc<RwLock<AccumulatedDriverStats>>,
}

impl FakeSink {
pub fn builder() -> FakeSinkBuilder {
FakeSinkBuilder(Self {
on_message_input: Callbacks::new(),
print: false,
stats: Arc::new(RwLock::new(DriverStatsInfo::default())),
stats: Arc::new(RwLock::new(AccumulatedDriverStats::default())),
})
}
}
Expand Down Expand Up @@ -93,13 +93,13 @@ impl Driver for FakeSink {
}

#[async_trait::async_trait]
impl DriverStats for FakeSink {
async fn stats(&self) -> DriverStatsInfo {
impl AccumulatedDriverStatsProvider for FakeSink {
async fn stats(&self) -> AccumulatedDriverStats {
self.stats.read().await.clone()
}

async fn reset_stats(&self) {
*self.stats.write().await = DriverStatsInfo {
*self.stats.write().await = AccumulatedDriverStats {
input: None,
output: None,
}
Expand Down Expand Up @@ -146,15 +146,15 @@ impl DriverInfo for FakeSinkInfo {
pub struct FakeSource {
period: std::time::Duration,
on_message_output: Callbacks<Arc<Protocol>>,
stats: Arc<RwLock<DriverStatsInfo>>,
stats: Arc<RwLock<AccumulatedDriverStats>>,
}

impl FakeSource {
pub fn builder(period: std::time::Duration) -> FakeSourceBuilder {
FakeSourceBuilder(Self {
period,
on_message_output: Callbacks::new(),
stats: Arc::new(RwLock::new(DriverStatsInfo::default())),
stats: Arc::new(RwLock::new(AccumulatedDriverStats::default())),
})
}
}
Expand Down Expand Up @@ -207,34 +207,35 @@ impl Driver for FakeSource {
buf.clear();
mavlink::write_v2_msg(&mut buf, header, &data).expect("Failed to write message");

let hub_sender_cloned = hub_sender.clone();
read_all_messages("FakeSource", &mut buf, move |message| {
let message = Arc::new(message);
let hub_sender = hub_sender_cloned.clone();

async move {
trace!("Fake message created: {message:?}");

self.stats
.write()
.await
.update_output(Arc::clone(&message))
.await;

for future in self.on_message_output.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!(
"Dropping message: on_message_input callback returned error: {error:?}"
);
continue;
read_all_messages("FakeSource", &mut buf, {
let hub_sender = hub_sender.clone();
move |message| {
let message = Arc::new(message);
let hub_sender = hub_sender.clone();

async move {
trace!("Fake message created: {message:?}");

self.stats
.write()
.await
.update_output(Arc::clone(&message))
.await;

for future in self.on_message_output.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!(
"Dropping message: on_message_input callback returned error: {error:?}"
);
continue;
}
}
}

if let Err(error) = hub_sender.send(message) {
error!("Failed to send message to hub: {error:?}");
if let Err(error) = hub_sender.send(message) {
error!("Failed to send message to hub: {error:?}");
}
}
}
})
}})
.await;

tokio::time::sleep(self.period).await;
Expand All @@ -247,13 +248,13 @@ impl Driver for FakeSource {
}

#[async_trait::async_trait]
impl DriverStats for FakeSource {
async fn stats(&self) -> DriverStatsInfo {
impl AccumulatedDriverStatsProvider for FakeSource {
async fn stats(&self) -> AccumulatedDriverStats {
self.stats.read().await.clone()
}

async fn reset_stats(&self) {
*self.stats.write().await = DriverStatsInfo {
*self.stats.write().await = AccumulatedDriverStats {
input: None,
output: None,
}
Expand Down Expand Up @@ -319,14 +320,17 @@ mod test {
let sink_messages = Arc::new(RwLock::new(Vec::<Arc<Protocol>>::with_capacity(1000)));

// FakeSink and task
let sink_messages_clone = sink_messages.clone();
let sink = FakeSink::builder()
.on_message_input(move |message: Arc<Protocol>| {
let sink_messages = sink_messages_clone.clone();
.on_message_input({
let sink_messages = sink_messages.clone();

move |message: Arc<Protocol>| {
let sink_messages = sink_messages.clone();

async move {
sink_messages.write().await.push(message);
Ok(())
async move {
sink_messages.write().await.push(message);
Ok(())
}
}
})
.build();
Expand All @@ -337,14 +341,16 @@ mod test {
});

// FakeSource and task
let source_messages_clone = source_messages.clone();
let source = FakeSource::builder(message_period)
.on_message_output(move |message: Arc<Protocol>| {
let source_messages = source_messages_clone.clone();

async move {
source_messages.write().await.push(message);
Ok(())
.on_message_output({
let source_messages = source_messages.clone();
move |message: Arc<Protocol>| {
let source_messages = source_messages.clone();

async move {
source_messages.write().await.push(message);
Ok(())
}
}
})
.build();
Expand All @@ -355,14 +361,16 @@ mod test {
});

// Monitoring task to wait the
let sink_messages_clone = sink_messages.clone();
let sink_monitor_task = tokio::spawn(async move {
loop {
if sink_messages_clone.read().await.len() >= number_of_messages {
break;
}
let sink_monitor_task = tokio::spawn({
let sink_messages = sink_messages.clone();
async move {
loop {
if sink_messages.read().await.len() >= number_of_messages {
break;
}

tokio::time::sleep(std::time::Duration::from_millis(1)).await;
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
}
});
let _ = tokio::time::timeout(timeout_time, sink_monitor_task)
Expand Down
30 changes: 16 additions & 14 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tokio::sync::broadcast;
use tracing::*;
use url::Url;

use crate::{protocol::Protocol, stats::driver::DriverStats};
use crate::{protocol::Protocol, stats::driver::AccumulatedDriverStatsProvider};

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Type {
Expand All @@ -36,7 +36,7 @@ pub struct DriverDescriptionLegacy {
}

#[async_trait::async_trait]
pub trait Driver: Send + Sync + DriverStats {
pub trait Driver: Send + Sync + AccumulatedDriverStatsProvider {
async fn run(&self, hub_sender: broadcast::Sender<Arc<Protocol>>) -> Result<()>;
fn info(&self) -> Box<dyn DriverInfo>;
}
Expand Down Expand Up @@ -224,7 +224,7 @@ mod tests {
use tokio::sync::RwLock;
use tracing::*;

use crate::stats::driver::DriverStatsInfo;
use crate::stats::driver::AccumulatedDriverStats;

use super::*;

Expand All @@ -244,15 +244,15 @@ mod tests {
pub struct ExampleDriver {
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
stats: Arc<RwLock<DriverStatsInfo>>,
stats: Arc<RwLock<AccumulatedDriverStats>>,
}

impl ExampleDriver {
pub fn new() -> ExampleDriverBuilder {
ExampleDriverBuilder(Self {
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
stats: Arc::new(RwLock::new(DriverStatsInfo::default())),
stats: Arc::new(RwLock::new(AccumulatedDriverStats::default())),
})
}
}
Expand Down Expand Up @@ -306,13 +306,13 @@ mod tests {
}

#[async_trait::async_trait]
impl DriverStats for ExampleDriver {
async fn stats(&self) -> DriverStatsInfo {
impl AccumulatedDriverStatsProvider for ExampleDriver {
async fn stats(&self) -> AccumulatedDriverStats {
self.stats.read().await.clone()
}

async fn reset_stats(&self) {
*self.stats.write().await = DriverStatsInfo {
*self.stats.write().await = AccumulatedDriverStats {
input: None,
output: None,
}
Expand Down Expand Up @@ -347,15 +347,17 @@ mod tests {
let (sender, _receiver) = tokio::sync::broadcast::channel(1);

let called = Arc::new(RwLock::new(false));
let called_cloned = called.clone();
let driver = ExampleDriver::new()
.on_message_input(move |_msg| {
let called = called_cloned.clone();
.on_message_input({
let called = called.clone();
move |_msg| {
let called = called.clone();

async move {
*called.write().await = true;
async move {
*called.write().await = true;

Err(anyhow!("Finished from callback"))
Err(anyhow!("Finished from callback"))
}
}
})
.build();
Expand Down
12 changes: 6 additions & 6 deletions src/drivers/serial/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ use tracing::*;
use crate::{
drivers::{Driver, DriverInfo},
protocol::{read_all_messages, Protocol},
stats::driver::{DriverStats, DriverStatsInfo},
stats::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider},
};

pub struct Serial {
pub port_name: String,
pub baud_rate: u32,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
stats: Arc<RwLock<DriverStatsInfo>>,
stats: Arc<RwLock<AccumulatedDriverStats>>,
}

pub struct SerialBuilder(Serial);
Expand Down Expand Up @@ -55,7 +55,7 @@ impl Serial {
baud_rate,
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
stats: Arc::new(RwLock::new(DriverStatsInfo::default())),
stats: Arc::new(RwLock::new(AccumulatedDriverStats::default())),
})
}

Expand Down Expand Up @@ -179,13 +179,13 @@ impl Driver for Serial {
}

#[async_trait::async_trait]
impl DriverStats for Serial {
async fn stats(&self) -> DriverStatsInfo {
impl AccumulatedDriverStatsProvider for Serial {
async fn stats(&self) -> AccumulatedDriverStats {
self.stats.read().await.clone()
}

async fn reset_stats(&self) {
*self.stats.write().await = DriverStatsInfo {
*self.stats.write().await = AccumulatedDriverStats {
input: None,
output: None,
}
Expand Down
15 changes: 7 additions & 8 deletions src/drivers/tcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ use crate::{
Driver, DriverInfo,
},
protocol::Protocol,
stats::driver::{DriverStats, DriverStatsInfo},
stats::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider},
};

pub struct TcpClient {
pub remote_addr: String,
on_message_input: Callbacks<Arc<Protocol>>,
on_message_output: Callbacks<Arc<Protocol>>,
stats: Arc<RwLock<DriverStatsInfo>>,
stats: Arc<RwLock<AccumulatedDriverStats>>,
}

pub struct TcpClientBuilder(TcpClient);
Expand Down Expand Up @@ -55,7 +55,7 @@ impl TcpClient {
remote_addr: remote_addr.to_string(),
on_message_input: Callbacks::new(),
on_message_output: Callbacks::new(),
stats: Arc::new(RwLock::new(DriverStatsInfo::default())),
stats: Arc::new(RwLock::new(AccumulatedDriverStats::default())),
})
}
}
Expand All @@ -80,10 +80,9 @@ impl Driver for TcpClient {
debug!("TcpClient successfully connected to {server_addr:?}");

let hub_receiver = hub_sender.subscribe();
let hub_sender_cloned = Arc::clone(&hub_sender);

tokio::select! {
result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message_input, &self.stats) => {
result = tcp_receive_task(read, server_addr, Arc::clone(&hub_sender), &self.on_message_input, &self.stats) => {
if let Err(e) = result {
error!("Error in TCP receive task: {e:?}");
}
Expand All @@ -106,13 +105,13 @@ impl Driver for TcpClient {
}

#[async_trait::async_trait]
impl DriverStats for TcpClient {
async fn stats(&self) -> DriverStatsInfo {
impl AccumulatedDriverStatsProvider for TcpClient {
async fn stats(&self) -> AccumulatedDriverStats {
self.stats.read().await.clone()
}

async fn reset_stats(&self) {
*self.stats.write().await = DriverStatsInfo {
*self.stats.write().await = AccumulatedDriverStats {
input: None,
output: None,
}
Expand Down
Loading
Loading