From bcc208cf59eeabac9784dc8afd55b9603e9b49e7 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 27 Feb 2024 15:14:12 -0500 Subject: [PATCH] fix middlewares: merge-sort per-request middleares and global ones --- src/core/server.ml | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/core/server.ml b/src/core/server.ml index e24bed89..764c1eae 100644 --- a/src/core/server.ml +++ b/src/core/server.ml @@ -60,12 +60,12 @@ module type IO_BACKEND = sig end type handler_result = - | Handle of cb_path_handler + | Handle of (int * Middleware.t) list * cb_path_handler | Fail of resp_error | Upgrade of upgrade_handler let unwrap_handler_result req = function - | Handle x -> x + | Handle (l, h) -> l, h | Fail (c, s) -> raise (Bad_req (c, s)) | Upgrade up -> raise (Upgrade (req, up)) @@ -101,6 +101,9 @@ let active_connections (self : t) = | None -> 0 | Some s -> s.active_connections () +let sort_middlewares_ l = + List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) l + let add_middleware ~stage self m = let stage = match stage with @@ -109,9 +112,7 @@ let add_middleware ~stage self m = | `Stage n -> n in self.middlewares <- (stage, m) :: self.middlewares; - self.middlewares_sorted <- - lazy - (List.stable_sort (fun (s1, _) (s2, _) -> compare s1 s2) self.middlewares) + self.middlewares_sorted <- lazy (sort_middlewares_ self.middlewares) let add_decode_request_cb self f = (* turn it into a middleware *) @@ -145,6 +146,7 @@ let set_top_handler self f = self.handler <- f and makes it into a handler. *) let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth ~tr_req self (route : _ Route.t) f = + let middlewares = List.map (fun h -> 5, h) middlewares in let ph req : handler_result option = match meth with | Some m when m <> req.Request.meth -> None (* ignore *) @@ -156,9 +158,7 @@ let add_route_handler_ ?(accept = fun _req -> Ok ()) ?(middlewares = []) ?meth | Ok () -> Some (Handle - (fun oc -> - Middleware.apply_l middlewares @@ fun req ~resp -> - tr_req oc req ~resp handler)) + (middlewares, fun oc req ~resp -> tr_req oc req ~resp handler)) | Error err -> Some (Fail err)) | None -> None (* path didn't match *)) in @@ -409,10 +409,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = (try (* is there a handler for this path? *) - let base_handler = + let handler_middlewares, base_handler = match find_map (fun ph -> ph req) self.path_handlers with | Some f -> unwrap_handler_result req f - | None -> fun _oc req ~resp -> resp (self.handler req) + | None -> [], fun _oc req ~resp -> resp (self.handler req) in (* handle expect/continue *) @@ -424,12 +424,21 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = | Some s -> bad_reqf 417 "unknown expectation %s" s | None -> ()); + (* merge per-request middlewares with the server-global middlewares *) + let global_middlewares = Lazy.force self.middlewares_sorted in + let all_middlewares = + if handler_middlewares = [] then + global_middlewares + else + sort_middlewares_ + (List.rev_append handler_middlewares self.middlewares) + in + (* apply middlewares *) let handler oc = List.fold_right (fun (_, m) h -> m h) - (Lazy.force self.middlewares_sorted) - (base_handler oc) + all_middlewares (base_handler oc) in (* now actually read request's body into a stream *)