From 81a9a98d644135838f19e1920122916784d7aaf6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= <joao.maker@gmail.com>
Date: Sun, 15 Jan 2023 00:50:55 -0300
Subject: [PATCH] Deserialization seems to be working. Need some refactoring

---
 build/binder.rs      |   4 +-
 build/main.rs        |   1 +
 build/parser.rs      | 132 ++++++++++++++++++++++++++++++++++++++++++-
 src/lib.rs           |  65 ++++++++++++++++++---
 src/serialize.rs     |  13 +++--
 tests/deserialize.rs |  20 +++++++
 tests/serialize.rs   |   2 +-
 7 files changed, 219 insertions(+), 18 deletions(-)
 create mode 100644 tests/deserialize.rs

diff --git a/build/binder.rs b/build/binder.rs
index a71d02b1d..82a85b669 100644
--- a/build/binder.rs
+++ b/build/binder.rs
@@ -6,7 +6,7 @@ pub fn generate<W: Write>(modules: Vec<String>, out: &mut W) {
     dbg!(&modules);
     let modules_tokens = modules.into_iter().map(|module| {
         let file_name = module.clone() + ".rs";
-        let module_ident = quote::format_ident!("{}", module.clone());
+        let module_ident = quote::format_ident!("{module}");
 
         quote! {
             pub mod #module_ident {
@@ -19,5 +19,5 @@ pub fn generate<W: Write>(modules: Vec<String>, out: &mut W) {
         #(#modules_tokens)*
     };
 
-    writeln!(out, "{}", tokens).unwrap();
+    writeln!(out, "{tokens}").unwrap();
 }
diff --git a/build/main.rs b/build/main.rs
index 1fa411153..7d31bebc6 100644
--- a/build/main.rs
+++ b/build/main.rs
@@ -80,6 +80,7 @@ pub fn main() {
 
 fn format_code(cwd: impl AsRef<Path>, path: impl AsRef<OsStr>) {
     if let Err(error) = Command::new("rustfmt")
+        .args([""])
         .arg("--edition")
         .arg("2021")
         .arg(path)
diff --git a/build/parser.rs b/build/parser.rs
index 5c5b12c74..5d44e01e2 100644
--- a/build/parser.rs
+++ b/build/parser.rs
@@ -182,6 +182,7 @@ impl MessageDefinition {
             })
             .collect();
 
+        // Serialization part
         let mut sum_quote = None;
         let mut variables_serialized: Vec<TokenStream> = vec![];
         let mut sum = 0;
@@ -236,6 +237,78 @@ impl MessageDefinition {
 
         let sum_quote = sum_quote.or_else(|| Some(quote! { #sum }));
 
+        // Descerialization part
+        let mut b: usize = 0; // current byte
+        let variables_deserialized: Vec<TokenStream> = self
+            .payload
+            .iter()
+            .map(|field| {
+                let name = ident!(field.name);
+                match &field.typ {
+                    PayloadType::I8 | PayloadType::U8 | PayloadType::CHAR => {
+                        let value = quote! {
+                            #name: payload[#b].into(),
+                        };
+                        b += field.typ.to_size();
+                        value
+                    }
+                    PayloadType::U16 | PayloadType::I16 | PayloadType::U32 | PayloadType::I32 | PayloadType::F32 => {
+                        let data_type = field.typ.to_rust();
+                        let data_size = field.typ.to_size();
+                        let field_token = quote! {
+                            #name: #data_type::from_le_bytes(payload[#b..#b + #data_size].try_into().expect("Wrong slice length")),
+                        };
+                        b += data_size;
+                        field_token
+                    }
+                    PayloadType::VECTOR(vector) => {
+                        let data_type = vector.data_type.to_rust();
+                        let data_size = vector.data_type.to_size();
+                        if let Some(size_type) = &vector.size_type {
+                            let length_name = quote::format_ident!("{}_length", field.name);
+                            let length_type = size_type.to_rust();
+                            let length = self.payload.len();
+                            let field_token = {
+                                let value = match vector.data_type {
+                                    PayloadType::CHAR |
+                                    PayloadType::U8 |
+                                    PayloadType::I8 => quote! {
+                                        payload[#b..#b + payload.len()].to_vec()
+                                    },
+                                    PayloadType::U16 |
+                                    PayloadType::U32 |
+                                    PayloadType::I16 |
+                                    PayloadType::I32 |
+                                    PayloadType::F32 => quote! {
+                                        payload[#b..#b + payload.len()]
+                                            .chunks_exact(#data_size)
+                                            .into_iter()
+                                            .map(|a| u16::from_le_bytes((*a).try_into().expect("Wrong slice length")))
+                                            .collect::<Vec<#data_type>>()
+                                    },
+                                    PayloadType::VECTOR(_) => unimplemented!("Vector of vectors are not supported"),
+                                };
+
+                                quote! {
+                                    #length_name: payload.len() as #length_type,
+                                    #name: #value,
+                                }
+                            };
+                            b += length;
+                            field_token
+                        } else {
+                            let length = self.payload.len();
+                            let field_token = quote! {
+                                #name: String::from_utf8(payload[#b..#b + payload.len()].to_vec()).unwrap(),
+                            };
+                            b += length;
+                            field_token
+                        }
+                    }
+                }
+            })
+            .collect();
+
         quote! {
             #[derive(Debug, Clone, PartialEq, Default)]
             #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -245,11 +318,19 @@ impl MessageDefinition {
             }
 
             impl Serialize for #struct_name {
-                fn serialize(self, buffer: &mut [u8]) -> usize {
+                fn serialize(&self, buffer: &mut [u8]) -> usize {
                     #(#variables_serialized)*
                     #sum_quote
                 }
             }
+
+            impl Deserialize for #struct_name {
+                fn deserialize(payload: &[u8]) -> Result<Self, &'static str> {
+                    Ok(Self {
+                        #(#variables_deserialized)*
+                    })
+                }
+            }
         }
     }
 }
@@ -301,6 +382,19 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr
         })
         .collect::<Vec<TokenStream>>();
 
+    let message_enums_deserialize = messages
+        .iter()
+        .map(|(name, message)| {
+            let pascal_message_name = ident!(name.to_case(Case::Pascal));
+            let struct_name = quote::format_ident!("{}Struct", pascal_message_name);
+            let id = message.id;
+
+            quote! {
+                #id => Messages::#pascal_message_name(#struct_name::deserialize(payload)?),
+            }
+        })
+        .collect::<Vec<TokenStream>>();
+
     quote! {
         impl PingMessage for Messages {
             fn message_name(&self) -> &'static str {
@@ -321,13 +415,45 @@ fn emit_ping_message(messages: HashMap<&String, &MessageDefinition>) -> TokenStr
                     _ => Err("Invalid message name."),
                 }
             }
+        }
 
-            fn serialize(self, buffer: &mut [u8]) -> usize {
+        impl Serialize for Messages {
+            fn serialize(&self, buffer: &mut [u8]) -> usize {
                 match self {
                     #(#message_enums_serialize)*
                 }
             }
         }
+
+        impl Deserialize for Messages {
+            fn deserialize(buffer: &[u8]) -> Result<Self, &'static str> {
+                // Parse start1 and start2
+                if !((buffer[0] == b'B') && (buffer[1] == b'R')) {
+                    return Err("Message should start with \"BR\" ASCII sequence");
+                }
+
+                // Get the package data
+                let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]);
+                let message_id = u16::from_le_bytes([buffer[4], buffer[5]]);
+                let _src_device_id = buffer[6];
+                let _dst_device_id = buffer[7];
+                let payload = &buffer[8..(8 + payload_length) as usize];
+                let _checksum = u16::from_le_bytes([
+                    buffer[(payload_length + 1) as usize],
+                    buffer[(payload_length + 2) as usize],
+                ]);
+
+                // Parse the payload
+                Ok(match message_id {
+                    #(#message_enums_deserialize)*
+                    _ => {
+                        return Err(&"Unknown message id");
+                    }
+                })
+            }
+        }
+
+
     }
 }
 
@@ -368,6 +494,8 @@ pub fn generate<R: Read, W: Write>(input: &mut R, output_rust: &mut W) {
     let code = quote! {
         use crate::serialize::PingMessage;
         use crate::serialize::Serialize;
+        use crate::serialize::Deserialize;
+        use std::convert::TryInto;
 
         #[cfg(feature = "serde")]
         use serde::{Deserialize, Serialize};
diff --git a/src/lib.rs b/src/lib.rs
index 4e91a5034..148180bee 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,11 +1,11 @@
 include!(concat!(env!("OUT_DIR"), "/mod.rs"));
 
-use crate::serialize::PingMessage;
+use crate::serialize::{Deserialize, PingMessage};
 
 const PAYLOAD_SIZE: usize = 255;
 
 use std::fmt;
-use std::io::Write;
+use std::{convert::TryFrom, io::Write};
 
 pub mod serialize;
 
@@ -21,6 +21,59 @@ impl Default for PingMessagePack {
     }
 }
 
+impl<T: PingMessage> From<&T> for PingMessagePack {
+    fn from(message: &T) -> Self {
+        let mut new: Self = Default::default();
+        new.set_message(message);
+        new
+    }
+}
+
+pub enum Messages {
+    Bluebps(bluebps::Messages),
+    Common(common::Messages),
+    Ping1d(ping1d::Messages),
+    Ping360(ping360::Messages),
+}
+
+impl TryFrom<&Vec<u8>> for Messages {
+    type Error = &'static str; // TODO: define error types for each kind of failure
+
+    fn try_from(buffer: &Vec<u8>) -> Result<Self, Self::Error> {
+        // Parse start1 and start2
+        if !((buffer[0] == b'B') && (buffer[1] == b'R')) {
+            return Err("Message should start with \"BR\" ASCII sequence");
+        }
+
+        // Get the package data
+        let payload_length = u16::from_le_bytes([buffer[2], buffer[3]]);
+        let _message_id = u16::from_le_bytes([buffer[4], buffer[5]]);
+        let _src_device_id = buffer[6];
+        let _dst_device_id = buffer[7];
+        let payload = &buffer[8..(8 + payload_length) as usize];
+        let _checksum = u16::from_le_bytes([
+            buffer[(payload_length + 1) as usize],
+            buffer[(payload_length + 2) as usize],
+        ]);
+
+        // Try to parse with each module
+        if let Ok(message) = bluebps::Messages::deserialize(buffer) {
+            return Ok(Messages::Bluebps(message));
+        }
+        if let Ok(message) = common::Messages::deserialize(buffer) {
+            return Ok(Messages::Common(message));
+        }
+        if let Ok(message) = ping1d::Messages::deserialize(buffer) {
+            return Ok(Messages::Ping1d(message));
+        }
+        if let Ok(message) = ping360::Messages::deserialize(buffer) {
+            return Ok(Messages::Ping360(message));
+        }
+
+        Err("Unknown message")
+    }
+}
+
 impl PingMessagePack {
     /**
      * Message Format
@@ -45,13 +98,7 @@ impl PingMessagePack {
         Default::default()
     }
 
-    pub fn from(message: impl PingMessage) -> Self {
-        let mut new: Self = Default::default();
-        new.set_message(message);
-        new
-    }
-
-    pub fn set_message(&mut self, message: impl PingMessage) {
+    pub fn set_message(&mut self, message: &impl PingMessage) {
         let message_id = message.message_id();
         let (left, right) = self.0.split_at_mut(Self::HEADER_SIZE);
         let length = message.serialize(right) as u16;
diff --git a/src/serialize.rs b/src/serialize.rs
index d3b311ebd..7c2e8af93 100644
--- a/src/serialize.rs
+++ b/src/serialize.rs
@@ -1,15 +1,20 @@
 pub trait PingMessage
 where
-    Self: Sized,
+    Self: Sized + Serialize + Deserialize,
 {
     fn message_id(&self) -> u16;
     fn message_name(&self) -> &'static str;
 
     fn message_id_from_name(name: &str) -> Result<u16, &'static str>;
-
-    fn serialize(self, buffer: &mut [u8]) -> usize;
 }
 
 pub trait Serialize {
-    fn serialize(self, buffer: &mut [u8]) -> usize;
+    fn serialize(&self, buffer: &mut [u8]) -> usize;
+}
+
+pub trait Deserialize
+where
+    Self: Sized,
+{
+    fn deserialize(buffer: &[u8]) -> Result<Self, &'static str>;
 }
diff --git a/tests/deserialize.rs b/tests/deserialize.rs
new file mode 100644
index 000000000..00dd30d0f
--- /dev/null
+++ b/tests/deserialize.rs
@@ -0,0 +1,20 @@
+use std::convert::TryFrom;
+
+use ping_rs::common::Messages as common_messages;
+use ping_rs::{common, Messages};
+
+#[test]
+fn test_simple_deserialization() {
+    let general_request =
+        common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 });
+
+    let buffer: Vec<u8> = vec![
+        0x42, 0x52, 0x02, 0x00, 0x06, 0x00, 0x00, 0x00, 0x05, 0x00, 0xa1, 0x00,
+    ];
+    let Messages::Common(parsed) = Messages::try_from(&buffer).unwrap() else {
+        panic!("");
+    };
+
+    // From official ping protocol documentation
+    assert_eq!(general_request, parsed);
+}
diff --git a/tests/serialize.rs b/tests/serialize.rs
index 2701563f2..ad30f2274 100644
--- a/tests/serialize.rs
+++ b/tests/serialize.rs
@@ -5,7 +5,7 @@ use ping_rs::{common, PingMessagePack};
 fn test_simple_serialization() {
     let general_request =
         common_messages::GeneralRequest(common::GeneralRequestStruct { requested_id: 5 });
-    let message = PingMessagePack::from(general_request);
+    let message = PingMessagePack::from(&general_request);
 
     // From official ping protocol documentation
     assert_eq!(