Skip to content

Commit

Permalink
src: lib: drivers: rest: Refactor send and recv tasks using same patt…
Browse files Browse the repository at this point in the history
…erns as other drivers
  • Loading branch information
joaoantoniocardoso authored and patrickelectric committed Nov 13, 2024
1 parent 0eac2b0 commit 39b405f
Showing 1 changed file with 83 additions and 75 deletions.
158 changes: 83 additions & 75 deletions src/lib/drivers/rest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ pub mod data;
use std::sync::Arc;

use anyhow::Result;
use mavlink::ardupilotmega::MavMessage;
use mavlink_codec::Packet;
use tokio::sync::{broadcast, RwLock};
use tracing::*;

use crate::{
callbacks::{Callbacks, MessageCallback},
drivers::{Driver, DriverInfo},
mavlink_json::{MAVLinkJSON, MAVLinkJSONHeader},
drivers::{generic_tasks::SendReceiveContext, Driver, DriverInfo},
mavlink_json::MAVLinkJSON,
protocol::Protocol,
stats::{
accumulated::driver::{AccumulatedDriverStats, AccumulatedDriverStatsProvider},
Expand Down Expand Up @@ -66,89 +64,94 @@ impl Rest {
})
}

#[instrument(level = "debug", skip(on_message_input))]
#[instrument(level = "debug", skip_all)]
async fn receive_task(
hub_sender: broadcast::Sender<Arc<Protocol>>,
on_message_input: &Callbacks<Arc<Protocol>>,
context: &SendReceiveContext,
ws_receiver: &mut broadcast::Receiver<String>,
stats: &Arc<RwLock<AccumulatedDriverStats>>,
) -> Result<()> {
match ws_receiver.recv().await {
Ok(message) => {
if let Ok(content) =
json5::from_str::<MAVLinkJSON<mavlink::ardupilotmega::MavMessage>>(&message)
{
let mut message_raw = mavlink::MAVLinkV2MessageRaw::new();
message_raw.serialize_message(content.header.inner, &content.message);
let bus_message = Arc::new(Protocol::new("Ws", Packet::from(message_raw)));
stats.write().await.stats.update_input(&bus_message);

for future in on_message_input.call_all(bus_message.clone()) {
if let Err(error) = future.await {
debug!("Dropping message: on_message_input callback returned error: {error:?}");
continue;
}
}

if let Err(error) = hub_sender.send(bus_message) {
error!("Failed to send message to hub: {error:?}");
}
return Ok(());
while let Ok(message) = ws_receiver.recv().await {
let Ok(content) =
json5::from_str::<MAVLinkJSON<mavlink::ardupilotmega::MavMessage>>(&message)
else {
debug!("Failed to parse message, not a valid MAVLinkMessage: {message:?}");
continue;
};

let bus_message = Arc::new(Protocol::from_mavlink_raw(
content.header.inner,
&content.message,
"Ws",
));

trace!("Received message: {bus_message:?}");

context.stats.write().await.stats.update_input(&bus_message);

for future in context.on_message_input.call_all(bus_message.clone()) {
if let Err(error) = future.await {
debug!("Dropping message: on_message_input callback returned error: {error:?}");
continue;
}
return Err(anyhow::anyhow!(
"Failed to parse message, not a valid MAVLinkMessage: {message:?}"
));
}
// We got problems
Err(error) => {
return Err(anyhow::anyhow!(
"Failed to receive message from ws: {error:?}"
));

if let Err(error) = context.hub_sender.send(bus_message) {
error!("Failed to send message to hub: {error:?}");
continue;
}

trace!("Message sent to hub");
}

debug!("Driver receiver task stopped!");

Ok(())
}

#[instrument(level = "debug", skip(on_message_output))]
async fn send_task(
mut hub_receiver: broadcast::Receiver<Arc<Protocol>>,
on_message_output: &Callbacks<Arc<Protocol>>,
stats: &Arc<RwLock<AccumulatedDriverStats>>,
) -> Result<()> {
loop {
match hub_receiver.recv().await {
Ok(message) => {
stats.write().await.stats.update_output(&message);
for future in on_message_output.call_all(message.clone()) {
if let Err(error) = future.await {
debug!("Dropping message: on_message_output callback returned error: {error:?}");
continue;
}
}
#[instrument(level = "debug", skip_all)]
async fn send_task(context: &SendReceiveContext) -> Result<()> {
let mut hub_receiver = context.hub_sender.subscribe();

let mut bytes =
mavlink::async_peek_reader::AsyncPeekReader::new(message.as_slice());
let Ok((header, message)) =
mavlink::read_v2_msg_async::<MavMessage, _>(&mut bytes).await
else {
continue;
};
loop {
let message = match hub_receiver.recv().await {
Ok(message) => message,
Err(broadcast::error::RecvError::Closed) => {
error!("Hub channel closed!");
break;
}
Err(broadcast::error::RecvError::Lagged(count)) => {
warn!("Channel lagged by {count} messages.");
continue;
}
};

let header = MAVLinkJSONHeader {
inner: header,
message_id: Some(mavlink::Message::message_id(&message)),
};
if message.origin.eq("Ws") {
continue; // Don't do loopback
}

let mavlink_message = MAVLinkJSON { header, message };
context.stats.write().await.stats.update_output(&message);

let json_string = parse_query(&mavlink_message);
data::update((header, mavlink_message.message));
crate::web::send_message(json_string).await;
}
Err(error) => {
error!("Failed to receive message from hub: {error:?}");
for future in context.on_message_output.call_all(message.clone()) {
if let Err(error) = future.await {
debug!(
"Dropping message: on_message_output callback returned error: {error:?}"
);
continue;
}
}

let Ok(mavlink_json) = message.to_mavlink_json().await else {
continue;
};

let json_string = parse_query(&mavlink_json);
data::update((mavlink_json.header, mavlink_json.message));

crate::web::send_message(json_string).await;
}

debug!("Driver sender task stopped!");

Ok(())
}
}

Expand All @@ -162,6 +165,13 @@ pub fn parse_query<T: serde::ser::Serialize>(message: &T) -> String {
impl Driver for Rest {
#[instrument(level = "debug", skip(self, hub_sender))]
async fn run(&self, hub_sender: broadcast::Sender<Arc<Protocol>>) -> Result<()> {
let context = SendReceiveContext {
hub_sender,
on_message_output: self.on_message_output.clone(),
on_message_input: self.on_message_input.clone(),
stats: self.stats.clone(),
};

let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
let mut first = true;
loop {
Expand All @@ -171,17 +181,15 @@ impl Driver for Rest {
interval.tick().await;
}

let hub_sender = hub_sender.clone();
let hub_receiver = hub_sender.subscribe();
let mut ws_receiver = crate::web::create_message_receiver();

tokio::select! {
result = Rest::send_task(hub_receiver, &self.on_message_output, &self.stats) => {
result = Rest::send_task(&context) => {
if let Err(e) = result {
error!("Error in rest sender task: {e:?}");
}
}
result = Rest::receive_task(hub_sender, &self.on_message_input, &mut ws_receiver, &self.stats) => {
result = Rest::receive_task(&context, &mut ws_receiver) => {
if let Err(e) = result {
error!("Error in rest receive task: {e:?}");
}
Expand Down

0 comments on commit 39b405f

Please sign in to comment.