diff --git a/rust/hook-api/src/handlers/webhook.rs b/rust/hook-api/src/handlers/webhook.rs index c0e5f413130052..38465d3ad172be 100644 --- a/rust/hook-api/src/handlers/webhook.rs +++ b/rust/hook-api/src/handlers/webhook.rs @@ -4,7 +4,7 @@ use std::time::Instant; use axum::{extract::State, http::StatusCode, Json}; use hook_common::webhook::{WebhookJobMetadata, WebhookJobParameters}; use serde_derive::Deserialize; -use serde_json::Value; +use serde_json::{Map, Value}; use url::Url; use hook_common::pgqueue::{NewJob, PgQueue}; @@ -65,85 +65,81 @@ pub async fn post_webhook( Ok(Json(WebhookPostResponse { error: None })) } -#[derive(Debug, Deserialize)] -pub struct HogFetchParameters { +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct HoghookFetchParameters { + #[serde(skip_serializing_if = "Option::is_none")] pub body: Option, + + #[serde(skip_serializing_if = "Option::is_none")] pub headers: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] pub method: Option, } -// Hoghook expects a JSON payload in the format of `HogFunctionInvocationResult` (as seen in -// plugin-server), but we accept a plain `Json` via Axum here, and this is why: -// * The reason we don't decode that into a `HogFunctionInvocationResult`-shaped Rust struct is that -// there's no benefit in mirroring the exact shape of that type (and keeping it sync with the -// plugin-server type). -// * Hoghook only cares about a small subset of the payload (the `asyncFunctionRequest` field), and -// the reason we don't decode *that* into a Rust struct is because the function args are a simple -// array (because this type is used for more than just `fetch` requests), and so we would need to -// manually validate and destructure the array elements anyway. -// * Additionally, don't want to discard the rest of the payload because we pass it back to the -// plugin-server after receiving the response body from the remote server. By accepting a plain -// `Json` we only decode the JSON once, we can do our minimal validation/extraction, and we -// can save the rest of the payload for later. +#[derive(Debug, Serialize, Deserialize)] +struct HoghookArgs( + String, + #[serde(default, skip_serializing_if = "Option::is_none")] Option, +); + +#[derive(Debug, Serialize, Deserialize)] +struct HoghookAsyncFunctionRequest { + name: String, + args: HoghookArgs, +} + +#[derive(Debug, Deserialize)] +pub struct HoghookPayload { + #[serde(rename = "asyncFunctionRequest")] + async_function_request: HoghookAsyncFunctionRequest, + + #[serde(flatten)] + passthrough: HashMap, +} + pub async fn post_hoghook( State(pg_queue): State, - Json(mut payload): Json, + Json(payload): Json, ) -> Result, (StatusCode, Json)> { debug!("received payload: {:?}", payload); - let parameters: WebhookJobParameters = match &mut payload { - Value::Object(object) => { - let async_fn_request = object - .get("asyncFunctionRequest") - .ok_or_else(|| bad_request("missing required field 'asyncFunctionRequest'"))?; - - let name = async_fn_request - .get("name") - .ok_or_else(|| bad_request("missing required field 'asyncFunctionRequest.name'"))?; - - if name != "fetch" { - return Err(bad_request("asyncFunctionRequest.name must be 'fetch'")); - } - - let args = async_fn_request - .get("args") - .ok_or_else(|| bad_request("missing required field 'asyncFunctionRequest.args'"))?; - - // Note that the URL is parsed (and thus validated as a valid URL) as part of - // `get_hostname` below. - let url = args.get(0).ok_or_else(|| { - bad_request("missing required field 'asyncFunctionRequest.args[0]'") - })?; - - let fetch_options: HogFetchParameters = if let Some(value) = args.get(1) { - serde_json::from_value(value.clone()).map_err(|_| { - bad_request("failed to deserialize asyncFunctionRequest.args[1]") - })? - } else { - HogFetchParameters { - body: None, - headers: None, - method: None, - } - }; - - WebhookJobParameters { - body: fetch_options.body.unwrap_or("".to_owned()), - headers: fetch_options.headers.unwrap_or_default(), - method: fetch_options.method.unwrap_or(HttpMethod::POST), - url: url - .as_str() - .ok_or_else(|| bad_request("url must be a string"))? - .to_owned(), - } + if payload.async_function_request.name != "fetch" { + return Err(bad_request("asyncFunctionRequest.name must be 'fetch'")); + } + + // Note that the URL is parsed (and thus validated as a valid URL) as part of + // `get_hostname` below. + let url = payload.async_function_request.args.0.clone(); + let parameters = if let Some(ref fetch_options) = payload.async_function_request.args.1 { + let fetch_options = fetch_options.clone(); + WebhookJobParameters { + body: fetch_options.body.unwrap_or("".to_owned()), + headers: fetch_options.headers.unwrap_or_default(), + method: fetch_options.method.unwrap_or(HttpMethod::POST), + url, + } + } else { + WebhookJobParameters { + body: "".to_owned(), + headers: HashMap::new(), + method: HttpMethod::POST, + url, } - _ => return Err(bad_request("expected JSON object")), }; let url_hostname = get_hostname(¶meters.url)?; let max_attempts = default_max_attempts() as i32; - let job = NewJob::new(max_attempts, payload, parameters, url_hostname.as_str()); + // Reconstruct the original JSON payload. + let mut json_map: Map = payload.passthrough.into_iter().collect(); + json_map.insert( + "asyncFunctionRequest".to_owned(), + serde_json::to_value(payload.async_function_request).unwrap(), + ); + let passthrough: Value = Value::Object(json_map.into()); + + let job = NewJob::new(max_attempts, passthrough, parameters, url_hostname.as_str()); let start_time = Instant::now(); @@ -491,7 +487,10 @@ mod tests { .await .unwrap(); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert!( + response.status() == StatusCode::BAD_REQUEST + || response.status() == StatusCode::UNPROCESSABLE_ENTITY + ); } } }