diff --git a/rust/src/detect/uint.rs b/rust/src/detect/uint.rs index 6cf31b2dbcf1..3d6a5baab0ca 100644 --- a/rust/src/detect/uint.rs +++ b/rust/src/detect/uint.rs @@ -16,7 +16,7 @@ */ use nom7::branch::alt; -use nom7::bytes::complete::{is_a, tag, take_while}; +use nom7::bytes::complete::{is_a, tag, tag_no_case, take_while}; use nom7::character::complete::digit1; use nom7::combinator::{all_consuming, map_opt, opt, value, verify}; use nom7::error::{make_error, ErrorKind}; @@ -46,20 +46,54 @@ pub struct DetectUintData { } pub trait DetectIntType: - std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded + std::str::FromStr + + std::cmp::PartialOrd + + num::PrimInt + + num::Bounded + + num::ToPrimitive + + num::FromPrimitive { } impl DetectIntType for T where - T: std::str::FromStr + std::cmp::PartialOrd + num::PrimInt + num::Bounded + T: std::str::FromStr + + std::cmp::PartialOrd + + num::PrimInt + + num::Bounded + + num::ToPrimitive + + num::FromPrimitive { } +pub fn detect_parse_uint_unit(i: &str) -> IResult<&str, u64> { + let (i, unit) = alt(( + value(1024, tag_no_case("kb")), + value(1024 * 1024, tag_no_case("mb")), + value(1024 * 1024 * 1024, tag_no_case("gb")), + ))(i)?; + return Ok((i, unit)); +} + +pub fn detect_parse_uint_with_unit(i: &str) -> IResult<&str, T> { + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + let (i, unit) = opt(detect_parse_uint_unit)(i)?; + if arg1 >= T::one() { + if let Some(u) = unit { + if T::max_value().to_u64().unwrap() / u < arg1.to_u64().unwrap() { + return Err(Err::Error(make_error(i, ErrorKind::Verify))); + } + let ru64 = arg1 * T::from_u64(u).unwrap(); + return Ok((i, ru64)); + } + } + Ok((i, arg1)) +} + pub fn detect_parse_uint_start_equal( i: &str, ) -> IResult<&str, DetectUintData> { let (i, _) = opt(tag("="))(i)?; let (i, _) = opt(is_a(" "))(i)?; - let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + let (i, arg1) = detect_parse_uint_with_unit(i)?; Ok(( i, DetectUintData { @@ -368,3 +402,34 @@ pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData) { // Just unbox... std::mem::drop(Box::from_raw(ctx)); } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_uint_unit() { + match detect_parse_uint::(" 2kb") { + Ok((_, val)) => { + assert_eq!(val.arg1, 2048); + } + Err(_) => { + assert!(false); + } + } + match detect_parse_uint::("2kb") { + Ok((_, _val)) => { + assert!(false); + } + Err(_) => {} + } + match detect_parse_uint::("3MB") { + Ok((_, val)) => { + assert_eq!(val.arg1, 3 * 1024 * 1024); + } + Err(_) => { + assert!(false); + } + } + } +}