diff --git a/build/binder.rs b/build/binder.rs index a71d02b1d..82a85b669 100644 --- a/build/binder.rs +++ b/build/binder.rs @@ -6,7 +6,7 @@ pub fn generate(modules: Vec, out: &mut W) { dbg!(&modules); let modules_tokens = modules.into_iter().map(|module| { let file_name = module.clone() + ".rs"; - let module_ident = quote::format_ident!("{}", module.clone()); + let module_ident = quote::format_ident!("{module}"); quote! { pub mod #module_ident { @@ -19,5 +19,5 @@ pub fn generate(modules: Vec, out: &mut W) { #(#modules_tokens)* }; - writeln!(out, "{}", tokens).unwrap(); + writeln!(out, "{tokens}").unwrap(); } diff --git a/build/main.rs b/build/main.rs index 1fa411153..7d31bebc6 100644 --- a/build/main.rs +++ b/build/main.rs @@ -80,6 +80,7 @@ pub fn main() { fn format_code(cwd: impl AsRef, path: impl AsRef) { if let Err(error) = Command::new("rustfmt") + .args([""]) .arg("--edition") .arg("2021") .arg(path) diff --git a/build/parser.rs b/build/parser.rs index 5c5b12c74..5d44e01e2 100644 --- a/build/parser.rs +++ b/build/parser.rs @@ -182,6 +182,7 @@ impl MessageDefinition { }) .collect(); + // Serialization part let mut sum_quote = None; let mut variables_serialized: Vec = vec![]; let mut sum = 0; @@ -236,6 +237,78 @@ impl MessageDefinition { let sum_quote = sum_quote.or_else(|| Some(quote! { #sum })); + // Descerialization part + let mut b: usize = 0; // current byte + let variables_deserialized: Vec = self + .payload + .iter() + .map(|field| { + let name = ident!(field.name); + match &field.typ { + PayloadType::I8 | PayloadType::U8 | PayloadType::CHAR => { + let value = quote! { + #name: payload[#b].into(), + }; + b += field.typ.to_size(); + value + } + PayloadType::U16 | PayloadType::I16 | PayloadType::U32 | PayloadType::I32 | PayloadType::F32 => { + let data_type = field.typ.to_rust(); + let data_size = field.typ.to_size(); + let field_token = quote! { + #name: #data_type::from_le_bytes(payload[#b..#b + #data_size].try_into().expect("Wrong slice length")), + }; + b += data_size; + field_token + } + PayloadType::VECTOR(vector) => { + let data_type = vector.data_type.to_rust(); + let data_size = vector.data_type.to_size(); + if let Some(size_type) = &vector.size_type { + let length_name = quote::format_ident!("{}_length", field.name); + let length_type = size_type.to_rust(); + let length = self.payload.len(); + let field_token = { + let value = match vector.data_type { + PayloadType::CHAR | + PayloadType::U8 | + PayloadType::I8 => quote! { + payload[#b..#b + payload.len()].to_vec() + }, + PayloadType::U16 | + PayloadType::U32 | + PayloadType::I16 | + PayloadType::I32 | + PayloadType::F32 => quote! { + payload[#b..#b + payload.len()] + .chunks_exact(#data_size) + .into_iter() + .map(|a| u16::from_le_bytes((*a).try_into().expect("Wrong slice length"))) + .collect::>() + }, + PayloadType::VECTOR(_) => unimplemented!("Vector of vectors are not supported"), + }; + + quote! { + #length_name: payload.len() as #length_type, + #name: #value, + } + }; + b += length; + field_token + } else { + let length = self.payload.len(); + let field_token = quote! { + #name: String::from_utf8(payload[#b..#b + payload.len()].to_vec()).unwrap(), + }; + b += length; + field_token + } + } + } + }) + .collect(); + quote! { #[derive(Debug, Clone, PartialEq, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -245,11 +318,19 @@ impl MessageDefinition { } impl Serialize for #struct_name { - fn serialize(self, buffer: &mut [u8]) -> usize { + fn serialize(&self, buffer: &mut [u8]) -> usize { #(#variables_serialized)* #sum_quote } } + + impl Deserialize for #struct_name { + fn deserialize(payload: &[u8]) -> Result { + Ok(Self { + #(#variables_deserialized)* + }) + } + } } } } @@ -301,6 +382,19 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr }) .collect::>(); + let message_enums_deserialize = messages + .iter() + .map(|(name, message)| { + let pascal_message_name = ident!(name.to_case(Case::Pascal)); + let struct_name = quote::format_ident!("{}Struct", pascal_message_name); + let id = message.id; + + quote! { + #id => Messages::#pascal_message_name(#struct_name::deserialize(payload)?), + } + }) + .collect::>(); + quote! { impl PingMessage for Messages { fn message_name(&self) -> &'static str { @@ -321,13 +415,45 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr _ => Err("Invalid message name."), } } + } - fn serialize(self, buffer: &mut [u8]) -> usize { + impl Serialize for Messages { + fn serialize(&self, buffer: &mut [u8]) -> usize { match self { #(#message_enums_serialize)* } } } + + impl Deserialize for Messages { + fn deserialize(buffer: &[u8]) -> Result { + // Parse start1 and start2 + if !((buffer[0] == b'B') && (buffer[1] == b'R')) { + return Err("Message should start with \"BR\" ASCII sequence"); + } + + // Get the package data + let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]); + let message_id = u16::from_le_bytes([buffer[4], buffer[5]]); + let _src_device_id = buffer[6]; + let _dst_device_id = buffer[7]; + let payload = &buffer[8..(8 + payload_length) as usize]; + let _checksum = u16::from_le_bytes([ + buffer[(payload_length + 1) as usize], + buffer[(payload_length + 2) as usize], + ]); + + // Parse the payload + Ok(match message_id { + #(#message_enums_deserialize)* + _ => { + return Err(&"Unknown message id"); + } + }) + } + } + + } } @@ -368,6 +494,8 @@ pub fn generate(input: &mut R, output_rust: &mut W) { let code = quote! { use crate::serialize::PingMessage; use crate::serialize::Serialize; + use crate::serialize::Deserialize; + use std::convert::TryInto; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/src/lib.rs b/src/lib.rs index 4e91a5034..148180bee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,11 @@ include!(concat!(env!("OUT_DIR"), "/mod.rs")); -use crate::serialize::PingMessage; +use crate::serialize::{Deserialize, PingMessage}; const PAYLOAD_SIZE: usize = 255; use std::fmt; -use std::io::Write; +use std::{convert::TryFrom, io::Write}; pub mod serialize; @@ -21,6 +21,59 @@ impl Default for PingMessagePack { } } +impl From<&T> for PingMessagePack { + fn from(message: &T) -> Self { + let mut new: Self = Default::default(); + new.set_message(message); + new + } +} + +pub enum Messages { + Bluebps(bluebps::Messages), + Common(common::Messages), + Ping1d(ping1d::Messages), + Ping360(ping360::Messages), +} + +impl TryFrom<&Vec> for Messages { + type Error = &'static str; // TODO: define error types for each kind of failure + + fn try_from(buffer: &Vec) -> Result { + // Parse start1 and start2 + if !((buffer[0] == b'B') && (buffer[1] == b'R')) { + return Err("Message should start with \"BR\" ASCII sequence"); + } + + // Get the package data + let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]); + let _message_id = u16::from_le_bytes([buffer[4], buffer[5]]); + let _src_device_id = buffer[6]; + let _dst_device_id = buffer[7]; + let payload = &buffer[8..(8 + payload_length) as usize]; + let _checksum = u16::from_le_bytes([ + buffer[(payload_length + 1) as usize], + buffer[(payload_length + 2) as usize], + ]); + + // Try to parse with each module + if let Ok(message) = bluebps::Messages::deserialize(buffer) { + return Ok(Messages::Bluebps(message)); + } + if let Ok(message) = common::Messages::deserialize(buffer) { + return Ok(Messages::Common(message)); + } + if let Ok(message) = ping1d::Messages::deserialize(buffer) { + return Ok(Messages::Ping1d(message)); + } + if let Ok(message) = ping360::Messages::deserialize(buffer) { + return Ok(Messages::Ping360(message)); + } + + Err("Unknown message") + } +} + impl PingMessagePack { /** * Message Format @@ -45,13 +98,7 @@ impl PingMessagePack { Default::default() } - pub fn from(message: impl PingMessage) -> Self { - let mut new: Self = Default::default(); - new.set_message(message); - new - } - - pub fn set_message(&mut self, message: impl PingMessage) { + pub fn set_message(&mut self, message: &impl PingMessage) { let message_id = message.message_id(); let (left, right) = self.0.split_at_mut(Self::HEADER_SIZE); let length = message.serialize(right) as u16; diff --git a/src/serialize.rs b/src/serialize.rs index d3b311ebd..7c2e8af93 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -1,15 +1,20 @@ pub trait PingMessage where - Self: Sized, + Self: Sized + Serialize + Deserialize, { fn message_id(&self) -> u16; fn message_name(&self) -> &'static str; fn message_id_from_name(name: &str) -> Result; - - fn serialize(self, buffer: &mut [u8]) -> usize; } pub trait Serialize { - fn serialize(self, buffer: &mut [u8]) -> usize; + fn serialize(&self, buffer: &mut [u8]) -> usize; +} + +pub trait Deserialize +where + Self: Sized, +{ + fn deserialize(buffer: &[u8]) -> Result; } diff --git a/tests/deserialize.rs b/tests/deserialize.rs new file mode 100644 index 000000000..00dd30d0f --- /dev/null +++ b/tests/deserialize.rs @@ -0,0 +1,20 @@ +use std::convert::TryFrom; + +use ping_rs::common::Messages as common_messages; +use ping_rs::{common, Messages}; + +#[test] +fn test_simple_deserialization() { + let general_request = + common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 }); + + let buffer: Vec = vec![ + 0x42, 0x52, 0x02, 0x00, 0x06, 0x00, 0x00, 0x00, 0x05, 0x00, 0xa1, 0x00, + ]; + let Messages::Common(parsed) = Messages::try_from(&buffer).unwrap() else { + panic!(""); + }; + + // From official ping protocol documentation + assert_eq!(general_request, parsed); +} diff --git a/tests/serialize.rs b/tests/serialize.rs index 2701563f2..ad30f2274 100644 --- a/tests/serialize.rs +++ b/tests/serialize.rs @@ -5,7 +5,7 @@ use ping_rs::{common, PingMessagePack}; fn test_simple_serialization() { let general_request = common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 }); - let message = PingMessagePack::from(general_request); + let message = PingMessagePack::from(&general_request); // From official ping protocol documentation assert_eq!(