From a0baf3ed30eb1167cdcf538f43bb899a6e5311df Mon Sep 17 00:00:00 2001 From: Matt Bray Date: Fri, 30 Aug 2024 11:45:20 +0100 Subject: [PATCH] fix(cohttp): ensure response body is always consumed Try to make it obvious that the body is consumed using `Util.consume_body` or `Util.drain_body` right next to the request call site. --- src/auth.ml | 45 +++++++++++++++++++++++++-------------- src/big_query.ml | 30 +++++++++++++++++++------- src/compute.ml | 14 +++++++----- src/container.ml | 7 ++++-- src/error.ml | 13 ++++------- src/kms.ml | 7 +++++- src/pub_sub.ml | 12 ++++++++--- src/secretmanager.ml | 8 +++++-- src/stackdriver_errors.ml | 5 ++++- src/storage.ml | 24 +++++++++++++++------ src/util.ml | 12 +++++++++++ 11 files changed, 123 insertions(+), 54 deletions(-) diff --git a/src/auth.ml b/src/auth.ml index 8336572..9ee90da 100644 --- a/src/auth.ml +++ b/src/auth.ml @@ -89,7 +89,9 @@ module Compute_engine = struct let ping () = let uri = Uri.of_string metadata_ip_root in + let open Lwt.Infix in Cohttp_lwt_unix.Client.get uri ~headers:metadata_headers + >>= Util.drain_body let get_project_id () : ( string, @@ -100,9 +102,10 @@ module Compute_engine = struct Uri.of_string (Printf.sprintf "%s/project/project-id" metadata_root) in Cohttp_lwt_unix.Client.get uri ~headers:metadata_headers + >>= Util.consume_body >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Cohttp_lwt.Body.to_string body |> ok + | `OK -> Lwt_result.return body | status -> `Bad_GCE_metadata_response status |> Lwt_result.fail end end @@ -230,14 +233,13 @@ module External_account_credentials = struct |> to_string let subject_token_of_response (t : t) - ((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) : + ((resp, body_str) : Cohttp.Response.t * string) : (string, [> `Bad_token_response of string ]) result Lwt.t = let open Lwt.Syntax in match Cohttp.Response.status resp with | `OK -> ( match t.credential_source.format.type_ with | `Json -> ( - let* body_str = Cohttp_lwt.Body.to_string body in try body_str |> Yojson.Basic.from_string |> subject_token_of_json t |> Lwt.return_ok @@ -245,7 +247,6 @@ module External_account_credentials = struct let* () = L.debug (fun m -> m "Type_error: %s" msg) in Lwt.return_error (`Bad_subject_token_response (resp, body_str)))) | _ -> - let* body_str = Cohttp_lwt.Body.to_string body in let* () = L.err (fun m -> m "response: %s" body_str) in Lwt.return_error (`Bad_subject_token_response (resp, body_str)) end @@ -367,15 +368,12 @@ let credentials_of_file (credentials_file : string) : lines |> String.concat "\n" |> credentials_of_string |> Lwt.return) let access_token_of_response ?(of_json = access_token_of_json) - ((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) : + ((resp, body_str) : Cohttp.Response.t * string) : (Access_token.t, [> `Bad_token_response of string ]) result Lwt.t = let open Lwt.Syntax in match Cohttp.Response.status resp with - | `OK -> - let* body_str = Cohttp_lwt.Body.to_string body in - body_str |> Yojson.Basic.from_string |> of_json |> Lwt.return + | `OK -> body_str |> Yojson.Basic.from_string |> of_json |> Lwt.return | _ -> - let* body_str = Cohttp_lwt.Body.to_string body in let* () = L.err (fun m -> m "response: %s" body_str) in Lwt.return_error (`Bad_token_response body_str) @@ -397,7 +395,11 @@ let access_token_of_credentials (scopes : string list) ("grant_type", [ "refresh_token" ]); ] in - let* res = Cohttp_lwt_unix.Client.post_form token_uri ~params |> ok in + let* res = + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post_form token_uri ~params + >>= Util.consume_body |> ok + in access_token_of_response ~of_json:access_token_of_json res | Service_account c -> ( let now = Unix.time () in @@ -432,8 +434,9 @@ let access_token_of_credentials (scopes : string list) ] in let* res = + let open Lwt.Infix in Cohttp_lwt_unix.Client.post_form (Uri.of_string c.token_uri) ~params - |> ok + >>= Util.consume_body |> ok in access_token_of_response ~of_json:access_token_of_json res | _ -> Lwt_result.fail (`Bad_credentials_priv_key "Not RSA key")) @@ -444,9 +447,10 @@ let access_token_of_credentials (scopes : string list) |> Uri.of_string in let* res = + let open Lwt.Infix in Cohttp_lwt_unix.Client.get uri ~headers:Compute_engine.Metadata.metadata_headers - |> ok + >>= Util.consume_body |> ok in access_token_of_response ~of_json:access_token_of_json res | External_account (c : External_account_credentials.t) -> ( @@ -463,10 +467,11 @@ let access_token_of_credentials (scopes : string list) let* subject_token = let subject_token_uri = Uri.of_string c.credential_source.url in let* resp = + let open Lwt.Infix in Cohttp_lwt_unix.Client.get ~headers:(Cohttp.Header.of_list c.credential_source.headers) subject_token_uri - |> ok + >>= Util.consume_body |> ok in External_account_credentials.subject_token_of_response c resp in @@ -486,7 +491,11 @@ let access_token_of_credentials (scopes : string list) ] in let body = Cohttp_lwt.Body.of_string (Yojson.Basic.to_string params) in - let* res = Cohttp_lwt_unix.Client.post token_uri ~body |> ok in + let* res = + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post token_uri ~body + >>= Util.consume_body |> ok + in Lwt_result.return res in match c.service_account_impersonation_url with @@ -514,7 +523,11 @@ let access_token_of_credentials (scopes : string list) in let uri = Uri.of_string sac in let* () = L.debug (fun m -> m "POST %a" Uri.pp_hum uri) |> ok in - let* res = Cohttp_lwt_unix.Client.post uri ~headers ~body |> ok in + let* res = + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body + >>= Util.consume_body |> ok + in let access_token_of_json (json : Yojson.Basic.t) : (Access_token.t, [> error ]) result = (* has a slightly different format from the access token in the other responses: @@ -579,7 +592,7 @@ let discover_credentials_with (discovery_mode : discovery_mode) = Lwt.catch (fun () -> let open Lwt_result.Syntax in - let* resp, _body = Compute_engine.Metadata.ping () |> ok in + let* resp = Compute_engine.Metadata.ping () |> ok in let* () = L.debug (fun m -> m "Got metadata response") |> ok in let has_metadata_header = Compute_engine.Metadata.response_has_metadata_header resp diff --git a/src/big_query.ml b/src/big_query.ml index eb1bb10..e79f3b8 100644 --- a/src/big_query.ml +++ b/src/big_query.ml @@ -2,6 +2,8 @@ let src = Logs.Src.create "gcloud.bigquery" module L = (val Logs_lwt.src_log src) +let ok = Lwt_result.ok + module Scopes = struct let bigquery = "https://www.googleapis.com/auth/bigquery" end @@ -85,11 +87,13 @@ module Datasets = struct ] in L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok - >>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + >>= fun () -> + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Cohttp_lwt.Body.to_string body |> Lwt_result.ok + | `OK -> Lwt_result.return body | x -> Error.of_response_status_code_and_body x body let list ?project_id () : (string, [> Error.t ]) Lwt_result.t = @@ -110,11 +114,13 @@ module Datasets = struct ] in L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok - >>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + >>= fun () -> + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Cohttp_lwt.Body.to_string body |> Lwt_result.ok + | `OK -> Lwt_result.return body | x -> Error.of_response_status_code_and_body x body module Tables = struct @@ -173,11 +179,13 @@ module Datasets = struct ] in L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok - >>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + >>= fun () -> + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json resp_of_yojson body + | `OK -> Error.parse_body_json resp_of_yojson body |> Lwt.return | x -> Error.of_response_status_code_and_body x body end end @@ -730,12 +738,15 @@ module Jobs = struct m "Query: %s" q_trimmed) |> Lwt_result.ok >>= fun () -> - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body + >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with | `OK -> Error.parse_body_json ~gzipped:use_gzip query_response_of_yojson body + |> Lwt.return >>= fun response -> L.debug (fun m -> m "%a" pp_query_response response) |> Lwt_result.ok >>= fun () -> Lwt_result.return response @@ -786,12 +797,15 @@ module Jobs = struct |> add_gzip_headers ~use_gzip in Lwt.catch - (fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + (fun () -> + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with | `OK -> Error.parse_body_json ~gzipped:use_gzip query_response_of_yojson body + |> Lwt.return >>= fun response -> L.debug (fun m -> m "%a" pp_query_response response) |> Lwt_result.ok >>= fun () -> Lwt_result.return response diff --git a/src/compute.ml b/src/compute.ml index 7caa27c..fdc2089 100644 --- a/src/compute.ml +++ b/src/compute.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let cloud_platform = "https://www.googleapis.com/auth/cloud-platform" let compute = "https://www.googleapis.com/auth/cloud-platform" @@ -48,13 +50,14 @@ module FirewallRules = struct ] in let body_str = rule |> rule_to_yojson |> Yojson.Safe.to_string in - print_endline body_str; let body = body_str |> Cohttp_lwt.Body.of_string in - Cohttp_lwt_unix.Client.post uri ~body ~headers |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~body ~headers + >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Lwt_result.ok (Cohttp_lwt.Body.to_string body) + | `OK -> Lwt_result.return body | status_code -> Error.of_response_status_code_and_body status_code body let delete ?project_id ~(name : string) () : @@ -78,10 +81,11 @@ module FirewallRules = struct Printf.sprintf "Bearer %s" token_info.Auth.token.access_token ); ] in - Cohttp_lwt_unix.Client.delete uri ~headers |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.delete uri ~headers >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Lwt_result.ok (Cohttp_lwt.Body.to_string body) + | `OK -> Lwt_result.return body | status_code -> Error.of_response_status_code_and_body status_code body end diff --git a/src/container.ml b/src/container.ml index 1a7841d..3715766 100644 --- a/src/container.ml +++ b/src/container.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let cloud_platform = "https://www.googleapis.com/auth/cloud-platform" end @@ -54,11 +56,12 @@ module Projects = struct token_info.Auth.token.access_token ); ] in - Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json of_yojson body + | `OK -> Error.parse_body_json of_yojson body |> Lwt.return | status_code -> Error.of_response_status_code_and_body status_code body end end diff --git a/src/error.ml b/src/error.ml index c7b1703..3e55cd0 100644 --- a/src/error.ml +++ b/src/error.ml @@ -56,10 +56,8 @@ let pp fmt (error : t) = | `Msg s -> Format.fprintf fmt "Msg: %s" s let parse_body_json ?(gzipped = false) - (transform : Yojson.Safe.t -> ('a, string) result) - (body : Cohttp_lwt.Body.t) : ('a, [> t ]) Lwt_result.t = - let open Lwt.Infix in - Cohttp_lwt.Body.to_string body >>= fun body_str -> + (transform : Yojson.Safe.t -> ('a, string) result) (body_str : string) : + ('a, [> t ]) result = let body = if gzipped then Ezgzip.decompress body_str @@ -72,18 +70,15 @@ let parse_body_json ?(gzipped = false) | Yojson.Json_error msg -> Error (`Json_parse_error (msg, body_str)) | e -> Error (`Json_parse_error (Printexc.to_string e, body_str)) in - parse_result |> CCResult.flat_map (fun json -> transform json |> CCResult.map_err (fun e -> `Json_transform_error (e, json))) - |> Lwt.return let of_response_status_code_and_body ?gzipped - (status_code : Cohttp.Code.status_code) (body : Cohttp_lwt.Body.t) : + (status_code : Cohttp.Code.status_code) (body_str : string) : ('a, [> t ]) Lwt_result.t = - let open Lwt.Infix in - parse_body_json ?gzipped api_json_error_of_yojson body >>= function + match parse_body_json ?gzipped api_json_error_of_yojson body_str with | Ok parsed_error -> Lwt_result.fail (`Gcloud_api_error (status_code, Json parsed_error)) | Error (`Json_parse_error (_, body_str)) -> diff --git a/src/kms.ml b/src/kms.ml index a20d40e..a88d142 100644 --- a/src/kms.ml +++ b/src/kms.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let cloudkms = "https://www.googleapis.com/auth/cloudkms" end @@ -36,7 +38,9 @@ module V1 = struct token_info.Auth.token.access_token ); ] in - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body + >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with @@ -52,6 +56,7 @@ module V1 = struct Error "Could not base64-decode the plaintext") | _ -> Error "Expected an object with field 'plaintext'") body + |> Lwt.return | x -> Error.of_response_status_code_and_body x body end end diff --git a/src/pub_sub.ml b/src/pub_sub.ml index 555d595..5e3f096 100644 --- a/src/pub_sub.ml +++ b/src/pub_sub.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let pubsub = "https://www.googleapis.com/auth/pubsub" end @@ -36,7 +38,9 @@ module Subscriptions = struct in Logs_lwt.debug (fun m -> m "POST %a" Uri.pp_hum uri) |> Lwt_result.ok >>= fun () -> - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body + >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with @@ -94,12 +98,14 @@ module Subscriptions = struct Logs_lwt.debug ~src:log_src_pull (fun m -> m "POST %a" Uri.pp_hum uri) |> Lwt_result.ok >>= fun () -> - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body + >>= Util.consume_body |> ok) (fun e -> `Network_error e |> Lwt_result.fail) >>= fun (resp, body) -> match Cohttp.Response.status resp with | `OK -> - Error.parse_body_json received_messages_of_yojson body + Error.parse_body_json received_messages_of_yojson body |> Lwt.return >|= fun { received_messages } -> let received_messages = received_messages diff --git a/src/secretmanager.ml b/src/secretmanager.ml index 9801f74..308bc4b 100644 --- a/src/secretmanager.ml +++ b/src/secretmanager.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let cloud_platform = "https://www.googleapis.com/auth/cloud-platform" end @@ -43,12 +45,14 @@ module V1 = struct ); ] in - Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers + >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) in match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json response_of_yojson body + | `OK -> Error.parse_body_json response_of_yojson body |> Lwt.return | status_code -> Error.of_response_status_code_and_body status_code body end diff --git a/src/stackdriver_errors.ml b/src/stackdriver_errors.ml index 3466b1b..5f1b189 100644 --- a/src/stackdriver_errors.ml +++ b/src/stackdriver_errors.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let stackdriver_integration = "https://www.googleapis.com/auth/stackdriver-integration" @@ -80,7 +82,8 @@ let report ?project_id (report_request : report_request) : report_request |> report_request_to_yojson |> Yojson.Safe.to_string in let body = body_str |> Cohttp_lwt.Body.of_string in - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok + (let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body >>= Util.consume_body |> ok) >>= fun (response, body) -> let status = Cohttp.Response.status response in match status with diff --git a/src/storage.ml b/src/storage.ml index 4702d41..ea6811c 100644 --- a/src/storage.ml +++ b/src/storage.ml @@ -1,3 +1,5 @@ +let ok = Lwt_result.ok + module Scopes = struct let devstorage_read_only = "https://www.googleapis.com/auth/devstorage.read_only" @@ -40,7 +42,10 @@ let get_object_stream (bucket_name : string) (object_path : string) : >>= fun (resp, body) -> match Cohttp.Response.status resp with | `OK -> Cohttp_lwt.Body.to_stream body |> Lwt_result.return - | status_code -> Error.of_response_status_code_and_body status_code body + | status_code -> + Cohttp_lwt.Body.to_string body + |> ok + >>= Error.of_response_status_code_and_body status_code let get_object (bucket_name : string) (object_path : string) : (string, [> Error.t ]) Lwt_result.t = @@ -67,11 +72,12 @@ let insert_object_ bucket_name name (body : Cohttp_lwt.Body.t) : Printf.sprintf "Bearer %s" token_info.Auth.token.access_token ); ] in - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json object__of_yojson body + | `OK -> Error.parse_body_json object__of_yojson body |> Lwt.return | status_code -> Error.of_response_status_code_and_body status_code body let insert_object bucket_name name (data : string) : @@ -118,11 +124,13 @@ let rewrite_object source_bucket source_object destination_bucket ] in let body = Cohttp_lwt.Body.empty in - Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.post uri ~headers ~body >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json rewrite_object_response_of_yojson body + | `OK -> + Error.parse_body_json rewrite_object_response_of_yojson body |> Lwt.return | status_code -> Error.of_response_status_code_and_body status_code body [@@@warning "-39"] @@ -168,9 +176,11 @@ let list_objects ?(delimiter : string option) ?(prefix : string option) Printf.sprintf "Bearer %s" token_info.Auth.token.access_token ); ] in - Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok) + let open Lwt.Infix in + Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok) (fun e -> Lwt_result.fail (`Network_error e)) >>= fun (resp, body) -> match Cohttp.Response.status resp with - | `OK -> Error.parse_body_json list_objects_response_of_yojson body + | `OK -> + Error.parse_body_json list_objects_response_of_yojson body |> Lwt.return | status_code -> Error.of_response_status_code_and_body status_code body diff --git a/src/util.ml b/src/util.ml index 5f27130..ef2d833 100644 --- a/src/util.ml +++ b/src/util.ml @@ -8,3 +8,15 @@ module List = struct in aux [] xss end + +let consume_body ((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) : + (Cohttp.Response.t * string) Lwt.t = + let open Lwt.Syntax in + let* body_str = Cohttp_lwt.Body.to_string body in + Lwt.return (resp, body_str) + +let drain_body ((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) : + Cohttp.Response.t Lwt.t = + let open Lwt.Syntax in + let* () = Cohttp_lwt.Body.drain_body body in + Lwt.return resp