Skip to content

Commit

Permalink
Merge pull request #104 from DalgoT4D/103-get_connection-should-pass-…
Browse files Browse the repository at this point in the history
…parameters-on-to-psycopg2

103 get connection should pass parameters on to psycopg2
  • Loading branch information
Ishankoradia authored Apr 17, 2024
2 parents 648f7f5 + 83b1397 commit 78ef9c5
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 37 deletions.
123 changes: 87 additions & 36 deletions dbt_automation/utils/postgres.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""helpers for postgres"""

import os
from logging import basicConfig, getLogger, INFO
import psycopg2
import os
from sshtunnel import SSHTunnelForwarder
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.interfaces.warehouse_interface import WarehouseInterface

Expand All @@ -15,38 +16,80 @@ class PostgresClient(WarehouseInterface):
"""a postgres client that can be used as a context manager"""

@staticmethod
def get_connection(host: str, port: str, user: str, password: str, database: str):
"""returns a psycopg connection"""
connection = psycopg2.connect(
host=host,
port=port,
user=user,
password=password,
database=database,
)
def get_connection(conn_info):
    """
    Open and return a psycopg2 connection built from ``conn_info``.

    Only the keys listed below are forwarded to ``psycopg2.connect``; a key
    that is absent from ``conn_info`` is simply left out of the call, and
    any other key in ``conn_info`` is ignored.

    host: str
    port: str
    user: str
    password: str
    database: str
    sslmode: require | disable | prefer | allow | verify-ca | verify-full
    sslrootcert: /path/to/cert
    """
    accepted_keys = (
        "host",
        "port",
        "user",
        "password",
        "database",
        "sslmode",
        "sslrootcert",
    )
    # Forward only the whitelisted keys that the caller actually supplied.
    connect_params = {
        name: conn_info[name] for name in accepted_keys if name in conn_info
    }
    return psycopg2.connect(**connect_params)

def __init__(self, conn_info: dict):
self.name = "postgres"
self.cursor = None
self.tunnel = None
self.connection = None

if conn_info is None: # take creds from env
conn_info = {
"host": os.getenv("DBHOST"),
"port": os.getenv("DBPORT"),
"username": os.getenv("DBUSER"),
"user": os.getenv("DBUSER"),
"password": os.getenv("DBPASSWORD"),
"database": os.getenv("DBNAME"),
}

self.connection = PostgresClient.get_connection(
conn_info.get("host"),
conn_info.get("port"),
conn_info.get("username"),
conn_info.get("password"),
conn_info.get("database"),
)
self.cursor = None
if "ssh_host" in conn_info:
self.tunnel = SSHTunnelForwarder(
(conn_info["ssh_host"], conn_info["ssh_port"]),
remote_bind_address=(conn_info["host"], conn_info["port"]),
# ...and credentials
ssh_pkey=conn_info.get("ssh_pkey"),
ssh_username=conn_info.get("ssh_username"),
ssh_password=conn_info.get("ssh_password"),
ssh_private_key_password=conn_info.get("ssh_private_key_password"),
)
self.tunnel.start()
conn_info["host"] = "localhost"
conn_info["port"] = self.tunnel.local_bind_port
self.connection = PostgresClient.get_connection(conn_info)

else:
self.connection = PostgresClient.get_connection(conn_info)
self.conn_info = conn_info

def __del__(self):
    """
    Destructor: release the cursor, the connection and the SSH tunnel.

    Resources are released in dependency order (cursor, then connection,
    then tunnel). The steps are chained with ``try``/``finally`` so that a
    failure while closing one resource does not leave the later ones open.
    The first error raised still propagates to the caller.

    ``getattr`` with a default is used because the destructor can run on an
    instance whose ``__init__`` did not get as far as setting every
    attribute.
    """
    cursor = getattr(self, "cursor", None)
    connection = getattr(self, "connection", None)
    tunnel = getattr(self, "tunnel", None)

    # Drop the references first so a later close()/__del__ is a no-op.
    self.cursor = None
    self.connection = None
    self.tunnel = None

    try:
        if cursor is not None:
            cursor.close()
    finally:
        try:
            if connection is not None:
                connection.close()
        finally:
            # The tunnel carries the connection, so it is stopped last.
            if tunnel is not None:
                tunnel.stop()

def runcmd(self, statement: str):
"""runs a command"""
if self.cursor is None:
Expand Down Expand Up @@ -75,7 +118,7 @@ def get_tables(self, schema: str) -> list:
def get_schemas(self) -> list:
"""returns the list of schema names in the given database connection"""
resultset = self.execute(
f"""
"""
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%' AND nspname != 'information_schema';
Expand Down Expand Up @@ -132,15 +175,15 @@ def get_table_columns(self, schema: str, table: str) -> list:
)
return [{"name": x[0], "data_type": x[1]} for x in resultset]

def get_columnspec(self, schema: str, table: str):
def get_columnspec(self, schema: str, table_id: str):
"""get the column schema for this table"""
return [
x[0]
for x in self.execute(
f"""SELECT column_name
FROM information_schema.columns
WHERE table_schema = '{schema}'
AND table_name = '{table}'
AND table_name = '{table_id}'
"""
)
]
Expand Down Expand Up @@ -193,33 +236,41 @@ def json_extract_op(self, json_column: str, json_field: str, sql_column: str):

def close(self):
    """
    Release the cursor, the connection and the SSH tunnel, in that order.

    The tunnel is stopped last because the connection runs through it.
    Each resource is released in its own ``try`` block, so a failure on one
    is logged and the remaining ones are still released. Each attribute is
    reset to ``None`` before its resource is released, which makes repeated
    calls harmless.

    Returns True in every case, including when a release step failed.
    """
    release_steps = (
        ("cursor", "close"),
        ("connection", "close"),
        ("tunnel", "stop"),
    )
    for attr_name, method_name in release_steps:
        resource = getattr(self, attr_name, None)
        if resource is None:
            continue
        setattr(self, attr_name, None)
        try:
            getattr(resource, method_name)()
        except Exception:
            # Best-effort cleanup: record the traceback and keep going.
            logger.exception(
                "something went wrong while closing the postgres connection"
            )

    return True

def generate_profiles_yaml_dbt(self, project_name, default_schema):
"""Generates the profiles.yml dictionary object for dbt"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

"""
<project_name>:
Generates the profiles.yml dictionary object for dbt
<project_name>:
outputs:
prod:
dbname:
host:
password:
prod:
dbname:
host:
password:
port: 5432
user: airbyte_user
schema:
schema:
threads: 4
type: postgres
target: prod
"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

profiles_yml = {
f"{project_name}": {
"outputs": {
Expand All @@ -228,7 +279,7 @@ def generate_profiles_yaml_dbt(self, project_name, default_schema):
"host": self.conn_info["host"],
"password": self.conn_info["password"],
"port": int(self.conn_info["port"]),
"user": self.conn_info["username"],
"user": self.conn_info["user"],
"schema": default_schema,
"threads": 4,
"type": "postgres",
Expand Down
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
bcrypt==4.1.2
cachetools==5.3.1
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
coverage==7.3.2
cryptography==42.0.5
dbt-automation @ git+https://github.com/DalgoT4D/dbt-automation.git
exceptiongroup==1.1.3
google-api-core==2.12.0
google-auth==2.23.2
Expand All @@ -14,22 +18,30 @@ grpcio==1.59.0
grpcio-status==1.59.0
idna==3.4
iniconfig==2.0.0
numpy==1.26.0
packaging==23.2
pandas==2.1.1
paramiko==3.4.0
pluggy==1.3.0
proto-plus==1.22.3
protobuf==4.24.4
psycopg2-binary==2.9.7
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.22
PyNaCl==1.5.0
pytest==7.4.3
pytest-cov==4.1.0
pytest-env==1.1.1
python-dateutil==2.8.2
python-dotenv==1.0.0
pytz==2023.3.post1
PyYAML==6.0.1
requests==2.31.0
rsa==4.9
six==1.16.0
sshtunnel==0.4.0
tomli==2.0.1
tqdm==4.66.1
tzdata==2023.3
urllib3==2.0.6
1 change: 1 addition & 0 deletions tests/warehouse/test_bigquery_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def test_unpivot(self):
"unpivot_columns": ["NGO", "SPOC"],
"unpivot_field_name": "col_field",
"unpivot_value_name": "col_val",
"cast_to": "STRING",
}

unpivot(
Expand Down
2 changes: 1 addition & 1 deletion tests/warehouse/test_postgres_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class TestPostgresOperations:
{
"host": os.environ.get("TEST_PG_DBHOST"),
"port": os.environ.get("TEST_PG_DBPORT"),
"username": os.environ.get("TEST_PG_DBUSER"),
"user": os.environ.get("TEST_PG_DBUSER"),
"database": os.environ.get("TEST_PG_DBNAME"),
"password": os.environ.get("TEST_PG_DBPASSWORD"),
},
Expand Down

0 comments on commit 78ef9c5

Please sign in to comment.