diff --git a/src/envoy/mod.rs b/src/envoy/mod.rs index db52720..6195cfa 100644 --- a/src/envoy/mod.rs +++ b/src/envoy/mod.rs @@ -37,7 +37,7 @@ pub use { AttributeContext_Request, }, base::Metadata, - external_auth::CheckRequest, + external_auth::{CheckRequest, DeniedHttpResponse, OkHttpResponse}, ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry}, rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}, }; diff --git a/src/service/auth.rs b/src/service/auth.rs index 0831cd6..ece695c 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -3,10 +3,12 @@ use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, AttributeContext_Request, CheckRequest, Metadata, SocketAddress, }; +use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; +use protobuf::Message; use proxy_wasm::hostcalls; -use proxy_wasm::types::MapType; +use proxy_wasm::types::{Bytes, MapType}; use std::collections::HashMap; pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; @@ -16,10 +18,35 @@ pub struct AuthService; #[allow(dead_code)] impl AuthService { - pub fn message(ce_host: String) -> CheckRequest { + pub fn request_message(ce_host: String) -> CheckRequest { AuthService::build_check_req(ce_host) } + pub fn response_message( + res_body_bytes: &Bytes, + status_code: u32, + ) -> GrpcMessageResult { + if status_code % 2 == 0 { + AuthService::response_message_ok(res_body_bytes) + } else { + AuthService::response_message_denied(res_body_bytes) + } + } + + fn response_message_ok(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::AuthOk(res)), + Err(e) => Err(e), + } + } + + fn response_message_denied(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::AuthDenied(res)), + Err(e) => Err(e), + } + } + fn build_check_req(ce_host: String) -> CheckRequest { let mut auth_req = CheckRequest::default(); let mut attr = AttributeContext::default(); diff --git a/src/service/grpc_message.rs b/src/service/grpc_message.rs index ad41fab..0ed70c9 100644 --- a/src/service/grpc_message.rs +++ b/src/service/grpc_message.rs @@ -1,11 +1,16 @@ use crate::configuration::ExtensionType; -use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; +use crate::envoy::{ + CheckRequest, DeniedHttpResponse, OkHttpResponse, RateLimitDescriptor, RateLimitRequest, + RateLimitResponse, +}; use crate::service::auth::AuthService; use crate::service::rate_limit::RateLimitService; use protobuf::reflect::MessageDescriptor; use protobuf::{ - Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, + Clear, CodedInputStream, CodedOutputStream, Message, ProtobufError, ProtobufResult, + UnknownFields, }; +use proxy_wasm::types::Bytes; use std::any::Any; #[derive(Clone, Debug)] @@ -126,11 +131,145 @@ impl GrpcMessageRequest { descriptors: protobuf::RepeatedField, ) -> Self { match extension_type { - ExtensionType::RateLimit => GrpcMessageRequest::RateLimit(RateLimitService::message( - domain.clone(), - descriptors, - )), - ExtensionType::Auth => GrpcMessageRequest::Auth(AuthService::message(domain.clone())), + ExtensionType::RateLimit => GrpcMessageRequest::RateLimit( + RateLimitService::request_message(domain.clone(), descriptors), + ), + ExtensionType::Auth => { + GrpcMessageRequest::Auth(AuthService::request_message(domain.clone())) + } } } } + +#[derive(Clone, Debug)] +pub enum GrpcMessageResponse { + AuthOk(OkHttpResponse), + AuthDenied(DeniedHttpResponse), + RateLimit(RateLimitResponse), +} + +impl Default for GrpcMessageResponse { + fn default() -> Self { + GrpcMessageResponse::RateLimit(RateLimitResponse::new()) + } +} + +impl Clear for GrpcMessageResponse { + fn clear(&mut self) { + todo!() + } +} + +impl Message for GrpcMessageResponse { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessageResponse::AuthOk(res) => res.descriptor(), + GrpcMessageResponse::AuthDenied(res) => res.descriptor(), + GrpcMessageResponse::RateLimit(res) => res.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessageResponse::AuthOk(res) => res.is_initialized(), + GrpcMessageResponse::AuthDenied(res) => res.is_initialized(), + GrpcMessageResponse::RateLimit(res) => res.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessageResponse::AuthOk(res) => res.merge_from(is), + GrpcMessageResponse::AuthDenied(res) => res.merge_from(is), + GrpcMessageResponse::RateLimit(res) => res.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessageResponse::AuthOk(res) => res.write_to_with_cached_sizes(os), + GrpcMessageResponse::AuthDenied(res) => res.write_to_with_cached_sizes(os), + GrpcMessageResponse::RateLimit(res) => res.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessageResponse::AuthOk(res) => res.write_to_bytes(), + GrpcMessageResponse::AuthDenied(res) => res.write_to_bytes(), + GrpcMessageResponse::RateLimit(res) => res.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessageResponse::AuthOk(res) => res.compute_size(), + GrpcMessageResponse::AuthDenied(res) => res.compute_size(), + GrpcMessageResponse::RateLimit(res) => res.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessageResponse::AuthOk(res) => res.get_cached_size(), + GrpcMessageResponse::AuthDenied(res) => res.get_cached_size(), + GrpcMessageResponse::RateLimit(res) => res.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessageResponse::AuthOk(res) => res.get_unknown_fields(), + GrpcMessageResponse::AuthDenied(res) => res.get_unknown_fields(), + GrpcMessageResponse::RateLimit(res) => res.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessageResponse::AuthOk(res) => res.mut_unknown_fields(), + GrpcMessageResponse::AuthDenied(res) => res.mut_unknown_fields(), + GrpcMessageResponse::RateLimit(res) => res.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessageResponse::AuthOk(res) => res.as_any(), + GrpcMessageResponse::AuthDenied(res) => res.as_any(), + GrpcMessageResponse::RateLimit(res) => res.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessageResponse::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessageResponse::RateLimit(RateLimitResponse::new())) + } +} + +impl GrpcMessageResponse { + pub fn new( + extension_type: ExtensionType, + res_body_bytes: &Bytes, + status_code: u32, + ) -> GrpcMessageResult { + match extension_type { + ExtensionType::RateLimit => RateLimitService::response_message(res_body_bytes), + ExtensionType::Auth => AuthService::response_message(res_body_bytes, status_code), + } + } +} + +pub type GrpcMessageResult = Result; diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index b6b0357..4a81884 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,5 +1,7 @@ use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use protobuf::RepeatedField; +use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; +use protobuf::{Message, RepeatedField}; +use proxy_wasm::types::Bytes; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -7,7 +9,7 @@ pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; pub struct RateLimitService; impl RateLimitService { - pub fn message( + pub fn request_message( domain: String, descriptors: RepeatedField, ) -> RateLimitRequest { @@ -19,6 +21,13 @@ impl RateLimitService { cached_size: Default::default(), } } + + pub fn response_message(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::RateLimit(res)), + Err(e) => Err(e), + } + } } #[cfg(test)] @@ -37,7 +46,7 @@ mod tests { field.set_entries(RepeatedField::from_vec(vec![entry])); let descriptors = RepeatedField::from_vec(vec![field]); - RateLimitService::message(domain.to_string(), descriptors.clone()) + RateLimitService::request_message(domain.to_string(), descriptors.clone()) } #[test] fn builds_correct_message() {