diff --git a/src/tlsh.rs b/src/tlsh.rs index 1cefdcf..7c9d88b 100644 --- a/src/tlsh.rs +++ b/src/tlsh.rs @@ -1,3 +1,5 @@ +use core::str::FromStr; + use crate::pearson::{b_mapping, fast_b_mapping}; use crate::quartile::get_quartiles; use crate::util::{l_capturing, swap_byte}; @@ -343,6 +345,61 @@ impl Option { + if s.len() != TLSH_STRING_LEN_REQ || s[0] != b'T' || s[1] != b'1' { + return None; + } + + let mut i = 2; + + let mut checksum = [0; TLSH_CHECKSUM_LEN]; + for k in &mut checksum { + *k = swap_byte(from_hex(s, &mut i)?); + } + + let lvalue = swap_byte(from_hex(s, &mut i)?); + let qb = from_hex(s, &mut i)?; + let q1_ratio = qb >> 4; + let q2_ratio = qb & 0x0F; + + let mut code = [0; CODE_SIZE]; + for c in code.iter_mut().rev() { + *c = from_hex(s, &mut i)?; + } + + Some(Self { + lvalue, + q1_ratio, + q2_ratio, + checksum, + code, + }) + } +} + +/// Error returned when failing to convert a hash string to a `Tlsh` object. +#[derive(Debug, PartialEq, Eq)] +pub struct ParseError; + +/// Parse a hash string and build the corresponding `Tlsh` object. +impl + FromStr for Tlsh +{ + type Err = ParseError; + + fn from_str(s: &str) -> Result { + Self::from_hash(s.as_bytes()).ok_or(ParseError) + } +} + +fn from_hex(s: &[u8], i: &mut usize) -> Option { + let a = char::from(s[*i]).to_digit(16)?; + *i += 1; + let b = char::from(s[*i]).to_digit(16)?; + *i += 1; + + Some(((a as u8) << 4) | (b as u8)) } fn to_hex(s: &mut [u8], s_idx: &mut usize, b: u8) { diff --git a/tests/it/hash.rs b/tests/it/hash.rs index 4d3458e..e6022bf 100644 --- a/tests/it/hash.rs +++ b/tests/it/hash.rs @@ -21,25 +21,55 @@ where } macro_rules! do_hash_test { - ($testname:ident, $name:expr, $type:ty) => { + ($testname:ident, $name:expr, $builder:ty, $tlsh:ty) => { #[test] fn $testname() { test_hash( &format!("tests/assets/tlsh/exp/example_data.{}.len.out_EXP", $name), |contents| { - let mut tlsh = <$type>::new(); - tlsh.update(contents); - tlsh.build() - .map(|v| String::from_utf8(v.hash().to_vec()).unwrap()) - .unwrap_or_default() + let mut builder = <$builder>::new(); + builder.update(contents); + let tlsh = builder.build().unwrap(); + let hash = String::from_utf8(tlsh.hash().to_vec()).unwrap(); + + // Test the FromStr implementation + let tlsh2 = hash.parse::<$tlsh>().unwrap(); + assert_eq!(tlsh.hash(), tlsh2.hash()); + + hash }, ) } }; } -do_hash_test!(test_hash_48_1, "48.1", tlsh2::TlshBuilder48_1); -do_hash_test!(test_hash_128_1, "128.1", tlsh2::TlshBuilder128_1); -do_hash_test!(test_hash_128_3, "128.3", tlsh2::TlshBuilder128_3); -do_hash_test!(test_hash_256_1, "256.1", tlsh2::TlshBuilder256_1); -do_hash_test!(test_hash_256_3, "256.3", tlsh2::TlshBuilder256_3); +do_hash_test!( + test_hash_48_1, + "48.1", + tlsh2::TlshBuilder48_1, + tlsh2::Tlsh48_1 +); +do_hash_test!( + test_hash_128_1, + "128.1", + tlsh2::TlshBuilder128_1, + tlsh2::Tlsh128_1 +); +do_hash_test!( + test_hash_128_3, + "128.3", + tlsh2::TlshBuilder128_3, + tlsh2::Tlsh128_3 +); +do_hash_test!( + test_hash_256_1, + "256.1", + tlsh2::TlshBuilder256_1, + tlsh2::Tlsh256_1 +); +do_hash_test!( + test_hash_256_3, + "256.3", + tlsh2::TlshBuilder256_3, + tlsh2::Tlsh256_3 +);