Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds hmget command #487

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions src/protocol/resp/src/request/hmget.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use super::*;
use std::io::{Error, ErrorKind};
use std::sync::Arc;

type ArcByteSlice = Arc<Box<[u8]>>;
#[derive(Debug, PartialEq, Eq)]
pub struct HmGetRequest {
key: ArcByteSlice,
fields: Arc<Box<[ArcByteSlice]>>,
}

impl HmGetRequest {
pub fn key(&self) -> &[u8] {
&self.key
}

pub fn fields(&self) -> Box<[&[u8]]> {
self.fields
.iter()
.map(|f| &***f)
.collect::<Vec<&[u8]>>()
.into_boxed_slice()
}
}

impl TryFrom<Message> for HmGetRequest {
type Error = Error;

fn try_from(other: Message) -> Result<Self, Error> {
if let Message::Array(array) = other {
if array.inner.is_none() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let mut array = array.inner.unwrap();

if array.len() <= 2 {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let key = take_bulk_string(&mut array)?;
if key.is_empty() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

let mut fields = Vec::with_capacity(array.len());
while array.len() >= 2 {
let field = take_bulk_string(&mut array)?;
if field.is_empty() {
return Err(Error::new(ErrorKind::Other, "malformed command"));
}

fields.push(field);
}

let f = Arc::new(Box::<[ArcByteSlice]>::from(fields));
Ok(Self { key, fields: f })
} else {
Err(Error::new(ErrorKind::Other, "malformed command"))
}
}
}

impl From<&HmGetRequest> for Message {
fn from(other: &HmGetRequest) -> Message {
let mut v = vec![
Message::bulk_string(b"HMGET"),
Message::BulkString(BulkString::from(other.key.clone())),
];
for kv in (*other.fields).iter() {
v.push(Message::BulkString(BulkString::from(kv.clone())));
}

Message::Array(Array { inner: Some(v) })
}
}

impl Compose for HmGetRequest {
fn compose(&self, buf: &mut dyn BufMut) -> usize {
let message = Message::from(self);
message.compose(buf)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn parser() {
let parser = RequestParser::new();

//1 field
if let Request::HmGet(request) = parser.parse(b"hmget key field1\r\n").unwrap().into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 1);
assert_eq!(request.fields()[0], b"field1");
} else {
panic!("invalid parse result");
}

//2 fields
if let Request::HmGet(request) = parser
.parse(b"hmget key field1 field2\r\n")
.unwrap()
.into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 2);
assert_eq!(request.fields()[0], b"field1");
assert_eq!(request.fields()[1], b"field2");
} else {
panic!("invalid parse result");
}

//3 fields
if let Request::HmGet(request) = parser
.parse(b"hmget key field1 field2 42\r\n")
.unwrap()
.into_inner()
{
assert_eq!(request.key(), b"key");
assert_eq!(request.fields().len(), 3);
assert_eq!(request.fields()[0], b"field1");
assert_eq!(request.fields()[1], b"field2");
assert_eq!(request.fields()[2], b"42");
} else {
panic!("invalid parse result");
}

//insufficient whitespace delimited strings
parser
.parse(b"hmget key\r\n")
.expect_err("malformed command");
}
}
15 changes: 15 additions & 0 deletions src/protocol/resp/src/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use std::sync::Arc;

mod badd;
mod get;
mod hmget;
mod set;

pub use badd::BAddRequest;
pub use get::GetRequest;
pub use hmget::HmGetRequest;
pub use set::SetRequest;

#[derive(Default)]
Expand Down Expand Up @@ -95,6 +97,9 @@ impl Parse<Request> for RequestParser {
Some(b"get") | Some(b"GET") => {
GetRequest::try_from(message).map(Request::from)
}
Some(b"hmget") | Some(b"HMGET") => {
HmGetRequest::try_from(message).map(Request::from)
}
Some(b"set") | Some(b"SET") => {
SetRequest::try_from(message).map(Request::from)
}
Expand All @@ -120,6 +125,7 @@ impl Compose for Request {
match self {
Self::BAdd(r) => r.compose(buf),
Self::Get(r) => r.compose(buf),
Self::HmGet(r) => r.compose(buf),
Self::Set(r) => r.compose(buf),
}
}
Expand All @@ -129,6 +135,7 @@ impl Compose for Request {
pub enum Request {
BAdd(BAddRequest),
Get(GetRequest),
HmGet(HmGetRequest),
Set(SetRequest),
}

Expand All @@ -144,6 +151,12 @@ impl From<GetRequest> for Request {
}
}

impl From<HmGetRequest> for Request {
fn from(other: HmGetRequest) -> Self {
Self::HmGet(other)
}
}

impl From<SetRequest> for Request {
fn from(other: SetRequest) -> Self {
Self::Set(other)
Expand All @@ -154,6 +167,7 @@ impl From<SetRequest> for Request {
pub enum Command {
BAdd,
Get,
HmGet,
Set,
}

Expand All @@ -164,6 +178,7 @@ impl TryFrom<&[u8]> for Command {
match other {
b"badd" | b"BADD" => Ok(Command::BAdd),
b"get" | b"GET" => Ok(Command::Get),
b"hmget" | b"HMGET" => Ok(Command::HmGet),
b"set" | b"SET" => Ok(Command::Set),
_ => Err(()),
}
Expand Down