From 631215914df9cf6e4d906cbebaad0147005c396f Mon Sep 17 00:00:00 2001 From: Shashi Kant Date: Fri, 26 Apr 2024 14:17:05 +0000 Subject: [PATCH] creating request in cache_key --- benches/request_template_bench.rs | 13 ++++- src/http/request_template.rs | 97 +++++++++++++++++-------------- src/lambda/io.rs | 8 ++- 3 files changed, 72 insertions(+), 46 deletions(-) diff --git a/benches/request_template_bench.rs b/benches/request_template_bench.rs index bc67c1486f..4cd609fb66 100644 --- a/benches/request_template_bench.rs +++ b/benches/request_template_bench.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::collections::hash_map::DefaultHasher; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use derive_setters::Setters; @@ -49,13 +50,21 @@ fn benchmark_to_request(c: &mut Criterion) { c.bench_function("with_mustache_literal", |b| { b.iter(|| { - black_box(tmpl_literal.to_request(&ctx).unwrap()); + black_box( + tmpl_literal + .to_request(&ctx, None::) + .unwrap(), + ); }) }); c.bench_function("with_mustache_expressions", |b| { b.iter(|| { - black_box(tmpl_mustache.to_request(&ctx).unwrap()); + black_box( + tmpl_mustache + .to_request(&ctx, None::) + .unwrap(), + ); }) }); } diff --git a/src/http/request_template.rs b/src/http/request_template.rs index 761c1a362e..f88996619f 100644 --- a/src/http/request_template.rs +++ b/src/http/request_template.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; use derive_setters::Setters; use hyper::HeaderMap; @@ -28,6 +29,7 @@ pub struct RequestTemplate { pub body_path: Option, pub endpoint: Endpoint, pub encoding: Encoding, + pub rendered_request: Arc>>>, } impl RequestTemplate { @@ -97,13 +99,19 @@ impl RequestTemplate { pub fn to_request( &self, ctx: &C, + mut hasher: Option, ) -> anyhow::Result { // Create url let url = self.create_url(ctx)?; let method = self.method.clone(); + if let Some(hasher) = hasher.as_mut() { + url.hash(hasher); + method.hash(hasher); + } let mut req = reqwest::Request::new(method, url); - req = self.set_headers(req, ctx); - req = self.set_body(req, ctx)?; + + req = self.set_headers(req, ctx, hasher.as_mut()); + req = self.set_body(req, ctx, hasher.as_mut())?; Ok(req) } @@ -113,11 +121,16 @@ impl RequestTemplate { &self, mut req: reqwest::Request, ctx: &C, + mut hasher: Option, ) -> anyhow::Result { if let Some(body_path) = &self.body_path { match &self.encoding { Encoding::ApplicationJson => { - req.body_mut().replace(body_path.render(ctx).into()); + let body = body_path.render(ctx); + if let Some(hasher) = hasher.as_mut() { + body.hash(hasher); + } + req.body_mut().replace(body.into()); } Encoding::ApplicationXWwwFormUrlencoded => { // TODO: this is a performance bottleneck @@ -128,6 +141,9 @@ impl RequestTemplate { Err(_) => body, }; + if let Some(hasher) = hasher.as_mut() { + form_data.hash(hasher); + } req.body_mut().replace(form_data.into()); } } @@ -140,6 +156,7 @@ impl RequestTemplate { &self, mut req: reqwest::Request, ctx: &C, + mut hasher: Option, ) -> reqwest::Request { let headers = self.create_headers(ctx); if !headers.is_empty() { @@ -163,6 +180,14 @@ impl RequestTemplate { } headers.extend(ctx.headers().to_owned()); + + if let Some(hasher) = hasher.as_mut() { + for (key, value) in headers.iter() { + key.hash(hasher); + value.hash(hasher); + } + } + req } @@ -175,6 +200,7 @@ impl RequestTemplate { body_path: Default::default(), endpoint: Endpoint::new(root_url.to_string()), encoding: Default::default(), + rendered_request: Default::default(), }) } @@ -220,6 +246,7 @@ impl TryFrom for RequestTemplate { body_path: body, endpoint, encoding, + rendered_request: Default::default(), }) } } @@ -229,27 +256,8 @@ impl CacheKey for RequestTemplate { let mut hasher = DefaultHasher::new(); let state = &mut hasher; - self.method.hash(state); - - let mut headers = vec![]; - for (name, mustache) in self.headers.iter() { - name.hash(state); - mustache.render(ctx).hash(state); - headers.push((name.to_string(), mustache.render(ctx))); - } - - for (name, value) in ctx.headers().iter() { - name.hash(state); - value.hash(state); - headers.push((name.to_string(), value.to_str().unwrap().to_string())); - } - - if let Some(body) = self.body_path.as_ref() { - body.render(ctx).hash(state) - } - - let url = self.create_url(ctx).unwrap(); - url.hash(state); + let request = self.to_request(ctx, Some(state)); + self.rendered_request.lock().unwrap().replace(request); hasher.finish() } @@ -258,6 +266,7 @@ impl CacheKey for RequestTemplate { #[cfg(test)] mod tests { use std::borrow::Cow; + use std::collections::hash_map::DefaultHasher; use derive_setters::Setters; use hyper::header::HeaderName; @@ -297,7 +306,7 @@ mod tests { impl RequestTemplate { fn to_body(&self, ctx: &C) -> anyhow::Result { let body = self - .to_request(ctx)? + .to_request(ctx, None::)? .body() .and_then(|a| a.as_bytes()) .map(|a| a.to_vec()) @@ -311,7 +320,7 @@ mod tests { fn test_url() { let tmpl = RequestTemplate::new("http://localhost:3000/").unwrap(); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.url().to_string(), "http://localhost:3000/"); } @@ -319,7 +328,7 @@ mod tests { fn test_url_path() { let tmpl = RequestTemplate::new("http://localhost:3000/foo/bar").unwrap(); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); } @@ -332,7 +341,7 @@ mod tests { } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); } @@ -347,7 +356,7 @@ mod tests { "booz": 1 } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.url().to_string(), "http://localhost:3000/foo/bar/boozes/1" @@ -365,7 +374,7 @@ mod tests { .unwrap() .query(query); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.url().to_string(), "http://localhost:3000/?foo=0&bar=1&baz=2" @@ -390,7 +399,7 @@ mod tests { "id": 2 } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.url().to_string(), "http://localhost:3000/?foo=0&bar=1&baz=2" @@ -417,7 +426,7 @@ mod tests { .unwrap() .headers(headers); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.headers().get("foo").unwrap(), "foo"); assert_eq!(req.headers().get("bar").unwrap(), "bar"); assert_eq!(req.headers().get("baz").unwrap(), "baz"); @@ -450,7 +459,7 @@ mod tests { "id": 2 } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.headers().get("foo").unwrap(), "0"); assert_eq!(req.headers().get("bar").unwrap(), "1"); assert_eq!(req.headers().get("baz").unwrap(), "2"); @@ -463,7 +472,7 @@ mod tests { .method(reqwest::Method::POST) .encoding(crate::config::Encoding::ApplicationJson); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/json" @@ -477,7 +486,7 @@ mod tests { .method(reqwest::Method::POST) .encoding(crate::config::Encoding::ApplicationXWwwFormUrlencoded); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/x-www-form-urlencoded" @@ -490,7 +499,7 @@ mod tests { .unwrap() .method(reqwest::Method::POST); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.method(), reqwest::Method::POST); } @@ -534,6 +543,8 @@ mod tests { } mod endpoint { + use std::collections::hash_map::DefaultHasher; + use hyper::HeaderMap; use serde_json::json; @@ -550,7 +561,7 @@ mod tests { .body(Some("foo".into())); let tmpl = RequestTemplate::try_from(endpoint).unwrap(); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.method(), reqwest::Method::POST); assert_eq!(req.headers().get("foo").unwrap(), "bar"); let body = req.body().unwrap().as_bytes().unwrap().to_owned(); @@ -575,7 +586,7 @@ mod tests { "header": "abc" } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.method(), reqwest::Method::POST); assert_eq!(req.headers().get("foo").unwrap(), "abc"); let body = req.body().unwrap().as_bytes().unwrap().to_owned(); @@ -589,7 +600,7 @@ mod tests { crate::endpoint::Endpoint::new("http://localhost:3000/?a={{args.a}}".to_string()); let tmpl = RequestTemplate::try_from(endpoint).unwrap(); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.url().to_string(), "http://localhost:3000/"); } @@ -604,7 +615,7 @@ mod tests { ]); let tmpl = RequestTemplate::try_from(endpoint).unwrap(); let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.url().to_string(), "http://localhost:3000/?q=1&b=1"); } @@ -620,7 +631,7 @@ mod tests { "d": "bar" } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.url().to_string(), "http://localhost:3000/foo?b=foo&d=bar" @@ -644,7 +655,7 @@ mod tests { "f": "baz" } })); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!( req.url().to_string(), "http://localhost:3000/foo?b=foo&d=bar&f=baz" @@ -658,7 +669,7 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert("baz", "qux".parse().unwrap()); let ctx = Context::default().headers(headers); - let req = tmpl.to_request(&ctx).unwrap(); + let req = tmpl.to_request(&ctx, None::).unwrap(); assert_eq!(req.headers().get("baz").unwrap(), "qux"); } } diff --git a/src/lambda/io.rs b/src/lambda/io.rs index 97d101b379..36a4c43fc9 100644 --- a/src/lambda/io.rs +++ b/src/lambda/io.rs @@ -1,4 +1,5 @@ use core::future::Future; +use std::collections::hash_map::DefaultHasher; use std::pin::Pin; use std::sync::Arc; @@ -76,7 +77,12 @@ impl IO { Box::pin(async move { match self { IO::Http { req_template, dl_id, .. } => { - let req = req_template.to_request(&ctx)?; + let req = req_template + .rendered_request + .lock() + .unwrap() + .take() + .unwrap_or(req_template.to_request(&ctx, None::))?; let is_get = req.method() == reqwest::Method::GET; let res = if is_get && ctx.request_ctx.is_batching_enabled() {