diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 717563a..281eab9 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,18 +1,17 @@ use crate::configuration::{FailureMode, FilterConfig}; -use crate::envoy::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}; +use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; +use crate::service::rate_limit::RateLimitService; +use crate::service::Service; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::{Action, Bytes}; use std::rc::Rc; -use std::time::Duration; - -const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; -const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; // tracing headers +#[derive(Clone)] pub enum TracingHeader { Traceparent, Tracestate, @@ -24,7 +23,7 @@ impl TracingHeader { [Traceparent, Tracestate, Baggage] } - fn as_str(&self) -> &'static str { + pub fn as_str(&self) -> &'static str { match self { Traceparent => "traceparent", Tracestate => "tracestate", @@ -64,27 +63,10 @@ impl Filter { return Action::Continue; } - let mut rl_req = RateLimitRequest::new(); - rl_req.set_domain(rlp.domain.clone()); - rl_req.set_hits_addend(1); - rl_req.set_descriptors(descriptors); - - let rl_req_serialized = Message::write_to_bytes(&rl_req).unwrap(); // TODO(rahulanand16nov): Error Handling - - let rl_tracing_headers = self - .tracing_headers - .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) - .collect(); - - match self.dispatch_grpc_call( - rlp.service.as_str(), - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - rl_tracing_headers, - Some(&rl_req_serialized), - Duration::from_secs(5), - ) { + let rls = RateLimitService::new(rlp.service.as_str(), self.tracing_headers.clone()); + let message = RateLimitService::message(rlp.domain.clone(), descriptors); + + match rls.send(message) { Ok(call_id) => { debug!( "#{} initiated gRPC call (id# {}) to Limitador", diff --git a/src/lib.rs b/src/lib.rs index 8ee6c31..fb1c60a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod filter; mod glob; mod policy; mod policy_index; +mod service; #[cfg(test)] mod tests { diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 0000000..b63bb82 --- /dev/null +++ b/src/service.rs @@ -0,0 +1,8 @@ +pub(crate) mod rate_limit; + +use protobuf::Message; +use proxy_wasm::types::Status; + +pub trait Service { + fn send(&self, message: M) -> Result; +} diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs new file mode 100644 index 0000000..6c4726c --- /dev/null +++ b/src/service/rate_limit.rs @@ -0,0 +1,113 @@ +use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; +use crate::filter::http_context::TracingHeader; +use crate::service::Service; +use protobuf::{Message, RepeatedField}; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::{Bytes, Status}; +use std::time::Duration; + +const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; +const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; +pub struct RateLimitService { + endpoint: String, + tracing_headers: Vec<(TracingHeader, Bytes)>, +} + +impl RateLimitService { + pub fn new(endpoint: &str, metadata: Vec<(TracingHeader, Bytes)>) -> RateLimitService { + Self { + endpoint: String::from(endpoint), + tracing_headers: metadata, + } + } + pub fn message( + domain: String, + descriptors: RepeatedField, + ) -> RateLimitRequest { + RateLimitRequest { + domain, + descriptors, + hits_addend: 1, + unknown_fields: Default::default(), + cached_size: Default::default(), + } + } +} + +fn grpc_call( + upstream_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: RateLimitRequest, +) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling + dispatch_grpc_call( + upstream_name, + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + initial_metadata, + Some(&msg), + Duration::from_secs(5), + ) +} + +impl Service for RateLimitService { + fn send(&self, message: RateLimitRequest) -> Result { + grpc_call( + self.endpoint.as_str(), + self.tracing_headers + .iter() + .map(|(header, value)| (header.as_str(), value.as_slice())) + .collect(), + message, + ) + } +} + +#[cfg(test)] +mod tests { + use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry, RateLimitRequest}; + use crate::service::rate_limit::RateLimitService; + //use crate::service::Service; + use protobuf::{CachedSize, RepeatedField, UnknownFields}; + //use proxy_wasm::types::Status; + //use crate::filter::http_context::{Filter}; + + fn build_message() -> RateLimitRequest { + let domain = "rlp1"; + let mut field = RateLimitDescriptor::new(); + let mut entry = RateLimitDescriptor_Entry::new(); + entry.set_key("key1".to_string()); + entry.set_value("value1".to_string()); + field.set_entries(RepeatedField::from_vec(vec![entry])); + let descriptors = RepeatedField::from_vec(vec![field]); + + RateLimitService::message(domain.to_string(), descriptors.clone()) + } + #[test] + fn builds_correct_message() { + let msg = build_message(); + + assert_eq!(msg.hits_addend, 1); + assert_eq!(msg.domain, "rlp1".to_string()); + assert_eq!(msg.descriptors.first().unwrap().entries[0].key, "key1"); + assert_eq!(msg.descriptors.first().unwrap().entries[0].value, "value1"); + assert_eq!(msg.unknown_fields, UnknownFields::default()); + assert_eq!(msg.cached_size, CachedSize::default()); + } + /*#[test] + fn sends_message() { + let msg = build_message(); + let metadata = vec![("header-1", "value-1".as_bytes())]; + let rls = RateLimitService::new("limitador-cluster", metadata); + + // TODO(didierofrivia): When we have a grpc response type, assert the async response + } + + fn grpc_call( + _upstream_name: &str, + _initial_metadata: Vec<(&str, &[u8])>, + _message: RateLimitRequest, + ) -> Result { + Ok(1) + } */ +}