From 6a1649544b83b4f004f6f9bbf56262457062613d Mon Sep 17 00:00:00 2001 From: kindywu Date: Mon, 6 May 2024 14:35:14 +0800 Subject: [PATCH] bugfix: remove NullBulkString / NullArray --- src/cmd/mod.rs | 30 +++++++---- src/resp/array.rs | 111 +++++++++++++++++++++++++++------------- src/resp/bulk_string.rs | 67 +++++++++++++++--------- src/resp/mod.rs | 9 ++-- src/resp/resp_frame.rs | 18 +++---- 5 files changed, 151 insertions(+), 84 deletions(-) diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index d4b6e4b..195cbae 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -57,7 +57,10 @@ impl TryFrom for Command { type Error = CommandError; fn try_from(v: RespFrame) -> Result { match v { - RespFrame::Array(array) => array.try_into(), + RespFrame::Array(array) => match array { + Some(array) => array.try_into(), + _ => Err(CommandError::InvalidCommand("Command is null".to_string())), + }, _ => Err(CommandError::InvalidCommand( "Command must be an Array".to_string(), )), @@ -69,9 +72,12 @@ impl TryFrom for Command { type Error = CommandError; fn try_from(v: RespArray) -> Result { match v.first() { - Some(RespFrame::BulkString(ref cmd)) => match cmd.as_ref() { - b"echo" => Ok(Echo::try_from(v)?.into()), - _ => Ok(Unrecognized.into()), + Some(RespFrame::BulkString(ref cmd)) => match cmd { + Some(cmd) => match cmd.as_ref() { + b"echo" => Ok(Echo::try_from(v)?.into()), + _ => Ok(Unrecognized.into()), + }, + _ => Err(CommandError::InvalidCommand("Command is null".to_string())), }, _ => Err(CommandError::InvalidCommand( "Command must have a BulkString as the first argument".to_string(), @@ -109,12 +115,16 @@ fn validate_command( for (i, name) in names.iter().enumerate() { match value[i] { RespFrame::BulkString(ref cmd) => { - if cmd.as_ref().to_ascii_lowercase() != name.as_bytes() { - return Err(CommandError::InvalidCommand(format!( - "Invalid command: expected {}, got {}", - name, - String::from_utf8_lossy(cmd.as_ref()) - ))); + if let Some(cmd) = cmd { + if cmd.as_ref().to_ascii_lowercase() != name.as_bytes() { + return Err(CommandError::InvalidCommand(format!( + "Invalid command: expected {}, got {}", + name, + String::from_utf8_lossy(cmd.as_ref()) + ))); + } + } else { + return Err(CommandError::InvalidCommand("Command is null".to_string())); } } _ => { diff --git a/src/resp/array.rs b/src/resp/array.rs index 2914369..690de0a 100644 --- a/src/resp/array.rs +++ b/src/resp/array.rs @@ -13,18 +13,18 @@ pub struct RespArray(pub(crate) Vec); const NULL_ARRAY_STRING: &[u8; 5] = b"*-1\r\n"; // - array: "*\r\n..." -impl RespEncode for RespArray { +impl RespEncode for Option { fn encode(self) -> Vec { - if self.0.is_empty() { - NULL_ARRAY_STRING.to_vec() - } else { + if let Some(s) = self { // let mut buf = Vec::with_capacity(BUF_CAP); let mut buf = Vec::new(); - buf.extend_from_slice(&format!("*{}\r\n", self.0.len()).into_bytes()); - for frame in self.0 { + buf.extend_from_slice(&format!("*{}\r\n", s.0.len()).into_bytes()); + for frame in s.0 { buf.extend_from_slice(&frame.encode()); } buf + } else { + NULL_ARRAY_STRING.to_vec() } } } @@ -32,31 +32,42 @@ impl RespEncode for RespArray { // - array: "*\r\n..." // - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n" // FIXME: need to handle incomplete -impl RespDecode for RespArray { +impl RespDecode for Option { const PREFIX: &'static str = "*"; fn decode(buf: &mut BytesMut) -> Result { let (end, len) = parse_length(buf, Self::PREFIX)?; - let total_len = calc_total_length(buf, end, len, Self::PREFIX)?; - // println!("len={},total_len={}", len, total_len); + if len == -1 { + Ok(None) + } else { + let len = len as usize; + let total_len = calc_total_length(buf, end, len, Self::PREFIX)?; - if buf.len() < total_len { - return Err(RespError::NotComplete); - } + // println!("len={},total_len={}", len, total_len); - buf.advance(end + CRLF_LEN); + if buf.len() < total_len { + return Err(RespError::NotComplete); + } - let mut frames = Vec::with_capacity(len); - for _ in 0..len { - frames.push(RespFrame::decode(buf)?); - } + buf.advance(end + CRLF_LEN); - Ok(RespArray::new(frames)) + let mut frames = Vec::with_capacity(len); + for _ in 0..len { + frames.push(RespFrame::decode(buf)?); + } + + Ok(Some(RespArray::new(frames))) + } } fn expect_length(buf: &[u8]) -> Result { let (end, len) = parse_length(buf, Self::PREFIX)?; - calc_total_length(buf, end, len, Self::PREFIX) + if len == -1 { + Ok(5) + } else { + let len = len as usize; + calc_total_length(buf, end, len, Self::PREFIX) + } } } @@ -104,11 +115,11 @@ mod tests { #[test] fn test_array_encode() { - let frame: RespFrame = RespArray::new(vec![ - BulkString::new("set".to_string()).into(), - BulkString::new("hello".to_string()).into(), - BulkString::new("world".to_string()).into(), - ]) + let frame: RespFrame = Some(RespArray::new(vec![ + Some(BulkString::new("set".to_string())).into(), + Some(BulkString::new("hello".to_string())).into(), + Some(BulkString::new("world".to_string())).into(), + ])) .into(); assert_eq!( &frame.encode(), @@ -121,27 +132,36 @@ mod tests { let mut buf = BytesMut::new(); buf.extend_from_slice(b"*2\r\n$4\r\necho\r\n$5\r\nhello\r\n"); - assert_eq!(RespArray::expect_length(&buf), Ok(25)); + assert_eq!(Option::::expect_length(&buf), Ok(25)); - let frame = RespArray::decode(&mut buf)?; - assert_eq!(frame, RespArray::new([b"echo".into(), b"hello".into()])); + let frame = Option::::decode(&mut buf)?; + assert_eq!( + frame, + Some(RespArray::new([b"echo".into(), b"hello".into()])) + ); buf.extend_from_slice(b"*2\r\n$4\r\necho\r\n"); - assert_eq!(RespArray::expect_length(&buf), Err(RespError::NotComplete)); - let ret = RespArray::decode(&mut buf); + assert_eq!( + Option::::expect_length(&buf), + Err(RespError::NotComplete) + ); + let ret = Option::::decode(&mut buf); assert_eq!(ret.unwrap_err(), RespError::NotComplete); buf.extend_from_slice(b"$5\r\nhello\r\n"); - assert_eq!(RespArray::expect_length(&buf), Ok(25)); - let frame = RespArray::decode(&mut buf)?; - assert_eq!(frame, RespArray::new([b"echo".into(), b"hello".into()])); + assert_eq!(Option::::expect_length(&buf), Ok(25)); + let frame = Option::::decode(&mut buf)?; + assert_eq!( + frame, + Some(RespArray::new([b"echo".into(), b"hello".into()])) + ); Ok(()) } #[test] fn test_null_array_encode() { - let frame: RespFrame = RespArray::new(Vec::new()).into(); + let frame: RespFrame = RespFrame::Array(None); assert_eq!(frame.encode(), b"*-1\r\n"); } @@ -150,10 +170,29 @@ mod tests { let mut buf = BytesMut::new(); buf.extend_from_slice(b"*-1\r\n"); - assert_eq!(RespArray::expect_length(&buf), Ok(5)); + assert_eq!(Option::::expect_length(&buf), Ok(5)); + + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, None); + + Ok(()) + } + + #[test] + fn test_zero_array_encode() { + let frame: RespFrame = Some(RespArray::new(Vec::new())).into(); + assert_eq!(frame.encode(), b"*0\r\n"); + } + + #[test] + fn test_zero_array_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*0\r\n"); + + assert_eq!(Option::::expect_length(&buf), Ok(4)); - let frame = RespArray::decode(&mut buf)?; - assert_eq!(frame, RespArray::new(Vec::new())); + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, Some(RespArray::new(Vec::new()))); Ok(()) } diff --git a/src/resp/bulk_string.rs b/src/resp/bulk_string.rs index 1d43f32..d80c0c0 100644 --- a/src/resp/bulk_string.rs +++ b/src/resp/bulk_string.rs @@ -35,27 +35,28 @@ pub struct BulkString(pub(crate) Vec); const NULL_BULK_STRING: &[u8; 5] = b"$-1\r\n"; // - bulk string: "$\r\n\r\n" -impl RespEncode for BulkString { +impl RespEncode for Option { fn encode(self) -> Vec { - if self.len() == 0 { - NULL_BULK_STRING.to_vec() - } else { - let mut buf = Vec::with_capacity(self.len() + 16); - buf.extend_from_slice(&format!("${}\r\n", self.len()).into_bytes()); - buf.extend_from_slice(&self); + if let Some(s) = self { + let mut buf = Vec::with_capacity(s.len() + 16); + buf.extend_from_slice(&format!("${}\r\n", s.len()).into_bytes()); + buf.extend_from_slice(&s); buf.extend_from_slice(b"\r\n"); buf + } else { + NULL_BULK_STRING.to_vec() } } } -impl RespDecode for BulkString { +impl RespDecode for Option { const PREFIX: &'static str = "$"; fn decode(buf: &mut BytesMut) -> Result { let (end, len) = parse_length(buf, Self::PREFIX)?; - if len == 0 { - Ok(BulkString::new(Vec::new())) + if len == -1 { + Ok(None) } else { + let len = len as usize; let remained = &buf[end + CRLF_LEN..]; if remained.len() < len + CRLF_LEN { return Err(RespError::NotComplete); @@ -64,15 +65,16 @@ impl RespDecode for BulkString { buf.advance(end + CRLF_LEN); let data = buf.split_to(len + CRLF_LEN); - Ok(BulkString::new(data[..len].to_vec())) + Ok(Some(BulkString::new(data[..len].to_vec()))) } } fn expect_length(buf: &[u8]) -> Result { let (end, len) = parse_length(buf, Self::PREFIX)?; - if len == 0 { + if len == -1 { Ok(5) } else { + let len = len as usize; Ok(end + CRLF_LEN + len + CRLF_LEN) } } @@ -131,7 +133,7 @@ mod tests { #[test] fn test_bulk_string_encode() { - let frame: RespFrame = BulkString::new(b"hello".to_vec()).into(); + let frame: RespFrame = Some(BulkString::new(b"hello".to_vec())).into(); assert_eq!(frame.encode(), b"$5\r\nhello\r\n"); } @@ -140,25 +142,25 @@ mod tests { let mut buf = BytesMut::new(); buf.extend_from_slice(b"$5\r\nhello\r\n"); - assert_eq!(BulkString::expect_length(&buf), Ok(11)); + assert_eq!(Option::::expect_length(&buf), Ok(11)); - let frame = BulkString::decode(&mut buf)?; - assert_eq!(frame, BulkString::new(b"hello")); + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, Some(BulkString::new(b"hello"))); buf.extend_from_slice(b"$5\r\nhello"); - let ret = BulkString::decode(&mut buf); + let ret = Option::::decode(&mut buf); assert_eq!(ret.unwrap_err(), RespError::NotComplete); buf.extend_from_slice(b"\r\n"); - let frame = BulkString::decode(&mut buf)?; - assert_eq!(frame, BulkString::new(b"hello")); + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, Some(BulkString::new(b"hello"))); Ok(()) } #[test] fn test_null_bulk_string_encode() { - let frame: RespFrame = BulkString::new(Vec::new()).into(); + let frame: RespFrame = RespFrame::BulkString(None); assert_eq!(frame.encode(), b"$-1\r\n"); } @@ -167,10 +169,29 @@ mod tests { let mut buf = BytesMut::new(); buf.extend_from_slice(b"$-1\r\n"); - assert_eq!(BulkString::expect_length(&buf), Ok(5)); + assert_eq!(Option::::expect_length(&buf), Ok(5)); + + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, None); + + Ok(()) + } + + #[test] + fn test_zero_bulk_string_encode() { + let frame: RespFrame = Some(BulkString::new(Vec::new())).into(); + assert_eq!(frame.encode(), b"$0\r\n\r\n"); + } + + #[test] + fn test_zero_bulk_string_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"$0\r\n\r\n"); + + assert_eq!(Option::::expect_length(&buf), Ok(6)); - let frame = BulkString::decode(&mut buf)?; - assert_eq!(frame, BulkString::new(Vec::new())); + let frame = Option::::decode(&mut buf)?; + assert_eq!(frame, Some(BulkString::new(Vec::new()))); Ok(()) } diff --git a/src/resp/mod.rs b/src/resp/mod.rs index 33edd61..b4a7f6d 100644 --- a/src/resp/mod.rs +++ b/src/resp/mod.rs @@ -69,14 +69,11 @@ fn find_crlf(buf: &[u8], nth: usize) -> Option { None } -fn parse_length(buf: &[u8], prefix: &str) -> Result<(usize, usize), RespError> { +fn parse_length(buf: &[u8], prefix: &str) -> Result<(usize, isize), RespError> { let end = extract_simple_frame_data(buf, prefix)?; let s = String::from_utf8_lossy(&buf[prefix.len()..end]); - if s == "-1" { - Ok((end, 0)) - } else { - Ok((end, s.parse()?)) - } + + Ok((end, s.parse()?)) } fn calc_total_length(buf: &[u8], end: usize, len: usize, prefix: &str) -> Result { diff --git a/src/resp/resp_frame.rs b/src/resp/resp_frame.rs index cf26c04..fca7751 100644 --- a/src/resp/resp_frame.rs +++ b/src/resp/resp_frame.rs @@ -8,9 +8,9 @@ use crate::{BulkString, RespArray, SimpleError, SimpleString}; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum RespFrame { SimpleString(SimpleString), - SimpleError(SimpleError), - BulkString(BulkString), - Array(RespArray), + Error(SimpleError), + BulkString(Option), + Array(Option), } // impl RespEncode for RespFrame { @@ -28,13 +28,13 @@ impl From<&str> for RespFrame { impl From<&[u8]> for RespFrame { fn from(s: &[u8]) -> Self { - BulkString(s.to_vec()).into() + Some(BulkString(s.to_vec())).into() } } impl From<&[u8; N]> for RespFrame { fn from(s: &[u8; N]) -> Self { - BulkString(s.to_vec()).into() + Some(BulkString(s.to_vec())).into() } } @@ -61,14 +61,14 @@ impl RespDecode for RespFrame { } Some(b'*') => { // try null array first - match RespArray::decode(buf) { + match Option::::decode(buf) { Ok(frame) => Ok(frame.into()), Err(e) => Err(e), } } Some(b'$') => { // try null bulk string first - match BulkString::decode(buf) { + match Option::::decode(buf) { Ok(frame) => Ok(frame.into()), Err(e) => Err(e), } @@ -84,8 +84,8 @@ impl RespDecode for RespFrame { fn expect_length(buf: &[u8]) -> Result { let mut iter = buf.iter().peekable(); match iter.peek() { - Some(b'*') => RespArray::expect_length(buf), - Some(b'$') => BulkString::expect_length(buf), + Some(b'*') => Option::::expect_length(buf), + Some(b'$') => Option::::expect_length(buf), Some(b'+') => SimpleString::expect_length(buf), _ => Err(RespError::NotComplete), }