diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index 148502b..54dd331 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -1,6 +1,7 @@ """helpers for postgres""" import os +import tempfile from logging import basicConfig, getLogger, INFO import psycopg2 from sshtunnel import SSHTunnelForwarder @@ -36,11 +37,43 @@ def get_connection(conn_info): "user", "password", "database", - "sslmode", - "sslrootcert", ]: if key in conn_info: connect_params[key] = conn_info[key] + + # ssl_mode is an alias for sslmode + if "ssl_mode" in conn_info: + conn_info["sslmode"] = conn_info["ssl_mode"] + + if "sslmode" in conn_info: + # sslmode can be a string or a boolean or a dict + if isinstance(conn_info["sslmode"], str): + # "require", "disable", "verify-ca", "verify-full" + connect_params["sslmode"] = conn_info["sslmode"] + elif isinstance(conn_info["sslmode"], bool): + # true = require, false = disable + connect_params["sslmode"] = ( + "require" if conn_info["sslmode"] else "disable" + ) + elif ( + isinstance(conn_info["sslmode"], dict) + and "mode" in conn_info["sslmode"] + ): + # mode is "require", "disable", "verify-ca", "verify-full" etc + connect_params["sslmode"] = conn_info["sslmode"]["mode"] + if "ca_certificate" in conn_info["sslmode"]: + # connect_params['sslcert'] needs a file path but + # conn_info['sslmode']['ca_certificate'] + # is a string (i.e. the actual certificate). so we write + # it to disk and pass the file path + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(conn_info["sslmode"]["ca_certificate"].encode()) + connect_params["sslrootcert"] = fp.name + connect_params["sslcert"] = fp.name + + if "sslrootcert" in conn_info: + connect_params["sslrootcert"] = conn_info["sslrootcert"] + connection = psycopg2.connect(**connect_params) return connection diff --git a/tests/utils/test_postgres.py b/tests/utils/test_postgres.py new file mode 100644 index 0000000..2d9dc5b --- /dev/null +++ b/tests/utils/test_postgres.py @@ -0,0 +1,113 @@ +from unittest.mock import patch, ANY +from dbt_automation.utils.postgres import PostgresClient + + +def test_get_connection_1(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + {"host": "HOST", "port": 1234, "user": "USER", "password": "PASSWORD"} + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + host="HOST", + port=1234, + user="USER", + password="PASSWORD", + ) + + +def test_get_connection_2(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "host": "HOST", + "port": 1234, + "user": "USER", + "password": "PASSWORD", + "database": "DATABASE", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + host="HOST", + port=1234, + user="USER", + password="PASSWORD", + database="DATABASE", + ) + + +def test_get_connection_3(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": "verify-ca", + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="verify-ca", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_4(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": True, + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="require", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_5(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": False, + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="disable", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_6(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": { + "mode": "disable", + } + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="disable", + ) + + +def test_get_connection_7(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + {"sslmode": {"mode": "disable", "ca_certificate": "LONG-CERTIFICATE"}} + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY, sslcert=ANY)