Skip to content

Commit

Permalink
SSL support for downstream connections (#129)
Browse files Browse the repository at this point in the history
* add ssl to client_handler

* vars for downstream creds

Co-authored-by: Wojtek Mach <[email protected]>

* Update lib/supavisor/client_handler.ex

Co-authored-by: Wojtek Mach <[email protected]>

* Update lib/supavisor/client_handler.ex

Co-authored-by: Wojtek Mach <[email protected]>

* Update lib/supavisor/client_handler.ex

Co-authored-by: Wojtek Mach <[email protected]>

* update client_handler

---------

Co-authored-by: Wojtek Mach <[email protected]>
  • Loading branch information
abc3 and wojtekmach authored Jul 11, 2023
1 parent b549e86 commit fd6e7a9
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 62 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.11
0.2.12
42 changes: 39 additions & 3 deletions config/runtime.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Config

require Logger
alias Supavisor.Helpers, as: H

secret_key_base =
Expand Down Expand Up @@ -102,8 +103,41 @@ upstream_ca =
File.read!(path)
|> H.cert_to_bin()
|> case do
{:ok, bin} -> bin
{:error, _} -> raise "There is no valid certificate in $GLOBAL_UPSTREAM_CA_PATH"
{:ok, bin} ->
Logger.info("Loaded upstream CA from $GLOBAL_UPSTREAM_CA_PATH",
ansi_color: :green
)

bin

{:error, _} ->
raise "There is no valid certificate in $GLOBAL_UPSTREAM_CA_PATH"
end
end

downstream_cert =
if path = System.get_env("GLOBAL_DOWNSTREAM_CERT_PATH") do
if File.exists?(path) do
Logger.info("Loaded downstream cert from $GLOBAL_DOWNSTREAM_CERT_PATH, path: #{path}",
ansi_color: :green
)

path
else
raise "There is no such file in $GLOBAL_DOWNSTREAM_CERT_PATH"
end
end

downstream_key =
if path = System.get_env("GLOBAL_DOWNSTREAM_KEY_PATH") do
if File.exists?(path) do
Logger.info("Loaded downstream key from $GLOBAL_DOWNSTREAM_KEY_PATH, path: #{path}",
ansi_color: :green
)

path
else
raise "There is no such file in $GLOBAL_DOWNSTREAM_KEY_PATH"
end
end

Expand All @@ -116,7 +150,9 @@ if config_env() != :test do
metrics_jwt_secret: System.get_env("METRICS_JWT_SECRET"),
proxy_port: System.get_env("PROXY_PORT", "7654") |> String.to_integer(),
prom_poll_rate: System.get_env("PROM_POLL_RATE", "15000") |> String.to_integer(),
global_upstream_ca: upstream_ca
global_upstream_ca: upstream_ca,
global_downstream_cert: downstream_cert,
global_downstream_key: downstream_key

config :supavisor, Supavisor.Repo,
url: System.get_env("DATABASE_URL", "ecto://postgres:postgres@localhost:6432/postgres"),
Expand Down
144 changes: 96 additions & 48 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ defmodule Supavisor.ClientHandler do
supervisor. Each client connection is assigned to a specific tenant supervisor.
"""

@type tcp_sock :: {:gen_tcp, :gen_tcp.socket()}
@type ssl_sock :: {:ssl, :ssl.sslsocket()}
@type sock :: tcp_sock() | ssl_sock()

require Logger

@behaviour :ranch_protocol
@behaviour :gen_statem

alias Supavisor.DbHandler, as: Db
alias Supavisor.Helpers, as: H
alias Supavisor.{Tenants, Protocol.Server, Monitoring.Telem}

@impl true
def start_link(ref, _socket, transport, opts) do
def start_link(ref, _sock, transport, opts) do
pid = :proc_lib.spawn_link(__MODULE__, :init, [ref, transport, opts])
{:ok, pid}
end
Expand All @@ -33,12 +38,12 @@ defmodule Supavisor.ClientHandler do
def init(ref, trans, _opts) do
Process.flag(:trap_exit, true)

{:ok, socket} = :ranch.handshake(ref)
:ok = trans.setopts(socket, [{:active, true}])
{:ok, sock} = :ranch.handshake(ref)
:ok = trans.setopts(sock, active: true)
Logger.debug("ClientHandler is: #{inspect(self())}")

data = %{
socket: socket,
sock: {:gen_tcp, sock},
trans: trans,
db_pid: nil,
tenant: nil,
Expand All @@ -48,29 +53,54 @@ defmodule Supavisor.ClientHandler do
query_start: nil,
mode: nil,
timeout: nil,
ps: nil
ps: nil,
ssl: false
}

:gen_statem.enter_loop(__MODULE__, [hibernate_after: 5_000], :exchange, data)
end

@impl true
def handle_event(:info, {:tcp, _, <<"GET", _::binary>>}, :exchange, data) do
def handle_event(:info, {_proto, _, <<"GET", _::binary>>}, :exchange, data) do
Logger.debug("Client is trying to request HTTP")
:gen_tcp.send(data.socket, "HTTP/1.1 204 OK\r\n\r\n")
sock_send(data.sock, "HTTP/1.1 204 OK\r\n\r\n")
{:stop, :normal, data}
end

def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, data) do
Logger.warn("Client is trying to connect with SSL")
# TODO: implement SSL negotiation
# SSL negotiation, S/N/Error
:gen_tcp.send(data.socket, "N")
def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %{sock: sock} = data) do
Logger.debug("Client is trying to connect with SSL")

:keep_state_and_data
downstream_cert = H.downstream_cert()
downstream_key = H.downstream_key()

# SSL negotiation, S/N/Error
if !!downstream_cert and !!downstream_key do
:ok = setopts(sock, active: false)
:ok = sock_send(sock, "S")

opts = [
certfile: downstream_cert,
keyfile: downstream_key
]

case :ssl.handshake(elem(sock, 1), opts) do
{:ok, ssl_sock} ->
socket = {:ssl, ssl_sock}
:ok = setopts(socket, active: true)
{:keep_state, %{data | sock: socket, ssl: true}}

error ->
Logger.error("SSL handshake error: #{inspect(error)}")
{:stop, :normal, data}
end
else
Logger.error("User requested SSL connection but no downstream cert/key found")
:ok = sock_send(data.sock, "N")
:keep_state_and_data
end
end

def handle_event(:info, {:tcp, _, bin}, :exchange, %{socket: socket} = data) do
def handle_event(:info, {_proto, _, bin}, :exchange, %{sock: sock} = data) do
hello = decode_startup_packet(bin)
Logger.warning("Client startup message: #{inspect(hello)}")
{user, external_id} = parse_user_info(hello.payload["user"])
Expand All @@ -85,26 +115,26 @@ defmodule Supavisor.ClientHandler do

{:error, reason} ->
Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}")
Server.send_error(socket, "XX000", "Tenant or user not found")

:ok = send_error(sock, "XX000", "Tenant or user not found")
{:stop, :normal, data}
end
end

def handle_event(:internal, {:handle, pass}, _, %{socket: socket} = data) do
def handle_event(:internal, {:handle, pass}, _, %{sock: sock} = data) do
Logger.info("Handle exchange")

case handle_exchange(socket, pass) do
case handle_exchange(sock, pass, data.ssl) do
{:error, reason} ->
Logger.error("Exchange error: #{inspect(reason)}")

"e=#{reason}"
|> Server.send_exchange_message(:final, socket)
msg = Server.exchange_message(:final, "e=#{reason}")
sock_send(sock, msg)

{:stop, :normal, data}

:ok ->
Logger.info("Exchange success")
:ok = :gen_tcp.send(socket, Server.authentication_ok())
:ok = sock_send(sock, Server.authentication_ok())
{:keep_state_and_data, {:next_event, :internal, :subscribe}}
end
end
Expand Down Expand Up @@ -132,8 +162,8 @@ defmodule Supavisor.ClientHandler do
end
end

def handle_event(:internal, {:greetings, ps}, _, data) do
:ok = :gen_tcp.send(data.socket, Server.greetings(ps))
def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do
:ok = sock_send(sock, Server.greetings(ps))
{:next_state, :idle, data}
end

Expand All @@ -143,25 +173,26 @@ defmodule Supavisor.ClientHandler do

def handle_event(:timeout, :wait_ps, _, data) do
Logger.error("Wait parameter status timeout, send default #{inspect(data.ps)}}")

ps = Server.encode_parameter_status(data.ps)
{:keep_state_and_data, {:next_event, :internal, {:greetings, ps}}}
end

# ignore termination messages
def handle_event(:info, {:tcp, _, <<?X, 4::32>>}, _, _) do
def handle_event(:info, {proto, _, <<?X, 4::32>>}, _, _) when proto in [:tcp, :ssl] do
Logger.warn("Receive termination")
:keep_state_and_data
end

def handle_event(:info, {:tcp, _, bin}, :idle, data) do
def handle_event(:info, {proto, _, bin}, :idle, data) do
ts = System.monotonic_time()
db_pid = db_checkout(:on_query, data)

{:next_state, :busy, %{data | db_pid: db_pid, query_start: ts},
{:next_event, :internal, {:tcp, nil, bin}}}
{:next_event, :internal, {proto, nil, bin}}}
end

def handle_event(_, {:tcp, _, bin}, :busy, data) do
def handle_event(_, {proto, _, bin}, :busy, data) when proto in [:tcp, :ssl] do
case Db.call(data.db_pid, bin) do
:ok ->
Logger.info("DB call success")
Expand All @@ -173,7 +204,7 @@ defmodule Supavisor.ClientHandler do
if size > 1_000_000 do
msg = "Db buffer size is too big: #{size}"
Logger.error(msg)
Server.send_error(data.socket, "XX000", msg)
sock_send(data.sock, Server.error_message("XX000", msg))
{:stop, :normal, data}
else
Logger.debug("DB call buffering")
Expand All @@ -183,7 +214,7 @@ defmodule Supavisor.ClientHandler do
{:error, reason} ->
msg = "DB call error: #{inspect(reason)}"
Logger.error(msg)
Server.send_error(data.socket, "XX000", msg)
sock_send(data.sock, Server.error_message("XX000", msg))
{:stop, :normal, data}
end
end
Expand All @@ -203,6 +234,12 @@ defmodule Supavisor.ClientHandler do
{:stop, :normal}
end

def handle_event(_, {closed, _}, _, data)
when closed in [:tcp_closed, :ssl_closed] do
Logger.debug("#{closed} soket closed for #{inspect(data.tenant)}")
{:stop, :normal}
end

# linked db_handler went down
def handle_event(:info, {:EXIT, db_pid, reason}, _, _) do
Logger.error("DB handler #{inspect(db_pid)} exited #{inspect(reason)}")
Expand All @@ -228,14 +265,14 @@ defmodule Supavisor.ClientHandler do
def handle_event({:call, from}, {:client_call, bin, ready?}, _, data) do
Logger.debug("--> --> bin #{inspect(byte_size(bin))} bytes")

reply = {:reply, from, :gen_tcp.send(data.socket, bin)}
reply = {:reply, from, sock_send(data.sock, bin)}

if ready? do
Logger.debug("Client is ready")

db_pid = handle_db_pid(data.mode, data.pool, data.db_pid)

Telem.network_usage(:client, {:gen_tcp, data.socket}, data.tenant, data.user_alias)
Telem.network_usage(:client, data.sock, data.tenant, data.user_alias)
Telem.client_query_time(data.query_start, data.tenant, data.user_alias)
{:next_state, :idle, %{data | db_pid: db_pid}, reply}
else
Expand Down Expand Up @@ -273,7 +310,7 @@ defmodule Supavisor.ClientHandler do
end

Logger.error(msg)
Server.send_error(data.socket, "XX000", msg)
sock_send(data.sock, Server.error_message("XX000", msg))
:ok
end

Expand Down Expand Up @@ -309,38 +346,32 @@ defmodule Supavisor.ClientHandler do
:undef
end

@spec handle_exchange(port, fun) :: :ok | {:error, String.t()}
def handle_exchange(socket, password) do
:ok = Server.send_request_authentication(socket)
@spec handle_exchange(sock(), fun(), boolean()) :: :ok | {:error, String.t()}
def handle_exchange({_, socket} = sock, password, ssl) do
:ok = sock_send(sock, Server.auth_request())

receive do
{:tcp, socket, bin} ->
{_proto, ^socket, bin} ->
case Server.decode_pkt(bin) do
{:ok,
%{tag: :password_message, payload: {:scram_sha_256, %{"n" => user, "r" => nonce}}},
_} ->
message = Server.exchange_first_message(nonce)
server_first_parts = :pgo_scram.parse_server_first(message, nonce)
channel = if ssl, do: "eSws", else: "biws"

{client_final_message, server_proof} =
:pgo_scram.get_client_final(
server_first_parts,
nonce,
user,
password.()
)
H.get_client_final(server_first_parts, nonce, user, password.(), channel)

:ok =
message
|> Server.send_exchange_message(:first, socket)
:ok = sock_send(sock, Server.exchange_message(:first, message))

receive do
{:tcp, socket, bin} ->
{_proto, ^socket, bin} ->
case Server.decode_pkt(bin) do
{:ok, %{tag: :password_message, payload: {:first_msg_response, %{"p" => p}}}, _} ->
if p == List.last(client_final_message) do
"v=#{Base.encode64(server_proof)}"
|> Server.send_exchange_message(:final, socket)
message = "v=#{Base.encode64(server_proof)}"
:ok = sock_send(sock, Server.exchange_message(:final, message))
else
{:error, "Invalid client signature"}
end
Expand Down Expand Up @@ -399,4 +430,21 @@ defmodule Supavisor.ClientHandler do
ps: user_info.default_parameter_status
}
end

@spec sock_send(tcp_sock() | ssl_sock(), iodata()) :: :ok | {:error, term()}
defp sock_send({mod, sock}, data) do
mod.send(sock, data)
end

@spec send_error(sock, String.t(), String.t()) :: :ok | {:error, term()}
defp send_error(sock, code, message) do
data = Server.error_message(code, message)
sock_send(sock, data)
end

@spec setopts(sock, term()) :: :ok | {:error, term()}
defp setopts({mod, sock}, opts) do
mod = if mod == :gen_tcp, do: :inet, else: mod
mod.setopts(sock, opts)
end
end
Loading

0 comments on commit fd6e7a9

Please sign in to comment.