diff --git a/Cargo.lock b/Cargo.lock index a92e48a1..c201d85f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1238,6 +1238,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "byteorder", "chrono", "clap", "ctrlc", diff --git a/Cargo.toml b/Cargo.toml index ee18fc78..95918299 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ shellexpand = "3.1" chrono = "0.4" url = { version = "2.5.2", features = ["serde"] } ctrlc = "3.4" +byteorder = "1.5.0" tracing = { version = "0.1.40", features = ["log", "async-await"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/src/cli.rs b/src/cli.rs index 404cf69c..5393064d 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -55,14 +55,17 @@ fn endpoints_parser(endpoint: &str) -> Result { let endpoint = endpoint.to_lowercase(); let mut split = endpoint.split(':'); - if split.clone().count() != 3 { + if split.clone().count() != 2 && split.clone().count() != 3 { return Err("Wrong endpoint format".to_string()); } let kind = split.next().expect( - "Endpoint should start with one of the kinds: udps, udpc, udpb, tcps, tcpc, or serial", + "Endpoint should start with one of the kinds: file, udps, udpc, udpb, tcps, tcpc, or serial", ); - if !matches!(kind, "udps" | "udpc" | "udpb" | "tcps" | "tcpc" | "serial") { + if !matches!( + kind, + "file" | "udps" | "udpc" | "udpb" | "tcps" | "tcpc" | "serial" + ) { return Err(format!("Unknown kind: {kind:?} for endpoint")); } @@ -126,6 +129,11 @@ pub fn tcp_client_endpoints() -> Vec { get_endpoint_with_kind("tcpc") } +#[instrument(level = "debug")] +pub fn file_server_endpoints() -> Vec { + get_endpoint_with_kind("file") +} + #[instrument(level = "debug")] pub fn tcp_server_endpoints() -> Vec { get_endpoint_with_kind("tcps") diff --git a/src/drivers/file/mod.rs b/src/drivers/file/mod.rs new file mode 100644 index 00000000..74f47ad3 --- /dev/null +++ b/src/drivers/file/mod.rs @@ -0,0 +1 @@ +pub mod server; diff --git a/src/drivers/file/server.rs b/src/drivers/file/server.rs new file mode 100644 index 00000000..fc785bc8 --- /dev/null +++ b/src/drivers/file/server.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use anyhow::Result; +use chrono::DateTime; +use mavlink::ardupilotmega::MavMessage; +use std::path::PathBuf; +use tokio::io::AsyncReadExt; +use tokio::sync::broadcast; +use tracing::*; + +use crate::drivers::{Driver, DriverInfo}; +use crate::protocol::Protocol; + +#[derive(Clone, Debug)] +pub struct FileServer { + pub path: PathBuf, +} + +impl FileServer { + #[instrument(level = "debug")] + pub fn try_new(file_path: &str) -> Result { + let path = PathBuf::from(file_path); + Ok(Self { path }) + } + + #[instrument(level = "debug", skip(reader, hub_sender))] + async fn handle_client( + server: FileServer, + mut reader: tokio::io::BufReader, + hub_sender: Arc>, + ) -> Result<()> { + let source_name = server.path.as_path().display().to_string(); + loop { + // Tlog files follow the byte format of + let Ok(us_since_epoch) = reader.read_u64().await else { + info!("End of file reached"); + break; + }; + + let Some(_date_time) = DateTime::from_timestamp_micros(us_since_epoch as i64) else { + warn!("Failed to convert unix time"); + continue; + }; + + // Ensure that we have at least a single byte before checking for a valid mavlink message + if (reader.buffer().is_empty()) { + info!("End of file reached"); + break; + } + + // Since the source is a tlog file that includes timestamps + raw mavlink messages. + // We first need to be sure that the next byte is the start of a mavlink message, + // otherwise the `read_v2_raw_message_async` will process valid timestamps as garbage. + if (reader.buffer()[0] != mavlink::MAV_STX_V2) { + warn!("Invalid MAVLink start byte, skipping"); + continue; + } + + let message = + match mavlink::read_v2_raw_message_async::(&mut reader).await { + Ok(message) => message, + Err(error) => { + error!("Failed to parse MAVLink message: {error:?}"); + continue; // Skip this iteration on error + } + }; + + let message = Protocol::new(&source_name, message); + + trace!("Received File message: {message:?}"); + if let Err(error) = hub_sender.send(message) { + error!("Failed to send message to hub: {error:?}"); + } + } + + debug!("File Receive task for {source_name} finished"); + Ok(()) + } +} + +#[async_trait::async_trait] +impl Driver for FileServer { + #[instrument(level = "debug", skip(self, hub_sender))] + async fn run(&self, hub_sender: broadcast::Sender) -> Result<()> { + let file = tokio::fs::File::open(self.path.clone()).await.unwrap(); + let mut reader = tokio::io::BufReader::new(file); + + tokio::spawn(FileServer::handle_client( + self.clone(), + reader, + Arc::new(hub_sender), + )); + + Ok(()) + } + + #[instrument(level = "debug", skip(self))] + fn info(&self) -> DriverInfo { + DriverInfo { + name: "FileServer".to_string(), + } + } +} diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 3de27f4b..71ebe962 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -1,4 +1,5 @@ pub mod fake; +pub mod file; pub mod tcp; pub mod udp; diff --git a/src/main.rs b/src/main.rs index e196e486..22b52ae1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,7 +18,7 @@ async fn main() -> Result<()> { logger::init(); let hub = hub::Hub::new( - 100, + 10000, Arc::new(RwLock::new( mavlink::ardupilotmega::MavComponent::MAV_COMP_ID_ONBOARD_COMPUTER as u8, )), @@ -29,6 +29,13 @@ async fn main() -> Result<()> { // Endpoints creation { + for endpoint in cli::file_server_endpoints() { + debug!("Creating File Server to {endpoint:?}"); + hub.add_driver(Arc::new( + drivers::file::server::FileServer::try_new(&endpoint).unwrap(), + )) + .await?; + } for endpoint in cli::tcp_client_endpoints() { debug!("Creating TCP Client to {endpoint:?}"); hub.add_driver(Arc::new(drivers::tcp::client::TcpClient::new(&endpoint)))