From 09809bbe756c7e059bc9d257783ad859e9b0ab9e Mon Sep 17 00:00:00 2001 From: Kunam Balaram Reddy Date: Fri, 7 Jun 2024 13:40:33 +0530 Subject: [PATCH] feat: multiple onRequest handlers (#1863) Co-authored-by: Tushar Mathur Co-authored-by: amit --- benches/data_loader_bench.rs | 2 +- ...impl_path_string_for_evaluation_context.rs | 2 +- examples/jsonplaceholder_script.graphql | 2 +- generated/.tailcallrc.graphql | 14 + generated/.tailcallrc.schema.json | 14 + .../javascript/{js_request.rs => codec.rs} | 322 ++++++++++++------ src/cli/javascript/js_response.rs | 206 ----------- src/cli/javascript/mod.rs | 18 +- src/cli/javascript/request_filter.rs | 151 -------- src/cli/javascript/runtime.rs | 8 +- src/cli/runtime/mod.rs | 24 +- src/core/app_context.rs | 9 +- src/core/blueprint/operators/http.rs | 12 +- src/core/blueprint/upstream.rs | 2 + src/core/config/config.rs | 5 + src/core/config/upstream.rs | 9 + src/core/generator/graphql_type.rs | 1 - src/core/http/mod.rs | 13 + src/core/ir/io.rs | 147 ++++++-- src/core/mod.rs | 6 +- src/core/runtime.rs | 32 +- src/core/worker.rs | 186 +++++++++- tailcall-aws-lambda/src/runtime.rs | 2 +- tailcall-cloudflare/src/runtime.rs | 2 +- tests/core/parse.rs | 9 +- tests/core/runtime.rs | 26 +- ...test-js-multi-onRequest-handlers.md_0.snap | 16 + ...js-multi-onRequest-handlers.md_client.snap | 24 ++ ...js-multi-onRequest-handlers.md_merged.snap | 12 + .../test-js-request-response-2.md_merged.snap | 4 +- .../test-js-request-response.md_merged.snap | 4 +- .../test-js-multi-onRequest-handlers.md | 47 +++ tests/execution/test-js-request-response-2.md | 2 +- tests/execution/test-js-request-response.md | 2 +- tests/server_spec.rs | 30 +- 35 files changed, 767 insertions(+), 598 deletions(-) rename src/cli/javascript/{js_request.rs => codec.rs} (50%) delete mode 100644 src/cli/javascript/js_response.rs delete mode 100644 src/cli/javascript/request_filter.rs create mode 100644 tests/core/snapshots/test-js-multi-onRequest-handlers.md_0.snap create mode 100644 tests/core/snapshots/test-js-multi-onRequest-handlers.md_client.snap create mode 100644 tests/core/snapshots/test-js-multi-onRequest-handlers.md_merged.snap create mode 100644 tests/execution/test-js-multi-onRequest-handlers.md diff --git a/benches/data_loader_bench.rs b/benches/data_loader_bench.rs index f2385df093..c4a743be46 100644 --- a/benches/data_loader_bench.rs +++ b/benches/data_loader_bench.rs @@ -80,7 +80,7 @@ pub fn benchmark_data_loader(c: &mut Criterion) { file: Arc::new(File {}), cache: Arc::new(Cache {}), extensions: Arc::new(vec![]), - http_worker: None, + cmd_worker: None, worker: None, }; let loader = HttpDataLoader::new(rt, None, false); diff --git a/benches/impl_path_string_for_evaluation_context.rs b/benches/impl_path_string_for_evaluation_context.rs index 7fac1c0e6f..d3e74c3184 100644 --- a/benches/impl_path_string_for_evaluation_context.rs +++ b/benches/impl_path_string_for_evaluation_context.rs @@ -246,7 +246,7 @@ fn request_context() -> RequestContext { file: Arc::new(File {}), cache: Arc::new(InMemoryCache::new()), extensions: Arc::new(vec![]), - http_worker: None, + cmd_worker: None, worker: None, }; RequestContext::new(runtime) diff --git a/examples/jsonplaceholder_script.graphql b/examples/jsonplaceholder_script.graphql index b141b92f0b..71924d48ff 100644 --- a/examples/jsonplaceholder_script.graphql +++ b/examples/jsonplaceholder_script.graphql @@ -1,6 +1,6 @@ schema @server(port: 8000, hostname: "0.0.0.0") - @upstream(baseURL: "http://jsonplaceholder.typicode.com", httpCache: 42) + @upstream(baseURL: "http://jsonplaceholder.typicode.com", httpCache: 42, onRequest: "onRequest") @link(type: Script, src: "scripts/echo.js") { query: Query } diff --git a/generated/.tailcallrc.graphql b/generated/.tailcallrc.graphql index 694b44c433..dabbc52ff1 100644 --- a/generated/.tailcallrc.graphql +++ b/generated/.tailcallrc.graphql @@ -161,6 +161,11 @@ directive @http( """ method: Method """ + onRequest field in @http directive gives the ability to specify the request interception + handler. + """ + onRequest: String + """ Schema of the output of the API call. It is automatically inferred in most cases. """ output: Schema @@ -382,6 +387,10 @@ directive @upstream( """ keepAliveWhileIdle: Boolean """ + onRequest field gives the ability to specify the global request interception handler. + """ + onRequest: String + """ The time in seconds that the connection pool will wait before closing idle connections. """ poolIdleTimeout: Int @@ -646,6 +655,11 @@ input Http { """ method: Method """ + onRequest field in @http directive gives the ability to specify the request interception + handler. + """ + onRequest: String + """ Schema of the output of the API call. It is automatically inferred in most cases. """ output: Schema diff --git a/generated/.tailcallrc.schema.json b/generated/.tailcallrc.schema.json index 83b13b251f..a53f3864d8 100644 --- a/generated/.tailcallrc.schema.json +++ b/generated/.tailcallrc.schema.json @@ -713,6 +713,13 @@ } ] }, + "onRequest": { + "description": "onRequest field in @http directive gives the ability to specify the request interception handler.", + "type": [ + "string", + "null" + ] + }, "output": { "description": "Schema of the output of the API call. It is automatically inferred in most cases.", "anyOf": [ @@ -1399,6 +1406,13 @@ "null" ] }, + "onRequest": { + "description": "onRequest field gives the ability to specify the global request interception handler.", + "type": [ + "string", + "null" + ] + }, "poolIdleTimeout": { "description": "The time in seconds that the connection pool will wait before closing idle connections.", "type": [ diff --git a/src/cli/javascript/js_request.rs b/src/cli/javascript/codec.rs similarity index 50% rename from src/cli/javascript/js_request.rs rename to src/cli/javascript/codec.rs index 0785352e15..c827d9baf7 100644 --- a/src/cli/javascript/js_request.rs +++ b/src/cli/javascript/codec.rs @@ -1,61 +1,42 @@ use std::collections::BTreeMap; -use std::fmt::Display; use std::str::FromStr; use headers::HeaderValue; use reqwest::header::HeaderName; -use reqwest::Request; use rquickjs::{FromJs, IntoJs}; -use serde::{Deserialize, Serialize}; -use crate::core::is_default; -use crate::core::worker::WorkerRequest; +use super::create_header_map; +use crate::core::http::Response; +use crate::core::worker::*; -impl WorkerRequest { - fn uri(&self) -> Uri { - self.0.url().into() - } - - fn method(&self) -> String { - self.0.method().to_string() - } - - fn headers(&self) -> anyhow::Result> { - let headers = self.0.headers(); - let mut map = BTreeMap::new(); - for (k, v) in headers.iter() { - map.insert(k.to_string(), v.to_str()?.to_string()); - } - Ok(map) - } +impl<'js> FromJs<'js> for Command { + fn from_js(ctx: &rquickjs::Ctx<'js>, value: rquickjs::Value<'js>) -> rquickjs::Result { + let object = value.as_object().ok_or(rquickjs::Error::FromJs { + from: value.type_name(), + to: "rquickjs::Object", + message: Some("unable to cast JS Value as object".to_string()), + })?; - fn body(&self) -> Option { - if let Some(body) = self.0.body() { - let bytes = body.as_bytes()?; - Some(String::from_utf8_lossy(bytes).to_string()) + if object.contains_key("request")? { + Ok(Command::Request(WorkerRequest::from_js( + ctx, + object.get("request")?, + )?)) + } else if object.contains_key("response")? { + Ok(Command::Response(WorkerResponse::from_js( + ctx, + object.get("response")?, + )?)) } else { - None + Err(rquickjs::Error::FromJs { + from: "object", + to: "tailcall::cli::javascript::request_filter::Command", + message: Some("object must contain either request or response".to_string()), + }) } } } -impl TryFrom<&reqwest::Request> for WorkerRequest { - type Error = anyhow::Error; - - fn try_from(value: &Request) -> Result { - let request = value - .try_clone() - .ok_or(anyhow::anyhow!("unable to clone request"))?; - Ok(WorkerRequest(request)) - } -} - -impl From for reqwest::Request { - fn from(val: WorkerRequest) -> Self { - val.0 - } -} - impl<'js> IntoJs<'js> for WorkerRequest { fn into_js(self, ctx: &rquickjs::Ctx<'js>) -> rquickjs::Result> { let object = rquickjs::Object::new(ctx.clone())?; @@ -122,13 +103,6 @@ impl<'js> FromJs<'js> for WorkerRequest { } } -#[derive(Serialize, Deserialize, Default, Debug, PartialEq, Eq)] -pub enum Scheme { - #[default] - Http, - Https, -} - impl<'js> IntoJs<'js> for Scheme { fn into_js(self, ctx: &rquickjs::Ctx<'js>) -> rquickjs::Result> { match self { @@ -161,20 +135,6 @@ impl<'js> FromJs<'js> for Scheme { } } -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] -#[serde(rename_all = "camelCase")] -pub struct Uri { - path: String, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - query: BTreeMap, - #[serde(default, skip_serializing_if = "is_default")] - scheme: Scheme, - #[serde(default, skip_serializing_if = "is_default")] - host: Option, - #[serde(default, skip_serializing_if = "is_default")] - port: Option, -} - impl<'js> IntoJs<'js> for Uri { fn into_js(self, ctx: &rquickjs::Ctx<'js>) -> rquickjs::Result> { let object = rquickjs::Object::new(ctx.clone())?; @@ -204,53 +164,213 @@ impl<'js> FromJs<'js> for Uri { } } -impl From<&reqwest::Url> for Uri { - fn from(value: &reqwest::Url) -> Self { - Self { - path: value.path().to_string(), - query: value.query_pairs().into_owned().collect(), - scheme: match value.scheme() { - "https" => Scheme::Https, - _ => Scheme::Http, - }, - host: value.host_str().map(|u| u.to_string()), - port: value.port(), - } +impl<'js> IntoJs<'js> for WorkerResponse { + fn into_js(self, ctx: &rquickjs::Ctx<'js>) -> rquickjs::Result> { + let object = rquickjs::Object::new(ctx.clone())?; + object.set("status", self.status())?; + object.set("headers", self.headers())?; + object.set("body", self.body())?; + Ok(object.into_value()) } } -impl Display for Uri { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let host = self.host.as_deref().unwrap_or("localhost"); - let port = self.port.map(|p| format!(":{}", p)).unwrap_or_default(); - let scheme = match self.scheme { - Scheme::Https => "https", - _ => "http", +impl<'js> FromJs<'js> for WorkerResponse { + fn from_js(_: &rquickjs::Ctx<'js>, value: rquickjs::Value<'js>) -> rquickjs::Result { + let object = value.as_object().ok_or(rquickjs::Error::FromJs { + from: value.type_name(), + to: "rquickjs::Object", + message: Some("unable to cast JS Value as object".to_string()), + })?; + let status = object.get::<&str, u16>("status")?; + let headers = object.get::<&str, BTreeMap>("headers")?; + let body = object.get::<&str, Option>("body")?; + let response = Response { + status: reqwest::StatusCode::from_u16(status).map_err(|_| rquickjs::Error::FromJs { + from: "u16", + to: "reqwest::StatusCode", + message: Some("invalid status code".to_string()), + })?, + headers: create_header_map(headers).map_err(|e| rquickjs::Error::FromJs { + from: "BTreeMap", + to: "reqwest::header::HeaderMap", + message: Some(e.to_string()), + })?, + body: body.unwrap_or_default(), }; - let path = self.path.as_str(); - let query = self - .query - .iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("&"); - - write!(f, "{}://{}{}{}", scheme, host, port, path)?; - - if !query.is_empty() { - write!(f, "?{}", query)?; - } - - Ok(()) + Ok(WorkerResponse(response)) } } #[cfg(test)] -mod tests { +mod test { + use std::collections::BTreeMap; + + use anyhow::Result; + use headers::{HeaderName, HeaderValue}; + use hyper::body::Bytes; use pretty_assertions::assert_eq; - use rquickjs::{Context, Runtime}; + use reqwest::header::HeaderMap; + use reqwest::Request; + use rquickjs::{Context, FromJs, IntoJs, Object, Runtime, String as JsString}; use super::*; + use crate::core::http::Response; + use crate::core::worker::{Command, WorkerRequest, WorkerResponse}; + + fn create_test_response() -> Result { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + let response = crate::core::http::Response { + status: reqwest::StatusCode::OK, + headers, + body: Bytes::from("Hello, World!"), + }; + let js_response: Result = response.try_into(); + js_response + } + + #[test] + fn test_to_js_response() { + let js_response = create_test_response(); + assert!(js_response.is_ok()); + let js_response = js_response.unwrap(); + assert_eq!(js_response.status(), 200); + assert_eq!( + js_response.headers().get("content-type").unwrap(), + "application/json" + ); + assert_eq!(js_response.body(), Some("Hello, World!".into())); + } + + #[test] + fn test_from_js_response() { + let js_response = create_test_response().unwrap(); + let response: Result> = js_response.try_into(); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status, reqwest::StatusCode::OK); + assert_eq!( + response.headers.get("content-type").unwrap(), + "application/json" + ); + assert_eq!(response.body, Bytes::from("Hello, World!")); + } + + #[test] + fn test_unusual_headers() { + let body = "a"; + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("x-unusual-header"), + HeaderValue::from_str("🚀").unwrap(), + ); + let response = crate::core::http::Response { + status: reqwest::StatusCode::OK, + headers, + body: body.into(), + }; + let js_response = WorkerResponse(response); + + let response: Result, _> = js_response.try_into(); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.headers.get("x-unusual-header").unwrap(), "🚀"); + assert_eq!(response.body, Bytes::from(body)); + } + + #[test] + fn test_response_into_js() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let value = create_test_response().unwrap().into_js(&ctx).unwrap(); + let object = value.as_object().unwrap(); + + let status = object.get::<&str, u16>("status").unwrap(); + let headers = object + .get::<&str, BTreeMap>("headers") + .unwrap(); + let body = object.get::<&str, Option>("body").unwrap(); + + assert_eq!(status, reqwest::StatusCode::OK); + assert_eq!(body, Some("Hello, World!".to_owned())); + assert!(headers.contains_key("content-type")); + assert_eq!( + headers.get("content-type"), + Some(&"application/json".to_owned()) + ); + }); + } + + #[test] + fn test_response_from_js() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let js_response = create_test_response().unwrap().into_js(&ctx).unwrap(); + let response = WorkerResponse::from_js(&ctx, js_response).unwrap(); + + assert_eq!(response.status(), reqwest::StatusCode::OK.as_u16()); + assert_eq!(response.body(), Some("Hello, World!".to_owned())); + assert_eq!( + response.headers().get("content-type"), + Some(&"application/json".to_owned()) + ); + }); + } + + #[test] + fn test_command_from_invalid_object() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let value = JsString::from_str(ctx.clone(), "invalid") + .unwrap() + .into_value(); + assert!(Command::from_js(&ctx, value).is_err()); + }); + } + + #[test] + fn test_command_from_request() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let request = + reqwest::Request::new(reqwest::Method::GET, "http://example.com/".parse().unwrap()); + let js_request: WorkerRequest = (&request).try_into().unwrap(); + let value = Object::new(ctx.clone()).unwrap(); + value.set("request", js_request.into_js(&ctx)).unwrap(); + assert!(Command::from_js(&ctx, value.into_value()).is_ok()); + }); + } + + #[test] + fn test_command_from_response() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let js_response = WorkerResponse::try_from(Response { + status: reqwest::StatusCode::OK, + headers: reqwest::header::HeaderMap::default(), + body: Bytes::new(), + }) + .unwrap(); + let value = Object::new(ctx.clone()).unwrap(); + value.set("response", js_response).unwrap(); + assert!(Command::from_js(&ctx, value.into_value()).is_ok()); + }); + } + + #[test] + fn test_command_from_arbitrary_object() { + let runtime = Runtime::new().unwrap(); + let context = Context::base(&runtime).unwrap(); + context.with(|ctx| { + let value = Object::new(ctx.clone()).unwrap(); + assert!(Command::from_js(&ctx, value.into_value()).is_err()); + }); + } #[test] fn test_reqwest_request_to_js_request() { diff --git a/src/cli/javascript/js_response.rs b/src/cli/javascript/js_response.rs deleted file mode 100644 index 2053f5ce3b..0000000000 --- a/src/cli/javascript/js_response.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::collections::BTreeMap; - -use hyper::body::Bytes; -use rquickjs::{FromJs, IntoJs}; - -use super::create_header_map; -use crate::core::http::Response; -use crate::core::worker::WorkerResponse; - -impl WorkerResponse { - pub fn status(&self) -> u16 { - self.0.status.as_u16() - } - - pub fn headers(&self) -> BTreeMap { - let mut headers = BTreeMap::new(); - for (key, value) in self.0.headers.iter() { - headers.insert(key.to_string(), value.to_str().unwrap().to_string()); - } - headers - } - - pub fn body(&self) -> Option { - let b = self.0.body.as_bytes(); - Some(String::from_utf8_lossy(b).to_string()) - } -} - -impl<'js> IntoJs<'js> for WorkerResponse { - fn into_js(self, ctx: &rquickjs::Ctx<'js>) -> rquickjs::Result> { - let object = rquickjs::Object::new(ctx.clone())?; - object.set("status", self.status())?; - object.set("headers", self.headers())?; - object.set("body", self.body())?; - Ok(object.into_value()) - } -} - -impl<'js> FromJs<'js> for WorkerResponse { - fn from_js(_: &rquickjs::Ctx<'js>, value: rquickjs::Value<'js>) -> rquickjs::Result { - let object = value.as_object().ok_or(rquickjs::Error::FromJs { - from: value.type_name(), - to: "rquickjs::Object", - message: Some("unable to cast JS Value as object".to_string()), - })?; - let status = object.get::<&str, u16>("status")?; - let headers = object.get::<&str, BTreeMap>("headers")?; - let body = object.get::<&str, Option>("body")?; - let response = Response { - status: reqwest::StatusCode::from_u16(status).map_err(|_| rquickjs::Error::FromJs { - from: "u16", - to: "reqwest::StatusCode", - message: Some("invalid status code".to_string()), - })?, - headers: create_header_map(headers).map_err(|e| rquickjs::Error::FromJs { - from: "BTreeMap", - to: "reqwest::header::HeaderMap", - message: Some(e.to_string()), - })?, - body: body.unwrap_or_default(), - }; - Ok(WorkerResponse(response)) - } -} - -impl TryFrom for Response { - type Error = anyhow::Error; - - fn try_from(res: WorkerResponse) -> Result { - let res = res.0; - Ok(Response { - status: res.status, - headers: res.headers, - body: Bytes::from(res.body.as_bytes().to_vec()), - }) - } -} - -impl TryFrom> for WorkerResponse { - type Error = anyhow::Error; - - fn try_from(res: Response) -> Result { - let body = String::from_utf8_lossy(res.body.as_ref()).to_string(); - Ok(WorkerResponse(Response { - status: res.status, - headers: res.headers, - body, - })) - } -} - -#[cfg(test)] -mod test { - use std::collections::BTreeMap; - - use anyhow::Result; - use headers::{HeaderName, HeaderValue}; - use hyper::body::Bytes; - use pretty_assertions::assert_eq; - use reqwest::header::HeaderMap; - use rquickjs::{Context, FromJs, IntoJs, Runtime}; - - use super::WorkerResponse; - - fn create_test_response() -> Result { - let mut headers = HeaderMap::new(); - headers.insert("content-type", "application/json".parse().unwrap()); - let response = crate::core::http::Response { - status: reqwest::StatusCode::OK, - headers, - body: Bytes::from("Hello, World!"), - }; - let js_response: Result = response.try_into(); - js_response - } - - #[test] - fn test_to_js_response() { - let js_response = create_test_response(); - assert!(js_response.is_ok()); - let js_response = js_response.unwrap(); - assert_eq!(js_response.status(), 200); - assert_eq!( - js_response.headers().get("content-type").unwrap(), - "application/json" - ); - assert_eq!(js_response.body(), Some("Hello, World!".into())); - } - - #[test] - fn test_from_js_response() { - let js_response = create_test_response().unwrap(); - let response: Result> = js_response.try_into(); - assert!(response.is_ok()); - let response = response.unwrap(); - assert_eq!(response.status, reqwest::StatusCode::OK); - assert_eq!( - response.headers.get("content-type").unwrap(), - "application/json" - ); - assert_eq!(response.body, Bytes::from("Hello, World!")); - } - - #[test] - fn test_unusual_headers() { - let body = "a"; - let mut headers = HeaderMap::new(); - headers.insert( - HeaderName::from_static("x-unusual-header"), - HeaderValue::from_str("🚀").unwrap(), - ); - let response = crate::core::http::Response { - status: reqwest::StatusCode::OK, - headers, - body: body.into(), - }; - let js_response = WorkerResponse(response); - - let response: Result, _> = js_response.try_into(); - assert!(response.is_ok()); - let response = response.unwrap(); - assert_eq!(response.headers.get("x-unusual-header").unwrap(), "🚀"); - assert_eq!(response.body, Bytes::from(body)); - } - - #[test] - fn test_response_into_js() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let value = create_test_response().unwrap().into_js(&ctx).unwrap(); - let object = value.as_object().unwrap(); - - let status = object.get::<&str, u16>("status").unwrap(); - let headers = object - .get::<&str, BTreeMap>("headers") - .unwrap(); - let body = object.get::<&str, Option>("body").unwrap(); - - assert_eq!(status, reqwest::StatusCode::OK); - assert_eq!(body, Some("Hello, World!".to_owned())); - assert!(headers.contains_key("content-type")); - assert_eq!( - headers.get("content-type"), - Some(&"application/json".to_owned()) - ); - }); - } - - #[test] - fn test_response_from_js() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let js_response = create_test_response().unwrap().into_js(&ctx).unwrap(); - let response = WorkerResponse::from_js(&ctx, js_response).unwrap(); - - assert_eq!(response.status(), reqwest::StatusCode::OK.as_u16()); - assert_eq!(response.body(), Some("Hello, World!".to_owned())); - assert_eq!( - response.headers().get("content-type"), - Some(&"application/json".to_owned()) - ); - }); - } -} diff --git a/src/cli/javascript/mod.rs b/src/cli/javascript/mod.rs index 6dd8cdd2cc..33a2a4d9e2 100644 --- a/src/cli/javascript/mod.rs +++ b/src/cli/javascript/mod.rs @@ -3,27 +3,13 @@ use std::sync::Arc; use hyper::header::{HeaderName, HeaderValue}; -mod js_request; - -mod js_response; - -pub mod request_filter; +pub mod codec; mod runtime; -pub use request_filter::RequestFilter; pub use runtime::Runtime; -use crate::core::{blueprint, HttpIO, WorkerIO}; - -pub fn init_http( - http: Arc, - script: blueprint::Script, -) -> Arc { - tracing::debug!("Initializing JavaScript HTTP filter: {}", script.source); - let script_io = Arc::new(Runtime::new(script)); - Arc::new(RequestFilter::new(http, script_io)) -} +use crate::core::{blueprint, WorkerIO}; pub fn init_worker_io(script: blueprint::Script) -> Arc + Send + Sync> where diff --git a/src/cli/javascript/request_filter.rs b/src/cli/javascript/request_filter.rs deleted file mode 100644 index c40defca40..0000000000 --- a/src/cli/javascript/request_filter.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::sync::Arc; - -use hyper::body::Bytes; -use rquickjs::FromJs; - -use crate::core::http::Response; -use crate::core::worker::{Command, Event, WorkerRequest, WorkerResponse}; -use crate::core::{HttpIO, WorkerIO}; - -impl<'js> FromJs<'js> for Command { - fn from_js(ctx: &rquickjs::Ctx<'js>, value: rquickjs::Value<'js>) -> rquickjs::Result { - let object = value.as_object().ok_or(rquickjs::Error::FromJs { - from: value.type_name(), - to: "rquickjs::Object", - message: Some("unable to cast JS Value as object".to_string()), - })?; - - if object.contains_key("request")? { - Ok(Command::Request(WorkerRequest::from_js( - ctx, - object.get("request")?, - )?)) - } else if object.contains_key("response")? { - Ok(Command::Response(WorkerResponse::from_js( - ctx, - object.get("response")?, - )?)) - } else { - Err(rquickjs::Error::FromJs { - from: "object", - to: "tailcall::cli::javascript::request_filter::Command", - message: Some("object must contain either request or response".to_string()), - }) - } - } -} - -pub struct RequestFilter { - worker: Arc>, - client: Arc, -} - -impl RequestFilter { - pub fn new( - client: Arc, - worker: Arc>, - ) -> Self { - Self { worker, client } - } - - #[async_recursion::async_recursion] - async fn on_request(&self, mut request: reqwest::Request) -> anyhow::Result> { - let js_request = WorkerRequest::try_from(&request)?; - let event = Event::Request(js_request); - let command = self.worker.call("onRequest", event).await?; - match command { - Some(command) => match command { - Command::Request(js_request) => { - let response = self.client.execute(js_request.into()).await?; - Ok(response) - } - Command::Response(js_response) => { - // Check if the response is a redirect - if (js_response.status() == 301 || js_response.status() == 302) - && js_response.headers().contains_key("location") - { - request - .url_mut() - .set_path(js_response.headers()["location"].as_str()); - self.on_request(request).await - } else { - Ok(js_response.try_into()?) - } - } - }, - None => self.client.execute(request).await, - } - } -} - -#[async_trait::async_trait] -impl HttpIO for RequestFilter { - async fn execute( - &self, - request: reqwest::Request, - ) -> anyhow::Result> { - self.on_request(request).await - } -} - -#[cfg(test)] -mod tests { - use hyper::body::Bytes; - use rquickjs::{Context, FromJs, IntoJs, Object, Runtime, String as JsString}; - - use crate::core::http::Response; - use crate::core::worker::{Command, WorkerRequest, WorkerResponse}; - - #[test] - fn test_command_from_invalid_object() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let value = JsString::from_str(ctx.clone(), "invalid") - .unwrap() - .into_value(); - assert!(Command::from_js(&ctx, value).is_err()); - }); - } - - #[test] - fn test_command_from_request() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let request = - reqwest::Request::new(reqwest::Method::GET, "http://example.com/".parse().unwrap()); - let js_request: WorkerRequest = (&request).try_into().unwrap(); - let value = Object::new(ctx.clone()).unwrap(); - value.set("request", js_request.into_js(&ctx)).unwrap(); - assert!(Command::from_js(&ctx, value.into_value()).is_ok()); - }); - } - - #[test] - fn test_command_from_response() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let js_response = WorkerResponse::try_from(Response { - status: reqwest::StatusCode::OK, - headers: reqwest::header::HeaderMap::default(), - body: Bytes::new(), - }) - .unwrap(); - let value = Object::new(ctx.clone()).unwrap(); - value.set("response", js_response).unwrap(); - assert!(Command::from_js(&ctx, value.into_value()).is_ok()); - }); - } - - #[test] - fn test_command_from_arbitrary_object() { - let runtime = Runtime::new().unwrap(); - let context = Context::base(&runtime).unwrap(); - context.with(|ctx| { - let value = Object::new(ctx.clone()).unwrap(); - assert!(Command::from_js(&ctx, value.into_value()).is_err()); - }); - } -} diff --git a/src/cli/javascript/runtime.rs b/src/cli/javascript/runtime.rs index 18c1f10cd9..0b068d6a6e 100644 --- a/src/cli/javascript/runtime.rs +++ b/src/cli/javascript/runtime.rs @@ -82,7 +82,7 @@ impl Drop for Runtime { #[async_trait::async_trait] impl WorkerIO for Runtime { - async fn call(&self, name: &'async_trait str, event: Event) -> anyhow::Result> { + async fn call(&self, name: &str, event: Event) -> anyhow::Result> { let script = self.script.clone(); let name = name.to_string(); // TODO if let Some(runtime) = &self.tokio_runtime { @@ -100,11 +100,7 @@ impl WorkerIO for Runtime { #[async_trait::async_trait] impl WorkerIO for Runtime { - async fn call( - &self, - name: &'async_trait str, - input: ConstValue, - ) -> anyhow::Result> { + async fn call(&self, name: &str, input: ConstValue) -> anyhow::Result> { let script = self.script.clone(); let name = name.to_string(); let value = serde_json::to_string(&input)?; diff --git a/src/cli/runtime/mod.rs b/src/cli/runtime/mod.rs index b6d5812ec6..b5e3d59246 100644 --- a/src/cli/runtime/mod.rs +++ b/src/cli/runtime/mod.rs @@ -23,17 +23,6 @@ fn init_file() -> Arc { Arc::new(file::NativeFileIO::init()) } -fn init_hook_http(http: Arc, script: Option) -> Arc { - #[cfg(feature = "js")] - if let Some(script) = script { - return super::javascript::init_http(http, script); - } - - let _ = script; - - http -} - fn init_http_worker_io( script: Option, ) -> Option>> { @@ -60,17 +49,18 @@ fn init_resolver_worker_io( // Provides access to http in native rust environment fn init_http(blueprint: &Blueprint) -> Arc { - let http_io = http::NativeHttp::init(&blueprint.upstream, &blueprint.telemetry); - init_hook_http(Arc::new(http_io), blueprint.server.script.clone()) + Arc::new(http::NativeHttp::init( + &blueprint.upstream, + &blueprint.telemetry, + )) } // Provides access to http in native rust environment fn init_http2_only(blueprint: &Blueprint) -> Arc { - let http_io = http::NativeHttp::init( + Arc::new(http::NativeHttp::init( &blueprint.upstream.clone().http2_only(true), &blueprint.telemetry, - ); - init_hook_http(Arc::new(http_io), blueprint.server.script.clone()) + )) } fn init_in_memory_cache() -> InMemoryCache { @@ -88,7 +78,7 @@ pub fn init(blueprint: &Blueprint) -> TargetRuntime { file: init_file(), cache: Arc::new(init_in_memory_cache()), extensions: Arc::new(vec![]), - http_worker: init_http_worker_io(blueprint.server.script.clone()), + cmd_worker: init_http_worker_io(blueprint.server.script.clone()), worker: init_resolver_worker_io(blueprint.server.script.clone()), } } diff --git a/src/core/app_context.rs b/src/core/app_context.rs index a87acb753c..85b394099a 100644 --- a/src/core/app_context.rs +++ b/src/core/app_context.rs @@ -44,7 +44,7 @@ impl AppContext { field.map_expr(|expr| { expr.modify(|expr| match expr { IR::IO(io) => match io { - IO::Http { req_template, group_by, .. } => { + IO::Http { req_template, group_by, http_filter, .. } => { let data_loader = HttpDataLoader::new( runtime.clone(), group_by.clone(), @@ -55,7 +55,8 @@ impl AppContext { let result = Some(IR::IO(IO::Http { req_template: req_template.clone(), group_by: group_by.clone(), - dl_id: Some(DataLoaderId(http_data_loaders.len())), + dl_id: Some(DataLoaderId::new(http_data_loaders.len())), + http_filter: http_filter.clone(), })); http_data_loaders.push(data_loader); @@ -74,7 +75,7 @@ impl AppContext { req_template: req_template.clone(), field_name: field_name.clone(), batch: *batch, - dl_id: Some(DataLoaderId(gql_data_loaders.len())), + dl_id: Some(DataLoaderId::new(gql_data_loaders.len())), })); gql_data_loaders.push(graphql_data_loader); @@ -95,7 +96,7 @@ impl AppContext { let result = Some(IR::IO(IO::Grpc { req_template: req_template.clone(), group_by: group_by.clone(), - dl_id: Some(DataLoaderId(grpc_data_loaders.len())), + dl_id: Some(DataLoaderId::new(grpc_data_loaders.len())), })); grpc_data_loaders.push(data_loader); diff --git a/src/core/blueprint/operators/http.rs b/src/core/blueprint/operators/http.rs index 6c7b172c6e..48a18d3a1a 100644 --- a/src/core/blueprint/operators/http.rs +++ b/src/core/blueprint/operators/http.rs @@ -2,7 +2,7 @@ use crate::core::blueprint::*; use crate::core::config::group_by::GroupBy; use crate::core::config::Field; use crate::core::endpoint::Endpoint; -use crate::core::http::{Method, RequestTemplate}; +use crate::core::http::{HttpFilter, Method, RequestTemplate}; use crate::core::ir::{IO, IR}; use crate::core::try_fold::TryFold; use crate::core::valid::{Valid, ValidationError, Validator}; @@ -59,14 +59,22 @@ pub fn compile_http( .into() }) .map(|req_template| { + // marge http and upstream on_request + let http_filter = http + .on_request + .clone() + .or(config_module.upstream.on_request.clone()) + .map(|on_request| HttpFilter { on_request }); + if !http.group_by.is_empty() && http.method == Method::GET { IR::IO(IO::Http { req_template, group_by: Some(GroupBy::new(http.group_by.clone())), dl_id: None, + http_filter, }) } else { - IR::IO(IO::Http { req_template, group_by: None, dl_id: None }) + IR::IO(IO::Http { req_template, group_by: None, dl_id: None, http_filter }) } }) } diff --git a/src/core/blueprint/upstream.rs b/src/core/blueprint/upstream.rs index 564e3e2612..0fc2bd5cc6 100644 --- a/src/core/blueprint/upstream.rs +++ b/src/core/blueprint/upstream.rs @@ -28,6 +28,7 @@ pub struct Upstream { pub batch: Option, pub http2_only: bool, pub dedupe: bool, + pub on_request: Option, } impl Upstream { @@ -82,6 +83,7 @@ impl TryFrom<&ConfigModule> for Upstream { batch, http2_only: (config_upstream).get_http_2_only(), dedupe: (config_upstream).get_dedupe(), + on_request: (config_upstream).get_on_request(), }) .to_result() } diff --git a/src/core/config/config.rs b/src/core/config/config.rs index ef2e003fa1..1c20acbbc1 100644 --- a/src/core/config/config.rs +++ b/src/core/config/config.rs @@ -411,6 +411,11 @@ pub struct Enum { /// REST API. In this scenario, the GraphQL server will make a GET request to /// the API endpoint specified when the `users` field is queried. pub struct Http { + #[serde(rename = "onRequest", default, skip_serializing_if = "is_default")] + /// onRequest field in @http directive gives the ability to specify the + /// request interception handler. + pub on_request: Option, + #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] /// This refers to the base URL of the API. If not specified, the default /// base URL is the one specified in the `@upstream` operator. diff --git a/src/core/config/upstream.rs b/src/core/config/upstream.rs index dfd32f613d..d8779e84e2 100644 --- a/src/core/config/upstream.rs +++ b/src/core/config/upstream.rs @@ -52,6 +52,11 @@ pub struct Proxy { /// upstream server connection. This includes settings like connection timeouts, /// keep-alive intervals, and more. If not specified, default values are used. pub struct Upstream { + #[serde(rename = "onRequest", default, skip_serializing_if = "is_default")] + /// onRequest field gives the ability to specify the global request + /// interception handler. + pub on_request: Option, + #[serde(default, skip_serializing_if = "is_default")] /// `allowedHeaders` defines the HTTP headers allowed to be forwarded to /// upstream services. If not set, no headers are forwarded, enhancing @@ -195,6 +200,10 @@ impl Upstream { pub fn get_dedupe(&self) -> bool { self.dedupe.unwrap_or(false) } + + pub fn get_on_request(&self) -> Option { + self.on_request.clone() + } } #[cfg(test)] diff --git a/src/core/generator/graphql_type.rs b/src/core/generator/graphql_type.rs index 5835cbb70f..6dcd527368 100644 --- a/src/core/generator/graphql_type.rs +++ b/src/core/generator/graphql_type.rs @@ -132,7 +132,6 @@ impl GraphQLType { } } -// FIXME: make it private /// Used to convert proto type names to GraphQL formatted names. /// Enum to represent the type of the descriptor #[derive(Clone, Debug, PartialEq)] diff --git a/src/core/http/mod.rs b/src/core/http/mod.rs index 6de8bec1ca..899e28d899 100644 --- a/src/core/http/mod.rs +++ b/src/core/http/mod.rs @@ -24,3 +24,16 @@ mod telemetry; pub static TAILCALL_HTTPS_ORIGIN: HeaderValue = HeaderValue::from_static("https://tailcall.run"); pub static TAILCALL_HTTP_ORIGIN: HeaderValue = HeaderValue::from_static("http://tailcall.run"); + +#[derive(Default, Clone, Debug)] +/// User can configure the filter/interceptor +/// for the http requests. +pub struct HttpFilter { + pub on_request: String, +} + +impl HttpFilter { + pub fn new(on_request: &str) -> Self { + HttpFilter { on_request: on_request.to_owned() } + } +} diff --git a/src/core/ir/io.rs b/src/core/ir/io.rs index b52df36d30..8d00906184 100644 --- a/src/core/ir/io.rs +++ b/src/core/ir/io.rs @@ -15,11 +15,14 @@ use crate::core::grpc::data_loader::GrpcDataLoader; use crate::core::grpc::protobuf::ProtobufOperation; use crate::core::grpc::request::execute_grpc_request; use crate::core::grpc::request_template::RenderedRequestTemplate; -use crate::core::http::{cache_policy, DataLoaderRequest, HttpDataLoader, Response}; +use crate::core::http::{ + cache_policy, DataLoaderRequest, HttpDataLoader, HttpFilter, RequestTemplate, Response, +}; use crate::core::ir::EvaluationError; use crate::core::json::JsonLike; use crate::core::valid::Validator; -use crate::core::{grpc, http}; +use crate::core::worker::*; +use crate::core::{grpc, http, WorkerIO}; #[derive(Clone, Debug, strum_macros::Display)] pub enum IO { @@ -27,6 +30,7 @@ pub enum IO { req_template: http::RequestTemplate, group_by: Option, dl_id: Option, + http_filter: Option, }, GraphQL { req_template: graphql::RequestTemplate, @@ -45,7 +49,15 @@ pub enum IO { } #[derive(Clone, Copy, Debug)] -pub struct DataLoaderId(pub usize); +pub struct DataLoaderId(usize); +impl DataLoaderId { + pub fn new(id: usize) -> Self { + Self(id) + } + pub fn as_usize(&self) -> usize { + self.0 + } +} impl Eval for IO { fn eval<'a, Ctx: super::ResolverContextLike<'a> + Sync + Send>( @@ -79,30 +91,20 @@ impl IO { ) -> Pin> + 'a + Send>> { Box::pin(async move { match self { - IO::Http { req_template, dl_id, .. } => { - let req = req_template.to_request(&ctx)?; - let is_get = req.method() == reqwest::Method::GET; - - let res = if is_get && ctx.request_ctx.is_batching_enabled() { - let data_loader: Option<&DataLoader> = - dl_id.and_then(|index| ctx.request_ctx.http_data_loaders.get(index.0)); - execute_request_with_dl(&ctx, req, data_loader).await? - } else { - execute_raw_request(&ctx, req).await? + IO::Http { req_template, dl_id, http_filter, .. } => { + let worker = &ctx.request_ctx.runtime.cmd_worker; + let executor = HttpRequestExecutor::new(ctx, req_template, dl_id); + let request = executor.init_request()?; + let response = match (&worker, http_filter) { + (Some(worker), Some(http_filter)) => { + executor + .execute_with_worker(request, worker, http_filter) + .await? + } + _ => executor.execute(request).await?, }; - if ctx.request_ctx.server.get_enable_http_validation() { - req_template - .endpoint - .output - .validate(&res.body) - .to_result() - .map_err(EvaluationError::from)?; - } - - set_headers(&ctx, &res); - - Ok(res.body) + Ok(response.body) } IO::GraphQL { req_template, field_name, dl_id, .. } => { let req = req_template.to_request(&ctx)?; @@ -306,3 +308,98 @@ fn parse_graphql_response<'ctx, Ctx: ResolverContextLike<'ctx>>( .map(|v| v.to_owned()) .unwrap_or_default()) } + +/// +/// Executing a HTTP request is a bit more complex than just sending a request +/// and getting a response. There are optimizations and customizations that the +/// user might have configured. HttpRequestExecutor is responsible for handling +/// all of that. +struct HttpRequestExecutor<'a, Context: ResolverContextLike<'a> + Send + Sync> { + evaluation_ctx: EvaluationContext<'a, Context>, + data_loader: Option<&'a DataLoader>, + request_template: &'a http::RequestTemplate, +} + +impl<'a, Context: ResolverContextLike<'a> + Send + Sync> HttpRequestExecutor<'a, Context> { + pub fn new( + evaluation_ctx: EvaluationContext<'a, Context>, + request_template: &'a RequestTemplate, + id: &Option, + ) -> Self { + let data_loader = if evaluation_ctx.request_ctx.is_batching_enabled() { + id.and_then(|id| evaluation_ctx.request_ctx.http_data_loaders.get(id.0)) + } else { + None + }; + + Self { evaluation_ctx, data_loader, request_template } + } + + pub fn init_request(&self) -> Result { + let ctx = &self.evaluation_ctx; + Ok(self.request_template.to_request(ctx)?) + } + + async fn execute( + &self, + req: Request, + ) -> Result, EvaluationError> { + let ctx = &self.evaluation_ctx; + let is_get = req.method() == reqwest::Method::GET; + let dl = &self.data_loader; + let response = if is_get && dl.is_some() { + execute_request_with_dl(ctx, req, self.data_loader).await? + } else { + execute_raw_request(ctx, req).await? + }; + + if ctx.request_ctx.server.get_enable_http_validation() { + self.request_template + .endpoint + .output + .validate(&response.body) + .to_result() + .map_err(EvaluationError::from)?; + } + + set_headers(ctx, &response); + + Ok(response) + } + + #[async_recursion::async_recursion] + async fn execute_with_worker( + &self, + mut request: reqwest::Request, + worker: &Arc>, + http_filter: &HttpFilter, + ) -> Result, EvaluationError> { + let js_request = WorkerRequest::try_from(&request)?; + let event = Event::Request(js_request); + + let command = worker.call(&http_filter.on_request, event).await?; + + match command { + Some(command) => match command { + Command::Request(w_request) => { + let response = self.execute(w_request.into()).await?; + Ok(response) + } + Command::Response(w_response) => { + // Check if the response is a redirect + if (w_response.status() == 301 || w_response.status() == 302) + && w_response.headers().contains_key("location") + { + request + .url_mut() + .set_path(w_response.headers()["location"].as_str()); + self.execute_with_worker(request, worker, http_filter).await + } else { + Ok(w_response.try_into()?) + } + } + }, + None => self.execute(request).await, + } + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index de4bd1dc72..ed7bcf5cb4 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -57,7 +57,9 @@ pub trait HttpIO: Sync + Send + 'static { async fn execute( &self, request: reqwest::Request, - ) -> anyhow::Result>; + ) -> anyhow::Result> { + self.execute(request).await + } } #[async_trait::async_trait] @@ -86,7 +88,7 @@ pub type EntityCache = dyn Cache; #[async_trait::async_trait] pub trait WorkerIO: Send + Sync + 'static { /// Calls a global JS function - async fn call(&self, name: &'async_trait str, input: In) -> anyhow::Result>; + async fn call(&self, name: &str, input: In) -> anyhow::Result>; } pub fn is_default(val: &T) -> bool { diff --git a/src/core/runtime.rs b/src/core/runtime.rs index 6dcbf0c0db..c7741156f8 100644 --- a/src/core/runtime.rs +++ b/src/core/runtime.rs @@ -29,7 +29,7 @@ pub struct TargetRuntime { /// functionality or integrate additional features. pub extensions: Arc>, /// Worker middleware for handling HTTP requests. - pub http_worker: Option>>, + pub cmd_worker: Option>>, /// Worker middleware for resolving data. pub worker: Option>>, } @@ -48,6 +48,7 @@ pub mod test { use std::time::Duration; use anyhow::{anyhow, Result}; + use async_graphql::Value; use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions}; use hyper::body::Bytes; use reqwest::Client; @@ -55,11 +56,12 @@ pub mod test { use tailcall_http_cache::HttpCacheManager; use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use crate::cli::javascript; + use crate::cli::javascript::init_worker_io; use crate::core::blueprint::Upstream; use crate::core::cache::InMemoryCache; use crate::core::http::Response; use crate::core::runtime::TargetRuntime; + use crate::core::worker::{Command, Event}; use crate::core::{blueprint, EnvIO, FileIO, HttpIO}; #[derive(Clone)] @@ -172,20 +174,8 @@ pub mod test { } pub fn init(script: Option) -> TargetRuntime { - let http = if let Some(script) = script.clone() { - javascript::init_http(TestHttp::init(&Default::default()), script) - } else { - TestHttp::init(&Default::default()) - }; - - let http2 = if let Some(script) = script { - javascript::init_http( - TestHttp::init(&Upstream::default().http2_only(true)), - script, - ) - } else { - TestHttp::init(&Upstream::default().http2_only(true)) - }; + let http = TestHttp::init(&Default::default()); + let http2 = TestHttp::init(&Upstream::default().http2_only(true)); let file = TestFileIO::init(); let env = TestEnvIO::init(); @@ -197,8 +187,14 @@ pub mod test { file: Arc::new(file), cache: Arc::new(InMemoryCache::new()), extensions: Arc::new(vec![]), - http_worker: None, - worker: None, + cmd_worker: match &script { + Some(script) => Some(init_worker_io::(script.to_owned())), + None => None, + }, + worker: match &script { + Some(script) => Some(init_worker_io::(script.to_owned())), + None => None, + }, } } } diff --git a/src/core/worker.rs b/src/core/worker.rs index 810428407a..ea72919206 100644 --- a/src/core/worker.rs +++ b/src/core/worker.rs @@ -1,4 +1,31 @@ -use crate::core::Response; +use std::collections::BTreeMap; +use std::fmt::Display; + +use hyper::body::Bytes; +use reqwest::Request; +use serde::{Deserialize, Serialize}; + +use crate::core::{is_default, Response}; + +#[derive(Serialize, Deserialize, Default, Debug, PartialEq, Eq)] +pub enum Scheme { + #[default] + Http, + Https, +} +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub struct Uri { + pub path: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub query: BTreeMap, + #[serde(default, skip_serializing_if = "is_default")] + pub scheme: Scheme, + #[serde(default, skip_serializing_if = "is_default")] + pub host: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub port: Option, +} #[derive(Debug)] pub struct WorkerResponse(pub Response); @@ -16,3 +43,160 @@ pub enum Command { Request(WorkerRequest), Response(WorkerResponse), } + +impl WorkerResponse { + pub fn status(&self) -> u16 { + self.0.status.as_u16() + } + + pub fn headers(&self) -> BTreeMap { + let mut headers = BTreeMap::new(); + for (key, value) in self.0.headers.iter() { + headers.insert(key.to_string(), value.to_str().unwrap().to_string()); + } + headers + } + + pub fn body(&self) -> Option { + let b = self.0.body.as_bytes(); + Some(String::from_utf8_lossy(b).to_string()) + } +} + +impl TryFrom for Response { + type Error = anyhow::Error; + + fn try_from(res: WorkerResponse) -> Result { + let res = res.0; + Ok(Response { + status: res.status, + headers: res.headers, + body: Bytes::from(res.body.as_bytes().to_vec()), + }) + } +} + +impl TryFrom for Response { + type Error = anyhow::Error; + + fn try_from(res: WorkerResponse) -> Result { + let body: async_graphql::Value = match res.body() { + Some(body) => serde_json::from_str(&body)?, + None => async_graphql::Value::Null, + }; + + Ok(Response { status: res.0.status, headers: res.0.headers, body }) + } +} + +impl TryFrom> for WorkerResponse { + type Error = anyhow::Error; + + fn try_from(res: Response) -> Result { + let body = String::from_utf8_lossy(res.body.as_ref()).to_string(); + Ok(WorkerResponse(Response { + status: res.status, + headers: res.headers, + body, + })) + } +} + +impl TryFrom> for WorkerResponse { + type Error = anyhow::Error; + + fn try_from(res: Response) -> Result { + let body = serde_json::to_string(&res.body)?; + Ok(WorkerResponse(Response { + status: res.status, + headers: res.headers, + body, + })) + } +} + +impl Display for Uri { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let host = self.host.as_deref().unwrap_or("localhost"); + let port = self.port.map(|p| format!(":{}", p)).unwrap_or_default(); + let scheme = match self.scheme { + Scheme::Https => "https", + _ => "http", + }; + let path = self.path.as_str(); + let query = self + .query + .iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join("&"); + + write!(f, "{}://{}{}{}", scheme, host, port, path)?; + + if !query.is_empty() { + write!(f, "?{}", query)?; + } + + Ok(()) + } +} + +impl WorkerRequest { + pub fn uri(&self) -> Uri { + self.0.url().into() + } + + pub fn method(&self) -> String { + self.0.method().to_string() + } + + pub fn headers(&self) -> anyhow::Result> { + let headers = self.0.headers(); + let mut map = BTreeMap::new(); + for (k, v) in headers.iter() { + map.insert(k.to_string(), v.to_str()?.to_string()); + } + Ok(map) + } + + pub fn body(&self) -> Option { + if let Some(body) = self.0.body() { + let bytes = body.as_bytes()?; + Some(String::from_utf8_lossy(bytes).to_string()) + } else { + None + } + } +} + +impl TryFrom<&reqwest::Request> for WorkerRequest { + type Error = anyhow::Error; + + fn try_from(value: &Request) -> Result { + let request = value + .try_clone() + .ok_or(anyhow::anyhow!("unable to clone request"))?; + Ok(WorkerRequest(request)) + } +} + +impl From for reqwest::Request { + fn from(val: WorkerRequest) -> Self { + val.0 + } +} + +impl From<&reqwest::Url> for Uri { + fn from(value: &reqwest::Url) -> Self { + Self { + path: value.path().to_string(), + query: value.query_pairs().into_owned().collect(), + scheme: match value.scheme() { + "https" => Scheme::Https, + _ => Scheme::Http, + }, + host: value.host_str().map(|u| u.to_string()), + port: value.port(), + } + } +} diff --git a/tailcall-aws-lambda/src/runtime.rs b/tailcall-aws-lambda/src/runtime.rs index 1c041c7ee2..53b9d58bec 100644 --- a/tailcall-aws-lambda/src/runtime.rs +++ b/tailcall-aws-lambda/src/runtime.rs @@ -60,7 +60,7 @@ pub fn init_runtime() -> TargetRuntime { env: init_env(), cache: init_cache(), extensions: Arc::new(vec![]), - http_worker: None, + cmd_worker: None, worker: None, } } diff --git a/tailcall-cloudflare/src/runtime.rs b/tailcall-cloudflare/src/runtime.rs index a966ecfac0..a0b84d00bb 100644 --- a/tailcall-cloudflare/src/runtime.rs +++ b/tailcall-cloudflare/src/runtime.rs @@ -41,7 +41,7 @@ pub fn init(env: Rc) -> anyhow::Result { file: init_file(env.clone(), &bucket_id)?, cache: init_cache(env), extensions: Arc::new(vec![]), - http_worker: None, + cmd_worker: None, worker: None, }) } diff --git a/tests/core/parse.rs b/tests/core/parse.rs index 6d3fb0f631..10fef0528a 100644 --- a/tests/core/parse.rs +++ b/tests/core/parse.rs @@ -270,15 +270,10 @@ impl ExecutionSpec { &self, config: &ConfigModule, env: HashMap, - http_client: Arc, + http: Arc, ) -> Arc { let blueprint = Blueprint::try_from(config).unwrap(); let script = blueprint.server.script.clone(); - let http = if let Some(script) = script.clone() { - javascript::init_http(http_client, script) - } else { - http_client - }; let http2_only = http.clone(); @@ -303,7 +298,7 @@ impl ExecutionSpec { env: Arc::new(Env::init(env)), cache: Arc::new(InMemoryCache::new()), extensions: Arc::new(vec![]), - http_worker, + cmd_worker: http_worker, worker, }; diff --git a/tests/core/runtime.rs b/tests/core/runtime.rs index 7146540786..cb0da3d72b 100644 --- a/tests/core/runtime.rs +++ b/tests/core/runtime.rs @@ -5,12 +5,14 @@ use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use async_graphql::Value; use derive_setters::Setters; -use tailcall::cli::javascript; +use tailcall::cli::javascript::init_worker_io; use tailcall::core::blueprint::Script; use tailcall::core::cache::InMemoryCache; use tailcall::core::config::Source; use tailcall::core::runtime::TargetRuntime; +use tailcall::core::worker::{Command, Event}; use super::env::Env; use super::file::TestFileIO; @@ -65,17 +67,9 @@ pub fn create_runtime( env: Option>, script: Option