From 1ffcb611eaf7010c5367804b67310be9a35a3a7e Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Sun, 7 Apr 2024 16:00:20 +0530 Subject: [PATCH 01/10] map_airbyte_keys_to_postgres_keys --- dbt_automation/utils/warehouseclient.py | 23 ++++++++ tests/utils/test_warehouseclient.py | 71 +++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 tests/utils/test_warehouseclient.py diff --git a/dbt_automation/utils/warehouseclient.py b/dbt_automation/utils/warehouseclient.py index 54d4783..87e0029 100644 --- a/dbt_automation/utils/warehouseclient.py +++ b/dbt_automation/utils/warehouseclient.py @@ -4,9 +4,32 @@ from dbt_automation.utils.bigquery import BigQueryClient +def map_airbyte_keys_to_postgres_keys(conn_info: dict): + """called below and by `post_system_transformation_tasks`""" + if "tunnel_method" in conn_info: + method = conn_info["tunnel_method"] + + if method["tunnel_method"] in ["SSH_KEY_AUTH", "SSH_PASSWORD_AUTH"]: + conn_info["ssh_host"] = method["tunnel_host"] + conn_info["ssh_port"] = method["tunnel_port"] + conn_info["ssh_username"] = method["tunnel_user"] + + if method["tunnel_method"] == "SSH_KEY_AUTH": + conn_info["ssh_pkey"] = method["ssh_key"] + + elif method["tunnel_method"] == "SSH_PASSWORD_AUTH": + conn_info["ssh_password"] = method["tunnel_user_password"] + + conn_info["user"] = conn_info["username"] + + return conn_info + + def get_client(warehouse: str, conn_info: dict = None, location: str = None): """constructs and returns an instance of the client for the right warehouse""" if warehouse == "postgres": + # conn_info gets passed to psycopg2.connect + conn_info = map_airbyte_keys_to_postgres_keys(conn_info) client = PostgresClient(conn_info) elif warehouse == "bigquery": client = BigQueryClient(conn_info, location) diff --git a/tests/utils/test_warehouseclient.py b/tests/utils/test_warehouseclient.py new file mode 100644 index 0000000..677cf99 --- /dev/null +++ b/tests/utils/test_warehouseclient.py @@ -0,0 +1,71 @@ +import pytest +from dbt_automation.utils.warehouseclient import map_airbyte_keys_to_postgres_keys + + +def test_map_airbyte_keys_to_postgres_keys_oldconfig(): + """verifies the correct mapping of keys""" + conn_info = { + "host": "host", + "port": 100, + "username": "user", + "password": "password", + "database": "database", + } + conn_info = map_airbyte_keys_to_postgres_keys(conn_info) + assert conn_info["host"] == "host" + assert conn_info["port"] == "port" + assert conn_info["username"] == "username" + assert conn_info["password"] == "password" + assert conn_info["database"] == "database" + + +def test_map_airbyte_keys_to_postgres_keys_sshkey(): + """verifies the correct mapping of keys""" + conn_info = { + "tunnel_method": { + "tunnel_method": "SSH_KEY_AUTH", + }, + "tunnel_host": "host", + "tunnel_port": 22, + "tunnel_user": "user", + "ssh_key": "ssh-key", + } + conn_info = map_airbyte_keys_to_postgres_keys(conn_info) + assert conn_info["ssh_host"] == "host" + assert conn_info["ssh_port"] == 22 + assert conn_info["ssh_username"] == "user" + assert conn_info["ssh_pkey"] == "ssh-key" + + +def test_map_airbyte_keys_to_postgres_keys_password(): + """verifies the correct mapping of keys""" + conn_info = { + "tunnel_method": { + "tunnel_method": "SSH_PASSWORD_AUTH", + }, + "tunnel_host": "host", + "tunnel_port": 22, + "tunnel_user": "user", + "ssh_password": "ssh-password", + } + conn_info = map_airbyte_keys_to_postgres_keys(conn_info) + assert conn_info["ssh_host"] == "host" + assert conn_info["ssh_port"] == 22 + assert conn_info["ssh_username"] == "user" + assert conn_info["ssh_password"] == "ssh-password" + + +def test_map_airbyte_keys_to_postgres_keys_notunnel(): + """verifies the correct mapping of keys""" + conn_info = { + "tunnel_method": { + "tunnel_method": "NO_TUNNEL", + }, + "tunnel_host": "host", + "tunnel_port": 22, + "tunnel_user": "user", + } + conn_info = map_airbyte_keys_to_postgres_keys(conn_info) + assert conn_info["ssh_host"] == "host" + assert conn_info["ssh_port"] == 22 + assert conn_info["ssh_username"] == "user" From 68d8e8b65f60d5e030e66004c88c7bc4b8828ec4 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Sun, 7 Apr 2024 17:07:50 +0530 Subject: [PATCH 02/10] support ssh tunneling for postgres --- dbt_automation/utils/postgres.py | 107 ++++++++++++++++-------- dbt_automation/utils/warehouseclient.py | 3 + requirements.txt | 12 +++ 3 files changed, 89 insertions(+), 33 deletions(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index a57a7af..dc6fd1f 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -1,8 +1,9 @@ """helpers for postgres""" +import os from logging import basicConfig, getLogger, INFO import psycopg2 -import os +from sshtunnel import SSHTunnelForwarder from dbt_automation.utils.columnutils import quote_columnname from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface @@ -15,15 +16,32 @@ class PostgresClient(WarehouseInterface): """a postgres client that can be used as a context manager""" @staticmethod - def get_connection(host: str, port: str, user: str, password: str, database: str): - """returns a psycopg connection""" - connection = psycopg2.connect( - host=host, - port=port, - user=user, - password=password, - database=database, - ) + def get_connection(conn_info): + """ + returns a psycopg connection + parameters are + host: str + port: str + user: str + password: str + database: str + sslmode: require | disable | prefer | allow | verify-ca | verify-full + sslrootcert: /path/to/cert + ... + """ + connect_params = {} + for key in [ + "host", + "port", + "user", + "password", + "database", + "sslmode", + "sslrootcert", + ]: + if key in conn_info: + connect_params[key] = conn_info[key] + connection = psycopg2.connect(**connect_params) return connection def __init__(self, conn_info: dict): @@ -32,21 +50,44 @@ def __init__(self, conn_info: dict): conn_info = { "host": os.getenv("DBHOST"), "port": os.getenv("DBPORT"), - "username": os.getenv("DBUSER"), + "user": os.getenv("DBUSER"), "password": os.getenv("DBPASSWORD"), "database": os.getenv("DBNAME"), } - self.connection = PostgresClient.get_connection( - conn_info.get("host"), - conn_info.get("port"), - conn_info.get("username"), - conn_info.get("password"), - conn_info.get("database"), - ) + self.tunnel = None + if "ssh_host" in conn_info: + self.tunnel = SSHTunnelForwarder( + (conn_info["ssh_host"], conn_info["ssh_port"]), + remote_bind_address=(conn_info["host"], conn_info["port"]), + # ...and credentials + ssh_pkey=conn_info.get("ssh_pkey"), + ssh_username=conn_info.get("ssh_username"), + ssh_password=conn_info.get("ssh_password"), + ssh_private_key_password=conn_info.get("ssh_private_key_password"), + ) + self.tunnel.start() + conn_info["host"] = "localhost" + conn_info["port"] = self.tunnel.local_bind_port + self.connection = PostgresClient.get_connection(conn_info) + + else: + self.connection = PostgresClient.get_connection(conn_info) self.cursor = None self.conn_info = conn_info + def __del__(self): + """destructor""" + if self.cursor is not None: + self.cursor.close() + self.cursor = None + if self.connection is not None: + self.connection.close() + self.connection = None + if self.tunnel is not None: + self.tunnel.stop() + self.tunnel = None + def runcmd(self, statement: str): """runs a command""" if self.cursor is None: @@ -75,7 +116,7 @@ def get_tables(self, schema: str) -> list: def get_schemas(self) -> list: """returns the list of schema names in the given database connection""" resultset = self.execute( - f""" + """ SELECT nspname FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' AND nspname != 'information_schema'; @@ -132,7 +173,7 @@ def get_table_columns(self, schema: str, table: str) -> list: ) return [{"name": x[0], "data_type": x[1]} for x in resultset] - def get_columnspec(self, schema: str, table: str): + def get_columnspec(self, schema: str, table_id: str): """get the column schema for this table""" return [ x[0] @@ -140,7 +181,7 @@ def get_columnspec(self, schema: str, table: str): f"""SELECT column_name FROM information_schema.columns WHERE table_schema = '{schema}' - AND table_name = '{table}' + AND table_name = '{table_id}' """ ) ] @@ -200,26 +241,26 @@ def close(self): return True def generate_profiles_yaml_dbt(self, project_name, default_schema): - """Generates the profiles.yml dictionary object for dbt""" - if project_name is None or default_schema is None: - raise ValueError("project_name and default_schema are required") - - target = "prod" - """ - : + Generates the profiles.yml dictionary object for dbt + : outputs: - prod: - dbname: - host: - password: + prod: + dbname: + host: + password: port: 5432 user: airbyte_user - schema: + schema: threads: 4 type: postgres target: prod """ + if project_name is None or default_schema is None: + raise ValueError("project_name and default_schema are required") + + target = "prod" + profiles_yml = { f"{project_name}": { "outputs": { diff --git a/dbt_automation/utils/warehouseclient.py b/dbt_automation/utils/warehouseclient.py index 87e0029..6c1459e 100644 --- a/dbt_automation/utils/warehouseclient.py +++ b/dbt_automation/utils/warehouseclient.py @@ -16,6 +16,9 @@ def map_airbyte_keys_to_postgres_keys(conn_info: dict): if method["tunnel_method"] == "SSH_KEY_AUTH": conn_info["ssh_pkey"] = method["ssh_key"] + conn_info["ssh_private_key_password"] = method.get( + "tunnel_private_key_password" + ) elif method["tunnel_method"] == "SSH_PASSWORD_AUTH": conn_info["ssh_password"] = method["tunnel_user_password"] diff --git a/requirements.txt b/requirements.txt index 2660098..126e2c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,11 @@ +bcrypt==4.1.2 cachetools==5.3.1 certifi==2023.7.22 +cffi==1.16.0 charset-normalizer==3.3.0 coverage==7.3.2 +cryptography==42.0.5 +dbt-automation==0.1 exceptiongroup==1.1.3 google-api-core==2.12.0 google-auth==2.23.2 @@ -14,22 +18,30 @@ grpcio==1.59.0 grpcio-status==1.59.0 idna==3.4 iniconfig==2.0.0 +numpy==1.26.0 packaging==23.2 +pandas==2.1.1 +paramiko==3.4.0 pluggy==1.3.0 proto-plus==1.22.3 protobuf==4.24.4 psycopg2-binary==2.9.7 pyasn1==0.5.0 pyasn1-modules==0.3.0 +pycparser==2.22 +PyNaCl==1.5.0 pytest==7.4.3 pytest-cov==4.1.0 pytest-env==1.1.1 python-dateutil==2.8.2 python-dotenv==1.0.0 +pytz==2023.3.post1 PyYAML==6.0.1 requests==2.31.0 rsa==4.9 six==1.16.0 +sshtunnel==0.4.0 tomli==2.0.1 tqdm==4.66.1 +tzdata==2023.3 urllib3==2.0.6 From f9df8019e9ab300d3ebd2591cb8f5f2da077de4f Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Tue, 9 Apr 2024 20:35:20 +0530 Subject: [PATCH 03/10] map_airbyte_keys_to_postgres_keys will be called in django before passing the conn_info in to get_client --- dbt_automation/utils/warehouseclient.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/dbt_automation/utils/warehouseclient.py b/dbt_automation/utils/warehouseclient.py index 6c1459e..c50df49 100644 --- a/dbt_automation/utils/warehouseclient.py +++ b/dbt_automation/utils/warehouseclient.py @@ -4,35 +4,10 @@ from dbt_automation.utils.bigquery import BigQueryClient -def map_airbyte_keys_to_postgres_keys(conn_info: dict): - """called below and by `post_system_transformation_tasks`""" - if "tunnel_method" in conn_info: - method = conn_info["tunnel_method"] - - if method["tunnel_method"] in ["SSH_KEY_AUTH", "SSH_PASSWORD_AUTH"]: - conn_info["ssh_host"] = method["tunnel_host"] - conn_info["ssh_port"] = method["tunnel_port"] - conn_info["ssh_username"] = method["tunnel_user"] - - if method["tunnel_method"] == "SSH_KEY_AUTH": - conn_info["ssh_pkey"] = method["ssh_key"] - conn_info["ssh_private_key_password"] = method.get( - "tunnel_private_key_password" - ) - - elif method["tunnel_method"] == "SSH_PASSWORD_AUTH": - conn_info["ssh_password"] = method["tunnel_user_password"] - - conn_info["user"] = conn_info["username"] - - return conn_info - - def get_client(warehouse: str, conn_info: dict = None, location: str = None): """constructs and returns an instance of the client for the right warehouse""" if warehouse == "postgres": # conn_info gets passed to psycopg2.connect - conn_info = map_airbyte_keys_to_postgres_keys(conn_info) client = PostgresClient(conn_info) elif warehouse == "bigquery": client = BigQueryClient(conn_info, location) From 50d70d0dadc18fa40b9d4aa0a1f8cf340bc535e8 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Tue, 9 Apr 2024 20:37:03 +0530 Subject: [PATCH 04/10] remove comment --- dbt_automation/utils/warehouseclient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dbt_automation/utils/warehouseclient.py b/dbt_automation/utils/warehouseclient.py index c50df49..54d4783 100644 --- a/dbt_automation/utils/warehouseclient.py +++ b/dbt_automation/utils/warehouseclient.py @@ -7,7 +7,6 @@ def get_client(warehouse: str, conn_info: dict = None, location: str = None): """constructs and returns an instance of the client for the right warehouse""" if warehouse == "postgres": - # conn_info gets passed to psycopg2.connect client = PostgresClient(conn_info) elif warehouse == "bigquery": client = BigQueryClient(conn_info, location) From ec0eeaec81e45209c18f1ff0d3c24f0483d092be Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Tue, 9 Apr 2024 20:39:27 +0530 Subject: [PATCH 05/10] moved to ddp backend --- tests/utils/test_warehouseclient.py | 71 ----------------------------- 1 file changed, 71 deletions(-) delete mode 100644 tests/utils/test_warehouseclient.py diff --git a/tests/utils/test_warehouseclient.py b/tests/utils/test_warehouseclient.py deleted file mode 100644 index 677cf99..0000000 --- a/tests/utils/test_warehouseclient.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -from dbt_automation.utils.warehouseclient import map_airbyte_keys_to_postgres_keys - - -def test_map_airbyte_keys_to_postgres_keys_oldconfig(): - """verifies the correct mapping of keys""" - conn_info = { - "host": "host", - "port": 100, - "username": "user", - "password": "password", - "database": "database", - } - conn_info = map_airbyte_keys_to_postgres_keys(conn_info) - assert conn_info["host"] == "host" - assert conn_info["port"] == "port" - assert conn_info["username"] == "username" - assert conn_info["password"] == "password" - assert conn_info["database"] == "database" - - -def test_map_airbyte_keys_to_postgres_keys_sshkey(): - """verifies the correct mapping of keys""" - conn_info = { - "tunnel_method": { - "tunnel_method": "SSH_KEY_AUTH", - }, - "tunnel_host": "host", - "tunnel_port": 22, - "tunnel_user": "user", - "ssh_key": "ssh-key", - } - conn_info = map_airbyte_keys_to_postgres_keys(conn_info) - assert conn_info["ssh_host"] == "host" - assert conn_info["ssh_port"] == 22 - assert conn_info["ssh_username"] == "user" - assert conn_info["ssh_pkey"] == "ssh-key" - - -def test_map_airbyte_keys_to_postgres_keys_password(): - """verifies the correct mapping of keys""" - conn_info = { - "tunnel_method": { - "tunnel_method": "SSH_PASSWORD_AUTH", - }, - "tunnel_host": "host", - "tunnel_port": 22, - "tunnel_user": "user", - "ssh_password": "ssh-password", - } - conn_info = map_airbyte_keys_to_postgres_keys(conn_info) - assert conn_info["ssh_host"] == "host" - assert conn_info["ssh_port"] == 22 - assert conn_info["ssh_username"] == "user" - assert conn_info["ssh_password"] == "ssh-password" - - -def test_map_airbyte_keys_to_postgres_keys_notunnel(): - """verifies the correct mapping of keys""" - conn_info = { - "tunnel_method": { - "tunnel_method": "NO_TUNNEL", - }, - "tunnel_host": "host", - "tunnel_port": 22, - "tunnel_user": "user", - } - conn_info = map_airbyte_keys_to_postgres_keys(conn_info) - assert conn_info["ssh_host"] == "host" - assert conn_info["ssh_port"] == 22 - assert conn_info["ssh_username"] == "user" From 6910215367db1ec9c38d445b1ad3ea22dd2d7810 Mon Sep 17 00:00:00 2001 From: Ishankoradia Date: Wed, 17 Apr 2024 11:26:48 +0530 Subject: [PATCH 06/10] minor change --- dbt_automation/utils/postgres.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index dc6fd1f..ab8d065 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -46,6 +46,10 @@ def get_connection(conn_info): def __init__(self, conn_info: dict): self.name = "postgres" + self.cursor = None + self.tunnel = None + self.connection = None + if conn_info is None: # take creds from env conn_info = { "host": os.getenv("DBHOST"), @@ -55,7 +59,6 @@ def __init__(self, conn_info: dict): "database": os.getenv("DBNAME"), } - self.tunnel = None if "ssh_host" in conn_info: self.tunnel = SSHTunnelForwarder( (conn_info["ssh_host"], conn_info["ssh_port"]), @@ -73,7 +76,6 @@ def __init__(self, conn_info: dict): else: self.connection = PostgresClient.get_connection(conn_info) - self.cursor = None self.conn_info = conn_info def __del__(self): From 8cfbfc0f0971ddc75661bb8a24eca9c4d40b9785 Mon Sep 17 00:00:00 2001 From: Ishankoradia Date: Wed, 17 Apr 2024 11:55:08 +0530 Subject: [PATCH 07/10] do the cleanup in close method also --- dbt_automation/utils/postgres.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index ab8d065..200b8e3 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -236,7 +236,15 @@ def json_extract_op(self, json_column: str, json_field: str, sql_column: str): def close(self): try: - self.connection.close() + if self.cursor is not None: + self.cursor.close() + self.cursor = None + if self.tunnel is not None: + self.tunnel.stop() + self.tunnel = None + if self.connection is not None: + self.connection.close() + self.connection = None except Exception: logger.error("something went wrong while closing the postgres connection") From 4d8edae1693e7ca1a45e616e2d5dbd34c3d80473 Mon Sep 17 00:00:00 2001 From: Ishankoradia Date: Wed, 17 Apr 2024 11:59:08 +0530 Subject: [PATCH 08/10] fixing the requirement for dbt automation --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 126e2c6..93bfeea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ cffi==1.16.0 charset-normalizer==3.3.0 coverage==7.3.2 cryptography==42.0.5 -dbt-automation==0.1 +dbt-automation @ git+https://github.com/DalgoT4D/dbt-automation.git exceptiongroup==1.1.3 google-api-core==2.12.0 google-auth==2.23.2 From 806ecf38e7b81aa9f9b3c99b814bfba4a0e168eb Mon Sep 17 00:00:00 2001 From: Ishankoradia Date: Wed, 17 Apr 2024 12:27:01 +0530 Subject: [PATCH 09/10] fixing the test cases --- dbt_automation/utils/postgres.py | 2 +- tests/warehouse/test_postgres_ops.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index 200b8e3..f997240 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -279,7 +279,7 @@ def generate_profiles_yaml_dbt(self, project_name, default_schema): "host": self.conn_info["host"], "password": self.conn_info["password"], "port": int(self.conn_info["port"]), - "user": self.conn_info["username"], + "user": self.conn_info["user"], "schema": default_schema, "threads": 4, "type": "postgres", diff --git a/tests/warehouse/test_postgres_ops.py b/tests/warehouse/test_postgres_ops.py index b9bc9fd..583f6ca 100644 --- a/tests/warehouse/test_postgres_ops.py +++ b/tests/warehouse/test_postgres_ops.py @@ -40,7 +40,7 @@ class TestPostgresOperations: { "host": os.environ.get("TEST_PG_DBHOST"), "port": os.environ.get("TEST_PG_DBPORT"), - "username": os.environ.get("TEST_PG_DBUSER"), + "user": os.environ.get("TEST_PG_DBUSER"), "database": os.environ.get("TEST_PG_DBNAME"), "password": os.environ.get("TEST_PG_DBPASSWORD"), }, From 83b13971c847aaf91ab6065d4b5fada730c4926a Mon Sep 17 00:00:00 2001 From: Ishankoradia Date: Wed, 17 Apr 2024 13:04:31 +0530 Subject: [PATCH 10/10] fix the bigquery unpivot test case --- tests/warehouse/test_bigquery_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/warehouse/test_bigquery_ops.py b/tests/warehouse/test_bigquery_ops.py index f6e6900..011581c 100644 --- a/tests/warehouse/test_bigquery_ops.py +++ b/tests/warehouse/test_bigquery_ops.py @@ -790,6 +790,7 @@ def test_unpivot(self): "unpivot_columns": ["NGO", "SPOC"], "unpivot_field_name": "col_field", "unpivot_value_name": "col_val", + "cast_to": "STRING", } unpivot(