Skip to content

Commit

Permalink
src: drivers: Move to builder and add callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoantoniocardoso authored and patrickelectric committed Sep 7, 2024
1 parent a0d44e5 commit 6111fc7
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 58 deletions.
11 changes: 6 additions & 5 deletions src/drivers/fake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct FakeSink {
}

impl FakeSink {
pub fn new() -> FakeSinkBuilder {
pub fn builder() -> FakeSinkBuilder {
FakeSinkBuilder(Self {
on_message: Callbacks::new(),
})
Expand Down Expand Up @@ -88,7 +88,7 @@ pub struct FakeSource {
}

impl FakeSource {
pub fn new(period: std::time::Duration) -> FakeSourceBuilder {
pub fn builder(period: std::time::Duration) -> FakeSourceBuilder {
FakeSourceBuilder(Self {
period,
on_message: Callbacks::new(),
Expand Down Expand Up @@ -200,8 +200,9 @@ impl DriverInfo for FakeSourceInfo {

#[cfg(test)]
mod test {
use anyhow::Result;
use std::sync::Arc;

use anyhow::Result;
use tokio::sync::{broadcast, RwLock};

use super::*;
Expand All @@ -219,7 +220,7 @@ mod test {

// FakeSink and task
let sink_messages_clone = sink_messages.clone();
let sink = FakeSink::new()
let sink = FakeSink::builder()
.on_message(move |message: Arc<Protocol>| {
let sink_messages = sink_messages_clone.clone();

Expand All @@ -237,7 +238,7 @@ mod test {

// FakeSource and task
let source_messages_clone = source_messages.clone();
let source = FakeSource::new(message_period)
let source = FakeSource::builder(message_period)
.on_message(move |message: Arc<Protocol>| {
let source_messages = source_messages_clone.clone();

Expand Down
35 changes: 31 additions & 4 deletions src/drivers/file/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{path::PathBuf, sync::Arc};

use anyhow::Result;
use mavlink_server::callbacks::{Callbacks, MessageCallback};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::broadcast,
Expand All @@ -14,6 +15,7 @@ use crate::{

pub struct FileClient {
pub path: PathBuf,
on_message: Callbacks<(u64, Arc<Protocol>)>,
}

pub struct FileClientBuilder(FileClient);
Expand All @@ -22,12 +24,23 @@ impl FileClientBuilder {
pub fn build(self) -> FileClient {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
where
C: MessageCallback<(u64, Arc<Protocol>)>,
{
self.0.on_message.add_callback(callback.into_boxed());
self
}
}

impl FileClient {
#[instrument(level = "debug")]
pub fn new(path: PathBuf) -> FileClientBuilder {
FileClientBuilder(Self { path })
pub fn builder(path: PathBuf) -> FileClientBuilder {
FileClientBuilder(Self {
path,
on_message: Callbacks::new(),
})
}

#[instrument(level = "debug", skip(self, writer, hub_receiver))]
Expand All @@ -41,8 +54,22 @@ impl FileClient {
loop {
match hub_receiver.recv().await {
Ok(message) => {
let raw_bytes = message.raw_bytes();
let timestamp = chrono::Utc::now().timestamp_micros() as u64;
let message = Arc::new(message);

for future in self
.on_message
.call_all((timestamp, (Arc::clone(&message))))
{
if let Err(error) = future.await {
debug!(
"Dropping message: on_message callback returned error: {error:?}"
);
continue;
}
}

let raw_bytes = message.raw_bytes();
writer.write_all(&timestamp.to_be_bytes()).await?;
writer.write_all(raw_bytes).await?;
writer.flush().await?;
Expand Down Expand Up @@ -91,6 +118,6 @@ impl DriverInfo for FileClientInfo {
}

fn create_endpoint_from_url(&self, url: &url::Url) -> Option<Arc<dyn Driver>> {
Some(Arc::new(FileClient::new(url.path().into()).build()))
Some(Arc::new(FileClient::builder(url.path().into()).build()))
}
}
13 changes: 7 additions & 6 deletions src/drivers/file/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ impl FileServerBuilder {

impl FileServer {
#[instrument(level = "debug")]
pub fn new(path: PathBuf) -> FileServerBuilder {
pub fn builder(path: PathBuf) -> FileServerBuilder {
FileServerBuilder(Self {
path,
on_message: Callbacks::new(),
})
}

#[instrument(level = "debug", skip(self, reader, hub_sender))]
async fn handle_client(
async fn handle_file(
&self,
reader: tokio::io::BufReader<tokio::fs::File>,
hub_sender: broadcast::Sender<Arc<Protocol>>,
Expand Down Expand Up @@ -124,7 +124,7 @@ impl Driver for FileServer {
let file = tokio::fs::File::open(self.path.clone()).await?;
let reader = tokio::io::BufReader::with_capacity(1024, file);

FileServer::handle_client(self, reader, hub_sender).await
FileServer::handle_file(self, reader, hub_sender).await
}

#[instrument(level = "debug", skip(self))]
Expand Down Expand Up @@ -179,7 +179,7 @@ impl DriverInfo for FileServerInfo {
}

fn create_endpoint_from_url(&self, url: &url::Url) -> Option<Arc<dyn Driver>> {
Some(Arc::new(FileServer::new(url.path().into()).build()))
Some(Arc::new(FileServer::builder(url.path().into()).build()))
}
}

Expand All @@ -204,7 +204,7 @@ mod tests {

let tlog_file = PathBuf::from_str("tests/files/00025-2024-04-22_18-49-07.tlog").unwrap();

let driver = FileServer::new(tlog_file.clone())
let driver = FileServer::builder(tlog_file.clone())
.on_message(move |args: (u64, Arc<Protocol>)| {
let messages_received = messages_received_cloned.clone();

Expand Down Expand Up @@ -232,7 +232,8 @@ mod tests {
let file_v2_messages = 30437;
let file_messages = file_v2_messages;
let mut total_messages_read = 0;
let res = tokio::time::timeout(tokio::time::Duration::from_secs(1), async {
let timeout_time = tokio::time::Duration::from_secs(1);
let res = tokio::time::timeout(timeout_time, async {
loop {
let messages_received_per_id = messages_received_per_id.read().await.clone();
total_messages_read = messages_received_per_id
Expand Down
33 changes: 27 additions & 6 deletions src/drivers/serial/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;

use anyhow::Result;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{broadcast, Mutex};
use mavlink_server::callbacks::{Callbacks, MessageCallback};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{broadcast, Mutex},
};
use tokio_serial::{self, SerialPortBuilderExt};
use tracing::*;

Expand All @@ -14,15 +17,33 @@ use crate::{
pub struct Serial {
pub port_name: String,
pub baud_rate: u32,
on_message: Callbacks<Arc<Protocol>>,
}

pub struct SerialBuilder(Serial);

impl SerialBuilder {
pub fn build(self) -> Serial {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self
}
}

impl Serial {
#[instrument(level = "debug")]
pub fn new(port_name: &str, baud_rate: u32) -> Self {
Self {
pub fn builder(port_name: &str, baud_rate: u32) -> SerialBuilder {
SerialBuilder(Self {
port_name: port_name.to_string(),
baud_rate,
}
on_message: Callbacks::new(),
})
}

#[instrument(level = "debug", skip(port))]
Expand Down Expand Up @@ -148,6 +169,6 @@ impl DriverInfo for SerialInfo {
})
.unwrap_or(115200); // Commun baudrate between flight controllers

Some(Arc::new(Serial::new(&port_name, baud_rate)))
Some(Arc::new(Serial::builder(&port_name, baud_rate).build()))
}
}
33 changes: 27 additions & 6 deletions src/drivers/tcp/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use anyhow::Result;
use mavlink_server::callbacks::{Callbacks, MessageCallback};
use tokio::{net::TcpStream, sync::broadcast};
use tracing::*;

Expand All @@ -14,14 +15,32 @@ use crate::{

pub struct TcpClient {
pub remote_addr: String,
on_message: Callbacks<Arc<Protocol>>,
}

pub struct TcpClientBuilder(TcpClient);

impl TcpClientBuilder {
pub fn build(self) -> TcpClient {
self.0
}

pub fn on_message<C>(self, callback: C) -> Self
where
C: MessageCallback<Arc<Protocol>>,
{
self.0.on_message.add_callback(callback.into_boxed());
self
}
}

impl TcpClient {
#[instrument(level = "debug")]
pub fn new(remote_addr: &str) -> Self {
Self {
pub fn builder(remote_addr: &str) -> TcpClientBuilder {
TcpClientBuilder(Self {
remote_addr: remote_addr.to_string(),
}
on_message: Callbacks::new(),
})
}
}

Expand All @@ -48,12 +67,12 @@ impl Driver for TcpClient {
let hub_sender_cloned = Arc::clone(&hub_sender);

tokio::select! {
result = tcp_receive_task(read, server_addr, hub_sender_cloned) => {
result = tcp_receive_task(read, server_addr, hub_sender_cloned, &self.on_message) => {
if let Err(e) = result {
error!("Error in TCP receive task: {e:?}");
}
}
result = tcp_send_task(write, server_addr, hub_receiver) => {
result = tcp_send_task(write, server_addr, hub_receiver, &self.on_message) => {
if let Err(e) = result {
error!("Error in TCP send task: {e:?}");
}
Expand Down Expand Up @@ -83,6 +102,8 @@ impl DriverInfo for TcpClientInfo {
fn create_endpoint_from_url(&self, url: &url::Url) -> Option<Arc<dyn Driver>> {
let host = url.host_str().unwrap();
let port = url.port().unwrap();
Some(Arc::new(TcpClient::new(&format!("{host}:{port}"))))
Some(Arc::new(
TcpClient::builder(&format!("{host}:{port}")).build(),
))
}
}
25 changes: 22 additions & 3 deletions src/drivers/tcp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::sync::Arc;

use anyhow::Result;
use mavlink_server::callbacks::Callbacks;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
Expand All @@ -14,11 +15,12 @@ pub mod client;
pub mod server;

/// Receives messages from the TCP Socket and sends them to the HUB Channel
#[instrument(level = "debug", skip(socket, hub_sender))]
#[instrument(level = "debug", skip(socket, hub_sender, on_message))]
async fn tcp_receive_task(
mut socket: OwnedReadHalf,
remote_addr: &str,
hub_sender: Arc<broadcast::Sender<Arc<Protocol>>>,
on_message: &Callbacks<Arc<Protocol>>,
) -> Result<()> {
let mut buf = Vec::with_capacity(1024);

Expand All @@ -32,7 +34,16 @@ async fn tcp_receive_task(
trace!("Received TCP packet: {buf:?}");

read_all_messages(remote_addr, &mut buf, |message| async {
if let Err(error) = hub_sender.send(Arc::new(message)) {
let message = Arc::new(message);

for future in on_message.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message callback returned error: {error:?}");
continue;
}
}

if let Err(error) = hub_sender.send(message) {
error!("Failed to send message to hub: {error:?}");
}
})
Expand All @@ -44,11 +55,12 @@ async fn tcp_receive_task(
}

/// Receives messages from the HUB Channel and sends them to the TCP Socket
#[instrument(level = "debug", skip(socket, hub_receiver))]
#[instrument(level = "debug", skip(socket, hub_receiver, on_message))]
async fn tcp_send_task(
mut socket: OwnedWriteHalf,
remote_addr: &str,
mut hub_receiver: broadcast::Receiver<Arc<Protocol>>,
on_message: &Callbacks<Arc<Protocol>>,
) -> Result<()> {
loop {
let message = match hub_receiver.recv().await {
Expand All @@ -67,6 +79,13 @@ async fn tcp_send_task(
continue; // Don't do loopback
}

for future in on_message.call_all(Arc::clone(&message)) {
if let Err(error) = future.await {
debug!("Dropping message: on_message callback returned error: {error:?}");
continue;
}
}

socket.write_all(message.raw_bytes()).await?;

trace!(
Expand Down
Loading

0 comments on commit 6111fc7

Please sign in to comment.