Skip to content

Commit

Permalink
Enhance Cursor.description
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Jan 11, 2023
1 parent fd78e41 commit ea55fc0
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
62 changes: 48 additions & 14 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_none_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="unknown")


def test_string_query_param(trino_connection):
Expand All @@ -128,6 +129,7 @@ def test_string_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == "six'"
assert_cursor_description(cur, trino_type="varchar(4)", size=4)


def test_execute_many(trino_connection):
Expand Down Expand Up @@ -241,10 +243,11 @@ def test_legacy_primitive_types_with_connection_and_cursor(
def test_decimal_query_param(trino_connection):
cur = trino_connection.cursor()

cur.execute("SELECT ?", params=(Decimal('0.142857'),))
cur.execute("SELECT ?", params=(Decimal('1112.142857'),))
rows = cur.fetchall()

assert rows[0][0] == Decimal('0.142857')
assert rows[0][0] == Decimal('1112.142857')
assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6)


def test_null_decimal(trino_connection):
Expand All @@ -254,6 +257,7 @@ def test_null_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_biggest_decimal(trino_connection):
Expand All @@ -264,6 +268,7 @@ def test_biggest_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_smallest_decimal(trino_connection):
Expand All @@ -274,6 +279,7 @@ def test_smallest_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0)


def test_highest_precision_decimal(trino_connection):
Expand All @@ -284,6 +290,7 @@ def test_highest_precision_decimal(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="decimal(38, 38)", precision=38, scale=38)


def test_datetime_query_param(trino_connection):
Expand All @@ -295,7 +302,7 @@ def test_datetime_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6)"
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)


def test_datetime_with_utc_time_zone_query_param(trino_connection):
Expand All @@ -307,7 +314,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
Expand All @@ -321,7 +328,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_named_time_zone_query_param(trino_connection):
Expand All @@ -333,7 +340,7 @@ def test_datetime_with_named_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp(6) with time zone"
assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6)


def test_datetime_with_trailing_zeros(trino_connection):
Expand All @@ -343,6 +350,7 @@ def test_datetime_with_trailing_zeros(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321000", "%Y-%m-%d %H:%M:%S.%f")
assert_cursor_description(cur, trino_type="timestamp(6)", precision=6)


def test_null_datetime_with_time_zone(trino_connection):
Expand All @@ -352,6 +360,7 @@ def test_null_datetime_with_time_zone(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)


def test_datetime_with_time_zone_numeric_offset(trino_connection):
Expand All @@ -361,6 +370,7 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3)


def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection):
Expand Down Expand Up @@ -404,6 +414,7 @@ def test_date_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert_cursor_description(cur, trino_type="date")


def test_null_date(trino_connection):
Expand All @@ -413,6 +424,7 @@ def test_null_date(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="date")


def test_unsupported_python_dates(trino_connection):
Expand Down Expand Up @@ -462,6 +474,16 @@ def test_supported_special_dates_query_param(trino_connection):
assert rows[0][0] == params


def test_char(trino_connection):
cur = trino_connection.cursor()

cur.execute("SELECT CHAR 'trino'")
rows = cur.fetchall()

assert rows[0][0] == 'trino'
assert_cursor_description(cur, trino_type="char(5)", size=5)


def test_time_query_param(trino_connection):
cur = trino_connection.cursor()

Expand All @@ -471,7 +493,7 @@ def test_time_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "time(6)"
assert_cursor_description(cur, trino_type="time(6)", precision=6)


def test_time_with_named_time_zone_query_param(trino_connection):
Expand Down Expand Up @@ -501,7 +523,7 @@ def test_time(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == time(1, 2, 3, 456000)
assert cur.description[0][1] == "time(3)"
assert_cursor_description(cur, trino_type="time(3)", precision=3)


def test_null_time(trino_connection):
Expand All @@ -511,6 +533,7 @@ def test_null_time(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="time(3)", precision=3)


def test_time_with_time_zone_negative_offset(trino_connection):
Expand All @@ -522,7 +545,7 @@ def test_time_with_time_zone_negative_offset(trino_connection):
tz = timezone(-timedelta(hours=8, minutes=0))

assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
assert cur.description[0][1] == "time(3) with time zone"
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


def test_time_with_time_zone_positive_offset(trino_connection):
Expand All @@ -534,7 +557,7 @@ def test_time_with_time_zone_positive_offset(trino_connection):
tz = timezone(timedelta(hours=8, minutes=0))

assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz)
assert cur.description[0][1] == "time(3) with time zone"
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


def test_null_date_with_time_zone(trino_connection):
Expand All @@ -544,6 +567,7 @@ def test_null_date_with_time_zone(trino_connection):
rows = cur.fetchall()

assert rows[0][0] is None
assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -717,7 +741,7 @@ def test_float_query_param(trino_connection):
cur.execute("SELECT ?", params=(1.1,))
rows = cur.fetchall()

assert cur.description[0][1] == "double"
assert_cursor_description(cur, trino_type="double")
assert rows[0][0] == 1.1


Expand All @@ -726,7 +750,7 @@ def test_float_nan_query_param(trino_connection):
cur.execute("SELECT ?", params=(float("nan"),))
rows = cur.fetchall()

assert cur.description[0][1] == "double"
assert_cursor_description(cur, trino_type="double")
assert isinstance(rows[0][0], float)
assert math.isnan(rows[0][0])

Expand All @@ -736,6 +760,7 @@ def test_float_inf_query_param(trino_connection):
cur.execute("SELECT ?", params=(float("inf"),))
rows = cur.fetchall()

assert_cursor_description(cur, trino_type="double")
assert rows[0][0] == float("inf")

cur.execute("SELECT ?", params=(float("-inf"),))
Expand All @@ -750,13 +775,13 @@ def test_int_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == 3
assert cur.description[0][1] == "integer"
assert_cursor_description(cur, trino_type="integer")

cur.execute("SELECT ?", params=(9223372036854775807,))
rows = cur.fetchall()

assert rows[0][0] == 9223372036854775807
assert cur.description[0][1] == "bigint"
assert_cursor_description(cur, trino_type="bigint")


@pytest.mark.parametrize('params', [
Expand Down Expand Up @@ -1239,3 +1264,12 @@ def test_describe_table_query(run_trino):
aliased=False,
)
]


def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None):
assert cur.description[0][1] == trino_type
assert cur.description[0][2] is None
assert cur.description[0][3] is size
assert cur.description[0][4] is precision
assert cur.description[0][5] is scale
assert cur.description[0][6] is None
4 changes: 4 additions & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"

HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"

LENGTH_TYPES = ["char", "varchar"]
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
SCALE_TYPES = ["decimal"]
31 changes: 28 additions & 3 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import trino.exceptions
import trino.logging
from trino import constants
from trino.constants import LENGTH_TYPES, PRECISION_TYPES, SCALE_TYPES
from trino.exceptions import (
DatabaseError,
DataError,
Expand Down Expand Up @@ -237,6 +238,31 @@ def from_row(cls, row: List[Any]):
return cls(*row)


class ColumnDescription(NamedTuple):
name: str
type_code: int
display_size: int
internal_size: int
precision: int
scale: int
null_ok: bool

@classmethod
def from_column(cls, column: Dict[str, Any]):
type_signature = column["typeSignature"]
raw_type = type_signature["rawType"]
arguments = type_signature["arguments"]
return cls(
column["name"], # name
column["type"], # type_code
None, # display_size
arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size
arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision
arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale
None # null_ok
)


class Cursor(object):
"""Database cursor.
Expand Down Expand Up @@ -278,14 +304,13 @@ def update_type(self):
return None

@property
def description(self):
def description(self) -> List[ColumnDescription]:
if self._query.columns is None:
return None

# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
return [
(col["name"], col["type"], None, None, None, None, None)
for col in self._query.columns
ColumnDescription.from_column(col) for col in self._query.columns
]

@property
Expand Down

0 comments on commit ea55fc0

Please sign in to comment.