Skip to content

Commit

Permalink
implement untyped rpc server on top of typed server
Browse files Browse the repository at this point in the history
  • Loading branch information
mbarbin committed Jan 6, 2024
1 parent 7121496 commit 9b80a3a
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 162 deletions.
10 changes: 5 additions & 5 deletions lib/grpc-eio/client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ module Typed_rpc = struct
'a

let make_handler (type request response)
(rpc : (request, _, response, _) Grpc.Rpc.Client_rpc.t) ~f =
~(rpc : (request, _, response, _) Grpc.Rpc.Client_rpc.t) ~f =
make_handler ~encode_request:rpc.encode_request
~decode_response:rpc.decode_response ~f

Expand All @@ -99,7 +99,7 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.stream )
Grpc.Rpc.Client_rpc.t) =
make_handler rpc ~f
make_handler ~rpc ~f

let client_streaming (type request response) ~f
(rpc :
Expand All @@ -108,7 +108,7 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.unary )
Grpc.Rpc.Client_rpc.t) =
make_handler rpc ~f:(fun request_writer responses ->
make_handler ~rpc ~f:(fun request_writer responses ->
let response, response_resolver = Eio.Promise.create () in
Eio.Fiber.pair
(fun () -> f request_writer response)
Expand All @@ -124,7 +124,7 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.stream )
Grpc.Rpc.Client_rpc.t) =
make_handler rpc ~f:(fun request_writer responses ->
make_handler ~rpc ~f:(fun request_writer responses ->
Seq.write request_writer request;
Seq.close_writer request_writer;
f responses)
Expand All @@ -136,7 +136,7 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.unary )
Grpc.Rpc.Client_rpc.t) =
make_handler rpc ~f:(fun request_writer responses ->
make_handler ~rpc ~f:(fun request_writer responses ->
Seq.write request_writer request;
Seq.close_writer request_writer;
let response = Seq.read_and_exhaust responses in
Expand Down
236 changes: 127 additions & 109 deletions lib/grpc-eio/server.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ServiceMap = Map.Make (String)

type service = H2.Reqd.t -> unit
type reqd_handler = H2.Reqd.t -> unit
type service = reqd_handler
type t = service ServiceMap.t

let v () = ServiceMap.empty
Expand Down Expand Up @@ -44,93 +45,50 @@ let handle_request t reqd =
| None -> respond_with `Unsupported_media_type)
| _ -> respond_with `Not_found

module Rpc = struct
type unary = string -> Grpc.Status.t * string option
type client_streaming = string Seq.t -> Grpc.Status.t * string option
type server_streaming = string -> (string -> unit) -> Grpc.Status.t

type bidirectional_streaming =
string Seq.t -> (string -> unit) -> Grpc.Status.t

type t =
| Unary of unary
| Client_streaming of client_streaming
| Server_streaming of server_streaming
| Bidirectional_streaming of bidirectional_streaming

let bidirectional_streaming ~f reqd =
let body = H2.Reqd.request_body reqd in
let request_reader, request_writer = Seq.create_reader_writer () in
let response_reader, response_writer = Seq.create_reader_writer () in
Connection.Untyped.grpc_recv_streaming body request_writer;
let status_promise, status_notify = Eio.Promise.create () in
Eio.Fiber.both
(fun () ->
let respond = Seq.write response_writer in
let status = f request_reader respond in
Seq.close_writer response_writer;
Eio.Promise.resolve status_notify status)
(fun () ->
try
Connection.Untyped.grpc_send_streaming reqd response_reader
status_promise
with exn ->
(* https://github.com/anmonteiro/ocaml-h2/issues/175 *)
Eio.traceln "%s" (Printexc.to_string exn))

let client_streaming ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
let status, response = f requests in
(match response with None -> () | Some response -> respond response);
status)

let server_streaming ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request -> f request respond)

let unary ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request ->
let status, response = f request in
(match response with
| None -> ()
| Some response -> respond response);
status)
end
let implement_rpc ~decode_request ~encode_response ~f reqd =
let body = H2.Reqd.request_body reqd in
let request_reader, request_writer = Seq.create_reader_writer () in
let response_reader, response_writer = Seq.create_reader_writer () in
Connection.Typed.grpc_recv_streaming ~decode:decode_request body
request_writer;
let status_promise, status_notify = Eio.Promise.create () in
Eio.Fiber.both
(fun () ->
let respond = Seq.write response_writer in
let status = f request_reader respond in
Seq.close_writer response_writer;
Eio.Promise.resolve status_notify status)
(fun () ->
try
Connection.Typed.grpc_send_streaming ~encode:encode_response reqd
response_reader status_promise
with exn ->
(* https://github.com/anmonteiro/ocaml-h2/issues/175 *)
Eio.traceln "%s" (Printexc.to_string exn))

module Service = struct
module RpcMap = Map.Make (String)
module Typed_rpc = struct
module Service = struct
module RpcMap = Map.Make (String)

type t = Rpc.t RpcMap.t
type t = reqd_handler RpcMap.t

let v () = RpcMap.empty
let add_rpc ~name ~rpc t = RpcMap.add name rpc t
let v () = RpcMap.empty
let add_rpc ~name ~rpc t = RpcMap.add name rpc t

let handle_request (t : t) reqd =
let request = H2.Reqd.request reqd in
let respond_with code =
H2.Reqd.respond_with_string reqd (H2.Response.create code) ""
in
let parts = String.split_on_char '/' request.target in
if List.length parts > 1 then
let rpc_name = List.nth parts (List.length parts - 1) in
let rpc = RpcMap.find_opt rpc_name t in
match rpc with
| Some rpc -> (
match rpc with
| Unary f -> Rpc.unary ~f reqd
| Client_streaming f -> Rpc.client_streaming ~f reqd
| Server_streaming f -> Rpc.server_streaming ~f reqd
| Bidirectional_streaming f -> Rpc.bidirectional_streaming ~f reqd)
| None -> respond_with `Not_found
else respond_with `Not_found
end
let handle_request (t : t) reqd =
let request = H2.Reqd.request reqd in
let respond_with code =
H2.Reqd.respond_with_string reqd (H2.Response.create code) ""
in
let parts = String.split_on_char '/' request.target in
if List.length parts > 1 then
let rpc_name = List.nth parts (List.length parts - 1) in
match RpcMap.find_opt rpc_name t with
| Some rpc -> rpc reqd
| None -> respond_with `Not_found
else respond_with `Not_found
end

module Typed_rpc = struct
type server = t

type ('request, 'response) unary =
Expand All @@ -154,7 +112,7 @@ module Typed_rpc = struct
'response_mode,
'service_spec )
Grpc.Rpc.Server_rpc.t;
rpc_impl : Rpc.t;
rpc_impl : reqd_handler;
}
-> 'service_spec t

Expand Down Expand Up @@ -195,19 +153,23 @@ module Typed_rpc = struct
in
Service.handle_request service)

let implement_rpc (type request response)
~(rpc_spec : (request, _, response, _, _) Grpc.Rpc.Server_rpc.t) ~f =
let rpc_impl =
implement_rpc ~decode_request:rpc_spec.decode_request
~encode_response:rpc_spec.encode_response ~f
in
T { rpc_spec; rpc_impl }

let bidirectional_streaming (type request response)
(rpc_spec :
( request,
Grpc.Rpc.Value_mode.stream,
response,
Grpc.Rpc.Value_mode.stream,
_ )
Grpc.Rpc.Server_rpc.t) ~f:handler =
let handler requests f =
let requests = Seq.map rpc_spec.decode_request requests in
handler requests (fun response -> f (rpc_spec.encode_response response))
in
T { rpc_spec; rpc_impl = Rpc.Bidirectional_streaming handler }
Grpc.Rpc.Server_rpc.t) ~f =
implement_rpc ~rpc_spec ~f

let unary (type request response)
(rpc_spec :
Expand All @@ -216,12 +178,16 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.unary,
_ )
Grpc.Rpc.Server_rpc.t) ~f:handler =
let handler buffer =
let status, response = handler (rpc_spec.decode_request buffer) in
(status, Option.map rpc_spec.encode_response response)
in
T { rpc_spec; rpc_impl = Rpc.Unary handler }
Grpc.Rpc.Server_rpc.t) ~f =
implement_rpc ~rpc_spec ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request ->
let status, response = f request in
(match response with
| None -> ()
| Some response -> respond response);
status)

let server_streaming (type request response)
(rpc_spec :
Expand All @@ -230,12 +196,11 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.stream,
_ )
Grpc.Rpc.Server_rpc.t) ~f:handler =
let handler buffer f =
handler (rpc_spec.decode_request buffer) (fun response ->
f (rpc_spec.encode_response response))
in
T { rpc_spec; rpc_impl = Rpc.Server_streaming handler }
Grpc.Rpc.Server_rpc.t) ~f =
implement_rpc ~rpc_spec ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request -> f request respond)

let client_streaming (type request response)
(rpc_spec :
Expand All @@ -244,11 +209,64 @@ module Typed_rpc = struct
response,
Grpc.Rpc.Value_mode.unary,
_ )
Grpc.Rpc.Server_rpc.t) ~f:handler =
let handler requests =
let requests = Seq.map rpc_spec.decode_request requests in
let status, response = handler requests in
(status, Option.map rpc_spec.encode_response response)
in
T { rpc_spec; rpc_impl = Rpc.Client_streaming handler }
Grpc.Rpc.Server_rpc.t) ~f =
implement_rpc ~rpc_spec ~f:(fun requests respond ->
let status, response = f requests in
(match response with None -> () | Some response -> respond response);
status)
end

module Rpc = struct
type unary = string -> Grpc.Status.t * string option
type client_streaming = string Seq.t -> Grpc.Status.t * string option
type server_streaming = string -> (string -> unit) -> Grpc.Status.t

type bidirectional_streaming =
string Seq.t -> (string -> unit) -> Grpc.Status.t

type t =
| Unary of unary
| Client_streaming of client_streaming
| Server_streaming of server_streaming
| Bidirectional_streaming of bidirectional_streaming

let bidirectional_streaming ~f reqd =
implement_rpc ~decode_request:Fun.id ~encode_response:Fun.id ~f reqd

let client_streaming ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
let status, response = f requests in
(match response with None -> () | Some response -> respond response);
status)

let server_streaming ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request -> f request respond)

let unary ~f reqd =
bidirectional_streaming reqd ~f:(fun requests respond ->
match Seq.read_and_exhaust requests with
| None -> Grpc.Status.(v OK)
| Some request ->
let status, response = f request in
(match response with
| None -> ()
| Some response -> respond response);
status)
end

module Service = struct
include Typed_rpc.Service

let add_rpc ~name ~rpc t =
add_rpc ~name
~rpc:
(match rpc with
| Rpc.Unary f -> Rpc.unary ~f
| Client_streaming f -> Rpc.client_streaming ~f
| Server_streaming f -> Rpc.server_streaming ~f
| Bidirectional_streaming f -> Rpc.bidirectional_streaming ~f)
t
end
Loading

0 comments on commit 9b80a3a

Please sign in to comment.