Skip to content

Commit

Permalink
partially fix headers
Browse files Browse the repository at this point in the history
  • Loading branch information
ssddOnTop committed Sep 19, 2024
1 parent c5c319a commit d159f4e
Show file tree
Hide file tree
Showing 22 changed files with 132 additions and 38 deletions.
4 changes: 2 additions & 2 deletions benches/request_template_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ impl PathString for Context {
}
}
impl HasHeaders for Context {
fn headers(&self) -> &HeaderMap {
&self.headers
fn headers(&self) -> HeaderMap {
self.headers.to_owned()
}
}
pub fn benchmark_to_request(c: &mut Criterion) {
Expand Down
18 changes: 15 additions & 3 deletions src/core/app_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,14 @@ impl AppContext {
field.map_expr(|expr| {
expr.modify(|expr| match expr {
IR::IO(io) => match io {
IO::Http { req_template, group_by, http_filter, batch, .. } => {
IO::Http {
req_template,
group_by,
http_filter,
batch,
headers,
..
} => {
let data_loader = HttpDataLoader::new(
runtime.clone(),
group_by.clone(),
Expand All @@ -61,14 +68,17 @@ impl AppContext {
dl_id: Some(DataLoaderId::new(http_data_loaders.len())),
http_filter: http_filter.clone(),
batch: batch.clone(),
headers: headers.clone(),
}));

http_data_loaders.push(data_loader);

result
}

IO::GraphQL { req_template, field_name, batch, .. } => {
IO::GraphQL {
req_template, field_name, batch, headers, ..
} => {
let graphql_data_loader =
GraphqlDataLoader::new(runtime.clone(), batch.is_some())
.into_data_loader(batch.clone().unwrap_or_default());
Expand All @@ -78,14 +88,15 @@ impl AppContext {
field_name: field_name.clone(),
batch: batch.clone(),
dl_id: Some(DataLoaderId::new(gql_data_loaders.len())),
headers: headers.clone(),
}));

gql_data_loaders.push(graphql_data_loader);

result
}

IO::Grpc { req_template, group_by, batch, .. } => {
IO::Grpc { req_template, group_by, batch, headers, .. } => {
let data_loader = GrpcDataLoader {
runtime: runtime.clone(),
operation: req_template.operation.clone(),
Expand All @@ -99,6 +110,7 @@ impl AppContext {
group_by: group_by.clone(),
dl_id: Some(DataLoaderId::new(grpc_data_loaders.len())),
batch: batch.clone(),
headers: headers.clone(),
}));

grpc_data_loaders.push(data_loader);
Expand Down
6 changes: 5 additions & 1 deletion src/core/auth/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ impl Verify for BasicVerifier {
async fn verify(&self, req_ctx: &RequestContext) -> Verification {
let header = req_ctx
.allowed_headers
.read()
.unwrap()
.typed_try_get::<Authorization<Basic>>();

let Ok(header) = header else {
Expand Down Expand Up @@ -67,6 +69,8 @@ testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4=

req_context
.allowed_headers
.read()
.unwrap()
.typed_insert(Authorization::basic(username, password));

req_context
Expand Down Expand Up @@ -119,7 +123,7 @@ testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4=
async fn verify_auth_failure() {
let provider = setup_provider();
let mut req_ctx = RequestContext::default();
req_ctx.allowed_headers.insert(
req_ctx.allowed_headers.lock().unwrap().insert(
"Authorization",
HeaderValue::from_static("Basic dGVzdHVzZXIyOm15cGFzc3dvcmQ"),
);
Expand Down
4 changes: 4 additions & 0 deletions src/core/auth/jwt/jwt_verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ impl JwtVerifier {
fn resolve_token(&self, request: &RequestContext) -> anyhow::Result<Option<String>> {
let value = request
.allowed_headers
.read()
.unwrap()
.typed_try_get::<Authorization<Bearer>>()?;

Ok(value.map(|token| token.token().to_owned()))
Expand Down Expand Up @@ -168,6 +170,8 @@ pub mod tests {

req_context
.allowed_headers
.read()
.unwrap()
.typed_insert(Authorization::bearer(token).unwrap());

req_context
Expand Down
12 changes: 11 additions & 1 deletion src/core/blueprint/operators/graphql.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::{HashMap, HashSet};
use std::str::FromStr;

use headers::{HeaderMap, HeaderName, HeaderValue};

use crate::core::blueprint::FieldDefinition;
use crate::core::config::{
Expand Down Expand Up @@ -71,7 +74,14 @@ pub fn compile_graphql(
.map(|req_template| {
let field_name = graphql.name.clone();
let batch = graphql.batch.clone();
IR::IO(IO::GraphQL { req_template, field_name, batch, dl_id: None })
let headers = graphql.headers.iter().filter_map(|kv| {
Some((
HeaderName::from_str(kv.key.as_str()).ok()?,
HeaderValue::from_str(kv.value.as_str()).ok()?,
))
});
let headers = HeaderMap::from_iter(headers);
IR::IO(IO::GraphQL { req_template, field_name, batch, dl_id: None, headers })
})
}

Expand Down
12 changes: 12 additions & 0 deletions src/core/blueprint/operators/grpc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Display;
use std::str::FromStr;

use headers::{HeaderMap, HeaderName, HeaderValue};
use prost_reflect::prost_types::FileDescriptorSet;
use prost_reflect::FieldDescriptor;

Expand Down Expand Up @@ -196,19 +198,29 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid<IR, String> {
body,
operation_type: operation_type.clone(),
};
let grpc_headers = grpc.headers.iter().filter_map(|kv| {
Some((
HeaderName::from_str(kv.key.as_str()).ok()?,
HeaderValue::from_str(kv.value.as_str()).ok()?,
))
});
let grpc_headers = HeaderMap::from_iter(grpc_headers);

if !grpc.batch_key.is_empty() {
IR::IO(IO::Grpc {
req_template,
group_by: Some(GroupBy::new(grpc.batch_key.clone(), None)),
dl_id: None,
batch: grpc.batch.clone(),
headers: grpc_headers,
})
} else {
IR::IO(IO::Grpc {
req_template,
group_by: None,
dl_id: None,
batch: grpc.batch.clone(),
headers: grpc_headers,
})
}
})
Expand Down
15 changes: 15 additions & 0 deletions src/core/blueprint/operators/http.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::str::FromStr;

use headers::{HeaderMap, HeaderName, HeaderValue};

use crate::core::blueprint::*;
use crate::core::config::group_by::GroupBy;
use crate::core::config::{Field, Resolver};
Expand Down Expand Up @@ -63,19 +67,29 @@ pub fn compile_http(
.or(config_module.upstream.on_request.clone())
.map(|on_request| HttpFilter { on_request });

let headers = http.headers.iter().filter_map(|kv| {
Some((
HeaderName::from_str(kv.key.as_str()).ok()?,
HeaderValue::from_str(kv.value.as_str()).ok()?,
))
});
let headers = HeaderMap::from_iter(headers);

if !http.batch_key.is_empty() && http.method == Method::GET {
// Find a query parameter that contains a reference to the {{.value}} key
let key = http.query.iter().find_map(|q| {
Mustache::parse(&q.value)
.expression_contains("value")
.then(|| q.key.clone())
});

IR::IO(IO::Http {
req_template,
group_by: Some(GroupBy::new(http.batch_key.clone(), key)),
dl_id: None,
http_filter,
batch: http.batch.clone(),
headers,
})
} else {
IR::IO(IO::Http {
Expand All @@ -84,6 +98,7 @@ pub fn compile_http(
dl_id: None,
http_filter,
batch: http.batch.clone(),
headers,
})
}
})
Expand Down
4 changes: 2 additions & 2 deletions src/core/config/reader_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ impl<'a> PathString for ConfigReaderContext<'a> {
}

impl HasHeaders for ConfigReaderContext<'_> {
fn headers(&self) -> &HeaderMap {
&self.headers
fn headers(&self) -> HeaderMap {
self.headers.to_owned()
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/core/graphql/request_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl RequestTemplate {
reqwest::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
headers.extend(ctx.headers().to_owned());
headers.extend(ctx.headers());
req
}

Expand Down Expand Up @@ -193,8 +193,8 @@ mod tests {
}

impl HasHeaders for Context {
fn headers(&self) -> &HeaderMap {
&self.headers
fn headers(&self) -> HeaderMap {
self.headers.to_owned()
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/core/grpc/request_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl RequestTemplate {
req_headers.extend(headers);
}

req_headers.extend(ctx.headers().to_owned());
req_headers.extend(ctx.headers());

req_headers
}
Expand Down Expand Up @@ -210,8 +210,8 @@ mod tests {
}

impl crate::core::has_headers::HasHeaders for Context {
fn headers(&self) -> &HeaderMap {
&self.headers
fn headers(&self) -> HeaderMap {
self.headers.to_owned()
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/core/has_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use hyper::HeaderMap;
use crate::core::ir::{EvalContext, ResolverContextLike};

pub trait HasHeaders {
fn headers(&self) -> &HeaderMap;
fn headers(&self) -> HeaderMap;
}

impl<'a, Ctx: ResolverContextLike> HasHeaders for EvalContext<'a, Ctx> {
fn headers(&self) -> &HeaderMap {
fn headers(&self) -> HeaderMap {
self.headers()
}
}
8 changes: 4 additions & 4 deletions src/core/http/request_context.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::num::NonZeroU64;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};

use async_graphql_value::ConstValue;
use cache_control::{Cachability, CacheControl};
Expand All @@ -27,7 +27,7 @@ pub struct RequestContext {
pub cookie_headers: Option<Arc<Mutex<HeaderMap>>>,
// A subset of all the headers received in the GraphQL Request that will be sent to the
// upstream.
pub allowed_headers: HeaderMap,
pub allowed_headers: Arc<RwLock<HeaderMap>>,
pub auth_ctx: AuthContext,
pub http_data_loaders: Arc<Vec<DataLoader<DataLoaderRequest, HttpDataLoader>>>,
pub gql_data_loaders: Arc<Vec<DataLoader<DataLoaderRequest, GraphqlDataLoader>>>,
Expand All @@ -54,7 +54,7 @@ impl RequestContext {
runtime: target_runtime,
cache: DedupeResult::new(true),
dedupe_handler: Arc::new(DedupeResult::new(false)),
allowed_headers: HeaderMap::new(),
allowed_headers: Default::default(),
auth_ctx: AuthContext::default(),
}
}
Expand Down Expand Up @@ -197,7 +197,7 @@ impl From<&AppContext> for RequestContext {
upstream: app_ctx.blueprint.upstream.clone(),
x_response_headers: Arc::new(Mutex::new(HeaderMap::new())),
cookie_headers,
allowed_headers: HeaderMap::new(),
allowed_headers: Default::default(),
auth_ctx: (&app_ctx.auth_ctx).into(),
http_data_loaders: app_ctx.http_data_loaders.clone(),
gql_data_loaders: app_ctx.gql_data_loaders.clone(),
Expand Down
4 changes: 2 additions & 2 deletions src/core/http/request_handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::BTreeSet;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::{Arc, RwLock};

use anyhow::Result;
use async_graphql::ServerError;
Expand Down Expand Up @@ -62,7 +62,7 @@ fn create_request_context(req: &Request<Body>, app_ctx: &AppContext) -> RequestC
let allowed_headers = create_allowed_headers(req.headers(), &allowed);

let _allowed = app_ctx.blueprint.server.get_experimental_headers();
RequestContext::from(app_ctx).allowed_headers(allowed_headers)
RequestContext::from(app_ctx).allowed_headers(Arc::new(RwLock::new(allowed_headers)))
}

fn update_cache_control_header(
Expand Down
6 changes: 3 additions & 3 deletions src/core/http/request_template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl RequestTemplate {
);
}

headers.extend(ctx.headers().to_owned());
headers.extend(ctx.headers());
req
}

Expand Down Expand Up @@ -352,8 +352,8 @@ mod tests {
}

impl crate::core::has_headers::HasHeaders for Context {
fn headers(&self) -> &HeaderMap {
&self.headers
fn headers(&self) -> HeaderMap {
self.headers.to_owned()
}
}

Expand Down
19 changes: 13 additions & 6 deletions src/core/ir/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,21 @@ impl<'a, Ctx: ResolverContextLike> EvalContext<'a, Ctx> {
}
}

pub fn headers(&self) -> &HeaderMap {
&self.request_ctx.allowed_headers
pub fn headers(&self) -> HeaderMap {
self.request_ctx.allowed_headers.read().unwrap().to_owned()
}

pub fn header(&self, key: &str) -> Option<&str> {
let value = self.headers().get(key)?;

value.to_str().ok()
pub fn header(&self, key: &str) -> Option<String> {
let value = self
.request_ctx
.allowed_headers
.read()
.unwrap()
.get(key)?
.to_str()
.map(|v| v.to_owned())
.ok()?;
Some(value)
}

pub fn env_var(&self, key: &str) -> Option<Cow<'_, str>> {
Expand Down
Loading

0 comments on commit d159f4e

Please sign in to comment.