diff --git a/benches/request_template_bench.rs b/benches/request_template_bench.rs index a08fb27a83..57d6139df9 100644 --- a/benches/request_template_bench.rs +++ b/benches/request_template_bench.rs @@ -38,8 +38,8 @@ impl PathString for Context { } } impl HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + fn headers(&self) -> HeaderMap { + self.headers.to_owned() } } pub fn benchmark_to_request(c: &mut Criterion) { diff --git a/src/core/app_context.rs b/src/core/app_context.rs index 57ba85f26d..f2eab338e0 100644 --- a/src/core/app_context.rs +++ b/src/core/app_context.rs @@ -47,7 +47,14 @@ impl AppContext { field.map_expr(|expr| { expr.modify(|expr| match expr { IR::IO(io) => match io { - IO::Http { req_template, group_by, http_filter, batch, .. } => { + IO::Http { + req_template, + group_by, + http_filter, + batch, + headers, + .. + } => { let data_loader = HttpDataLoader::new( runtime.clone(), group_by.clone(), @@ -61,6 +68,7 @@ impl AppContext { dl_id: Some(DataLoaderId::new(http_data_loaders.len())), http_filter: http_filter.clone(), batch: batch.clone(), + headers: headers.clone(), })); http_data_loaders.push(data_loader); @@ -68,7 +76,9 @@ impl AppContext { result } - IO::GraphQL { req_template, field_name, batch, .. } => { + IO::GraphQL { + req_template, field_name, batch, headers, .. + } => { let graphql_data_loader = GraphqlDataLoader::new(runtime.clone(), batch.is_some()) .into_data_loader(batch.clone().unwrap_or_default()); @@ -78,6 +88,7 @@ impl AppContext { field_name: field_name.clone(), batch: batch.clone(), dl_id: Some(DataLoaderId::new(gql_data_loaders.len())), + headers: headers.clone(), })); gql_data_loaders.push(graphql_data_loader); @@ -85,7 +96,7 @@ impl AppContext { result } - IO::Grpc { req_template, group_by, batch, .. } => { + IO::Grpc { req_template, group_by, batch, headers, .. } => { let data_loader = GrpcDataLoader { runtime: runtime.clone(), operation: req_template.operation.clone(), @@ -99,6 +110,7 @@ impl AppContext { group_by: group_by.clone(), dl_id: Some(DataLoaderId::new(grpc_data_loaders.len())), batch: batch.clone(), + headers: headers.clone(), })); grpc_data_loaders.push(data_loader); diff --git a/src/core/auth/basic.rs b/src/core/auth/basic.rs index 9856f00d66..391541cc78 100644 --- a/src/core/auth/basic.rs +++ b/src/core/auth/basic.rs @@ -18,6 +18,8 @@ impl Verify for BasicVerifier { async fn verify(&self, req_ctx: &RequestContext) -> Verification { let header = req_ctx .allowed_headers + .read() + .unwrap() .typed_try_get::>(); let Ok(header) = header else { @@ -67,6 +69,8 @@ testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4= req_context .allowed_headers + .read() + .unwrap() .typed_insert(Authorization::basic(username, password)); req_context @@ -119,7 +123,7 @@ testuser3:{SHA}Y2fEjdGT1W6nsLqtJbGUVeUp9e4= async fn verify_auth_failure() { let provider = setup_provider(); let mut req_ctx = RequestContext::default(); - req_ctx.allowed_headers.insert( + req_ctx.allowed_headers.lock().unwrap().insert( "Authorization", HeaderValue::from_static("Basic dGVzdHVzZXIyOm15cGFzc3dvcmQ"), ); diff --git a/src/core/auth/jwt/jwt_verify.rs b/src/core/auth/jwt/jwt_verify.rs index 5238b4364e..be9a183bd3 100644 --- a/src/core/auth/jwt/jwt_verify.rs +++ b/src/core/auth/jwt/jwt_verify.rs @@ -41,6 +41,8 @@ impl JwtVerifier { fn resolve_token(&self, request: &RequestContext) -> anyhow::Result> { let value = request .allowed_headers + .read() + .unwrap() .typed_try_get::>()?; Ok(value.map(|token| token.token().to_owned())) @@ -168,6 +170,8 @@ pub mod tests { req_context .allowed_headers + .read() + .unwrap() .typed_insert(Authorization::bearer(token).unwrap()); req_context diff --git a/src/core/blueprint/operators/graphql.rs b/src/core/blueprint/operators/graphql.rs index 7f151cb7cc..f950750b29 100644 --- a/src/core/blueprint/operators/graphql.rs +++ b/src/core/blueprint/operators/graphql.rs @@ -1,4 +1,7 @@ use std::collections::{HashMap, HashSet}; +use std::str::FromStr; + +use headers::{HeaderMap, HeaderName, HeaderValue}; use crate::core::blueprint::FieldDefinition; use crate::core::config::{ @@ -71,7 +74,14 @@ pub fn compile_graphql( .map(|req_template| { let field_name = graphql.name.clone(); let batch = graphql.batch.clone(); - IR::IO(IO::GraphQL { req_template, field_name, batch, dl_id: None }) + let headers = graphql.headers.iter().filter_map(|kv| { + Some(( + HeaderName::from_str(kv.key.as_str()).ok()?, + HeaderValue::from_str(kv.value.as_str()).ok()?, + )) + }); + let headers = HeaderMap::from_iter(headers); + IR::IO(IO::GraphQL { req_template, field_name, batch, dl_id: None, headers }) }) } diff --git a/src/core/blueprint/operators/grpc.rs b/src/core/blueprint/operators/grpc.rs index 5d9e585c5b..070cf6a56e 100644 --- a/src/core/blueprint/operators/grpc.rs +++ b/src/core/blueprint/operators/grpc.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use std::str::FromStr; +use headers::{HeaderMap, HeaderName, HeaderValue}; use prost_reflect::prost_types::FileDescriptorSet; use prost_reflect::FieldDescriptor; @@ -196,12 +198,21 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid { body, operation_type: operation_type.clone(), }; + let grpc_headers = grpc.headers.iter().filter_map(|kv| { + Some(( + HeaderName::from_str(kv.key.as_str()).ok()?, + HeaderValue::from_str(kv.value.as_str()).ok()?, + )) + }); + let grpc_headers = HeaderMap::from_iter(grpc_headers); + if !grpc.batch_key.is_empty() { IR::IO(IO::Grpc { req_template, group_by: Some(GroupBy::new(grpc.batch_key.clone(), None)), dl_id: None, batch: grpc.batch.clone(), + headers: grpc_headers, }) } else { IR::IO(IO::Grpc { @@ -209,6 +220,7 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid { group_by: None, dl_id: None, batch: grpc.batch.clone(), + headers: grpc_headers, }) } }) diff --git a/src/core/blueprint/operators/http.rs b/src/core/blueprint/operators/http.rs index cc356c40f8..e8755df092 100644 --- a/src/core/blueprint/operators/http.rs +++ b/src/core/blueprint/operators/http.rs @@ -1,3 +1,7 @@ +use std::str::FromStr; + +use headers::{HeaderMap, HeaderName, HeaderValue}; + use crate::core::blueprint::*; use crate::core::config::group_by::GroupBy; use crate::core::config::{Field, Resolver}; @@ -63,6 +67,14 @@ pub fn compile_http( .or(config_module.upstream.on_request.clone()) .map(|on_request| HttpFilter { on_request }); + let headers = http.headers.iter().filter_map(|kv| { + Some(( + HeaderName::from_str(kv.key.as_str()).ok()?, + HeaderValue::from_str(kv.value.as_str()).ok()?, + )) + }); + let headers = HeaderMap::from_iter(headers); + if !http.batch_key.is_empty() && http.method == Method::GET { // Find a query parameter that contains a reference to the {{.value}} key let key = http.query.iter().find_map(|q| { @@ -70,12 +82,14 @@ pub fn compile_http( .expression_contains("value") .then(|| q.key.clone()) }); + IR::IO(IO::Http { req_template, group_by: Some(GroupBy::new(http.batch_key.clone(), key)), dl_id: None, http_filter, batch: http.batch.clone(), + headers, }) } else { IR::IO(IO::Http { @@ -84,6 +98,7 @@ pub fn compile_http( dl_id: None, http_filter, batch: http.batch.clone(), + headers, }) } }) diff --git a/src/core/config/reader_context.rs b/src/core/config/reader_context.rs index 693e75fcfd..b4e05a6d06 100644 --- a/src/core/config/reader_context.rs +++ b/src/core/config/reader_context.rs @@ -29,8 +29,8 @@ impl<'a> PathString for ConfigReaderContext<'a> { } impl HasHeaders for ConfigReaderContext<'_> { - fn headers(&self) -> &HeaderMap { - &self.headers + fn headers(&self) -> HeaderMap { + self.headers.to_owned() } } diff --git a/src/core/graphql/request_template.rs b/src/core/graphql/request_template.rs index c32433e76e..4c6f819ccd 100644 --- a/src/core/graphql/request_template.rs +++ b/src/core/graphql/request_template.rs @@ -57,7 +57,7 @@ impl RequestTemplate { reqwest::header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); - headers.extend(ctx.headers().to_owned()); + headers.extend(ctx.headers()); req } @@ -193,8 +193,8 @@ mod tests { } impl HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + fn headers(&self) -> HeaderMap { + self.headers.to_owned() } } diff --git a/src/core/grpc/request_template.rs b/src/core/grpc/request_template.rs index ceddf1ac2a..bff5b3fc36 100644 --- a/src/core/grpc/request_template.rs +++ b/src/core/grpc/request_template.rs @@ -103,7 +103,7 @@ impl RequestTemplate { req_headers.extend(headers); } - req_headers.extend(ctx.headers().to_owned()); + req_headers.extend(ctx.headers()); req_headers } @@ -210,8 +210,8 @@ mod tests { } impl crate::core::has_headers::HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + fn headers(&self) -> HeaderMap { + self.headers.to_owned() } } diff --git a/src/core/has_headers.rs b/src/core/has_headers.rs index 3c882d514e..e2faced623 100644 --- a/src/core/has_headers.rs +++ b/src/core/has_headers.rs @@ -3,11 +3,11 @@ use hyper::HeaderMap; use crate::core::ir::{EvalContext, ResolverContextLike}; pub trait HasHeaders { - fn headers(&self) -> &HeaderMap; + fn headers(&self) -> HeaderMap; } impl<'a, Ctx: ResolverContextLike> HasHeaders for EvalContext<'a, Ctx> { - fn headers(&self) -> &HeaderMap { + fn headers(&self) -> HeaderMap { self.headers() } } diff --git a/src/core/http/request_context.rs b/src/core/http/request_context.rs index a41d7e8cad..4fd555c54d 100644 --- a/src/core/http/request_context.rs +++ b/src/core/http/request_context.rs @@ -1,6 +1,6 @@ use std::num::NonZeroU64; use std::str::FromStr; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use async_graphql_value::ConstValue; use cache_control::{Cachability, CacheControl}; @@ -27,7 +27,7 @@ pub struct RequestContext { pub cookie_headers: Option>>, // A subset of all the headers received in the GraphQL Request that will be sent to the // upstream. - pub allowed_headers: HeaderMap, + pub allowed_headers: Arc>, pub auth_ctx: AuthContext, pub http_data_loaders: Arc>>, pub gql_data_loaders: Arc>>, @@ -54,7 +54,7 @@ impl RequestContext { runtime: target_runtime, cache: DedupeResult::new(true), dedupe_handler: Arc::new(DedupeResult::new(false)), - allowed_headers: HeaderMap::new(), + allowed_headers: Default::default(), auth_ctx: AuthContext::default(), } } @@ -197,7 +197,7 @@ impl From<&AppContext> for RequestContext { upstream: app_ctx.blueprint.upstream.clone(), x_response_headers: Arc::new(Mutex::new(HeaderMap::new())), cookie_headers, - allowed_headers: HeaderMap::new(), + allowed_headers: Default::default(), auth_ctx: (&app_ctx.auth_ctx).into(), http_data_loaders: app_ctx.http_data_loaders.clone(), gql_data_loaders: app_ctx.gql_data_loaders.clone(), diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index 10fe134ad9..e7f54671b2 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -1,6 +1,6 @@ use std::collections::BTreeSet; use std::ops::Deref; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use anyhow::Result; use async_graphql::ServerError; @@ -62,7 +62,7 @@ fn create_request_context(req: &Request, app_ctx: &AppContext) -> RequestC let allowed_headers = create_allowed_headers(req.headers(), &allowed); let _allowed = app_ctx.blueprint.server.get_experimental_headers(); - RequestContext::from(app_ctx).allowed_headers(allowed_headers) + RequestContext::from(app_ctx).allowed_headers(Arc::new(RwLock::new(allowed_headers))) } fn update_cache_control_header( diff --git a/src/core/http/request_template.rs b/src/core/http/request_template.rs index d484270b5b..67ec99acd3 100644 --- a/src/core/http/request_template.rs +++ b/src/core/http/request_template.rs @@ -179,7 +179,7 @@ impl RequestTemplate { ); } - headers.extend(ctx.headers().to_owned()); + headers.extend(ctx.headers()); req } @@ -352,8 +352,8 @@ mod tests { } impl crate::core::has_headers::HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + fn headers(&self) -> HeaderMap { + self.headers.to_owned() } } diff --git a/src/core/ir/eval_context.rs b/src/core/ir/eval_context.rs index 597335c00e..34b43db3d5 100644 --- a/src/core/ir/eval_context.rs +++ b/src/core/ir/eval_context.rs @@ -79,14 +79,21 @@ impl<'a, Ctx: ResolverContextLike> EvalContext<'a, Ctx> { } } - pub fn headers(&self) -> &HeaderMap { - &self.request_ctx.allowed_headers + pub fn headers(&self) -> HeaderMap { + self.request_ctx.allowed_headers.read().unwrap().to_owned() } - pub fn header(&self, key: &str) -> Option<&str> { - let value = self.headers().get(key)?; - - value.to_str().ok() + pub fn header(&self, key: &str) -> Option { + let value = self + .request_ctx + .allowed_headers + .read() + .unwrap() + .get(key)? + .to_str() + .map(|v| v.to_owned()) + .ok()?; + Some(value) } pub fn env_var(&self, key: &str) -> Option> { diff --git a/src/core/ir/model.rs b/src/core/ir/model.rs index d37edcfcf2..4635d698ae 100644 --- a/src/core/ir/model.rs +++ b/src/core/ir/model.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use std::num::NonZeroU64; use async_graphql::Value; +use headers::HeaderMap; use strum_macros::Display; use super::discriminator::Discriminator; @@ -44,18 +45,21 @@ pub enum IO { dl_id: Option, http_filter: Option, batch: Option, + headers: HeaderMap, }, GraphQL { req_template: graphql::RequestTemplate, field_name: String, batch: Option, dl_id: Option, + headers: HeaderMap, }, Grpc { req_template: grpc::RequestTemplate, group_by: Option, dl_id: Option, batch: Option, + headers: HeaderMap, }, Js { name: String, diff --git a/src/core/jit/exec.rs b/src/core/jit/exec.rs index f8d0da9331..5136046666 100644 --- a/src/core/jit/exec.rs +++ b/src/core/jit/exec.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex}; use derive_getters::Getters; use futures_util::future::join_all; +use headers::HeaderMap; use super::context::{Context, RequestContext}; use super::{DataPath, OperationPlan, Positioned, Response, Store}; @@ -182,4 +183,5 @@ pub trait IRExecutor { ir: &'a IR, ctx: &'a Context<'a, Self::Input, Self::Output>, ) -> Result; + fn headers<'a>(&self, ir: &'a IR) -> Option<&'a HeaderMap>; } diff --git a/src/core/jit/exec_const.rs b/src/core/jit/exec_const.rs index 6e1c74682f..b53598065a 100644 --- a/src/core/jit/exec_const.rs +++ b/src/core/jit/exec_const.rs @@ -1,13 +1,14 @@ use std::sync::Arc; use async_graphql_value::ConstValue; +use headers::HeaderMap; use super::context::Context; use super::exec::{Executor, IRExecutor}; use super::{Error, OperationPlan, Request, Response, Result}; use crate::core::app_context::AppContext; use crate::core::http::RequestContext; -use crate::core::ir::model::IR; +use crate::core::ir::model::{IO, IR}; use crate::core::ir::EvalContext; use crate::core::jit::synth::Synth; @@ -59,7 +60,25 @@ impl<'ctx> IRExecutor for ConstValueExec<'ctx> { ) -> Result { let req_context = &self.req_context; let mut eval_ctx = EvalContext::new(req_context, ctx); + if let Some(h) = self.headers(ir) { + self.req_context + .allowed_headers + .write() + .unwrap() + .extend(h.to_owned()); + } Ok(ir.eval(&mut eval_ctx).await?) } + fn headers<'a>(&self, ir: &'a IR) -> Option<&'a HeaderMap> { + match ir { + IR::IO(io) => match io { + IO::Http { headers, .. } => Some(headers), + IO::GraphQL { headers, .. } => Some(headers), + IO::Grpc { headers, .. } => Some(headers), + IO::Js { .. } => None, + }, + _ => None, + } + } } diff --git a/src/core/path.rs b/src/core/path.rs index b50bb6f6ce..f475319ea0 100644 --- a/src/core/path.rs +++ b/src/core/path.rs @@ -88,7 +88,7 @@ impl<'a, Ctx: ResolverContextLike> EvalContext<'a, Ctx> { .and_then(move |(head, tail)| match head.as_ref() { "value" => Some(ValueString::Value(ctx.path_value(tail)?)), "args" => Some(ValueString::Value(ctx.path_arg(tail)?)), - "headers" => Some(ValueString::String(Cow::Borrowed( + "headers" => Some(ValueString::String(Cow::Owned( ctx.header(tail[0].as_ref())?, ))), "vars" => Some(ValueString::String(Cow::Borrowed( diff --git a/tests/core/snapshots/grpc-url-from-upstream.md_merged.snap b/tests/core/snapshots/grpc-url-from-upstream.md_merged.snap index 0163528d57..467d1fcc21 100644 --- a/tests/core/snapshots/grpc-url-from-upstream.md_merged.snap +++ b/tests/core/snapshots/grpc-url-from-upstream.md_merged.snap @@ -4,7 +4,7 @@ expression: formatter --- schema @server(port: 8000) - @upstream(baseURL: "http://localhost:50051", batch: {delay: 10, headers: []}, httpCache: 42) + @upstream(baseURL: "http://localhost:50051", httpCache: 42) @link(id: "news", src: "news.proto", type: Protobuf) { query: Query } diff --git a/tests/core/snapshots/resolve-with-headers.md_merged.snap b/tests/core/snapshots/resolve-with-headers.md_merged.snap index b382420ea5..7959e4048e 100644 --- a/tests/core/snapshots/resolve-with-headers.md_merged.snap +++ b/tests/core/snapshots/resolve-with-headers.md_merged.snap @@ -2,7 +2,7 @@ source: tests/core/spec.rs expression: formatter --- -schema @server @upstream(allowedHeaders: ["authorization"]) { +schema @server @upstream { query: Query } @@ -14,5 +14,10 @@ type Post { } type Query { - post1: Post @http(baseURL: "http://jsonplaceholder.typicode.com", path: "/posts/{{.headers.authorization}}") + post1: Post + @http( + baseURL: "http://jsonplaceholder.typicode.com" + path: "/posts/{{.headers.authorization}}" + allowedHeaders: ["authorization"] + ) } diff --git a/tests/execution/grpc-url-from-upstream.md b/tests/execution/grpc-url-from-upstream.md index bbfaf28a0a..ad5303b5d9 100644 --- a/tests/execution/grpc-url-from-upstream.md +++ b/tests/execution/grpc-url-from-upstream.md @@ -39,7 +39,7 @@ message NewsList { ```graphql @config schema @server(port: 8000) - @upstream(httpCache: 42, batch: {delay: 10}, baseURL: "http://localhost:50051") + @upstream(httpCache: 42, baseURL: "http://localhost:50051") @link(id: "news", src: "news.proto", type: Protobuf) { query: Query }