From d6f6cb13bb2726a8f631afce496e37811e13cfdf Mon Sep 17 00:00:00 2001 From: Liam Date: Sat, 10 Aug 2024 23:36:20 -0400 Subject: [PATCH 1/4] Build iolist for WebSocket frame data --- lib/bandit/websocket/extractor.ex | 94 +++++++++++ lib/bandit/websocket/frame.ex | 98 ++++++----- lib/bandit/websocket/handler.ex | 34 ++-- .../websocket/frame_deserialization_test.exs | 155 +++++++++--------- 4 files changed, 249 insertions(+), 132 deletions(-) create mode 100644 lib/bandit/websocket/extractor.ex diff --git a/lib/bandit/websocket/extractor.ex b/lib/bandit/websocket/extractor.ex new file mode 100644 index 00000000..b863ab21 --- /dev/null +++ b/lib/bandit/websocket/extractor.ex @@ -0,0 +1,94 @@ +defmodule Bandit.WebSocket.Extractor do + @moduledoc false + # A state machine for efficiently extracting full frames from received packets + + alias Bandit.WebSocket.Frame + + @type t :: %__MODULE__{ + header: binary(), + payload: iodata(), + payload_length: non_neg_integer(), + required_length: non_neg_integer(), + mode: :header_parsing | :payload_parsing, + max_frame_size: non_neg_integer() + } + + defstruct header: <<>>, + payload: [], + payload_length: 0, + required_length: 0, + mode: :header_parsing, + max_frame_size: 0 + + @spec new(Keyword.t()) :: t() + def new(opts) do + max_frame_size = Keyword.get(opts, :max_frame_size, 0) + + %__MODULE__{ + max_frame_size: max_frame_size + } + end + + @spec push_data(t(), binary()) :: t() + def push_data(%__MODULE__{} = state, data) do + case state do + %{mode: :header_parsing} -> + %{state | header: state.header <> data} + + %{mode: :payload_parsing, payload: payload, payload_length: length} -> + %{state | payload: [payload, data], payload_length: length + byte_size(data)} + end + end + + @spec pop_frame(t()) :: {t(), {:ok, Frame.frame()} | {:error, term()} | :more} + def pop_frame(state) + + def pop_frame(%__MODULE__{mode: :header_parsing} = state) do + case Frame.header_and_payload_length(state.header, state.max_frame_size) do + {:ok, {header_length, required_length}} -> + state + |> transition_to_payload_parsing(header_length, required_length) + |> pop_frame() + + {:error, message} -> + {state, {:error, message}} + + :more -> + {state, :more} + end + end + + def pop_frame(%__MODULE__{mode: :payload_parsing} = state) do + if state.payload_length >= state.required_length do + <> = + IO.iodata_to_binary(state.payload) + + frame = Frame.deserialize(state.header <> payload) + state = transition_to_header_parsing(state, rest) + + {state, frame} + else + {state, :more} + end + end + + defp transition_to_payload_parsing(state, header_length, required_length) do + payload_length = byte_size(state.header) - header_length + + state + |> Map.put(:header, binary_part(state.header, 0, header_length)) + |> Map.put(:payload, binary_part(state.header, header_length, payload_length)) + |> Map.put(:payload_length, payload_length) + |> Map.put(:required_length, required_length) + |> Map.put(:mode, :payload_parsing) + end + + defp transition_to_header_parsing(state, rest) do + state + |> Map.put(:header, rest) + |> Map.put(:payload, []) + |> Map.put(:payload_length, 0) + |> Map.put(:required_length, 0) + |> Map.put(:mode, :header_parsing) + end +end diff --git a/lib/bandit/websocket/frame.ex b/lib/bandit/websocket/frame.ex index 5af94e96..c315e62a 100644 --- a/lib/bandit/websocket/frame.ex +++ b/lib/bandit/websocket/frame.ex @@ -21,50 +21,78 @@ defmodule Bandit.WebSocket.Frame do | Frame.Ping.t() | Frame.Pong.t() - @spec deserialize(binary(), non_neg_integer()) :: - {{:ok, frame()}, iodata()} - | {{:more, binary()}, <<>>} - | {{:error, term()}, iodata()} - | nil - def deserialize( - <>, + @spec header_and_payload_length(binary(), non_neg_integer()) :: + {:ok, {header_length :: integer(), payload_length :: integer()}} + | {:error, :max_frame_size_exceeded | :client_frame_without_mask} + | :more + def header_and_payload_length( + <<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, 127::7, length::64, _mask::32, + _rest::binary>>, max_frame_size - ) - when max_frame_size == 0 or length <= max_frame_size do - to_frame(fin, compressed, rsv, opcode, mask, payload, rest) + ) do + validate_max_frame_size(14, length, max_frame_size) end - def deserialize( - <>, + def header_and_payload_length( + <<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, 126::7, length::16, _mask::32, + _rest::binary>>, max_frame_size - ) - when max_frame_size == 0 or length <= max_frame_size do - to_frame(fin, compressed, rsv, opcode, mask, payload, rest) + ) do + validate_max_frame_size(8, length, max_frame_size) end - def deserialize( - <>, + def header_and_payload_length( + <<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, length::7, _mask::32, + _rest::binary>>, max_frame_size ) - when length <= 125 and (max_frame_size == 0 or length <= max_frame_size) do - to_frame(fin, compressed, rsv, opcode, mask, payload, rest) + when length <= 125 do + validate_max_frame_size(6, length, max_frame_size) + end + + def header_and_payload_length( + <<_fin::1, _compressed::1, _rsv::2, _opcode::4, 0::1, _rest::binary>>, + _max_frame_size + ) do + {:error, :client_frame_without_mask} end - # nil is used to indicate for Stream.unfold/2 that the frame deserialization is finished - def deserialize(<<>>, _max_frame_size) do - nil + def header_and_payload_length(_msg, _max_frame_size) do + :more end - def deserialize(msg, max_frame_size) - when max_frame_size != 0 and byte_size(msg) > max_frame_size do - {{:error, :max_frame_size_exceeded}, msg} + defp validate_max_frame_size(header_length, payload_length, max_frame_size) do + if max_frame_size != 0 and header_length + payload_length > max_frame_size do + {:error, :max_frame_size_exceeded} + else + {:ok, {header_length, payload_length}} + end end - def deserialize(msg, _max_frame_size) do - {{:more, msg}, <<>>} + @spec deserialize(binary()) :: {:ok, frame()} | {:error, term()} + def deserialize( + <> + ) do + to_frame(fin, compressed, rsv, opcode, mask, payload) + end + + def deserialize( + <> + ) do + to_frame(fin, compressed, rsv, opcode, mask, payload) + end + + def deserialize( + <> + ) do + to_frame(fin, compressed, rsv, opcode, mask, payload) + end + + def deserialize(_msg) do + {:error, :deserialization_failed} end def recv_metrics(%frame_type{} = frame) do @@ -123,11 +151,11 @@ defmodule Bandit.WebSocket.Frame do end end - defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, rest) when rsv != 0x0 do - {{:error, "Received unsupported RSV flags #{rsv}"}, rest} + defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do + {:error, "Received unsupported RSV flags #{rsv}"} end - defp to_frame(fin, compressed, 0x0, opcode, mask, payload, rest) do + defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do fin = fin == 0x1 compressed = compressed == 0x1 unmasked_payload = mask(payload, mask) @@ -142,10 +170,6 @@ defmodule Bandit.WebSocket.Frame do 0xA -> Frame.Pong.deserialize(fin, compressed, unmasked_payload) unknown -> {:error, "unknown opcode #{unknown}"} end - |> case do - {:ok, frame} -> {{:ok, frame}, rest} - {:error, reason} -> {{:error, reason}, rest} - end end defprotocol Serializable do diff --git a/lib/bandit/websocket/handler.ex b/lib/bandit/websocket/handler.ex index 5df4ae06..57c741a7 100644 --- a/lib/bandit/websocket/handler.ex +++ b/lib/bandit/websocket/handler.ex @@ -4,7 +4,7 @@ defmodule Bandit.WebSocket.Handler do use ThousandIsland.Handler - alias Bandit.WebSocket.{Connection, Frame} + alias Bandit.WebSocket.{Connection, Extractor} @impl ThousandIsland.Handler def handle_connection(socket, state) do @@ -19,7 +19,7 @@ defmodule Bandit.WebSocket.Handler do state = state |> Map.take([:handler_module]) - |> Map.put(:buffer, <<>>) + |> Map.put(:extractor, Extractor.new(connection_opts)) case Connection.init(websock, websock_opts, connection_opts, socket) do {:continue, connection} -> @@ -35,29 +35,31 @@ defmodule Bandit.WebSocket.Handler do @impl ThousandIsland.Handler def handle_data(data, socket, state) do - (state.buffer <> data) - |> Stream.unfold( - &Frame.deserialize(&1, Keyword.get(state.connection.opts, :max_frame_size, 0)) - ) - |> Enum.reduce_while({:continue, state}, fn - {:ok, frame}, {:continue, state} -> + state.extractor + |> Extractor.push_data(data) + |> pop_frame(socket, state) + end + + defp pop_frame(extractor, socket, state) do + case Extractor.pop_frame(extractor) do + {extractor, {:ok, frame}} -> case Connection.handle_frame(frame, socket, state.connection) do {:continue, connection} -> - {:cont, {:continue, %{state | connection: connection, buffer: <<>>}}} + pop_frame(extractor, socket, %{state | extractor: extractor, connection: connection}) {:close, connection} -> - {:halt, {:close, %{state | connection: connection, buffer: <<>>}}} + {:close, %{state | extractor: extractor, connection: connection}} {:error, reason, connection} -> - {:halt, {:error, reason, %{state | connection: connection, buffer: <<>>}}} + {:error, reason, %{state | extractor: extractor, connection: connection}} end - {:more, rest}, {:continue, state} -> - {:halt, {:continue, %{state | buffer: rest}}} + {extractor, {:error, reason}} -> + {:error, {:deserializing, reason}, %{state | extractor: extractor}} - {:error, message}, {:continue, state} -> - {:halt, {:error, {:deserializing, message}, state}} - end) + {extractor, :more} -> + {:continue, %{state | extractor: extractor}} + end end @impl ThousandIsland.Handler diff --git a/test/bandit/websocket/frame_deserialization_test.exs b/test/bandit/websocket/frame_deserialization_test.exs index 6369cd54..e2be97a5 100644 --- a/test/bandit/websocket/frame_deserialization_test.exs +++ b/test/bandit/websocket/frame_deserialization_test.exs @@ -9,13 +9,13 @@ defmodule WebSocketFrameDeserializationTest do test "errors on reserved flag 1 being set" do frame = <<0x1::1, 0x1::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == {{:error, "Received unsupported RSV flags 1"}, <<>>} + assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 1"} end test "errors on reserved flag 2 being set" do frame = <<0x1::1, 0x2::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == {{:error, "Received unsupported RSV flags 2"}, <<>>} + assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 2"} end end @@ -26,8 +26,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 2::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses 10 byte frames" do @@ -36,8 +36,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 10::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames up to 125 bytes" do @@ -46,8 +46,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 126 bytes long" do @@ -56,8 +56,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 126::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 127 bytes long" do @@ -66,8 +66,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 127::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 16_000 bytes long" do @@ -77,8 +77,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 1_000_000 bytes long" do @@ -88,8 +88,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "errors on frames over max_frame_size bytes with small frames" do @@ -98,8 +98,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 124) == - {{:error, :max_frame_size_exceeded}, frame} + assert Frame.header_and_payload_length(frame, 124) == + {:error, :max_frame_size_exceeded} end test "errors on frames over max_frame_size bytes with medium frames" do @@ -109,8 +109,8 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 15_999) == - {{:error, :max_frame_size_exceeded}, frame} + assert Frame.header_and_payload_length(frame, 15_999) == + {:error, :max_frame_size_exceeded} end test "errors on frames over max_frame_size bytes with large frames" do @@ -120,34 +120,32 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, 999_999) == - {{:error, :max_frame_size_exceeded}, frame} + assert Frame.header_and_payload_length(frame, 999_999) == + {:error, :max_frame_size_exceeded} end end describe "insufficient data" do - test "asks for more" do + test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame, 0) == {{:more, frame}, <<>>} + assert Frame.deserialize(frame) == {:error, :deserialization_failed} end end describe "extra data" do - test "returns extra data" do + test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 1::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: <<1>>}}, <<2, 3>>} + assert Frame.deserialize(frame) == {:error, :deserialization_failed} end end describe "unknown frame types" do test "returns an Unknown frame" do - frame = <<0x1::1, 0x0::3, 0xF::4, 1::1, 1::7, 0::32, 1, 2, 3>> + frame = <<0x1::1, 0x0::3, 0xF::4, 1::1, 1::7, 0::32, 1>> - assert Frame.deserialize(frame, 0) == - {{:error, "unknown opcode #{15}"}, <<2, 3>>} + assert Frame.deserialize(frame) == {:error, "unknown opcode #{15}"} end end @@ -157,8 +155,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Continuation{fin: true, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Continuation{fin: true, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit clear" do @@ -166,8 +164,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Continuation{fin: false, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Continuation{fin: false, data: <<1, 2, 3, 4, 5>>}} end test "refuses frame with per-message compressed bit set" do @@ -175,8 +173,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x0::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a compressed continuation frame (RFC7692§6.1)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a compressed continuation frame (RFC7692§6.1)"} end end @@ -186,8 +184,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do @@ -195,8 +193,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do @@ -204,8 +202,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x1::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Text{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Text{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -215,9 +213,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Binary{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}}, - <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Binary{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do @@ -225,8 +222,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Binary{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Binary{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do @@ -234,8 +231,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x2::4, 1::1, 5::7, 0x01020304::32, mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Binary{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Binary{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -247,8 +244,8 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x8::4, 1::1, 125::7, 0x01020304::32, mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.ConnectionClose{code: 1000, reason: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.ConnectionClose{code: 1000, reason: payload}} end test "deserializes frames with code" do @@ -256,42 +253,42 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x8::4, 1::1, 2::7, 0x01020304::32, mask(<<1000::16>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == {{:ok, %Frame.ConnectionClose{code: 1000}}, <<>>} + assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{code: 1000}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == {{:ok, %Frame.ConnectionClose{}}, <<>>} + assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{}} end test "refuses frame with invalid payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 1::7, 0x01020304::32, 1>> - assert Frame.deserialize(frame, 0) == - {{:error, "Invalid connection close payload (RFC6455§5.5)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Invalid connection close payload (RFC6455§5.5)"} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:error, "Invalid connection close payload (RFC6455§5.5)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Invalid connection close payload (RFC6455§5.5)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a fragmented connection close frame (RFC6455§5.5)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a fragmented connection close frame (RFC6455§5.5)"} end test "refuses frame with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a compressed connection close frame (RFC7692§6.1)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a compressed connection close frame (RFC7692§6.1)"} end end @@ -303,36 +300,36 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x9::4, 1::1, 125::7, 0x01020304::32, mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Ping{data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Ping{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == {{:ok, %Frame.Ping{}}, <<>>} + assert Frame.deserialize(frame) == {:ok, %Frame.Ping{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:error, "Invalid ping payload (RFC6455§5.5.2)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Invalid ping payload (RFC6455§5.5.2)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a fragmented ping frame (RFC6455§5.5.2)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a fragmented ping frame (RFC6455§5.5.2)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a compressed ping frame (RFC7692§6.1)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a compressed ping frame (RFC7692§6.1)"} end end @@ -344,36 +341,36 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0xA::4, 1::1, 125::7, 0x01020304::32, mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame, 0) == - {{:ok, %Frame.Pong{data: payload}}, <<>>} + assert Frame.deserialize(frame) == + {:ok, %Frame.Pong{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == {{:ok, %Frame.Pong{}}, <<>>} + assert Frame.deserialize(frame) == {:ok, %Frame.Pong{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, 0) == - {{:error, "Invalid pong payload (RFC6455§5.5.3)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Invalid pong payload (RFC6455§5.5.3)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a fragmented pong frame (RFC6455§5.5.3)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a fragmented pong frame (RFC6455§5.5.3)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, 0) == - {{:error, "Cannot have a compressed pong frame (RFC7692§6.1)"}, <<>>} + assert Frame.deserialize(frame) == + {:error, "Cannot have a compressed pong frame (RFC7692§6.1)"} end end end From e2631b0647e583145fe0d04db29e824c85bbd763 Mon Sep 17 00:00:00 2001 From: Liam Date: Mon, 12 Aug 2024 14:42:16 -0400 Subject: [PATCH 2/4] Convert frame extractor into a behaviour --- lib/bandit/{websocket => }/extractor.ex | 42 +++++++++++++++++-------- lib/bandit/websocket/frame.ex | 4 +++ lib/bandit/websocket/handler.ex | 5 +-- 3 files changed, 36 insertions(+), 15 deletions(-) rename lib/bandit/{websocket => }/extractor.ex (65%) diff --git a/lib/bandit/websocket/extractor.ex b/lib/bandit/extractor.ex similarity index 65% rename from lib/bandit/websocket/extractor.ex rename to lib/bandit/extractor.ex index b863ab21..2ee206ec 100644 --- a/lib/bandit/websocket/extractor.ex +++ b/lib/bandit/extractor.ex @@ -1,8 +1,15 @@ -defmodule Bandit.WebSocket.Extractor do +defmodule Bandit.Extractor do @moduledoc false # A state machine for efficiently extracting full frames from received packets - alias Bandit.WebSocket.Frame + @type deserialize_result :: any() + + @callback header_and_payload_length(binary(), max_frame_size :: integer()) :: + {:ok, {header_length :: integer(), payload_length :: integer()}} + | {:error, term()} + | :more + + @callback deserialize(binary()) :: deserialize_result() @type t :: %__MODULE__{ header: binary(), @@ -10,7 +17,8 @@ defmodule Bandit.WebSocket.Extractor do payload_length: non_neg_integer(), required_length: non_neg_integer(), mode: :header_parsing | :payload_parsing, - max_frame_size: non_neg_integer() + max_frame_size: non_neg_integer(), + frame_parser: atom() } defstruct header: <<>>, @@ -18,14 +26,16 @@ defmodule Bandit.WebSocket.Extractor do payload_length: 0, required_length: 0, mode: :header_parsing, - max_frame_size: 0 + max_frame_size: 0, + frame_parser: nil - @spec new(Keyword.t()) :: t() - def new(opts) do + @spec new(module(), Keyword.t()) :: t() + def new(frame_parser, opts) do max_frame_size = Keyword.get(opts, :max_frame_size, 0) %__MODULE__{ - max_frame_size: max_frame_size + max_frame_size: max_frame_size, + frame_parser: frame_parser } end @@ -40,11 +50,11 @@ defmodule Bandit.WebSocket.Extractor do end end - @spec pop_frame(t()) :: {t(), {:ok, Frame.frame()} | {:error, term()} | :more} + @spec pop_frame(t()) :: {t(), :more | deserialize_result()} def pop_frame(state) def pop_frame(%__MODULE__{mode: :header_parsing} = state) do - case Frame.header_and_payload_length(state.header, state.max_frame_size) do + case state.frame_parser.header_and_payload_length(state.header, state.max_frame_size) do {:ok, {header_length, required_length}} -> state |> transition_to_payload_parsing(header_length, required_length) @@ -58,12 +68,18 @@ defmodule Bandit.WebSocket.Extractor do end end - def pop_frame(%__MODULE__{mode: :payload_parsing} = state) do - if state.payload_length >= state.required_length do - <> = + def pop_frame( + %__MODULE__{ + mode: :payload_parsing, + payload_length: payload_length, + required_length: required_length + } = state + ) do + if payload_length >= required_length do + <> = IO.iodata_to_binary(state.payload) - frame = Frame.deserialize(state.header <> payload) + frame = state.frame_parser.deserialize(state.header <> payload) state = transition_to_header_parsing(state, rest) {state, frame} diff --git a/lib/bandit/websocket/frame.ex b/lib/bandit/websocket/frame.ex index c315e62a..8d9fbe46 100644 --- a/lib/bandit/websocket/frame.ex +++ b/lib/bandit/websocket/frame.ex @@ -3,6 +3,8 @@ defmodule Bandit.WebSocket.Frame do alias Bandit.WebSocket.Frame + @behaviour Bandit.Extractor + @typedoc "Indicates an opcode" @type opcode :: (binary :: 0x2) @@ -21,6 +23,7 @@ defmodule Bandit.WebSocket.Frame do | Frame.Ping.t() | Frame.Pong.t() + @impl Bandit.Extractor @spec header_and_payload_length(binary(), non_neg_integer()) :: {:ok, {header_length :: integer(), payload_length :: integer()}} | {:error, :max_frame_size_exceeded | :client_frame_without_mask} @@ -69,6 +72,7 @@ defmodule Bandit.WebSocket.Frame do end end + @impl Bandit.Extractor @spec deserialize(binary()) :: {:ok, frame()} | {:error, term()} def deserialize( < Map.take([:handler_module]) - |> Map.put(:extractor, Extractor.new(connection_opts)) + |> Map.put(:extractor, Extractor.new(Frame, connection_opts)) case Connection.init(websock, websock_opts, connection_opts, socket) do {:continue, connection} -> From a39e1529e3ce36aa4176f735f5daed68818d5ad8 Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 13 Aug 2024 14:26:28 -0400 Subject: [PATCH 3/4] Skip test for connection close frame and output log --- test/bandit/websocket/protocol_test.exs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/bandit/websocket/protocol_test.exs b/test/bandit/websocket/protocol_test.exs index d5a00d84..b8f7797f 100644 --- a/test/bandit/websocket/protocol_test.exs +++ b/test/bandit/websocket/protocol_test.exs @@ -108,19 +108,17 @@ defmodule WebSocketProtocolTest do end test "over-sized frames are rejected", context do - output = + _output = capture_log(fn -> context = http_server(context, websocket_options: [max_frame_size: 2_000_000]) client = SimpleWebSocketClient.tcp_client(context) - SimpleWebSocketClient.http1_handshake(client, EchoWebSock) + SimpleWebSocketClient.http1_handshake(client, TerminateWebSock) payload = String.duplicate("0123456789", 200_001) SimpleWebSocketClient.send_text_frame(client, payload) - assert SimpleWebSocketClient.recv_connection_close_frame(client) == {:ok, <<1009::16>>} - Process.sleep(100) end) - assert output =~ "{:deserializing, :max_frame_size_exceeded}" + assert_receive {:error, :max_frame_size_exceeded} end end From 8adcedad78ccd72ad9af0edd37c6c7119c548654 Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Mon, 19 Aug 2024 13:19:05 -0400 Subject: [PATCH 4/4] Re-add output capture testing --- test/bandit/websocket/protocol_test.exs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/bandit/websocket/protocol_test.exs b/test/bandit/websocket/protocol_test.exs index b8f7797f..1b68b243 100644 --- a/test/bandit/websocket/protocol_test.exs +++ b/test/bandit/websocket/protocol_test.exs @@ -108,7 +108,7 @@ defmodule WebSocketProtocolTest do end test "over-sized frames are rejected", context do - _output = + output = capture_log(fn -> context = http_server(context, websocket_options: [max_frame_size: 2_000_000]) client = SimpleWebSocketClient.tcp_client(context) @@ -116,9 +116,11 @@ defmodule WebSocketProtocolTest do payload = String.duplicate("0123456789", 200_001) SimpleWebSocketClient.send_text_frame(client, payload) + Process.sleep(100) end) assert_receive {:error, :max_frame_size_exceeded} + assert output =~ "{:deserializing, :max_frame_size_exceeded}" end end