Skip to content

Commit

Permalink
Deserialization seems to be working. Need some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoantoniocardoso committed Jan 15, 2023
1 parent ef16c34 commit e40fe6d
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 18 deletions.
4 changes: 2 additions & 2 deletions build/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub fn generate<W: Write>(modules: Vec<String>, out: &mut W) {
dbg!(&modules);
let modules_tokens = modules.into_iter().map(|module| {
let file_name = module.clone() + ".rs";
let module_ident = quote::format_ident!("{}", module.clone());
let module_ident = quote::format_ident!("{module}");

quote! {
pub mod #module_ident {
Expand All @@ -19,5 +19,5 @@ pub fn generate<W: Write>(modules: Vec<String>, out: &mut W) {
#(#modules_tokens)*
};

writeln!(out, "{}", tokens).unwrap();
writeln!(out, "{tokens}").unwrap();
}
1 change: 1 addition & 0 deletions build/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub fn main() {

fn format_code(cwd: impl AsRef<Path>, path: impl AsRef<OsStr>) {
if let Err(error) = Command::new("rustfmt")
.args([""])
.arg("--edition")
.arg("2021")
.arg(path)
Expand Down
132 changes: 130 additions & 2 deletions build/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl MessageDefinition {
})
.collect();

// Serialization part
let mut sum_quote = None;
let mut variables_serialized: Vec<TokenStream> = vec![];
let mut sum = 0;
Expand Down Expand Up @@ -236,6 +237,78 @@ impl MessageDefinition {

let sum_quote = sum_quote.or_else(|| Some(quote! { #sum }));

// Descerialization part
let mut b: usize = 0; // current byte
let variables_deserialized: Vec<TokenStream> = self
.payload
.iter()
.map(|field| {
let name = ident!(field.name);
match &field.typ {
PayloadType::I8 | PayloadType::U8 | PayloadType::CHAR => {
let value = quote! {
#name: payload[#b].into(),
};
b += field.typ.to_size();
value
}
PayloadType::U16 | PayloadType::I16 | PayloadType::U32 | PayloadType::I32 | PayloadType::F32 => {
let data_type = field.typ.to_rust();
let data_size = field.typ.to_size();
let field_token = quote! {
#name: #data_type::from_le_bytes(payload[#b..#b + #data_size].try_into().expect("Wrong slice length")),
};
b += data_size;
field_token
}
PayloadType::VECTOR(vector) => {
let data_type = vector.data_type.to_rust();
let data_size = vector.data_type.to_size();
if let Some(size_type) = &vector.size_type {
let length_name = quote::format_ident!("{}_length", field.name);
let length_type = size_type.to_rust();
let length = self.payload.len();
let field_token = {
let value = match vector.data_type {
PayloadType::CHAR |
PayloadType::U8 |
PayloadType::I8 => quote! {
payload[#b..#b + payload.len()].to_vec()
},
PayloadType::U16 |
PayloadType::U32 |
PayloadType::I16 |
PayloadType::I32 |
PayloadType::F32 => quote! {
payload[#b..#b + payload.len()]
.chunks_exact(#data_size)
.into_iter()
.map(|a| u16::from_le_bytes((*a).try_into().expect("Wrong slice length")))
.collect::<Vec<#data_type>>()
},
PayloadType::VECTOR(_) => unimplemented!("Vector of vectors are not supported"),
};

quote! {
#length_name: payload.len() as #length_type,
#name: #value,
}
};
b += length;
field_token
} else {
let length = self.payload.len();
let field_token = quote! {
#name: String::from_utf8(payload[#b..#b + payload.len()].to_vec()).unwrap(),
};
b += length;
field_token
}
}
}
})
.collect();

quote! {
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand All @@ -245,11 +318,19 @@ impl MessageDefinition {
}

impl Serialize for #struct_name {
fn serialize(self, buffer: &mut [u8]) -> usize {
fn serialize(&self, buffer: &mut [u8]) -> usize {
#(#variables_serialized)*
#sum_quote
}
}

impl Deserialize for #struct_name {
fn deserialize(payload: &[u8]) -> Result<Self, &'static str> {
Ok(Self {
#(#variables_deserialized)*
})
}
}
}
}
}
Expand Down Expand Up @@ -301,6 +382,19 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr
})
.collect::<Vec<TokenStream>>();

let message_enums_deserialize = messages
.iter()
.map(|(name, message)| {
let pascal_message_name = ident!(name.to_case(Case::Pascal));
let struct_name = quote::format_ident!("{}Struct", pascal_message_name);
let id = message.id;

quote! {
#id => Messages::#pascal_message_name(#struct_name::deserialize(payload)?),
}
})
.collect::<Vec<TokenStream>>();

quote! {
impl PingMessage for Messages {
fn message_name(&self) -> &'static str {
Expand All @@ -321,13 +415,45 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr
_ => Err("Invalid message name."),
}
}
}

fn serialize(self, buffer: &mut [u8]) -> usize {
impl Serialize for Messages {
fn serialize(&self, buffer: &mut [u8]) -> usize {
match self {
#(#message_enums_serialize)*
}
}
}

impl Deserialize for Messages {
fn deserialize(buffer: &[u8]) -> Result<Self, &'static str> {
// Parse start1 and start2
if !((buffer[0] == b'B') && (buffer[1] == b'R')) {
return Err("Message should start with \"BR\" ASCII sequence");
}

// Get the package data
let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]);
let message_id = u16::from_le_bytes([buffer[4], buffer[5]]);
let _src_device_id = buffer[6];
let _dst_device_id = buffer[7];
let payload = &buffer[8..(8 + payload_length) as usize];
let _checksum = u16::from_le_bytes([
buffer[(payload_length + 1) as usize],
buffer[(payload_length + 2) as usize],
]);

// Parse the payload
Ok(match message_id {
#(#message_enums_deserialize)*
_ => {
return Err(&"Unknown message id");
}
})
}
}


}
}

Expand Down Expand Up @@ -368,6 +494,8 @@ pub fn generate<R: Read, W: Write>(input: &mut R, output_rust: &mut W) {
let code = quote! {
use crate::serialize::PingMessage;
use crate::serialize::Serialize;
use crate::serialize::Deserialize;
use std::convert::TryInto;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down
65 changes: 56 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
include!(concat!(env!("OUT_DIR"), "/mod.rs"));

use crate::serialize::PingMessage;
use crate::serialize::{Deserialize, PingMessage};

const PAYLOAD_SIZE: usize = 255;

use std::io::Write;
use std::{convert::TryFrom, io::Write};

pub mod serialize;

Expand All @@ -20,6 +20,59 @@ impl Default for PingMessagePack {
}
}

impl<T: PingMessage> From<&T> for PingMessagePack {
fn from(message: &T) -> Self {
let mut new: Self = Default::default();
new.set_message(message);
new
}
}

pub enum Messages {
Bluebps(bluebps::Messages),
Common(common::Messages),
Ping1d(ping1d::Messages),
Ping360(ping360::Messages),
}

impl TryFrom<&Vec<u8>> for Messages {
type Error = &'static str; // TODO: define error types for each kind of failure

fn try_from(buffer: &Vec<u8>) -> Result<Self, Self::Error> {
// Parse start1 and start2
if !((buffer[0] == b'B') && (buffer[1] == b'R')) {
return Err("Message should start with \"BR\" ASCII sequence");
}

// Get the package data
let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]);
let _message_id = u16::from_le_bytes([buffer[4], buffer[5]]);
let _src_device_id = buffer[6];
let _dst_device_id = buffer[7];
let payload = &buffer[8..(8 + payload_length) as usize];
let _checksum = u16::from_le_bytes([
buffer[(payload_length + 1) as usize],
buffer[(payload_length + 2) as usize],
]);

// Try to parse with each module
if let Ok(message) = bluebps::Messages::deserialize(buffer) {
return Ok(Messages::Bluebps(message));
}
if let Ok(message) = common::Messages::deserialize(buffer) {
return Ok(Messages::Common(message));
}
if let Ok(message) = ping1d::Messages::deserialize(buffer) {
return Ok(Messages::Ping1d(message));
}
if let Ok(message) = ping360::Messages::deserialize(buffer) {
return Ok(Messages::Ping360(message));
}

Err("Unknown message")
}
}

impl PingMessagePack {
/**
* Message Format
Expand All @@ -44,13 +97,7 @@ impl PingMessagePack {
Self([0; 1 + Self::HEADER_SIZE + PAYLOAD_SIZE + 2])
}

pub fn from(message: impl PingMessage) -> Self {
let mut new: Self = Default::default();
new.set_message(message);
new
}

pub fn set_message(&mut self, message: impl PingMessage) {
pub fn set_message(&mut self, message: &impl PingMessage) {
let message_id = message.message_id();
let (left, right) = self.0.split_at_mut(Self::HEADER_SIZE);
let length = message.serialize(right) as u16;
Expand Down
13 changes: 9 additions & 4 deletions src/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
pub trait PingMessage
where
Self: Sized,
Self: Sized + Serialize + Deserialize,
{
fn message_id(&self) -> u16;
fn message_name(&self) -> &'static str;

fn message_id_from_name(name: &str) -> Result<u16, &'static str>;

fn serialize(self, buffer: &mut [u8]) -> usize;
}

pub trait Serialize {
fn serialize(self, buffer: &mut [u8]) -> usize;
fn serialize(&self, buffer: &mut [u8]) -> usize;
}

pub trait Deserialize
where
Self: Sized,
{
fn deserialize(buffer: &[u8]) -> Result<Self, &'static str>;
}
20 changes: 20 additions & 0 deletions tests/deserialize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use std::convert::TryFrom;

use ping_rs::common::Messages as common_messages;
use ping_rs::{common, Messages};

#[test]
fn test_simple_deserialization() {
let general_request =
common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 });

let buffer: Vec<u8> = vec![
0x42, 0x52, 0x02, 0x00, 0x06, 0x00, 0x00, 0x00, 0x05, 0x00, 0xa1, 0x00,
];
let Messages::Common(parsed) = Messages::try_from(&buffer).unwrap() else {
panic!("");
};

// From official ping protocol documentation
assert_eq!(general_request, parsed);
}
2 changes: 1 addition & 1 deletion tests/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ping_rs::{common, PingMessagePack};
fn test_simple_serialization() {
let general_request =
common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 });
let message = PingMessagePack::from(general_request);
let message = PingMessagePack::from(&general_request);

// From official ping protocol documentation
assert_eq!(
Expand Down

0 comments on commit e40fe6d

Please sign in to comment.