Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build iolist for WebSocket frame data #390

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions lib/bandit/extractor.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
defmodule Bandit.Extractor do
@moduledoc false
# A state machine for efficiently extracting full frames from received packets

@type deserialize_result :: any()

@callback header_and_payload_length(binary(), max_frame_size :: integer()) ::
{:ok, {header_length :: integer(), payload_length :: integer()}}
| {:error, term()}
| :more

@callback deserialize(binary()) :: deserialize_result()

@type t :: %__MODULE__{
header: binary(),
payload: iodata(),
payload_length: non_neg_integer(),
required_length: non_neg_integer(),
mode: :header_parsing | :payload_parsing,
max_frame_size: non_neg_integer(),
frame_parser: atom()
}

defstruct header: <<>>,
payload: [],
payload_length: 0,
required_length: 0,
mode: :header_parsing,
max_frame_size: 0,
frame_parser: nil

@spec new(module(), Keyword.t()) :: t()
def new(frame_parser, opts) do
max_frame_size = Keyword.get(opts, :max_frame_size, 0)

%__MODULE__{
max_frame_size: max_frame_size,
frame_parser: frame_parser
}
end

@spec push_data(t(), binary()) :: t()
def push_data(%__MODULE__{} = state, data) do
case state do
%{mode: :header_parsing} ->
%{state | header: state.header <> data}

%{mode: :payload_parsing, payload: payload, payload_length: length} ->
%{state | payload: [payload, data], payload_length: length + byte_size(data)}
end
end

@spec pop_frame(t()) :: {t(), :more | deserialize_result()}
def pop_frame(state)

def pop_frame(%__MODULE__{mode: :header_parsing} = state) do
case state.frame_parser.header_and_payload_length(state.header, state.max_frame_size) do
{:ok, {header_length, required_length}} ->
state
|> transition_to_payload_parsing(header_length, required_length)
|> pop_frame()

{:error, message} ->
{state, {:error, message}}

:more ->
{state, :more}
end
end

def pop_frame(
%__MODULE__{
mode: :payload_parsing,
payload_length: payload_length,
required_length: required_length
} = state
) do
if payload_length >= required_length do
<<payload::binary-size(required_length), rest::binary>> =
IO.iodata_to_binary(state.payload)

frame = state.frame_parser.deserialize(state.header <> payload)
state = transition_to_header_parsing(state, rest)

{state, frame}
else
{state, :more}
end
end

defp transition_to_payload_parsing(state, header_length, required_length) do
payload_length = byte_size(state.header) - header_length

state
|> Map.put(:header, binary_part(state.header, 0, header_length))
|> Map.put(:payload, binary_part(state.header, header_length, payload_length))
|> Map.put(:payload_length, payload_length)
|> Map.put(:required_length, required_length)
|> Map.put(:mode, :payload_parsing)
end

defp transition_to_header_parsing(state, rest) do
state
|> Map.put(:header, rest)
|> Map.put(:payload, [])
|> Map.put(:payload_length, 0)
|> Map.put(:required_length, 0)
|> Map.put(:mode, :header_parsing)
end
end
102 changes: 65 additions & 37 deletions lib/bandit/websocket/frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ defmodule Bandit.WebSocket.Frame do

alias Bandit.WebSocket.Frame

@behaviour Bandit.Extractor

@typedoc "Indicates an opcode"
@type opcode ::
(binary :: 0x2)
Expand All @@ -21,50 +23,80 @@ defmodule Bandit.WebSocket.Frame do
| Frame.Ping.t()
| Frame.Pong.t()

@spec deserialize(binary(), non_neg_integer()) ::
{{:ok, frame()}, iodata()}
| {{:more, binary()}, <<>>}
| {{:error, term()}, iodata()}
| nil
def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 127::7, length::64, mask::32,
payload::binary-size(length), rest::binary>>,
@impl Bandit.Extractor
@spec header_and_payload_length(binary(), non_neg_integer()) ::
{:ok, {header_length :: integer(), payload_length :: integer()}}
| {:error, :max_frame_size_exceeded | :client_frame_without_mask}
| :more
def header_and_payload_length(
<<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, 127::7, length::64, _mask::32,
_rest::binary>>,
max_frame_size
)
when max_frame_size == 0 or length <= max_frame_size do
to_frame(fin, compressed, rsv, opcode, mask, payload, rest)
) do
validate_max_frame_size(14, length, max_frame_size)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 126::7, length::16, mask::32,
payload::binary-size(length), rest::binary>>,
def header_and_payload_length(
<<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, 126::7, length::16, _mask::32,
_rest::binary>>,
max_frame_size
)
when max_frame_size == 0 or length <= max_frame_size do
to_frame(fin, compressed, rsv, opcode, mask, payload, rest)
) do
validate_max_frame_size(8, length, max_frame_size)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, length::7, mask::32,
payload::binary-size(length), rest::binary>>,
def header_and_payload_length(
<<_fin::1, _compressed::1, _rsv::2, _opcode::4, 1::1, length::7, _mask::32,
_rest::binary>>,
max_frame_size
)
when length <= 125 and (max_frame_size == 0 or length <= max_frame_size) do
to_frame(fin, compressed, rsv, opcode, mask, payload, rest)
when length <= 125 do
validate_max_frame_size(6, length, max_frame_size)
end

def header_and_payload_length(
<<_fin::1, _compressed::1, _rsv::2, _opcode::4, 0::1, _rest::binary>>,
_max_frame_size
) do
{:error, :client_frame_without_mask}
end

def header_and_payload_length(_msg, _max_frame_size) do
:more
end

# nil is used to indicate for Stream.unfold/2 that the frame deserialization is finished
def deserialize(<<>>, _max_frame_size) do
nil
defp validate_max_frame_size(header_length, payload_length, max_frame_size) do
if max_frame_size != 0 and header_length + payload_length > max_frame_size do
{:error, :max_frame_size_exceeded}
else
{:ok, {header_length, payload_length}}
end
end

def deserialize(msg, max_frame_size)
when max_frame_size != 0 and byte_size(msg) > max_frame_size do
{{:error, :max_frame_size_exceeded}, msg}
@impl Bandit.Extractor
@spec deserialize(binary()) :: {:ok, frame()} | {:error, term()}
def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 127::7, length::64, mask::32,
payload::binary-size(length)>>
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
end

def deserialize(msg, _max_frame_size) do
{{:more, msg}, <<>>}
def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 126::7, length::16, mask::32,
payload::binary-size(length)>>
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, length::7, mask::32,
payload::binary-size(length)>>
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
end

def deserialize(_msg) do
{:error, :deserialization_failed}
end

def recv_metrics(%frame_type{} = frame) do
Expand Down Expand Up @@ -123,11 +155,11 @@ defmodule Bandit.WebSocket.Frame do
end
end

defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, rest) when rsv != 0x0 do
{{:error, "Received unsupported RSV flags #{rsv}"}, rest}
defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do
{:error, "Received unsupported RSV flags #{rsv}"}
end

defp to_frame(fin, compressed, 0x0, opcode, mask, payload, rest) do
defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do
fin = fin == 0x1
compressed = compressed == 0x1
unmasked_payload = mask(payload, mask)
Expand All @@ -142,10 +174,6 @@ defmodule Bandit.WebSocket.Frame do
0xA -> Frame.Pong.deserialize(fin, compressed, unmasked_payload)
unknown -> {:error, "unknown opcode #{unknown}"}
end
|> case do
{:ok, frame} -> {{:ok, frame}, rest}
{:error, reason} -> {{:error, reason}, rest}
end
end

defprotocol Serializable do
Expand Down
33 changes: 18 additions & 15 deletions lib/bandit/websocket/handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Bandit.WebSocket.Handler do

use ThousandIsland.Handler

alias Bandit.Extractor
alias Bandit.WebSocket.{Connection, Frame}

@impl ThousandIsland.Handler
Expand All @@ -19,7 +20,7 @@ defmodule Bandit.WebSocket.Handler do
state =
state
|> Map.take([:handler_module])
|> Map.put(:buffer, <<>>)
|> Map.put(:extractor, Extractor.new(Frame, connection_opts))

case Connection.init(websock, websock_opts, connection_opts, socket) do
{:continue, connection} ->
Expand All @@ -35,29 +36,31 @@ defmodule Bandit.WebSocket.Handler do

@impl ThousandIsland.Handler
def handle_data(data, socket, state) do
(state.buffer <> data)
|> Stream.unfold(
&Frame.deserialize(&1, Keyword.get(state.connection.opts, :max_frame_size, 0))
)
|> Enum.reduce_while({:continue, state}, fn
{:ok, frame}, {:continue, state} ->
state.extractor
|> Extractor.push_data(data)
|> pop_frame(socket, state)
end

defp pop_frame(extractor, socket, state) do
case Extractor.pop_frame(extractor) do
{extractor, {:ok, frame}} ->
case Connection.handle_frame(frame, socket, state.connection) do
{:continue, connection} ->
{:cont, {:continue, %{state | connection: connection, buffer: <<>>}}}
pop_frame(extractor, socket, %{state | extractor: extractor, connection: connection})

{:close, connection} ->
{:halt, {:close, %{state | connection: connection, buffer: <<>>}}}
{:close, %{state | extractor: extractor, connection: connection}}

{:error, reason, connection} ->
{:halt, {:error, reason, %{state | connection: connection, buffer: <<>>}}}
{:error, reason, %{state | extractor: extractor, connection: connection}}
end

{:more, rest}, {:continue, state} ->
{:halt, {:continue, %{state | buffer: rest}}}
{extractor, {:error, reason}} ->
{:error, {:deserializing, reason}, %{state | extractor: extractor}}

{:error, message}, {:continue, state} ->
{:halt, {:error, {:deserializing, message}, state}}
end)
{extractor, :more} ->
{:continue, %{state | extractor: extractor}}
end
end

@impl ThousandIsland.Handler
Expand Down
Loading