diff --git a/lib/extensions/postgres_cdc_rls/migrations.ex b/lib/extensions/postgres_cdc_rls/migrations.ex index 28d427d18..c9e3e5ed5 100644 --- a/lib/extensions/postgres_cdc_rls/migrations.ex +++ b/lib/extensions/postgres_cdc_rls/migrations.ex @@ -107,7 +107,7 @@ defmodule Extensions.PostgresCdcRls.Migrations do "db_user" => db_user, "db_password" => db_password, "db_socket_opts" => db_socket_opts - } = _args + } = args ) do {host, port, name, user, pass} = H.decrypt_creds( @@ -118,26 +118,27 @@ defmodule Extensions.PostgresCdcRls.Migrations do db_password ) - Repo.with_dynamic_repo( - [ - hostname: host, - port: port, - database: name, - password: pass, - username: user, - pool_size: 2, - socket_options: db_socket_opts - ], - fn repo -> - Ecto.Migrator.run( - Repo, - @migrations, - :up, - all: true, - prefix: "realtime", - dynamic_repo: repo - ) - end - ) + ssl_enforced = H.default_ssl_param(args) + + [ + hostname: host, + port: port, + database: name, + password: pass, + username: user, + pool_size: 2, + socket_options: db_socket_opts + ] + |> H.maybe_enforce_ssl_config(ssl_enforced) + |> Repo.with_dynamic_repo(fn repo -> + Ecto.Migrator.run( + Repo, + @migrations, + :up, + all: true, + prefix: "realtime", + dynamic_repo: repo + ) + end) end end diff --git a/lib/extensions/postgres_cdc_rls/replication_poller.ex b/lib/extensions/postgres_cdc_rls/replication_poller.ex index 47321fd36..b6a39f9f7 100644 --- a/lib/extensions/postgres_cdc_rls/replication_poller.ex +++ b/lib/extensions/postgres_cdc_rls/replication_poller.ex @@ -8,7 +8,8 @@ defmodule Extensions.PostgresCdcRls.ReplicationPoller do require Logger - import Realtime.Helpers, only: [cancel_timer: 1, decrypt_creds: 5] + import Realtime.Helpers, + only: [cancel_timer: 1, decrypt_creds: 5, default_ssl_param: 1, maybe_enforce_ssl_config: 2] alias Extensions.PostgresCdcRls.{Replications, MessageDispatcher} alias DBConnection.Backoff @@ -28,6 +29,8 @@ defmodule Extensions.PostgresCdcRls.ReplicationPoller do @impl true def init(args) do + ssl_enforced = default_ssl_param(args) + {:ok, conn} = connect_db( args["db_host"], @@ -35,7 +38,8 @@ defmodule Extensions.PostgresCdcRls.ReplicationPoller do args["db_name"], args["db_user"], args["db_password"], - args["db_socket_opts"] + args["db_socket_opts"], + ssl_enforced ) tenant = args["id"] @@ -306,10 +310,10 @@ defmodule Extensions.PostgresCdcRls.ReplicationPoller do defp convert_errors(_), do: nil - defp connect_db(host, port, name, user, pass, socket_opts) do + defp connect_db(host, port, name, user, pass, socket_opts, ssl_enforced) do {host, port, name, user, pass} = decrypt_creds(host, port, name, user, pass) - Postgrex.start_link( + [ hostname: host, port: port, database: name, @@ -320,7 +324,9 @@ defmodule Extensions.PostgresCdcRls.ReplicationPoller do application_name: "realtime_rls" ], socket_options: socket_opts - ) + ] + |> maybe_enforce_ssl_config(ssl_enforced) + |> Postgrex.start_link() end defp prepare_replication( diff --git a/lib/extensions/postgres_cdc_rls/subscription_manager.ex b/lib/extensions/postgres_cdc_rls/subscription_manager.ex index 93b5c98c3..f690c5c9f 100644 --- a/lib/extensions/postgres_cdc_rls/subscription_manager.ex +++ b/lib/extensions/postgres_cdc_rls/subscription_manager.ex @@ -69,8 +69,13 @@ defmodule Extensions.PostgresCdcRls.SubscriptionManager do Logger.metadata(external_id: id, project: id) - {:ok, conn} = H.connect_db(host, port, name, user, pass, socket_opts, 1) - {:ok, conn_pub} = H.connect_db(host, port, name, user, pass, socket_opts, subs_pool_size) + ssl_enforced = H.default_ssl_param(args) + + {:ok, conn} = H.connect_db(host, port, name, user, pass, socket_opts, 1, 5_000, ssl_enforced) + + {:ok, conn_pub} = + H.connect_db(host, port, name, user, pass, socket_opts, subs_pool_size, 5_000, ssl_enforced) + {:ok, _} = Subscriptions.maybe_delete_all(conn) Rls.update_meta(id, self(), conn_pub) diff --git a/lib/extensions/postgres_cdc_rls/subscriptions_checker.ex b/lib/extensions/postgres_cdc_rls/subscriptions_checker.ex index 9404cb725..ae65715eb 100644 --- a/lib/extensions/postgres_cdc_rls/subscriptions_checker.ex +++ b/lib/extensions/postgres_cdc_rls/subscriptions_checker.ex @@ -49,7 +49,9 @@ defmodule Extensions.PostgresCdcRls.SubscriptionsChecker do Logger.metadata(external_id: id, project: id) - {:ok, conn} = H.connect_db(host, port, name, user, pass, socket_opts, 1) + ssl_enforced = H.default_ssl_param(args) + + {:ok, conn} = H.connect_db(host, port, name, user, pass, socket_opts, 1, 5_000, ssl_enforced) state = %State{ id: id, diff --git a/lib/realtime/helpers.ex b/lib/realtime/helpers.ex index 694661eb4..4dae972d6 100644 --- a/lib/realtime/helpers.ex +++ b/lib/realtime/helpers.ex @@ -29,10 +29,21 @@ defmodule Realtime.Helpers do String.t(), list(), non_neg_integer(), - non_neg_integer() + non_neg_integer(), + boolean() ) :: {:ok, pid} | {:error, Postgrex.Error.t() | term()} - def connect_db(host, port, name, user, pass, socket_opts, pool \\ 5, queue_target \\ 5_000) do + def connect_db( + host, + port, + name, + user, + pass, + socket_opts, + pool \\ 5, + queue_target \\ 5_000, + ssl_enforced \\ true + ) do secure_key = Application.get_env(:realtime, :db_enc_key) host = decrypt!(host, secure_key) @@ -41,7 +52,7 @@ defmodule Realtime.Helpers do pass = decrypt!(pass, secure_key) user = decrypt!(user, secure_key) - Postgrex.start_link( + [ hostname: host, port: port, database: name, @@ -53,7 +64,33 @@ defmodule Realtime.Helpers do application_name: "supabase_realtime" ], socket_options: socket_opts - ) + ] + |> maybe_enforce_ssl_config(ssl_enforced) + |> Postgrex.start_link() + end + + @spec default_ssl_param(map) :: boolean + def default_ssl_param(%{"ssl_enforced" => ssl_enforced}) when is_boolean(ssl_enforced), + do: ssl_enforced + + def default_ssl_param(_), do: true + + @spec maybe_enforce_ssl_config(maybe_improper_list, boolean()) :: maybe_improper_list + def maybe_enforce_ssl_config(db_config, ssl_enforced) + when is_list(db_config) and is_boolean(ssl_enforced) do + if ssl_enforced do + enforce_ssl_config(db_config) + else + db_config + end + end + + def maybe_enforce_ssl_config(db_config, _) do + enforce_ssl_config(db_config) + end + + defp enforce_ssl_config(db_config) when is_list(db_config) do + db_config ++ [ssl: true, ssl_opts: [verify: :verify_none]] end @doc """ diff --git a/lib/realtime_web.ex b/lib/realtime_web.ex index 96420b755..63f2ccc45 100644 --- a/lib/realtime_web.ex +++ b/lib/realtime_web.ex @@ -84,7 +84,7 @@ defmodule RealtimeWeb do def channel do quote do - use Phoenix.Channel + use Phoenix.Channel, log_join: :info, log_handle_in: :info import RealtimeWeb.Gettext end end diff --git a/priv/repo/seeds_after_migration.exs b/priv/repo/seeds_after_migration.exs index 327ba5861..cff31f4fd 100644 --- a/priv/repo/seeds_after_migration.exs +++ b/priv/repo/seeds_after_migration.exs @@ -27,7 +27,8 @@ end "poll_max_changes" => 100, "poll_max_record_bytes" => 1_048_576, "publication" => publication, - "region" => "us-east-1" + "region" => "us-east-1", + "ssl_enforced" => false } } ],