Skip to content

Commit

Permalink
Use Plug.Exception to handle Bandit.HTTPError in the general case
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Apr 19, 2024
1 parent 07d5ebc commit 0042fa5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 28 deletions.
20 changes: 5 additions & 15 deletions lib/bandit/http1/socket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -382,28 +382,18 @@ defmodule Bandit.HTTP1.Socket do

def supported_upgrade?(_socket, protocol), do: protocol == :websocket

def send_on_error(%@for{} = socket, %Bandit.HTTPError{} = error) do
socket = send_error(socket, error.status)
%{socket | write_state: :sent}
end

def send_on_error(%@for{} = socket, error) do
socket = send_error(socket, Plug.Exception.status(error))
%{socket | write_state: :sent}
end

defp send_error(socket, status) do
receive do
{:plug_conn, :sent} -> socket
{:plug_conn, :sent} -> %{socket | write_state: :sent}
after
0 -> send_headers(socket, status, [], :no_body)
0 -> send_headers(socket, Plug.Exception.status(error), [], :no_body)
end
end

@spec request_error!(term()) :: no_return()
@spec request_error!(term(), atom()) :: no_return()
defp request_error!(reason, status \\ :bad_request) do
raise Bandit.HTTPError, message: to_string(reason), status: Plug.Conn.Status.code(status)
@spec request_error!(term(), Plug.Conn.status()) :: no_return()
defp request_error!(reason, plug_status \\ :bad_request) do
raise Bandit.HTTPError, message: to_string(reason), plug_status: plug_status
end
end
end
5 changes: 0 additions & 5 deletions lib/bandit/http2/stream.ex
Original file line number Diff line number Diff line change
Expand Up @@ -504,11 +504,6 @@ defmodule Bandit.HTTP2.Stream do

def supported_upgrade?(_stream, _protocol), do: false

def send_on_error(%@for{} = stream, %Bandit.HTTPError{}) do
do_send(stream, {:send_rst_stream, Bandit.HTTP2.Errors.protocol_error()})
%{stream | state: :closed}
end

def send_on_error(%@for{} = stream, %Bandit.HTTP2.Errors.StreamError{} = error) do
do_send(stream, {:send_rst_stream, error.error_code})
%{stream | state: :closed}
Expand Down
2 changes: 1 addition & 1 deletion lib/bandit/http_error.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ defmodule Bandit.HTTPError do
# Represents an error suitable for return as an HTTP status. Note that these may be surfaced
# from anywhere that such a message is well defined, including within HTTP/1 transport concerns
# and also within shared HTTP semantics (ie: within Bandit.Adapter or Bandit.Pipeline)
defexception message: nil, status: 400
defexception message: nil, plug_status: :bad_request
end
6 changes: 3 additions & 3 deletions lib/bandit/pipeline.ex
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ defmodule Bandit.Pipeline do
end

@spec request_error!(term()) :: no_return()
@spec request_error!(term(), atom()) :: no_return()
defp request_error!(reason, status \\ :bad_request) do
raise Bandit.HTTPError, message: reason, status: Plug.Conn.Status.code(status)
@spec request_error!(term(), Plug.Conn.status()) :: no_return()
defp request_error!(reason, plug_status \\ :bad_request) do
raise Bandit.HTTPError, message: reason, plug_status: plug_status
end
end
8 changes: 4 additions & 4 deletions test/bandit/http2/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2343,7 +2343,7 @@ defmodule HTTP2ProtocolTest do
assert Jason.decode!(body)["host"] == "banana"
end

test "resets stream if no host header set", context do
test "sends 400 if no host header set", context do
output =
capture_log(fn ->
socket = SimpleH2Client.setup_connection(context)
Expand All @@ -2355,7 +2355,7 @@ defmodule HTTP2ProtocolTest do
]

SimpleH2Client.send_headers(socket, 1, true, headers)
assert SimpleH2Client.recv_rst_stream(socket) == {:ok, 1, 1}
assert {:ok, 1, true, [{":status", "400"} | _], _} = SimpleH2Client.recv_headers(socket)
Process.sleep(100)
end)

Expand Down Expand Up @@ -2414,7 +2414,7 @@ defmodule HTTP2ProtocolTest do
assert Jason.decode!(body)["port"] == 1234
end

test "resets stream if port cannot be parsed from host header", context do
test "sends 400 if port cannot be parsed from host header", context do
output =
capture_log(fn ->
socket = SimpleH2Client.setup_connection(context)
Expand All @@ -2427,7 +2427,7 @@ defmodule HTTP2ProtocolTest do
]

SimpleH2Client.send_headers(socket, 1, true, headers)
assert SimpleH2Client.recv_rst_stream(socket) == {:ok, 1, 1}
assert {:ok, 1, true, [{":status", "400"} | _], _} = SimpleH2Client.recv_headers(socket)
Process.sleep(100)
end)

Expand Down

0 comments on commit 0042fa5

Please sign in to comment.