Merge pull request #186 from BritishGeologicalSurvey/more-type-hints
More type hints
volcan01010 authored Sep 14, 2023
2 parents 14220b6 + a836b79 commit 3729ba6
Showing 4 changed files with 158 additions and 66 deletions.
13 changes: 13 additions & 0 deletions .flake8
@@ -1,2 +1,15 @@
[flake8]
max-line-length = 120
extend-ignore =
    # Do not require self and cls to have type hints
    # See: https://github.com/sco1/flake8-annotations/issues/75
    ANN101, ANN102
per-file-ignores =
    # Do not check annotations in these files
    abort.py: A,
    connect.py: A,
    db_helper_factory.py: A,
    db_params.py: A,
    row_factories.py: A
    utils.py: A,
    etlhelper/db_helpers/*: A,
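
For context, ANN101 and ANN102 are the flake8-annotations codes that require explicit type annotations on self and cls; the new config opts out of them globally, in line with the linked upstream issue. A minimal sketch of the methods those codes would otherwise flag (the class below is hypothetical, not part of etlhelper):

class Widget:
    # ANN101 would require an annotation for self here; with the
    # extend-ignore setting above, only arguments and return types are checked
    def name(self) -> str:
        return "widget"

    # ANN102 would likewise require an annotation for cls
    @classmethod
    def default(cls) -> "Widget":
        return cls()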
185 changes: 119 additions & 66 deletions etlhelper/etl.py
@@ -3,7 +3,6 @@
"""
import logging
import re
from collections import namedtuple
from copy import deepcopy
from itertools import (
zip_longest,
@@ -12,6 +11,12 @@
from typing import (
Any,
Callable,
Collection,
Iterable,
Iterator,
NamedTuple,
Optional,
Union,
)

from etlhelper.abort import (
@@ -26,18 +31,31 @@
ETLHelperQueryError,
)
from etlhelper.row_factories import dict_row_factory

Chunk = Any
from etlhelper.types import (
Connection,
Row,
Chunk,
)

logger = logging.getLogger('etlhelper')
CHUNKSIZE = 5000


class FailedRow(NamedTuple):
row: Row
exception: Exception


# iter_chunks is where data are retrieved from source database
# All data extraction processes call this function.
def iter_chunks(select_query, conn, parameters=(),
row_factory=dict_row_factory,
transform=None, chunk_size=CHUNKSIZE):
def iter_chunks(
select_query: str,
conn: Connection,
parameters: tuple = (),
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
chunk_size: int = CHUNKSIZE
) -> Iterator[Chunk]:
"""
Run SQL query against connection and return iterator object to loop over
results in batches of chunksize (default 5000).
@@ -114,9 +132,14 @@ def iter_chunks(select_query, conn, parameters=(),
first_pass = False


def iter_rows(select_query, conn, parameters=(),
row_factory=dict_row_factory,
transform=None, chunk_size=CHUNKSIZE):
def iter_rows(
select_query: str,
conn: Connection,
parameters: tuple = (),
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
chunk_size: int = CHUNKSIZE
) -> Iterator[Row]:
"""
Run SQL query against connection and return iterator object to loop over
results, row-by-row.
@@ -137,9 +160,14 @@ def iter_rows(select_query, conn, parameters=(),
yield row


def fetchone(select_query, conn, parameters=(),
row_factory=dict_row_factory, transform=None,
chunk_size=1):
def fetchone(
select_query: str,
conn: Connection,
parameters: tuple = (),
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
chunk_size: int = 1
) -> Optional[Row]:
"""
Get first result of query. See iter_rows for details. Note: iter_rows is
recommended for looping over rows individually.
@@ -165,9 +193,14 @@ def fetchone(select_query, conn, parameters=(),
return result


def fetchall(select_query, conn, parameters=(),
row_factory=dict_row_factory, transform=None,
chunk_size=CHUNKSIZE):
def fetchall(
select_query: str,
conn: Connection,
parameters: tuple = (),
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
chunk_size: int = CHUNKSIZE
) -> Chunk:
"""
Get all results of query as a list. See iter_rows for details.
:param select_query: str, SQL query to execute
@@ -184,14 +217,14 @@ def fetchall(select_query, conn, parameters=(),


def executemany(
query: str,
conn,
rows: list[tuple[Any]],
transform: Callable[[Chunk], Chunk] = None,
on_error: Callable = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE,
) -> tuple[int, int]:
query: str,
conn: Connection,
rows: Iterable[Row],
transform: Optional[Callable[[Chunk], Chunk]] = None,
on_error: Optional[Callable[[list[FailedRow]], Any]] = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE,
) -> tuple[int, int]:
"""
Use query to insert/update data from rows to database at conn. This
method uses the executemany or execute_batch (PostgreSQL) commands to
@@ -230,11 +263,11 @@ def executemany(
failed = 0

with helper.cursor(conn) as cursor:
for chunk in _chunker(rows, chunk_size):
for chunk_with_nones in _chunker(rows, chunk_size):
raise_for_abort("abort_etlhelper_threads() called during executemany")

# Chunker pads to whole chunk with None; remove these
chunk = [row for row in chunk if row is not None]
chunk = [row for row in chunk_with_nones if row is not None]

# Apply transform
if transform:
@@ -291,7 +324,11 @@ def executemany(
return processed, failed


def _execute_by_row(query, conn, chunk):
def _execute_by_row(
query: str,
conn: Connection,
chunk: Chunk
) -> list[FailedRow]:
"""
Retry execution of rows individually and return failed rows along with
their errors. Successful inserts are committed. This is because
@@ -302,8 +339,7 @@ def _execute_by_row(query, conn, chunk):
:param conn: open dbapi connection, used for transactions
:returns failed_rows: list of (row, exception) tuples
"""
FailedRow = namedtuple('FailedRow', 'row, exception')
failed_rows = []
failed_rows: list[FailedRow] = []

for row in chunk:
try:
@@ -316,17 +352,17 @@


def copy_rows(
select_query,
source_conn,
insert_query,
dest_conn,
parameters=(),
row_factory=dict_row_factory,
transform=None,
on_error=None,
commit_chunks=True,
chunk_size=CHUNKSIZE,
) -> tuple[int, int]:
select_query: str,
source_conn: Connection,
insert_query: str,
dest_conn: Connection,
parameters: tuple = (),
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
on_error: Optional[Callable] = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE,
) -> tuple[int, int]:
"""
Copy rows from source_conn to dest_conn. select_query and insert_query
specify the data to be transferred.
@@ -371,7 +407,11 @@ def copy_rows(
return processed, failed


def execute(query, conn, parameters=()):
def execute(
query: str,
conn: Connection,
parameters: Collection[Any] = ()
) -> None:
"""
Run SQL query against connection.
@@ -398,10 +438,17 @@ def execute(query, conn, parameters=()):
raise ETLHelperQueryError(msg)


def copy_table_rows(table, source_conn, dest_conn, target=None,
row_factory=dict_row_factory,
transform=None, on_error=None, commit_chunks=True,
chunk_size=CHUNKSIZE):
def copy_table_rows(
table: str,
source_conn: Connection,
dest_conn: Connection,
target: Optional[str] = None,
row_factory: Callable = dict_row_factory,
transform: Optional[Callable[[Chunk], Chunk]] = None,
on_error: Optional[Callable] = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE
) -> tuple[int, int]:
"""
Copy rows from 'table' in source_conn to same or target table in dest_conn.
This is a simple copy of all columns and rows using `load` to insert data.
@@ -446,14 +493,14 @@ def copy_table_rows(table, source_conn, dest_conn, target=None,


def load(
table: str,
conn,
rows: list,
transform: Callable = None,
on_error: Callable = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE,
):
table: str,
conn: Connection,
rows: Iterable[Row],
transform: Optional[Callable[[Chunk], Chunk]] = None,
on_error: Optional[Callable] = None,
commit_chunks: bool = True,
chunk_size: int = CHUNKSIZE,
) -> tuple[int, int]:
"""
Load data from iterable of named tuples or dictionaries into pre-existing
table in database on conn.
@@ -513,7 +560,11 @@ def load(
return processed, failed


def generate_insert_sql(table, row, conn):
def generate_insert_sql(
table: str,
row: Row,
conn: Connection
) -> str:
"""Generate insert SQL for table, getting column names from row and the
placeholder style from the connection. `row` is either a namedtuple or
a dictionary."""
@@ -527,25 +578,20 @@ def generate_insert_sql(table, row, conn):
}

# Namedtuples use a query with positional placeholders
if not hasattr(row, 'keys'):
if hasattr(row, '_asdict'):
paramstyle = helper.positional_paramstyle

# Convert namedtuple to dictionary to easily access keys
try:
row = row._asdict()
except AttributeError:
msg = f"Row is not a dictionary or namedtuple ({type(row)})"
raise ETLHelperInsertError(msg)

columns = row.keys()
row_dict = row._asdict()
columns = row_dict.keys()
if paramstyle == "numeric":
placeholders = [paramstyles[paramstyle].format(number=i + 1)
for i in range(len(columns))]
else:
placeholders = [paramstyles[paramstyle]] * len(columns)

# Dictionaries use a query with named placeholders
else:
elif hasattr(row, 'keys'):
paramstyle = helper.named_paramstyle
if not paramstyle:
msg = (f"Database connection ({str(conn.__class__)}) doesn't support named parameters. "
Expand All @@ -555,6 +601,10 @@ def generate_insert_sql(table, row, conn):
columns = row.keys()
placeholders = [paramstyles[paramstyle].format(name=c) for c in columns]

else:
msg = f"Row is not a dictionary or namedtuple ({type(row)})"
raise ETLHelperInsertError(msg)

# Validate identifiers to prevent malicious code injection
for identifier in (table, *columns):
validate_identifier(identifier)
@@ -565,7 +615,7 @@
return sql


def validate_identifier(identifier):
def validate_identifier(identifier: str) -> None:
"""
Validate characters used in identifier e.g. table or column name.
Identifiers must comprise alpha-numeric characters, plus `_` or `$` and
@@ -588,10 +638,13 @@ def validate_identifier(identifier):
raise ETLHelperBadIdentifierError(msg)


def _chunker(iterable, n_chunks, fillvalue=None):
def _chunker(
iterable: Iterable[Row],
n_chunks: int,
) -> Iterator[tuple[Union[Row, None], ...]]:
"""Collect data into fixed-length chunks or blocks.
Code from recipe at https://docs.python.org/3.6/library/itertools.html
"""
# _chunker('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
# _chunker('ABCDEFG', 3) --> ABC DEF G"
args = [iter(iterable)] * n_chunks
return zip_longest(*args, fillvalue=fillvalue)
return zip_longest(*args, fillvalue=None)
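
Taken together, the new hints make the call sites self-describing: transform takes and returns a whole Chunk, on_error receives a list of FailedRow tuples, and any DBAPI connection matching the Connection protocol is accepted. A minimal sketch of one typed call, assuming a SQLite database and that executemany is importable from the package root (the table and rows are invented for illustration):

import sqlite3

from etlhelper import executemany
from etlhelper.etl import FailedRow
from etlhelper.types import Chunk


def add_id_offset(chunk: Chunk) -> Chunk:
    # transform receives a whole chunk of rows and must return one
    return [{**row, "id": row["id"] + 1000} for row in chunk]


def log_failures(failed_rows: list[FailedRow]) -> None:
    # on_error receives the (row, exception) named tuples for a chunk
    for row, exception in failed_rows:
        print(f"Failed to insert {row}: {exception}")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE grain_size (id INTEGER PRIMARY KEY, name TEXT)")

rows = [{"id": 1, "name": "fine"}, {"id": 2, "name": "coarse"}]
processed, failed = executemany(
    "INSERT INTO grain_size (id, name) VALUES (:id, :name)",
    conn,
    rows,
    transform=add_id_offset,
    on_error=log_failures,
)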
23 changes: 23 additions & 0 deletions etlhelper/types.py
@@ -0,0 +1,23 @@
from typing import (
Any,
Collection,
Protocol,
)


class Connection(Protocol):
def close(self) -> None:
...

def commit(self) -> None:
...

def rollback(self) -> None:
...

def cursor(self): # noqa Cursor Protocol not defined
...


Row = Collection[Any]
Chunk = list[Row]
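
Because Connection is a typing.Protocol, conformance is structural: any DBAPI connection object that provides close, commit, rollback and cursor type-checks against it, without inheriting from anything in etlhelper. A short sketch (the helper function is illustrative, not part of the package):

import sqlite3

from etlhelper.types import Connection


def commit_and_close(conn: Connection) -> None:
    # mypy accepts sqlite3.Connection here because it structurally
    # provides the methods the Protocol declares
    conn.commit()
    conn.close()


commit_and_close(sqlite3.connect(":memory:"))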
3 changes: 3 additions & 0 deletions requirements-dev.txt
@@ -1,9 +1,12 @@
oracledb
cffi
flake8
flake8-annotations
mypy
ipdb
ipython
psycopg2-binary
types-psycopg2
pyodbc
pytest-cov
pytest>=4.6