From 5c3ea61cf4ba2859049f67704bf910884435ba8d Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Thu, 14 Nov 2024 14:01:05 +0000 Subject: [PATCH 1/2] Replace all unwraps Signed-off-by: Adam Cattermole --- src/configuration.rs | 10 +++++----- src/data/attribute.rs | 5 ++++- src/data/cel.rs | 38 +++++++++++++++++++++++++++++++------- src/filter.rs | 3 ++- src/service.rs | 18 ++++++++++++------ src/service/auth.rs | 6 +++--- src/service/rate_limit.rs | 2 +- 7 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index 9d49cc69..de84f6dc 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -106,10 +106,7 @@ impl TryFrom for FilterConfig { .expect("Predicates must not be compiled yet!"); for datum in &action.data { - let result = datum.item.compile(); - if result.is_err() { - return Err(result.err().unwrap()); - } + datum.item.compile()?; } } @@ -204,7 +201,10 @@ impl<'de> Visitor<'de> for TimeoutVisitor { E: Error, { match duration(Arc::new(string)) { - Ok(Value::Duration(duration)) => Ok(Timeout(duration.to_std().unwrap())), + Ok(Value::Duration(duration)) => duration + .to_std() + .map(Timeout) + .map_err(|e| E::custom(e.to_string())), Err(e) => Err(E::custom(e)), _ => Err(E::custom("Unsupported Duration Value")), } diff --git a/src/data/attribute.rs b/src/data/attribute.rs index 9e65caa9..a3a1d230 100644 --- a/src/data/attribute.rs +++ b/src/data/attribute.rs @@ -166,7 +166,10 @@ fn process_metadata(s: &Struct, prefix: String) -> Vec<(String, String)> { let nested_struct = value.get_struct_value(); result.extend(process_metadata(nested_struct, current_prefix)); } else if let Some(v) = json { - result.push((current_prefix, serde_json::to_string(&v).unwrap())); + match serde_json::to_string(&v) { + Ok(ser) => result.push((current_prefix, ser)), + Err(e) => error!("failed to serialize json Value: {e:?}"), + } } } result diff --git a/src/data/cel.rs b/src/data/cel.rs index 844c726b..e83e9666 100644 --- a/src/data/cel.rs +++ b/src/data/cel.rs @@ -7,8 +7,10 @@ use cel_parser::{parse, Expression as CelExpression, Member, ParseError}; use chrono::{DateTime, FixedOffset}; #[cfg(feature = "debug-host-behaviour")] use log::debug; +use log::warn; use proxy_wasm::types::{Bytes, Status}; use serde_json::Value as JsonValue; +use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; @@ -87,7 +89,7 @@ impl Expression { /// Decodes the query string and returns a Map where the key is the parameter's name and /// the value is either a [`Value::String`] or a [`Value::List`] if the parameter's name is repeated -/// and the second arg is set not set to `false`. +/// and the second arg is not set to `false`. /// see [`tests::decodes_query_string`] fn decode_query_string(This(s): This>, Arguments(args): Arguments) -> ResolveResult { let allow_repeats = if args.len() == 2 { @@ -102,8 +104,22 @@ fn decode_query_string(This(s): This>, Arguments(args): Arguments) - for part in s.split('&') { let mut kv = part.split('='); if let (Some(key), Some(value)) = (kv.next(), kv.next().or(Some(""))) { - let new_v: Value = decode(value).unwrap().into_owned().into(); - match map.entry(decode(key).unwrap().into_owned().into()) { + let new_v: Value = decode(value) + .unwrap_or_else(|e| { + warn!("failed to decode query value, using default: {e:?}"); + Cow::from(value) + }) + .into_owned() + .into(); + match map.entry( + decode(key) + .unwrap_or_else(|e| { + warn!("failed to decode query key, using default: {e:?}"); + Cow::from(key) + }) + .into_owned() + .into(), + ) { Entry::Occupied(mut e) => { if allow_repeats { if let Value::List(ref mut list) = e.get_mut() { @@ -118,7 +134,15 @@ fn decode_query_string(This(s): This>, Arguments(args): Arguments) - } } Entry::Vacant(e) => { - e.insert(decode(value).unwrap().into_owned().into()); + e.insert( + decode(value) + .unwrap_or_else(|e| { + warn!("failed to decode query value, using default: {e:?}"); + Cow::from(value) + }) + .into_owned() + .into(), + ); } } } @@ -296,11 +320,11 @@ fn json_to_cel(json: &str) -> Value { JsonValue::Bool(b) => b.into(), JsonValue::Number(n) => { if n.is_u64() { - n.as_u64().unwrap().into() + n.as_u64().expect("Unreachable: number must be u64").into() } else if n.is_i64() { - n.as_i64().unwrap().into() + n.as_i64().expect("Unreachable: number must be i64").into() } else { - n.as_f64().unwrap().into() + n.as_f64().expect("Unreachable: number must be f64").into() } } JsonValue::String(str) => str.into(), diff --git a/src/filter.rs b/src/filter.rs index ab359182..ed9c2418 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -26,7 +26,8 @@ extern "C" fn start() { proxy_wasm::set_log_level(LogLevel::Trace); std::panic::set_hook(Box::new(|panic_info| { - proxy_wasm::hostcalls::log(LogLevel::Critical, &panic_info.to_string()).unwrap(); + proxy_wasm::hostcalls::log(LogLevel::Critical, &panic_info.to_string()) + .expect("failed to log panic_info"); })); proxy_wasm::set_root_context(|context_id| -> Box { info!("#{} set_root_context", context_id); diff --git a/src/service.rs b/src/service.rs index 404671fa..a981add3 100644 --- a/src/service.rs +++ b/src/service.rs @@ -13,6 +13,7 @@ use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; use log::warn; use protobuf::Message; use proxy_wasm::hostcalls; +use proxy_wasm::types::Status::SerializationFailure; use proxy_wasm::types::{BufferType, Bytes, MapType, Status}; use std::cell::OnceCell; use std::rc::Rc; @@ -56,8 +57,8 @@ impl GrpcService { resp_size: usize, ) -> Result { let failure_mode = operation.get_failure_mode(); - if let Some(res_body_bytes) = - hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() + if let Ok(Some(res_body_bytes)) = + hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size) { match GrpcMessageResponse::new(operation.get_service_type(), &res_body_bytes) { Ok(res) => match operation.get_service_type() { @@ -75,7 +76,7 @@ impl GrpcService { } } } else { - warn!("grpc response body is empty!"); + warn!("failed to get grpc buffer or return data is null!"); GrpcService::handle_error_on_grpc_response(failure_mode); Err(StatusCode::InternalServerError) } @@ -85,9 +86,11 @@ impl GrpcService { match failure_mode { FailureMode::Deny => { hostcalls::send_http_response(500, vec![], Some(b"Internal Server Error.\n")) - .unwrap(); + .expect("failed to send_http_response 500"); + } + FailureMode::Allow => { + hostcalls::resume_http_request().expect("failed to resume_http_request") } - FailureMode::Allow => hostcalls::resume_http_request().unwrap(), } } } @@ -140,7 +143,10 @@ impl GrpcServiceHandler { message: GrpcMessageRequest, timeout: Duration, ) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); + let msg = Message::write_to_bytes(&message).map_err(|e| { + warn!("Failed to write protobuf message to bytes: {e:?}"); + SerializationFailure + })?; let metadata = self .header_resolver .get(get_map_values_bytes_fn) diff --git a/src/service/auth.rs b/src/service/auth.rs index 33af0689..925c4762 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -64,7 +64,7 @@ impl AuthService { let mut request = AttributeContext_Request::default(); let mut http = AttributeContext_HttpRequest::default(); let headers: HashMap = hostcalls::get_map(MapType::HttpRequestHeaders) - .unwrap() + .expect("failed to retrieve HttpRequestHeaders from host") .into_iter() .collect(); @@ -151,7 +151,7 @@ impl AuthService { header.get_header().get_key(), header.get_header().get_value(), ) - .unwrap() + .expect("failed to add_map_value to HttpRequestHeaders") }); Ok(GrpcResult::default()) } @@ -170,7 +170,7 @@ impl AuthService { response_headers, Some(denied_response.get_body().as_ref()), ) - .unwrap(); + .expect("failed to send_http_response"); Err(status_code) } None => { diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 4d8f2424..d2a692f3 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -57,7 +57,7 @@ impl RateLimitService { response_headers.push((header.get_key(), header.get_value())); } hostcalls::send_http_response(429, response_headers, Some(b"Too Many Requests\n")) - .unwrap(); + .expect("failed to send_http_response 429 while OVER_LIMIT"); Err(StatusCode::TooManyRequests) } GrpcMessageResponse::RateLimit(RateLimitResponse { From b3d43079e475f276c367ac58e85b6a71ae9ddf04 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Thu, 14 Nov 2024 14:42:59 +0000 Subject: [PATCH 2/2] Expect in tests Signed-off-by: Adam Cattermole --- src/configuration.rs | 13 +++-- src/configuration/action_set_index.rs | 10 ++-- src/data/cel.rs | 48 ++++++++++-------- src/lib.rs | 4 +- src/operation_dispatcher.rs | 72 +++++++++++++++++++++++---- src/service/rate_limit.rs | 18 ++++++- 6 files changed, 120 insertions(+), 45 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index de84f6dc..7f9771f1 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -277,7 +277,7 @@ mod test { } assert!(res.is_ok()); - let filter_config = res.unwrap(); + let filter_config = res.expect("result is ok"); assert_eq!(filter_config.action_sets.len(), 1); let services = &filter_config.services; @@ -364,7 +364,7 @@ mod test { } assert!(res.is_ok()); - let filter_config = res.unwrap(); + let filter_config = res.expect("result is ok"); assert_eq!(filter_config.action_sets.len(), 0); } @@ -410,12 +410,15 @@ mod test { } assert!(res.is_ok()); - let filter_config = res.unwrap(); + let filter_config = res.expect("result is ok"); assert_eq!(filter_config.action_sets.len(), 1); let services = &filter_config.services; assert_eq!( - services.get("limitador").unwrap().timeout, + services + .get("limitador") + .expect("limitador service to be set") + .timeout, Timeout(Duration::from_millis(20)) ); @@ -510,7 +513,7 @@ mod test { } assert!(res.is_ok()); - let result = FilterConfig::try_from(res.unwrap()); + let result = FilterConfig::try_from(res.expect("result is ok")); let filter_config = result.expect("That didn't work"); let rlp_option = filter_config .index diff --git a/src/configuration/action_set_index.rs b/src/configuration/action_set_index.rs index e2153cec..068fd96a 100644 --- a/src/configuration/action_set_index.rs +++ b/src/configuration/action_set_index.rs @@ -68,7 +68,7 @@ mod tests { let val = index.get_longest_match_action_sets("example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap()[0].name, "rlp1"); + assert_eq!(val.expect("value must be some")[0].name, "rlp1"); } #[test] @@ -90,7 +90,7 @@ mod tests { let val = index.get_longest_match_action_sets("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap()[0].name, "rlp1"); + assert_eq!(val.expect("value must be some")[0].name, "rlp1"); } #[test] @@ -103,11 +103,11 @@ mod tests { let val = index.get_longest_match_action_sets("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap()[0].name, "rlp2"); + assert_eq!(val.expect("value must be some")[0].name, "rlp2"); let val = index.get_longest_match_action_sets("example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap()[0].name, "rlp1"); + assert_eq!(val.expect("value must be some")[0].name, "rlp1"); } #[test] @@ -118,6 +118,6 @@ mod tests { let val = index.get_longest_match_action_sets("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap()[0].name, "rlp1"); + assert_eq!(val.expect("value must be some")[0].name, "rlp1"); } } diff --git a/src/data/cel.rs b/src/data/cel.rs index e83e9666..6ef63020 100644 --- a/src/data/cel.rs +++ b/src/data/cel.rs @@ -557,10 +557,14 @@ pub mod data { fn it_works() { let map = AttributeMap::new( [ - known_attribute_for(&"request.method".into()).unwrap(), - known_attribute_for(&"request.referer".into()).unwrap(), - known_attribute_for(&"source.address".into()).unwrap(), - known_attribute_for(&"destination.port".into()).unwrap(), + known_attribute_for(&"request.method".into()) + .expect("request.method known attribute exists"), + known_attribute_for(&"request.referer".into()) + .expect("request.referer known attribute exists"), + known_attribute_for(&"source.address".into()) + .expect("source.address known attribute exists"), + known_attribute_for(&"destination.port".into()) + .expect("destination.port known attribute exists"), ] .into(), ); @@ -572,10 +576,10 @@ pub mod data { assert!(map.data.contains_key("destination")); assert!(map.data.contains_key("request")); - match map.data.get("source").unwrap() { + match map.data.get("source").expect("source is some") { Token::Node(map) => { assert_eq!(map.len(), 1); - match map.get("address").unwrap() { + match map.get("address").expect("address is some") { Token::Node(_) => panic!("Not supposed to get here!"), Token::Value(v) => assert_eq!(v.path, "source.address".into()), } @@ -583,10 +587,10 @@ pub mod data { Token::Value(_) => panic!("Not supposed to get here!"), } - match map.data.get("destination").unwrap() { + match map.data.get("destination").expect("destination is some") { Token::Node(map) => { assert_eq!(map.len(), 1); - match map.get("port").unwrap() { + match map.get("port").expect("port is some") { Token::Node(_) => panic!("Not supposed to get here!"), Token::Value(v) => assert_eq!(v.path, "destination.port".into()), } @@ -594,16 +598,16 @@ pub mod data { Token::Value(_) => panic!("Not supposed to get here!"), } - match map.data.get("request").unwrap() { + match map.data.get("request").expect("request is some") { Token::Node(map) => { assert_eq!(map.len(), 2); assert!(map.get("method").is_some()); - match map.get("method").unwrap() { + match map.get("method").expect("method is some") { Token::Node(_) => panic!("Not supposed to get here!"), Token::Value(v) => assert_eq!(v.path, "request.method".into()), } assert!(map.get("referer").is_some()); - match map.get("referer").unwrap() { + match map.get("referer").expect("referer is some") { Token::Node(_) => panic!("Not supposed to get here!"), Token::Value(v) => assert_eq!(v.path, "request.referer".into()), } @@ -635,7 +639,7 @@ mod tests { let value = Expression::new( "auth.identity.anonymous && auth.identity != null && auth.identity.foo > 3", ) - .unwrap(); + .expect("This is valid CEL!"); assert_eq!(value.attributes.len(), 3); assert_eq!(value.attributes[0].path, "auth.identity".into()); } @@ -650,7 +654,7 @@ mod tests { "true".bytes().collect(), ))); let value = Expression::new("auth.identity.anonymous") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, true.into()); @@ -660,7 +664,7 @@ mod tests { "42".bytes().collect(), ))); let value = Expression::new("auth.identity.age") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, 42.into()); @@ -670,7 +674,7 @@ mod tests { "42.3".bytes().collect(), ))); let value = Expression::new("auth.identity.age") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, 42.3.into()); @@ -680,7 +684,7 @@ mod tests { "\"John\"".bytes().collect(), ))); let value = Expression::new("auth.identity.age") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, "John".into()); @@ -690,7 +694,7 @@ mod tests { "-42".bytes().collect(), ))); let value = Expression::new("auth.identity.name") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, (-42).into()); @@ -701,7 +705,7 @@ mod tests { "some random crap".bytes().collect(), ))); let value = Expression::new("auth.identity.age") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, "some random crap".into()); @@ -775,12 +779,14 @@ mod tests { 80_i64.to_le_bytes().into(), ))); let value = known_attribute_for(&"destination.port".into()) - .unwrap() + .expect("destination.port known attribute exists") .get(); assert_eq!(value, 80.into()); property::test::TEST_PROPERTY_VALUE .set(Some(("request.method".into(), "GET".bytes().collect()))); - let value = known_attribute_for(&"request.method".into()).unwrap().get(); + let value = known_attribute_for(&"request.method".into()) + .expect("request.method known attribute exists") + .get(); assert_eq!(value, "GET".into()); } @@ -802,7 +808,7 @@ mod tests { b"\xCA\xFE".to_vec(), ))); let value = Expression::new("getHostProperty(['foo', 'bar.baz'])") - .unwrap() + .expect("This is valid CEL!") .eval() .expect("This must evaluate!"); assert_eq!(value, Value::Bytes(Arc::new(b"\xCA\xFE".to_vec()))); diff --git a/src/lib.rs b/src/lib.rs index 9f065b3c..ffaf17aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,9 @@ mod tests { .push(header("test", "some value")); resp.response_headers_to_add .push(header("other", "header value")); - let buffer = resp.write_to_bytes().unwrap(); + let buffer = resp + .write_to_bytes() + .expect("must be able to write RateLimitResponse to bytes"); let expected: [u8; 45] = [ 8, 1, 26, 18, 10, 4, 116, 101, 115, 116, 18, 10, 115, 111, 109, 101, 32, 118, 97, 108, 117, 101, 26, 21, 10, 5, 111, 116, 104, 101, 114, 18, 12, 104, 101, 97, 100, 101, 114, diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index cc171431..3b7779cc 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -463,34 +463,84 @@ mod tests { assert_eq!(operation_dispatcher.waiting_operations.len(), 0); let mut op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().unwrap().get_result(), Ok(66)); assert_eq!( - *op.clone().unwrap().unwrap().get_service_type(), + op.clone() + .expect("ok result") + .expect("operation is some") + .get_result(), + Ok(66) + ); + assert_eq!( + *op.clone() + .expect("ok result") + .expect("operation is some") + .get_service_type(), ServiceType::RateLimit ); - assert_eq!(op.unwrap().unwrap().get_state(), State::Waiting); + assert_eq!( + op.expect("ok result") + .expect("operation is some") + .get_state(), + State::Waiting + ); assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().unwrap().get_result(), Ok(66)); - assert_eq!(op.unwrap().unwrap().get_state(), State::Done); + assert_eq!( + op.clone() + .expect("ok result") + .expect("operation is some") + .get_result(), + Ok(66) + ); + assert_eq!( + op.expect("ok result") + .expect("operation is some") + .get_state(), + State::Done + ); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().unwrap().get_result(), Ok(77)); assert_eq!( - *op.clone().unwrap().unwrap().get_service_type(), + op.clone() + .expect("ok result") + .expect("operation is some") + .get_result(), + Ok(77) + ); + assert_eq!( + *op.clone() + .expect("ok result") + .expect("operation is some") + .get_service_type(), ServiceType::Auth ); - assert_eq!(op.unwrap().unwrap().get_state(), State::Waiting); + assert_eq!( + op.expect("ok result") + .expect("operation is some") + .get_state(), + State::Waiting + ); assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().unwrap().get_result(), Ok(77)); - assert_eq!(op.unwrap().unwrap().get_state(), State::Done); + assert_eq!( + op.clone() + .expect("ok result") + .expect("operation is some") + .get_result(), + Ok(77) + ); + assert_eq!( + op.expect("ok result") + .expect("operation is some") + .get_state(), + State::Done + ); assert_eq!(operation_dispatcher.waiting_operations.len(), 1); op = operation_dispatcher.next(); - assert!(op.unwrap().is_none()); + assert!(op.expect("ok result").is_none()); assert!(operation_dispatcher.get_current_operation_state().is_none()); assert_eq!(operation_dispatcher.waiting_operations.len(), 0); } diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index d2a692f3..b4dcf809 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -106,8 +106,22 @@ mod tests { assert_eq!(msg.hits_addend, 1); assert_eq!(msg.domain, "rlp1".to_string()); - assert_eq!(msg.descriptors.first().unwrap().entries[0].key, "key1"); - assert_eq!(msg.descriptors.first().unwrap().entries[0].value, "value1"); + assert_eq!( + msg.descriptors + .first() + .expect("must have a descriptor") + .entries[0] + .key, + "key1" + ); + assert_eq!( + msg.descriptors + .first() + .expect("must have a descriptor") + .entries[0] + .value, + "value1" + ); assert_eq!(msg.unknown_fields, UnknownFields::default()); assert_eq!(msg.cached_size, CachedSize::default()); }