Skip to content

Commit

Permalink
feat: impl writer for socket and test it
Browse files Browse the repository at this point in the history
  • Loading branch information
leostera committed Dec 24, 2023
1 parent e5f8a86 commit 2cefea2
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 17 deletions.
19 changes: 9 additions & 10 deletions riot/lib/io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ module Writer = struct
let flush = B.flush
end

type 't write = (module Write with type t = 't)
type 't writer = Writer : ('t write * 't) -> 't writer
type 'src write = (module Write with type t = 'src)
type 'src t = Writer : ('src write * 'src) -> 'src t

let of_write_src : type src. src write -> src -> src writer =
let of_write_src : type src. src write -> src -> src t =
fun write src -> Writer (write, src)

let write :
type src. src writer -> data:Buffer.t -> (int, [> `Closed ]) result =
let write : type src. src t -> data:Buffer.t -> (int, [> `Closed ]) result =
fun (Writer ((module W), src)) ~data -> W.write src ~data
end

Expand All @@ -116,14 +115,14 @@ module Reader = struct
let read = B.read
end

type 't read = (module Read with type t = 't)
type 't reader = Reader : ('t read * 't) -> 't reader
type 'src read = (module Read with type t = 'src)
type 'src t = Reader : ('src read * 'src) -> 'src t
type 'src reader = 'src t

let of_read_src : type src. src read -> src -> src reader =
let of_read_src : type src. src read -> src -> src t =
fun read src -> Reader (read, src)

let read : type src. src reader -> buf:Buffer.t -> (int, [> `Closed ]) result
=
let read : type src. src t -> buf:Buffer.t -> (int, [> `Closed ]) result =
fun (Reader ((module R), src)) ~buf -> R.read src ~buf

module Buffered = struct
Expand Down
9 changes: 9 additions & 0 deletions riot/lib/net.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,13 @@ module Socket = struct
end)

let to_reader t = Io.Reader.of_read_src (module Read) t

module Write = Io.Writer.Make (struct
type t = stream_socket

let write t ~data = send ~data t
let flush _t = Ok ()
end)

let to_writer t = Io.Writer.of_write_src (module Write) t
end
14 changes: 9 additions & 5 deletions riot/riot.mli
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ module IO : sig
end

module Writer : sig
type 'src writer
type 'src t

val write : 'src writer -> data:Buffer.t -> (int, [> `Closed ]) result
val write : 'src t -> data:Buffer.t -> (int, [> `Closed ]) result

module Make (B : Write) : sig
type t = B.t
Expand All @@ -462,7 +462,8 @@ module IO : sig
end

module Reader : sig
type 'src reader
type 'src t
type 'src reader = 'src t

val read : 'src reader -> buf:Buffer.t -> (int, [> `Closed ]) result

Expand All @@ -488,8 +489,8 @@ module File : sig
val open_write : string -> [ `w ] file
val close : _ file -> unit
val remove : string -> unit
val to_reader : [ `r ] file -> [ `r ] file IO.Reader.reader
val to_writer : [ `w ] file -> [ `w ] file IO.Writer.writer
val to_reader : [ `r ] file -> [ `r ] file IO.Reader.t
val to_writer : [ `w ] file -> [ `w ] file IO.Writer.t
end

module Net : sig
Expand Down Expand Up @@ -556,6 +557,9 @@ module Net : sig
Format.formatter ->
[ IO.unix_error | `Closed | `Timeout | `System_limit ] ->
unit

val to_reader : stream_socket -> stream_socket IO.Reader.t
val to_writer : stream_socket -> stream_socket IO.Writer.t
end
end

Expand Down
5 changes: 5 additions & 0 deletions test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
(modules net_addr_uri_test)
(libraries riot))

(test
(name net_reader_writer_test)
(modules net_reader_writer_test)
(libraries riot))

(test
(name net_test)
(modules net_test)
Expand Down
4 changes: 2 additions & 2 deletions test/io_writer_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ let () =
Logger.set_log_level (Some Info);
let now = Ptime_clock.now () in
let path =
Format.asprintf "./test/generated/%a.io_writer_test.txt"
(Ptime.pp_rfc3339 ()) now
Format.asprintf "./generated/%a.io_writer_test.txt" (Ptime.pp_rfc3339 ())
now
in
let file = File.open_write path in
let writer = File.to_writer file in
Expand Down
111 changes: 111 additions & 0 deletions test/net_reader_writer_test.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
open Riot

type Message.t += Received of string

(* rudimentary tcp echo server *)
let server port =
let socket = Net.Socket.listen ~port () |> Result.get_ok in
Logger.debug (fun f -> f "Started server on %d" port);
process_flag (Trap_exit true);
let conn, addr = Net.Socket.accept socket |> Result.get_ok in
Logger.debug (fun f ->
f "Accepted client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn);
let close () =
Net.Socket.close conn;
Logger.debug (fun f ->
f "Closed client %a (%a)" Net.Addr.pp addr Net.Socket.pp conn)
in

let reader = Net.Socket.to_reader conn in
let writer = Net.Socket.to_writer conn in

let buf = IO.Buffer.with_capacity 1024 in
let rec echo () =
Logger.debug (fun f ->
f "Reading from client client %a (%a)" Net.Addr.pp addr Net.Socket.pp
conn);
match IO.Reader.read reader ~buf with
| Ok len -> (
Logger.debug (fun f -> f "Server received %d bytes" len);
let data = IO.Buffer.sub ~off:0 ~len buf in
match IO.Writer.write ~data writer with
| Ok bytes ->
Logger.debug (fun f -> f "Server sent %d bytes" bytes);
echo ()
| Error `Closed -> close ()
| Error (`Unix_error unix_err) ->
Logger.error (fun f ->
f "send unix error %s" (Unix.error_message unix_err));
close ())
| Error (`Closed | `Timeout) -> close ()
| Error (`Unix_error unix_err) ->
Logger.error (fun f ->
f "recv unix error %s" (Unix.error_message unix_err));
close ()
in
echo ()

let client port main =
let addr = Net.Addr.(tcp loopback port) in
let conn = Net.Socket.connect addr |> Result.get_ok in
Logger.debug (fun f -> f "Connected to server on %d" port);
let data = IO.Buffer.of_string "hello world" in

let reader = Net.Socket.to_reader conn in
let writer = Net.Socket.to_writer conn in

let rec send_loop n =
sleep 0.001;
if n = 0 then Logger.error (fun f -> f "client retried too many times")
else
match IO.Writer.write ~data writer with
| Ok bytes -> Logger.debug (fun f -> f "Client sent %d bytes" bytes)
| Error `Closed -> Logger.debug (fun f -> f "connection closed")
| Error (`Unix_error (ENOTCONN | EPIPE)) -> send_loop n
| Error (`Unix_error unix_err) ->
Logger.error (fun f ->
f "client unix error %s" (Unix.error_message unix_err));
send_loop (n - 1)
in
send_loop 10_000;

let buf = IO.Buffer.with_capacity 128 in
let recv_loop () =
match IO.Reader.read ~buf reader with
| Ok bytes ->
Logger.debug (fun f -> f "Client received %d bytes" bytes);
bytes
| Error (`Closed | `Timeout) ->
Logger.error (fun f -> f "Server closed the connection");
0
| Error (`Unix_error unix_err) ->
Logger.error (fun f ->
f "client unix error %s" (Unix.error_message unix_err));
0
in
let len = recv_loop () in

if len = 0 then send main (Received "empty paylaod")
else send main (Received (IO.Buffer.to_string buf))

let () =
Riot.run @@ fun () ->
let _ = Logger.start () |> Result.get_ok in
Logger.set_log_level (Some Info);
let port = 2112 in
let main = self () in
let _server = spawn (fun () -> server port) in
let _client = spawn (fun () -> client port main) in
match receive () with
| Received "hello world" ->
Logger.info (fun f -> f "net_reader_writer_test: OK");
sleep 0.001;
shutdown ()
| Received other ->
Logger.error (fun f -> f "net_reader_writer_test: bad payload: %S" other);
sleep 0.001;
Stdlib.exit 1
| _ ->
Logger.error (fun f -> f "net_reader_writer_test: unexpected message");
sleep 0.001;
Stdlib.exit 1

0 comments on commit 2cefea2

Please sign in to comment.