From 5e5cdc7d137586ca67f5e584b103bd2bf14aa470 Mon Sep 17 00:00:00 2001 From: Tristan Sloughter Date: Sun, 17 Jun 2018 10:57:28 -0600 Subject: [PATCH] client stream callbacks --- README.md | 4 +- src/chatterbox.app.src | 2 +- src/chatterbox_static_content_handler.erl | 138 ------------------- src/chatterbox_static_stream.erl | 14 +- src/h2_client.erl | 12 +- src/h2_connection.erl | 117 +++++++++------- src/h2_stream.erl | 157 +++++++++++++++++----- src/h2_stream_set.erl | 24 ++-- test/client_server_SUITE.erl | 10 +- test/double_body_handler.erl | 18 +-- test/echo_handler.erl | 18 +-- test/flow_control_handler.erl | 18 +-- test/peer_test_handler.erl | 18 +-- test/server_connection_receive_window.erl | 18 +-- test/server_stream_receive_window.erl | 18 +-- 15 files changed, 290 insertions(+), 296 deletions(-) delete mode 100644 src/chatterbox_static_content_handler.erl diff --git a/README.md b/README.md index 3e315fd9..5cbc8fb9 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ RequestHeaders = [ RequestBody = <<>>, -{ok, {ResponseHeaders, ResponseBody}} +{ok, {ResponseHeaders, ResponseBody, ResponseTrailers}} = h2_client:sync_request(Pid, RequestHeaders, RequestBody). ``` @@ -125,7 +125,7 @@ So you can use a receive block, like this ```erlang receive {'END_STREAM', StreamId} -> - {ok, {ResponseHeaders, ResponseBody}} = h2_client:get_response(Pid, StreamId) + {ok, {ResponseHeaders, ResponseBody, ResponseTrailers}} = h2_client:get_response(Pid, StreamId) end, ``` diff --git a/src/chatterbox.app.src b/src/chatterbox.app.src index 97b4adc7..6599afdd 100644 --- a/src/chatterbox.app.src +++ b/src/chatterbox.app.src @@ -45,7 +45,7 @@ ]}, - {maintainers, ["Joe DeVivo"]}, + {maintainers, ["Joe DeVivo", "Tristan Sloughter"]}, {licenses, ["MIT"]}, {links, [{"Github", "https://github.com/joedevivo/chatterbox"}]} ]}. diff --git a/src/chatterbox_static_content_handler.erl b/src/chatterbox_static_content_handler.erl deleted file mode 100644 index b6300120..00000000 --- a/src/chatterbox_static_content_handler.erl +++ /dev/null @@ -1,138 +0,0 @@ --module(chatterbox_static_content_handler). - --include("http2.hrl"). - --export([ - spawn_handle/4, - handle/4 - ]). - --spec spawn_handle( - pid(), - stream_id(), %% Stream Id - hpack:headers(), %% Decoded Request Headers - iodata() %% Request Body - ) -> pid(). -spawn_handle(Pid, StreamId, Headers, ReqBody) -> - Handler = fun() -> - handle(Pid, StreamId, Headers, ReqBody) - end, - spawn_link(Handler). - --spec handle( - pid(), - stream_id(), - hpack:headers(), - iodata() - ) -> ok. -handle(ConnPid, StreamId, Headers, _ReqBody) -> - Path = binary_to_list(proplists:get_value(<<":path">>, Headers)), - - %% QueryString Hack? - Path2 = case string:chr(Path, $?) of - 0 -> Path; - X -> string:substr(Path, 1, X-1) - end, - - %% Dot Hack - Path3 = case Path2 of - [$.|T] -> T; - Other -> Other - end, - - - Path4 = case Path3 of - [$/|T2] -> [$/|T2]; - Other2 -> [$/|Other2] - end, - - %% TODO: Should have a better way of extracting root_dir (i.e. not on every request) - StaticHandlerSettings = application:get_env(chatterbox, ?MODULE, []), - RootDir = proplists:get_value(root_dir, StaticHandlerSettings, code:priv_dir(chatterbox)), - - %% TODO: Logic about "/" vs "index.html", "index.htm", etc... - %% Directory browsing? - File = RootDir ++ Path4, - - case {filelib:is_file(File), filelib:is_dir(File)} of - {_, true} -> - ResponseHeaders = [ - {<<":status">>,<<"403">>} - ], - h2_connection:send_headers(ConnPid, StreamId, ResponseHeaders), - h2_connection:send_body(ConnPid, StreamId, <<"No soup for you!">>), - ok; - {true, false} -> - Ext = filename:extension(File), - MimeType = case Ext of - ".js" -> <<"text/javascript">>; - ".html" -> <<"text/html">>; - ".css" -> <<"text/css">>; - ".scss" -> <<"text/css">>; - ".woff" -> <<"application/font-woff">>; - ".ttf" -> <<"application/font-snft">>; - _ -> <<"unknown">> - end, - {ok, Data} = file:read_file(File), - - ResponseHeaders = [ - {<<":status">>, <<"200">>}, - {<<"content-type">>, MimeType} - ], - - h2_connection:send_headers(ConnPid, StreamId, ResponseHeaders), - - - case {MimeType, h2_connection:is_push(ConnPid)} of - {<<"text/html">>, true} -> - %% Search Data for resources to push - {ok, RE} = re:compile(" - [dot_hack(lists:last(M)) || M <- Matches]; - _ -> [] - end, - - NewStreams = - lists:foldl(fun(R, Acc) -> - NewStreamId = h2_connection:new_stream(ConnPid), - PHeaders = generate_push_promise_headers(Headers, <<$/,R/binary>> - ), - h2_connection:send_promise(ConnPid, StreamId, NewStreamId, PHeaders), - [{NewStreamId, PHeaders}|Acc] - end, - [], - Resources - ), - - [spawn_handle(ConnPid, NewStreamId, PHeaders, <<>>) || {NewStreamId, PHeaders} <- NewStreams], - - ok; - _ -> - ok - end, - h2_connection:send_body(ConnPid, StreamId, Data), - ok; - {false, false} -> - ResponseHeaders = [ - {<<":status">>,<<"404">>} - ], - h2_connection:send_headers(ConnPid, StreamId, ResponseHeaders), - h2_connection:send_body(ConnPid, StreamId, <<"No soup for you!">>), - ok - end, - ok. - --spec generate_push_promise_headers(hpack:headers(), binary()) -> hpack:headers(). -generate_push_promise_headers(Request, Path) -> - [ - {<<":path">>, Path},{<<":method">>, <<"GET">>}| - lists:filter(fun({<<":authority">>,_}) -> true; - ({<<":scheme">>, _}) -> true; - (_) -> false end, Request) - ]. - --spec dot_hack(binary()) -> binary(). -dot_hack(<<$.,Bin/binary>>) -> - Bin; -dot_hack(Bin) -> Bin. diff --git a/src/chatterbox_static_stream.erl b/src/chatterbox_static_stream.erl index d4404c0a..9b779e16 100644 --- a/src/chatterbox_static_stream.erl +++ b/src/chatterbox_static_stream.erl @@ -6,10 +6,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(cb_static, { @@ -23,16 +23,16 @@ init(ConnPid, StreamId, _) -> {ok, #cb_static{connection_pid=ConnPid, stream_id=StreamId}}. -on_receive_request_headers(Headers, State) -> +on_receive_headers(Headers, State) -> {ok, State#cb_static{req_headers=Headers}}. on_send_push_promise(Headers, State) -> {ok, State#cb_static{req_headers=Headers}}. -on_receive_request_data(_Bin, State)-> +on_receive_data(_Bin, State)-> {ok, State}. -on_request_end_stream(State=#cb_static{connection_pid=ConnPid, +on_end_stream(State=#cb_static{connection_pid=ConnPid, stream_id=StreamId}) -> Headers = State#cb_static.req_headers, @@ -102,7 +102,7 @@ on_request_end_stream(State=#cb_static{connection_pid=ConnPid, lists:foldl( fun(R, Acc) -> - NewStreamId = h2_connection:new_stream(ConnPid), + {NewStreamId, _} = h2_connection:new_stream(ConnPid), PHeaders = generate_push_promise_headers(Headers, <<$/,R/binary>> ), h2_connection:send_promise(ConnPid, StreamId, NewStreamId, PHeaders), diff --git a/src/h2_client.erl b/src/h2_client.erl index 493292de..b180e0bc 100644 --- a/src/h2_client.erl +++ b/src/h2_client.erl @@ -19,6 +19,7 @@ start_link/2, start_link/3, start_link/4, + start_link/5, start_ssl_upgrade_link/4, stop/1, send_request/3, @@ -102,14 +103,17 @@ start_link(https, Host, SSLOptions) | ignore | {error, term()}. start_link(Transport, Host, Port, SSLOptions) -> + start_link(Transport, Host, Port, SSLOptions, #{}). + +start_link(Transport, Host, Port, SSLOptions, ConnectionSettings) -> NewT = case Transport of http -> gen_tcp; https -> ssl end, - h2_connection:start_client_link(NewT, Host, Port, SSLOptions, chatterbox:settings(client)). + h2_connection:start_client_link(NewT, Host, Port, SSLOptions, chatterbox:settings(client), ConnectionSettings). start_ssl_upgrade_link(Host, Port, InitialMessage, SSLOptions) -> - h2_connection:start_ssl_upgrade_link(Host, Port, InitialMessage, SSLOptions, chatterbox:settings(client)). + h2_connection:start_ssl_upgrade_link(Host, Port, InitialMessage, SSLOptions, chatterbox:settings(client), #{}). -spec stop(pid()) -> ok. stop(Pid) -> @@ -139,7 +143,7 @@ send_request(CliPid, Headers, Body) -> case h2_connection:new_stream(CliPid) of {error, _Code} = Err -> Err; - StreamId -> + {StreamId, _} -> h2_connection:send_headers(CliPid, StreamId, Headers), h2_connection:send_body(CliPid, StreamId, Body), {ok, StreamId} @@ -149,7 +153,7 @@ send_ping(CliPid) -> h2_connection:send_ping(CliPid). -spec get_response(pid(), stream_id()) -> - {ok, {hpack:header(), iodata()}} + {ok, {hpack:headers(), iodata(), hpack:headers()}} | not_ready | {error, term()}. get_response(CliPid, StreamId) -> diff --git a/src/h2_connection.erl b/src/h2_connection.erl index 071f5744..0a58f6eb 100644 --- a/src/h2_connection.erl +++ b/src/h2_connection.erl @@ -4,8 +4,8 @@ %% Start/Stop API -export([ - start_client_link/5, - start_ssl_upgrade_link/5, + start_client_link/6, + start_ssl_upgrade_link/6, start_server_link/3, become/1, become/2, @@ -26,6 +26,7 @@ is_push/1, new_stream/1, new_stream/2, + new_stream/4, send_promise/4, get_response/2, get_peer/1, @@ -90,8 +91,8 @@ settings_sent = queue:new() :: queue:queue(), next_available_stream_id = 2 :: stream_id(), streams :: h2_stream_set:stream_set(), - stream_callback_mod = application:get_env(chatterbox, stream_callback_mod, chatterbox_static_stream) :: module(), - stream_callback_opts = application:get_env(chatterbox, stream_callback_opts, []) :: list(), + stream_callback_mod :: module() | undefined, + stream_callback_opts :: list() | undefined, buffer = empty :: empty | {binary, binary()} | {frame, h2_frame:header(), binary()}, continuation = undefined :: undefined | #continuation_state{}, flow_control = auto :: auto | manual, @@ -105,25 +106,33 @@ -export_type([send_option/0, send_opts/0]). +-ifdef(OTP_RELEASE). +-define(ssl_accept(ClientSocket, SSLOptions), ssl:handshake(ClientSocket, SSLOptions)). +-else. +-define(ssl_accept(ClientSocket, SSLOptions), ssl:ssl_accept(ClientSocket, SSLOptions)). +-endif. + -spec start_client_link(gen_tcp | ssl, inet:ip_address() | inet:hostname(), inet:port_number(), [ssl:ssl_option()], - settings() + settings(), + maps:map() ) -> {ok, pid()} | ignore | {error, term()}. -start_client_link(Transport, Host, Port, SSLOptions, Http2Settings) -> - gen_statem:start_link(?MODULE, {client, Transport, Host, Port, SSLOptions, Http2Settings}, []). +start_client_link(Transport, Host, Port, SSLOptions, Http2Settings, ConnectionSettings) -> + gen_statem:start_link(?MODULE, {client, Transport, Host, Port, SSLOptions, Http2Settings, ConnectionSettings}, []). -spec start_ssl_upgrade_link(inet:ip_address() | inet:hostname(), inet:port_number(), binary(), [ssl:ssl_option()], - settings() + settings(), + maps:map() ) -> {ok, pid()} | ignore | {error, term()}. -start_ssl_upgrade_link(Host, Port, InitialMessage, SSLOptions, Http2Settings) -> - gen_statem:start_link(?MODULE, {client_ssl_upgrade, Host, Port, InitialMessage, SSLOptions, Http2Settings}, []). +start_ssl_upgrade_link(Host, Port, InitialMessage, SSLOptions, Http2Settings, ConnectionSettings) -> + gen_statem:start_link(?MODULE, {client_ssl_upgrade, Host, Port, InitialMessage, SSLOptions, Http2Settings, ConnectionSettings}, []). -spec start_server_link(socket(), [ssl:ssl_option()], @@ -161,7 +170,7 @@ become({Transport, Socket}, Http2Settings, ConnectionSettings) -> NewState). %% Init callback -init({client, Transport, Host, Port, SSLOptions, Http2Settings}) -> +init({client, Transport, Host, Port, SSLOptions, Http2Settings, ConnectionSettings}) -> case Transport:connect(Host, Port, client_options(Transport, SSLOptions)) of {ok, Socket} -> ok = sock:setopts({Transport, Socket}, [{packet, raw}, binary]), @@ -169,6 +178,8 @@ init({client, Transport, Host, Port, SSLOptions, Http2Settings}) -> InitialState = #connection{ type = client, + stream_callback_mod = maps:get(stream_callback_mod, ConnectionSettings, undefined), + stream_callback_opts = maps:get(stream_callback_opts, ConnectionSettings, []), streams = h2_stream_set:new(client), socket = {Transport, Socket}, next_available_stream_id=1, @@ -181,7 +192,7 @@ init({client, Transport, Host, Port, SSLOptions, Http2Settings}) -> {error, Reason} -> {stop, Reason} end; -init({client_ssl_upgrade, Host, Port, InitialMessage, SSLOptions, Http2Settings}) -> +init({client_ssl_upgrade, Host, Port, InitialMessage, SSLOptions, Http2Settings, ConnectionSettings}) -> case gen_tcp:connect(Host, Port, [{active, false}]) of {ok, TCP} -> gen_tcp:send(TCP, InitialMessage), @@ -193,6 +204,8 @@ init({client_ssl_upgrade, Host, Port, InitialMessage, SSLOptions, Http2Settings} InitialState = #connection{ type = client, + stream_callback_mod = maps:get(stream_callback_mod, ConnectionSettings, undefined), + stream_callback_opts = maps:get(stream_callback_opts, ConnectionSettings, []), streams = h2_stream_set:new(client), socket = {ssl, Socket}, next_available_stream_id=1, @@ -290,16 +303,19 @@ get_peercert(Pid) -> is_push(Pid) -> gen_statem:call(Pid, is_push). --spec new_stream(pid()) -> stream_id() | {error, error_code()}. +-spec new_stream(pid()) -> {stream_id(), pid()} | {error, error_code()}. new_stream(Pid) -> new_stream(Pid, self()). -spec new_stream(pid(), pid()) -> - stream_id() + {stream_id(), pid()} | {error, error_code()}. new_stream(Pid, NotifyPid) -> gen_statem:call(Pid, {new_stream, NotifyPid}). +new_stream(Pid, CallbackMod, CallbackOpts, NotifyPid) -> + gen_statem:call(Pid, {new_stream, CallbackMod, CallbackOpts, NotifyPid}). + -spec send_promise(pid(), stream_id(), stream_id(), hpack:headers()) -> ok. send_promise(Pid, StreamId, NewStreamId, Headers) -> gen_statem:cast(Pid, {send_promise, StreamId, NewStreamId, Headers}), @@ -348,7 +364,7 @@ listen(info, {inet_async, ListenSocket, Ref, {ok, ClientSocket}}, gen_tcp -> ClientSocket; ssl -> - {ok, AcceptSocket} = ssl:ssl_accept(ClientSocket, SSLOptions), + {ok, AcceptSocket} = ?ssl_accept(ClientSocket, SSLOptions), {ok, <<"h2">>} = ssl:negotiated_protocol(AcceptSocket), AcceptSocket end, @@ -635,11 +651,12 @@ route_frame({#frame_header{type=?HEADERS}=FH, _Payload}=Frame, Conn#connection.socket, (Conn#connection.peer_settings)#settings.initial_window_size, (Conn#connection.self_settings)#settings.initial_window_size, + Conn#connection.type, Streams) of {error, ErrorCode, NewStream} -> rst_stream(NewStream, ErrorCode, Conn), {none, Conn}; - NewStreams -> + {_, NewStreams} -> {headers, Conn#connection{streams=NewStreams}} end; {active, server} -> @@ -735,7 +752,7 @@ route_frame({H=#frame_header{ %% reserved(local) and reserved(remote) aren't technically %% 'active', but they're being counted that way right now. Again, %% that only matters if Server Push is enabled. - NewStreams = + {_, NewStreams} = h2_stream_set:new_stream( PSID, NotifyPid, @@ -744,6 +761,7 @@ route_frame({H=#frame_header{ Conn#connection.socket, (Conn#connection.peer_settings)#settings.initial_window_size, (Conn#connection.self_settings)#settings.initial_window_size, + Conn#connection.type, Streams), Continuation = #continuation_state{ @@ -879,7 +897,8 @@ route_frame(Frame, #connection{}=Conn) -> handle_event(_, {stream_finished, StreamId, Headers, - Body}, + Body, + Trailers}, Conn) -> Stream = h2_stream_set:get(StreamId, Conn#connection.streams), case h2_stream_set:type(Stream) of @@ -888,7 +907,7 @@ handle_event(_, {stream_finished, Response = case Conn#connection.type of server -> garbage; - client -> {Headers, Body} + client -> {Headers, Body, Trailers} end, {_NewStream, NewStreams} = h2_stream_set:close( @@ -1126,32 +1145,12 @@ handle_event({call, From}, {get_response, StreamId}, {keep_state, Conn, [{reply, From, Reply}]}; handle_event({call, From}, {new_stream, NotifyPid}, #connection{ - streams=Streams, - next_available_stream_id=NextId + stream_callback_mod=CallbackMod, + stream_callback_opts=CallbackOpts }=Conn) -> - {Reply, NewStreams} = - case - h2_stream_set:new_stream( - NextId, - NotifyPid, - Conn#connection.stream_callback_mod, - Conn#connection.stream_callback_opts, - Conn#connection.socket, - Conn#connection.peer_settings#settings.initial_window_size, - Conn#connection.self_settings#settings.initial_window_size, - Streams) - of - {error, Code, _NewStream} -> - %% TODO: probably want to have events like this available for metrics - %% tried to create new_stream but there are too many - {{error, Code}, Streams}; - GoodStreamSet -> - {NextId, GoodStreamSet} - end, - {keep_state, Conn#connection{ - next_available_stream_id=NextId+2, - streams=NewStreams - }, [{reply, From, Reply}]}; + new_stream_(From, CallbackMod, CallbackOpts, NotifyPid, Conn); +handle_event({call, From}, {new_stream, CallbackMod, CallbackState, NotifyPid}, Conn) -> + new_stream_(From, CallbackMod, CallbackState, NotifyPid, Conn); handle_event({call, From}, is_push, #connection{ peer_settings=#settings{enable_push=Push} @@ -1270,6 +1269,33 @@ terminate(_Reason, _StateName, _Conn=#connection{}) -> terminate(_Reason, _StateName, _State) -> ok. +new_stream_(From, CallbackMod, CallbackState, NotifyPid, Conn=#connection{streams=Streams, + next_available_stream_id=NextId}) -> + {Reply, NewStreams} = + case + h2_stream_set:new_stream( + NextId, + NotifyPid, + CallbackMod, + CallbackState, + Conn#connection.socket, + Conn#connection.peer_settings#settings.initial_window_size, + Conn#connection.self_settings#settings.initial_window_size, + Conn#connection.type, + Streams) + of + {error, Code, _NewStream} -> + %% TODO: probably want to have events like this available for metrics + %% tried to create new_stream but there are too many + {{error, Code}, Streams}; + {Pid, GoodStreamSet} -> + {{NextId, Pid}, GoodStreamSet} + end, + {keep_state, Conn#connection{ + next_available_stream_id=NextId+2, + streams=NewStreams + }, [{reply, From, Reply}]}. + -spec go_away(error_code(), connection()) -> {next_state, closing, connection()}. go_away(ErrorCode, #connection{ @@ -1650,12 +1676,13 @@ send_request(NextId, NotifyPid, Conn, Streams, Headers, Body) -> Conn#connection.socket, Conn#connection.peer_settings#settings.initial_window_size, Conn#connection.self_settings#settings.initial_window_size, + Conn#connection.type, Streams) of {error, Code, _NewStream} -> %% error creating new stream {error, Code}; - GoodStreamSet -> + {_, GoodStreamSet} -> send_headers(self(), NextId, Headers), send_body(self(), NextId, Body), diff --git a/src/h2_stream.erl b/src/h2_stream.erl index 4a9d2b7f..c1232bb3 100644 --- a/src/h2_stream.erl +++ b/src/h2_stream.erl @@ -3,7 +3,7 @@ %% Public API -export([ - start_link/5, + start_link/6, send_event/2, send_pp/2, send_data/2, @@ -56,13 +56,14 @@ request_end_headers = false :: boolean(), response_headers = [] :: hpack:headers(), response_trailers = [] :: hpack:headers(), - response_body :: iodata() | undefined, + response_body = undefined :: iodata() | undefined, response_end_headers = false :: boolean(), response_end_stream = false :: boolean(), next_state = undefined :: undefined | stream_state_name(), promised_stream = undefined :: undefined | state(), callback_state = undefined :: any(), - callback_mod = undefined :: module() + callback_mod = undefined :: module(), + type :: client | server }). -type state() :: #stream_state{}. @@ -76,7 +77,7 @@ ) -> {ok, callback_state()}. --callback on_receive_request_headers( +-callback on_receive_headers( Headers :: hpack:headers(), CallbackState :: callback_state()) -> {ok, NewState :: callback_state()}. @@ -86,12 +87,12 @@ CallbackState :: callback_state()) -> {ok, NewState :: callback_state()}. --callback on_receive_request_data( +-callback on_receive_data( iodata(), CallbackState :: callback_state())-> {ok, NewState :: callback_state()}. --callback on_request_end_stream( +-callback on_end_stream( CallbackState :: callback_state()) -> {ok, NewState :: callback_state()}. @@ -101,15 +102,17 @@ Connection :: pid(), CallbackModule :: module(), CallbackOptions :: list(), + Type :: client | server, Socket :: sock:socket() ) -> {ok, pid()} | ignore | {error, term()}. -start_link(StreamId, Connection, CallbackModule, CallbackOptions, Socket) -> +start_link(StreamId, Connection, CallbackModule, CallbackOptions, Type, Socket) -> gen_statem:start_link(?MODULE, [StreamId, Connection, CallbackModule, CallbackOptions, + Type, Socket], []). @@ -152,22 +155,38 @@ rst_stream(Pid, Code) -> stop(Pid) -> gen_statem:stop(Pid). +init([ + StreamId, + ConnectionPid, + CB=undefined, + _CBOptions, + Type, + Socket + ]) -> + {ok, idle, #stream_state{ + callback_mod=CB, + socket=Socket, + stream_id=StreamId, + connection=ConnectionPid, + type = Type + }}; init([ StreamId, ConnectionPid, CB, CBOptions, + Type, Socket ]) -> %% TODO: Check for CB implementing this behaviour - {ok, CallbackState} = CB:init(ConnectionPid, StreamId, [Socket | CBOptions]), - + {ok, NewCBState} = callback(CB, init, [ConnectionPid, StreamId], [Socket | CBOptions]), {ok, idle, #stream_state{ callback_mod=CB, socket=Socket, stream_id=StreamId, connection=ConnectionPid, - callback_state=CallbackState + callback_state=NewCBState, + type = Type }}. callback_mode() -> @@ -183,6 +202,19 @@ callback_mode() -> %% PUSH_PROMISE frame with that Stream Id. It's a subtle thing, but it %% drove me crazy until I figured it out +callback(undefined, _, _, State) -> + {ok, State}; +callback(Mod, Fun, Args, State) -> + %% load the module if it isn't already + AllArgs = Args ++ [State], + erlang:function_exported(Mod, module_info, 0) orelse code:ensure_loaded(Mod), + case erlang:function_exported(Mod, Fun, length(AllArgs)) of + true -> + erlang:apply(Mod, Fun, AllArgs); + false -> + {ok, State} + end. + %% Server 'RECV H' idle(cast, {recv_h, Headers}, #stream_state{ @@ -191,7 +223,7 @@ idle(cast, {recv_h, Headers}, }=Stream) -> case is_valid_headers(request, Headers) of ok -> - {ok, NewCBState} = CB:on_receive_request_headers(Headers, CallbackState), + {ok, NewCBState} = callback(CB, on_receive_headers, [Headers], CallbackState), {next_state, open, Stream#stream_state{ @@ -208,7 +240,7 @@ idle(cast, {send_pp, Headers}, callback_mod=CB, callback_state=CallbackState }=Stream) -> - {ok, NewCBState} = CB:on_send_push_promise(Headers, CallbackState), + {ok, NewCBState} = callback(CB, on_send_push_promise, [Headers], CallbackState), {next_state, reserved_local, Stream#stream_state{ @@ -244,7 +276,7 @@ reserved_local(timeout, _, callback_mod=CB }=Stream) -> check_content_length(Stream), - {ok, NewCBState} = CB:on_request_end_stream(CallbackState), + {ok, NewCBState} = callback(CB, on_end_stream, [], CallbackState), {next_state, reserved_local, Stream#stream_state{ @@ -271,19 +303,27 @@ reserved_local(Type, Event, State) -> reserved_remote(cast, {recv_h, Headers}, #stream_state{ + callback_mod=CB, + callback_state=CallbackState }=Stream) -> + {ok, NewCBState} = callback(CB, on_receive_headers, [Headers], CallbackState), {next_state, half_closed_local, Stream#stream_state{ - response_headers=Headers + response_headers=Headers, + callback_state=NewCBState }}; reserved_remote(cast, {recv_t, Headers}, #stream_state{ + callback_mod=CB, + callback_state=CallbackState }=Stream) -> + {ok, NewCBState} = callback(CB, on_receive_headers, [Headers], CallbackState), {next_state, half_closed_local, Stream#stream_state{ - response_headers=Headers + response_headers=Headers, + callback_state=NewCBState }}; reserved_remote(Type, Event, State) -> handle_event(Type, Event, State). @@ -295,7 +335,7 @@ open(cast, recv_es, }=Stream) -> case check_content_length(Stream) of ok -> - {ok, NewCBState} = CB:on_request_end_stream(CallbackState), + {ok, NewCBState} = callback(CB, on_end_stream, [], CallbackState), {next_state, half_closed_remote, Stream#stream_state{ @@ -320,7 +360,7 @@ open(cast, {recv_data, }=Stream) when ?NOT_FLAG(Flags, ?FLAG_END_STREAM) -> Bin = h2_frame_data:data(Payload), - {ok, NewCBState} = CB:on_receive_request_data(Bin, CallbackState), + {ok, NewCBState} = callback(CB, on_receive_data, [Bin], CallbackState), {next_state, open, Stream#stream_state{ @@ -343,20 +383,20 @@ open(cast, {recv_data, }=Stream) when ?IS_FLAG(Flags, ?FLAG_END_STREAM) -> Bin = h2_frame_data:data(Payload), - {ok, CallbackState1} = CB:on_receive_request_data(Bin, CallbackState), + {ok, NewCBState} = callback(CB, on_receive_data, [Bin], CallbackState), NewStream = Stream#stream_state{ incoming_frames=queue:in(F, IFQ), request_body_size=Stream#stream_state.request_body_size+L, request_end_stream=true, - callback_state=CallbackState1 + callback_state=NewCBState }, case check_content_length(NewStream) of ok -> - {ok, NewCBState} = CB:on_request_end_stream(CallbackState1), + {ok, NewCBState1} = callback(CB, on_end_stream, [], NewCBState), {next_state, half_closed_remote, NewStream#stream_state{ - callback_state=NewCBState + callback_state=NewCBState1 }}; rst_stream -> {next_state, @@ -366,17 +406,30 @@ open(cast, {recv_data, %% Trailers open(cast, {recv_h, Trailers}, - #stream_state{}=Stream) -> + #stream_state{type=server}=Stream) -> case is_valid_headers(request, Trailers) of ok -> - {next_state, - open, + {keep_state, Stream#stream_state{ request_headers=Stream#stream_state.request_headers ++ Trailers }}; {error, Code} -> rst_stream_(Code, Stream) end; +open(cast, {recv_h, Headers}, + #stream_state{type=client, + callback_mod=CB, + callback_state=CallbackState}=Stream) -> + case is_valid_headers(response, Headers) of + ok -> + {ok, NewCBState} = callback(CB, on_receive_headers, [Headers], CallbackState), + {keep_state, + Stream#stream_state{ + callback_state=NewCBState, + response_headers=Headers}}; + {error, Code} -> + rst_stream_(Code, Stream) + end; open(cast, {send_data, {#frame_header{ type=?HEADERS, @@ -506,24 +559,30 @@ half_closed_remote(Type, Event, State) -> %% but that stream may be ready to transition, it'll make sense, I %% hope! half_closed_local(cast, - {recv_h, Headers}, - #stream_state{}=Stream) -> + {recv_h, Headers}, + #stream_state{callback_mod=CB, + callback_state=CallbackState + }=Stream) -> case is_valid_headers(response, Headers) of ok -> + {ok, NewCBState} = callback(CB, on_receive_headers, [Headers], CallbackState), {next_state, half_closed_local, Stream#stream_state{ + callback_state=NewCBState, response_headers=Headers}}; {error, Code} -> rst_stream_(Code, Stream) end; + half_closed_local(cast, {recv_data, {#frame_header{ flags=Flags, type=?DATA - },_}=F}, + }, _}=F}, #stream_state{ + callback_mod=undefined, incoming_frames=IFQ } = Stream) -> NewQ = queue:in(F, IFQ), @@ -544,27 +603,60 @@ half_closed_local(cast, incoming_frames=NewQ }} end; - +half_closed_local(cast, + {recv_data, + {#frame_header{ + flags=Flags, + type=?DATA + }, Payload}}, + #stream_state{ + callback_mod=CB, + callback_state=CallbackState + } = Stream) -> + Data = h2_frame_data:data(Payload), + {ok, NewCBState} = callback(CB, on_receive_data, [Data], CallbackState), + case ?IS_FLAG(Flags, ?FLAG_END_STREAM) of + true -> + {ok, NewCBState1} = callback(CB, on_end_stream, [], NewCBState), + {next_state, closed, + Stream#stream_state{ + callback_state=NewCBState1 + }, 0}; + _ -> + {next_state, + half_closed_local, + Stream#stream_state{ + callback_state=NewCBState + }} + end; half_closed_local(cast, recv_es, #stream_state{ response_body = undefined, + callback_mod=CB, + callback_state=CallbackState, incoming_frames = Q } = Stream) -> + {ok, NewCBState} = callback(CB, on_end_stream, [], CallbackState), Data = [h2_frame_data:data(Payload) || {#frame_header{type=?DATA}, Payload} <- queue:to_list(Q)], {next_state, closed, Stream#stream_state{ incoming_frames=queue:new(), - response_body = Data + response_body = Data, + callback_state=NewCBState }, 0}; half_closed_local(cast, recv_es, #stream_state{ - response_body = Data + response_body = Data, + callback_mod=CB, + callback_state=CallbackState } = Stream) -> + {ok, NewCBState} = callback(CB, on_end_stream, [], CallbackState), {next_state, closed, Stream#stream_state{ incoming_frames=queue:new(), - response_body = Data + response_body = Data, + callback_state=NewCBState }, 0}; half_closed_local(_, _, @@ -579,7 +671,8 @@ closed(timeout, _, {stream_finished, Stream#stream_state.stream_id, Stream#stream_state.response_headers, - Stream#stream_state.response_body}), + Stream#stream_state.response_body, + Stream#stream_state.response_trailers}), {stop, normal, Stream}; closed(_, _, #stream_state{}=Stream) -> diff --git a/src/h2_stream_set.erl b/src/h2_stream_set.erl index 4a2c6a1a..e7582e17 100644 --- a/src/h2_stream_set.erl +++ b/src/h2_stream_set.erl @@ -96,6 +96,8 @@ response_headers :: hpack:headers() | undefined, % The response body response_body :: binary() | undefined, + % The response trailers received + response_trailers :: hpack:headers() | undefined, % Can this be thrown away? garbage = false :: boolean() | undefined }). @@ -123,7 +125,7 @@ -export( [ new/1, - new_stream/8, + new_stream/9, get/2, upsert/2, sort/1 @@ -210,8 +212,9 @@ new(server) -> Socket :: sock:socket(), InitialSendWindow :: integer(), InitialRecvWindow :: integer(), + Type :: client | server, StreamSet :: stream_set()) -> - stream_set() + {pid(), stream_set()} | {error, error_code(), closed_stream()}. new_stream( StreamId, @@ -221,6 +224,7 @@ new_stream( Socket, InitialSendWindow, InitialRecvWindow, + Type, StreamSet) -> PeerSubset = get_peer_subset(StreamId, StreamSet), case PeerSubset#peer_subset.max_active =/= unlimited andalso @@ -234,6 +238,7 @@ new_stream( self(), CBMod, CBOpts, + Type, Socket ), NewStream = #active_stream{ @@ -257,7 +262,7 @@ new_stream( h2_stream:stop(Pid), {error, ?REFUSED_STREAM, #closed_stream{id=StreamId}}; NewStreamSet -> - NewStreamSet + {Pid, NewStreamSet} end end. @@ -505,24 +510,26 @@ close(Closed=#closed_stream{}, Streams) -> {Closed, Streams}; close(_Idle=#idle_stream{id=StreamId}, - {Headers, Body}, + {Headers, Body, Trailers}, Streams) -> Closed = #closed_stream{ id=StreamId, response_headers=Headers, - response_body=Body + response_body=Body, + response_trailers=Trailers }, {Closed, upsert(Closed, Streams)}; close(#active_stream{ id=Id, notify_pid=NotifyPid }, - {Headers, Body}, + {Headers, Body, Trailers}, Streams) -> Closed = #closed_stream{ id=Id, response_headers=Headers, response_body=Body, + response_trailers=Trailers, notify_pid=NotifyPid }, {Closed, upsert(Closed, Streams)}. @@ -817,8 +824,9 @@ update_data_queue(_, _, S) -> response(#closed_stream{ response_headers=Headers, - response_body=Body}) -> - {Headers, Body}; + response_body=Body, + response_trailers=Trailers}) -> + {Headers, Body, Trailers}; response(_) -> no_response. diff --git a/test/client_server_SUITE.erl b/test/client_server_SUITE.erl index 071f8223..000c59e0 100644 --- a/test/client_server_SUITE.erl +++ b/test/client_server_SUITE.erl @@ -67,7 +67,7 @@ complex_request(_Config) -> {<<"accept-encoding">>, <<"gzip, deflate">>}, {<<"user-agent">>, <<"chattercli/0.0.1 :D">>} ], - {ok, {ResponseHeaders, ResponseBody}} = h2_client:sync_request(Client, RequestHeaders, <<>>), + {ok, {ResponseHeaders, ResponseBody, _Trailers}} = h2_client:sync_request(Client, RequestHeaders, <<>>), ct:pal("Response Headers: ~p", [ResponseHeaders]), ct:pal("Response Body: ~p", [ResponseBody]), @@ -87,7 +87,7 @@ upgrade_tcp_connection(_Config) -> {<<"accept-encoding">>, <<"gzip, deflate">>}, {<<"user-agent">>, <<"chattercli/0.0.1 :D">>} ], - {ok, {ResponseHeaders, ResponseBody}} = h2_client:sync_request(Client, RequestHeaders, <<>>), + {ok, {ResponseHeaders, ResponseBody, _Trailers}} = h2_client:sync_request(Client, RequestHeaders, <<>>), ct:pal("Response Headers: ~p", [ResponseHeaders]), ct:pal("Response Body: ~p", [ResponseBody]), ok. @@ -105,7 +105,7 @@ basic_push(_Config) -> {<<"accept-encoding">>, <<"gzip, deflate">>}, {<<"user-agent">>, <<"chattercli/0.0.1 :D">>} ], - {ok, {ResponseHeaders, ResponseBody}} = h2_client:sync_request(Client, RequestHeaders, <<>>), + {ok, {ResponseHeaders, ResponseBody, _Trailers}} = h2_client:sync_request(Client, RequestHeaders, <<>>), ct:pal("Response Headers: ~p", [ResponseHeaders]), ct:pal("Response Body: ~p", [ResponseBody]), @@ -156,7 +156,7 @@ get_peer_in_handler(_Config) -> ], - {ok, {ResponseHeaders, ResponseBody}} = h2_client:sync_request(Client, RequestHeaders, <<>>), + {ok, {ResponseHeaders, ResponseBody, _Trailers}} = h2_client:sync_request(Client, RequestHeaders, <<>>), ct:pal("Response Headers: ~p", [ResponseHeaders]), ct:pal("Response Body: ~p", [ResponseBody]), ok. @@ -176,7 +176,7 @@ send_body_opts(_Config) -> ExpectedResponseBody = <<"BodyPart1\nBodyPart2">>, - {ok, {ResponseHeaders, ResponseBody}} = h2_client:sync_request(Client, RequestHeaders, <<>>), + {ok, {ResponseHeaders, ResponseBody, _Trailers}} = h2_client:sync_request(Client, RequestHeaders, <<>>), ct:pal("Response Headers: ~p", [ResponseHeaders]), ct:pal("Response Body: ~p", [ResponseBody]), ?assertEqual(ExpectedResponseBody, (iolist_to_binary(ResponseBody))), diff --git a/test/double_body_handler.erl b/test/double_body_handler.erl index 5fb1b117..9877c510 100644 --- a/test/double_body_handler.erl +++ b/test/double_body_handler.erl @@ -6,10 +6,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(state, {conn_pid :: pid(), @@ -21,25 +21,25 @@ init(ConnPid, StreamId, _Opts) -> {ok, #state{conn_pid=ConnPid, stream_id=StreamId}}. --spec on_receive_request_headers( +-spec on_receive_headers( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. -on_receive_request_headers(_Headers, State) -> {ok, State}. +on_receive_headers(_Headers, State) -> {ok, State}. -spec on_send_push_promise( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. on_send_push_promise(_Headers, State) -> {ok, State}. --spec on_receive_request_data( +-spec on_receive_data( iodata(), CallbackState :: any())-> {ok, NewState :: any()}. -on_receive_request_data(_Data, State) -> {ok, State}. +on_receive_data(_Data, State) -> {ok, State}. --spec on_request_end_stream( +-spec on_end_stream( CallbackState :: any()) -> {ok, NewState :: any()}. -on_request_end_stream(State=#state{conn_pid=ConnPid, +on_end_stream(State=#state{conn_pid=ConnPid, stream_id=StreamId}) -> ResponseHeaders = [ {<<":status">>,<<"200">>} diff --git a/test/echo_handler.erl b/test/echo_handler.erl index 8db32af2..33a3e792 100644 --- a/test/echo_handler.erl +++ b/test/echo_handler.erl @@ -6,10 +6,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(state, {conn_pid :: pid(), @@ -22,26 +22,26 @@ init(ConnPid, StreamId, _Opts) -> {ok, #state{conn_pid=ConnPid, stream_id=StreamId}}. --spec on_receive_request_headers( +-spec on_receive_headers( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. -on_receive_request_headers(_Headers, State) -> {ok, State}. +on_receive_headers(_Headers, State) -> {ok, State}. -spec on_send_push_promise( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. on_send_push_promise(_Headers, State) -> {ok, State}. --spec on_receive_request_data( +-spec on_receive_data( iodata(), CallbackState :: any())-> {ok, NewState :: any()}. -on_receive_request_data(Data, State=#state{buffer=Buffer}) -> +on_receive_data(Data, State=#state{buffer=Buffer}) -> {ok, State#state{buffer = <>}}. --spec on_request_end_stream( +-spec on_end_stream( CallbackState :: any()) -> {ok, NewState :: any()}. -on_request_end_stream(State=#state{conn_pid=ConnPid, +on_end_stream(State=#state{conn_pid=ConnPid, stream_id=StreamId, buffer=Buffer}) -> ResponseHeaders = [ diff --git a/test/flow_control_handler.erl b/test/flow_control_handler.erl index 18ef4b63..17cbc306 100644 --- a/test/flow_control_handler.erl +++ b/test/flow_control_handler.erl @@ -8,10 +8,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(state, {conn_pid :: pid(), @@ -23,25 +23,25 @@ init(ConnPid, StreamId, _Opts) -> {ok, #state{conn_pid=ConnPid, stream_id=StreamId}}. --spec on_receive_request_headers( +-spec on_receive_headers( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. -on_receive_request_headers(_Headers, State) -> {ok, State}. +on_receive_headers(_Headers, State) -> {ok, State}. -spec on_send_push_promise( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. on_send_push_promise(_Headers, State) -> {ok, State}. --spec on_receive_request_data( +-spec on_receive_data( iodata(), CallbackState :: any())-> {ok, NewState :: any()}. -on_receive_request_data(_Data, State) -> {ok, State}. +on_receive_data(_Data, State) -> {ok, State}. --spec on_request_end_stream( +-spec on_end_stream( CallbackState :: any()) -> {ok, NewState :: any()}. -on_request_end_stream(State=#state{conn_pid=ConnPid, +on_end_stream(State=#state{conn_pid=ConnPid, stream_id=StreamId}) -> ResponseHeaders = [ {<<":status">>,<<"200">>} diff --git a/test/peer_test_handler.erl b/test/peer_test_handler.erl index 3e95747b..2246a415 100644 --- a/test/peer_test_handler.erl +++ b/test/peer_test_handler.erl @@ -6,10 +6,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(state, {conn_pid :: pid(), @@ -23,10 +23,10 @@ init(ConnPid, StreamId, _Opts) -> {ok, #state{conn_pid=ConnPid, stream_id=StreamId}}. --spec on_receive_request_headers( +-spec on_receive_headers( Headers :: hpack:headers(), CallbackState :: any()) -> {ok, NewState :: any()}. -on_receive_request_headers(_Headers, State=#state{conn_pid=ConnPid}) -> +on_receive_headers(_Headers, State=#state{conn_pid=ConnPid}) -> {ok, Peer} = h2_connection:get_peer(ConnPid), {ok, State#state{peer=Peer}}. @@ -35,15 +35,15 @@ on_receive_request_headers(_Headers, State=#state{conn_pid=ConnPid}) -> CallbackState :: any()) -> {ok, NewState :: any()}. on_send_push_promise(_Headers, State) -> {ok, State}. --spec on_receive_request_data( +-spec on_receive_data( iodata(), CallbackState :: any())-> {ok, NewState :: any()}. -on_receive_request_data(_Data, State) -> {ok, State}. +on_receive_data(_Data, State) -> {ok, State}. --spec on_request_end_stream( +-spec on_end_stream( CallbackState :: any()) -> {ok, NewState :: any()}. -on_request_end_stream(State=#state{conn_pid=ConnPid, +on_end_stream(State=#state{conn_pid=ConnPid, stream_id=StreamId, peer={Address, Port}}) -> Body = list_to_binary(io_lib:format("Address: ~p, Port: ~p", [Address, Port])), diff --git a/test/server_connection_receive_window.erl b/test/server_connection_receive_window.erl index 4535538f..3f5716e6 100644 --- a/test/server_connection_receive_window.erl +++ b/test/server_connection_receive_window.erl @@ -4,10 +4,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(cb_static, { @@ -18,19 +18,19 @@ init(_ConnPid, _StreamId, _Opts) -> %% You need to pull settings here from application:env or something {ok, #cb_static{}}. -on_receive_request_headers(Headers, State) -> +on_receive_headers(Headers, State) -> h2_stream:send_window_update(65535), - ct:pal("on_receive_request_headers(~p, ~p)", [Headers, State]), + ct:pal("on_receive_headers(~p, ~p)", [Headers, State]), {ok, State#cb_static{req_headers=Headers}}. on_send_push_promise(Headers, State) -> ct:pal("on_send_push_promise(~p, ~p)", [Headers, State]), {ok, State#cb_static{req_headers=Headers}}. -on_receive_request_data(Bin, State)-> - ct:pal("on_receive_request_data(~p, ~p)", [Bin, State]), +on_receive_data(Bin, State)-> + ct:pal("on_receive_data(~p, ~p)", [Bin, State]), {ok, State}. -on_request_end_stream(State) -> - ct:pal("on_request_end_stream(~p)", [State]), +on_end_stream(State) -> + ct:pal("on_end_stream(~p)", [State]), {ok, State}. diff --git a/test/server_stream_receive_window.erl b/test/server_stream_receive_window.erl index 3007ccea..fb6d7e6c 100644 --- a/test/server_stream_receive_window.erl +++ b/test/server_stream_receive_window.erl @@ -4,10 +4,10 @@ -export([ init/3, - on_receive_request_headers/2, + on_receive_headers/2, on_send_push_promise/2, - on_receive_request_data/2, - on_request_end_stream/1 + on_receive_data/2, + on_end_stream/1 ]). -record(cb_static, { @@ -18,19 +18,19 @@ init(_ConnPid, _StreamId, _Opts) -> %% You need to pull settings here from application:env or something {ok, #cb_static{}}. -on_receive_request_headers(Headers, State) -> +on_receive_headers(Headers, State) -> h2_stream:send_connection_window_update(65535), - ct:pal("on_receive_request_headers(~p, ~p)", [Headers, State]), + ct:pal("on_receive_headers(~p, ~p)", [Headers, State]), {ok, State#cb_static{req_headers=Headers}}. on_send_push_promise(Headers, State) -> ct:pal("on_send_push_promise(~p, ~p)", [Headers, State]), {ok, State#cb_static{req_headers=Headers}}. -on_receive_request_data(_Bin, State)-> - ct:pal("on_receive_request_data(Bin!, ~p)", [State]), +on_receive_data(_Bin, State)-> + ct:pal("on_receive_data(Bin!, ~p)", [State]), {ok, State}. -on_request_end_stream(State) -> - ct:pal("on_request_end_stream(~p)", [State]), +on_end_stream(State) -> + ct:pal("on_end_stream(~p)", [State]), {ok, State}.