From 9b80a3a6c8a5197c7f448aacb621f2224f7e08ee Mon Sep 17 00:00:00 2001 From: Mathieu Barbin Date: Sat, 6 Jan 2024 15:39:39 +0100 Subject: [PATCH] implement untyped rpc server on top of typed server --- lib/grpc-eio/client.ml | 10 +- lib/grpc-eio/server.ml | 236 +++++++++++++++++++++------------------- lib/grpc-eio/server.mli | 100 +++++++++-------- 3 files changed, 184 insertions(+), 162 deletions(-) diff --git a/lib/grpc-eio/client.ml b/lib/grpc-eio/client.ml index 512cc29..95d20b2 100644 --- a/lib/grpc-eio/client.ml +++ b/lib/grpc-eio/client.ml @@ -88,7 +88,7 @@ module Typed_rpc = struct 'a let make_handler (type request response) - (rpc : (request, _, response, _) Grpc.Rpc.Client_rpc.t) ~f = + ~(rpc : (request, _, response, _) Grpc.Rpc.Client_rpc.t) ~f = make_handler ~encode_request:rpc.encode_request ~decode_response:rpc.decode_response ~f @@ -99,7 +99,7 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.stream ) Grpc.Rpc.Client_rpc.t) = - make_handler rpc ~f + make_handler ~rpc ~f let client_streaming (type request response) ~f (rpc : @@ -108,7 +108,7 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.unary ) Grpc.Rpc.Client_rpc.t) = - make_handler rpc ~f:(fun request_writer responses -> + make_handler ~rpc ~f:(fun request_writer responses -> let response, response_resolver = Eio.Promise.create () in Eio.Fiber.pair (fun () -> f request_writer response) @@ -124,7 +124,7 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.stream ) Grpc.Rpc.Client_rpc.t) = - make_handler rpc ~f:(fun request_writer responses -> + make_handler ~rpc ~f:(fun request_writer responses -> Seq.write request_writer request; Seq.close_writer request_writer; f responses) @@ -136,7 +136,7 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.unary ) Grpc.Rpc.Client_rpc.t) = - make_handler rpc ~f:(fun request_writer responses -> + make_handler ~rpc ~f:(fun request_writer responses -> Seq.write request_writer request; Seq.close_writer request_writer; let response = Seq.read_and_exhaust responses in diff --git a/lib/grpc-eio/server.ml b/lib/grpc-eio/server.ml index 175f87b..b0fdfb9 100644 --- a/lib/grpc-eio/server.ml +++ b/lib/grpc-eio/server.ml @@ -1,6 +1,7 @@ module ServiceMap = Map.Make (String) -type service = H2.Reqd.t -> unit +type reqd_handler = H2.Reqd.t -> unit +type service = reqd_handler type t = service ServiceMap.t let v () = ServiceMap.empty @@ -44,93 +45,50 @@ let handle_request t reqd = | None -> respond_with `Unsupported_media_type) | _ -> respond_with `Not_found -module Rpc = struct - type unary = string -> Grpc.Status.t * string option - type client_streaming = string Seq.t -> Grpc.Status.t * string option - type server_streaming = string -> (string -> unit) -> Grpc.Status.t - - type bidirectional_streaming = - string Seq.t -> (string -> unit) -> Grpc.Status.t - - type t = - | Unary of unary - | Client_streaming of client_streaming - | Server_streaming of server_streaming - | Bidirectional_streaming of bidirectional_streaming - - let bidirectional_streaming ~f reqd = - let body = H2.Reqd.request_body reqd in - let request_reader, request_writer = Seq.create_reader_writer () in - let response_reader, response_writer = Seq.create_reader_writer () in - Connection.Untyped.grpc_recv_streaming body request_writer; - let status_promise, status_notify = Eio.Promise.create () in - Eio.Fiber.both - (fun () -> - let respond = Seq.write response_writer in - let status = f request_reader respond in - Seq.close_writer response_writer; - Eio.Promise.resolve status_notify status) - (fun () -> - try - Connection.Untyped.grpc_send_streaming reqd response_reader - status_promise - with exn -> - (* https://github.com/anmonteiro/ocaml-h2/issues/175 *) - Eio.traceln "%s" (Printexc.to_string exn)) - - let client_streaming ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - let status, response = f requests in - (match response with None -> () | Some response -> respond response); - status) - - let server_streaming ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - match Seq.read_and_exhaust requests with - | None -> Grpc.Status.(v OK) - | Some request -> f request respond) - - let unary ~f reqd = - bidirectional_streaming reqd ~f:(fun requests respond -> - match Seq.read_and_exhaust requests with - | None -> Grpc.Status.(v OK) - | Some request -> - let status, response = f request in - (match response with - | None -> () - | Some response -> respond response); - status) -end +let implement_rpc ~decode_request ~encode_response ~f reqd = + let body = H2.Reqd.request_body reqd in + let request_reader, request_writer = Seq.create_reader_writer () in + let response_reader, response_writer = Seq.create_reader_writer () in + Connection.Typed.grpc_recv_streaming ~decode:decode_request body + request_writer; + let status_promise, status_notify = Eio.Promise.create () in + Eio.Fiber.both + (fun () -> + let respond = Seq.write response_writer in + let status = f request_reader respond in + Seq.close_writer response_writer; + Eio.Promise.resolve status_notify status) + (fun () -> + try + Connection.Typed.grpc_send_streaming ~encode:encode_response reqd + response_reader status_promise + with exn -> + (* https://github.com/anmonteiro/ocaml-h2/issues/175 *) + Eio.traceln "%s" (Printexc.to_string exn)) -module Service = struct - module RpcMap = Map.Make (String) +module Typed_rpc = struct + module Service = struct + module RpcMap = Map.Make (String) - type t = Rpc.t RpcMap.t + type t = reqd_handler RpcMap.t - let v () = RpcMap.empty - let add_rpc ~name ~rpc t = RpcMap.add name rpc t + let v () = RpcMap.empty + let add_rpc ~name ~rpc t = RpcMap.add name rpc t - let handle_request (t : t) reqd = - let request = H2.Reqd.request reqd in - let respond_with code = - H2.Reqd.respond_with_string reqd (H2.Response.create code) "" - in - let parts = String.split_on_char '/' request.target in - if List.length parts > 1 then - let rpc_name = List.nth parts (List.length parts - 1) in - let rpc = RpcMap.find_opt rpc_name t in - match rpc with - | Some rpc -> ( - match rpc with - | Unary f -> Rpc.unary ~f reqd - | Client_streaming f -> Rpc.client_streaming ~f reqd - | Server_streaming f -> Rpc.server_streaming ~f reqd - | Bidirectional_streaming f -> Rpc.bidirectional_streaming ~f reqd) - | None -> respond_with `Not_found - else respond_with `Not_found -end + let handle_request (t : t) reqd = + let request = H2.Reqd.request reqd in + let respond_with code = + H2.Reqd.respond_with_string reqd (H2.Response.create code) "" + in + let parts = String.split_on_char '/' request.target in + if List.length parts > 1 then + let rpc_name = List.nth parts (List.length parts - 1) in + match RpcMap.find_opt rpc_name t with + | Some rpc -> rpc reqd + | None -> respond_with `Not_found + else respond_with `Not_found + end -module Typed_rpc = struct type server = t type ('request, 'response) unary = @@ -154,7 +112,7 @@ module Typed_rpc = struct 'response_mode, 'service_spec ) Grpc.Rpc.Server_rpc.t; - rpc_impl : Rpc.t; + rpc_impl : reqd_handler; } -> 'service_spec t @@ -195,6 +153,14 @@ module Typed_rpc = struct in Service.handle_request service) + let implement_rpc (type request response) + ~(rpc_spec : (request, _, response, _, _) Grpc.Rpc.Server_rpc.t) ~f = + let rpc_impl = + implement_rpc ~decode_request:rpc_spec.decode_request + ~encode_response:rpc_spec.encode_response ~f + in + T { rpc_spec; rpc_impl } + let bidirectional_streaming (type request response) (rpc_spec : ( request, @@ -202,12 +168,8 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.stream, _ ) - Grpc.Rpc.Server_rpc.t) ~f:handler = - let handler requests f = - let requests = Seq.map rpc_spec.decode_request requests in - handler requests (fun response -> f (rpc_spec.encode_response response)) - in - T { rpc_spec; rpc_impl = Rpc.Bidirectional_streaming handler } + Grpc.Rpc.Server_rpc.t) ~f = + implement_rpc ~rpc_spec ~f let unary (type request response) (rpc_spec : @@ -216,12 +178,16 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.unary, _ ) - Grpc.Rpc.Server_rpc.t) ~f:handler = - let handler buffer = - let status, response = handler (rpc_spec.decode_request buffer) in - (status, Option.map rpc_spec.encode_response response) - in - T { rpc_spec; rpc_impl = Rpc.Unary handler } + Grpc.Rpc.Server_rpc.t) ~f = + implement_rpc ~rpc_spec ~f:(fun requests respond -> + match Seq.read_and_exhaust requests with + | None -> Grpc.Status.(v OK) + | Some request -> + let status, response = f request in + (match response with + | None -> () + | Some response -> respond response); + status) let server_streaming (type request response) (rpc_spec : @@ -230,12 +196,11 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.stream, _ ) - Grpc.Rpc.Server_rpc.t) ~f:handler = - let handler buffer f = - handler (rpc_spec.decode_request buffer) (fun response -> - f (rpc_spec.encode_response response)) - in - T { rpc_spec; rpc_impl = Rpc.Server_streaming handler } + Grpc.Rpc.Server_rpc.t) ~f = + implement_rpc ~rpc_spec ~f:(fun requests respond -> + match Seq.read_and_exhaust requests with + | None -> Grpc.Status.(v OK) + | Some request -> f request respond) let client_streaming (type request response) (rpc_spec : @@ -244,11 +209,64 @@ module Typed_rpc = struct response, Grpc.Rpc.Value_mode.unary, _ ) - Grpc.Rpc.Server_rpc.t) ~f:handler = - let handler requests = - let requests = Seq.map rpc_spec.decode_request requests in - let status, response = handler requests in - (status, Option.map rpc_spec.encode_response response) - in - T { rpc_spec; rpc_impl = Rpc.Client_streaming handler } + Grpc.Rpc.Server_rpc.t) ~f = + implement_rpc ~rpc_spec ~f:(fun requests respond -> + let status, response = f requests in + (match response with None -> () | Some response -> respond response); + status) +end + +module Rpc = struct + type unary = string -> Grpc.Status.t * string option + type client_streaming = string Seq.t -> Grpc.Status.t * string option + type server_streaming = string -> (string -> unit) -> Grpc.Status.t + + type bidirectional_streaming = + string Seq.t -> (string -> unit) -> Grpc.Status.t + + type t = + | Unary of unary + | Client_streaming of client_streaming + | Server_streaming of server_streaming + | Bidirectional_streaming of bidirectional_streaming + + let bidirectional_streaming ~f reqd = + implement_rpc ~decode_request:Fun.id ~encode_response:Fun.id ~f reqd + + let client_streaming ~f reqd = + bidirectional_streaming reqd ~f:(fun requests respond -> + let status, response = f requests in + (match response with None -> () | Some response -> respond response); + status) + + let server_streaming ~f reqd = + bidirectional_streaming reqd ~f:(fun requests respond -> + match Seq.read_and_exhaust requests with + | None -> Grpc.Status.(v OK) + | Some request -> f request respond) + + let unary ~f reqd = + bidirectional_streaming reqd ~f:(fun requests respond -> + match Seq.read_and_exhaust requests with + | None -> Grpc.Status.(v OK) + | Some request -> + let status, response = f request in + (match response with + | None -> () + | Some response -> respond response); + status) +end + +module Service = struct + include Typed_rpc.Service + + let add_rpc ~name ~rpc t = + add_rpc ~name + ~rpc: + (match rpc with + | Rpc.Unary f -> Rpc.unary ~f + | Client_streaming f -> Rpc.client_streaming ~f + | Server_streaming f -> Rpc.server_streaming ~f + | Bidirectional_streaming f -> Rpc.bidirectional_streaming ~f) + t end diff --git a/lib/grpc-eio/server.mli b/lib/grpc-eio/server.mli index 76c000e..1f4be5a 100644 --- a/lib/grpc-eio/server.mli +++ b/lib/grpc-eio/server.mli @@ -1,53 +1,6 @@ include Grpc.Server.S -module Rpc : sig - type unary = string -> Grpc.Status.t * string option - (** [unary] is the type for a unary grpc rpc, one request, one response. *) - - type client_streaming = string Seq.t -> Grpc.Status.t * string option - (** [client_streaming] is the type for an rpc where the client streams the requests and the server responds once. *) - - type server_streaming = string -> (string -> unit) -> Grpc.Status.t - (** [server_streaming] is the type for an rpc where the client sends one request and the server sends multiple responses. *) - - type bidirectional_streaming = - string Seq.t -> (string -> unit) -> Grpc.Status.t - (** [bidirectional_streaming] is the type for an rpc where both the client and server can send multiple messages. *) - - type t = - | Unary of unary - | Client_streaming of client_streaming - | Server_streaming of server_streaming - | Bidirectional_streaming of bidirectional_streaming - - (** [t] represents the types of rpcs available in gRPC. *) - - val unary : f:unary -> H2.Reqd.t -> unit - (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and handles sending the response. *) - - val client_streaming : f:client_streaming -> H2.Reqd.t -> unit - (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from and handles sending the response. *) - - val server_streaming : f:server_streaming -> H2.Reqd.t -> unit - (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] and handles sending the responses pushed out. *) - - val bidirectional_streaming : f:bidirectional_streaming -> H2.Reqd.t -> unit - (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests from and andles sending the responses pushed out. *) -end - -module Service : sig - type t - (** [t] represents a gRPC service with potentially multiple rpcs and the information needed to route to them. *) - - val v : unit -> t - (** [v ()] creates a new service *) - - val add_rpc : name:string -> rpc:Rpc.t -> t -> t - (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to it with [name]. *) - - val handle_request : t -> H2.Reqd.t -> unit - (** [handle_request t reqd] handles routing [reqd] to the correct rpc if available in [t]. *) -end +(** {1 Typed API} *) module Typed_rpc : sig (** This is an experimental API to build RPCs on the server side. Compared to @@ -133,3 +86,54 @@ module Typed_rpc : sig takes care of registering the services based on the names provided by the protoc specification. *) end + +(** {1 Untyped API} *) + +module Rpc : sig + type unary = string -> Grpc.Status.t * string option + (** [unary] is the type for a unary grpc rpc, one request, one response. *) + + type client_streaming = string Seq.t -> Grpc.Status.t * string option + (** [client_streaming] is the type for an rpc where the client streams the requests and the server responds once. *) + + type server_streaming = string -> (string -> unit) -> Grpc.Status.t + (** [server_streaming] is the type for an rpc where the client sends one request and the server sends multiple responses. *) + + type bidirectional_streaming = + string Seq.t -> (string -> unit) -> Grpc.Status.t + (** [bidirectional_streaming] is the type for an rpc where both the client and server can send multiple messages. *) + + type t = + | Unary of unary + | Client_streaming of client_streaming + | Server_streaming of server_streaming + | Bidirectional_streaming of bidirectional_streaming + + (** [t] represents the types of rpcs available in gRPC. *) + + val unary : f:unary -> H2.Reqd.t -> unit + (** [unary ~f reqd] calls [f] with the request obtained from [reqd] and handles sending the response. *) + + val client_streaming : f:client_streaming -> H2.Reqd.t -> unit + (** [client_streaming ~f reqd] calls [f] with a stream to pull requests from and handles sending the response. *) + + val server_streaming : f:server_streaming -> H2.Reqd.t -> unit + (** [server_streaming ~f reqd] calls [f] with the request optained from [reqd] and handles sending the responses pushed out. *) + + val bidirectional_streaming : f:bidirectional_streaming -> H2.Reqd.t -> unit + (** [bidirectional_streaming ~f reqd] calls [f] with a stream to pull requests from and andles sending the responses pushed out. *) +end + +module Service : sig + type t + (** [t] represents a gRPC service with potentially multiple rpcs and the information needed to route to them. *) + + val v : unit -> t + (** [v ()] creates a new service *) + + val add_rpc : name:string -> rpc:Rpc.t -> t -> t + (** [add_rpc ~name ~rpc t] adds [rpc] to [t] and ensures that [t] can route to it with [name]. *) + + val handle_request : t -> H2.Reqd.t -> unit + (** [handle_request t reqd] handles routing [reqd] to the correct rpc if available in [t]. *) +end