Skip to content

Commit

Permalink
DateTime and more: Use sqlalchemy_cratedb.dialect.DateTime ...
Browse files Browse the repository at this point in the history
... instead of `sa.DateTime` and `sa.TIMESTAMP`. Introduce
`visit_TIMESTAMP` from PGTypeCompiler to render SQL DDL clauses like
`TIMESTAMP WITH|WITHOUT TIME ZONE`.
  • Loading branch information
amotl committed Jun 24, 2024
1 parent d3ca5df commit 1258b4a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 28 deletions.
12 changes: 11 additions & 1 deletion src/sqlalchemy_cratedb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def visit_SMALLINT(self, type_, **kw):
return 'SHORT'

def visit_datetime(self, type_, **kw):
return 'TIMESTAMP'
return self.visit_TIMESTAMP(type_, **kw)

def visit_date(self, type_, **kw):
return 'TIMESTAMP'
Expand All @@ -245,6 +245,16 @@ def visit_FLOAT_VECTOR(self, type_, **kw):
raise ValueError("FloatVector must be initialized with dimension size")
return f"FLOAT_VECTOR({dimensions})"

def visit_TIMESTAMP(self, type_, **kw):
"""
Support for `TIMESTAMP WITH|WITHOUT TIME ZONE`.
From `sqlalchemy.dialects.postgresql.base.PGTypeCompiler`.
"""
return "TIMESTAMP %s" % (
(type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
)


class CrateCompiler(compiler.SQLCompiler):

Expand Down
11 changes: 6 additions & 5 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
"boolean": sqltypes.Boolean,
"short": sqltypes.SmallInteger,
"smallint": sqltypes.SmallInteger,
"timestamp": sqltypes.TIMESTAMP,
"timestamp with time zone": sqltypes.TIMESTAMP,
"timestamp": sqltypes.TIMESTAMP(timezone=False),
"timestamp with time zone": sqltypes.TIMESTAMP(timezone=True),
"object": ObjectType,
"integer": sqltypes.Integer,
"long": sqltypes.NUMERIC,
Expand All @@ -61,8 +61,8 @@
TYPES_MAP["boolean_array"] = ARRAY(sqltypes.Boolean)
TYPES_MAP["short_array"] = ARRAY(sqltypes.SmallInteger)
TYPES_MAP["smallint_array"] = ARRAY(sqltypes.SmallInteger)
TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP)
TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP)
TYPES_MAP["timestamp_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=False))
TYPES_MAP["timestamp with time zone_array"] = ARRAY(sqltypes.TIMESTAMP(timezone=True))
TYPES_MAP["long_array"] = ARRAY(sqltypes.NUMERIC)
TYPES_MAP["bigint_array"] = ARRAY(sqltypes.NUMERIC)
TYPES_MAP["double_array"] = ARRAY(sqltypes.DECIMAL)
Expand Down Expand Up @@ -147,8 +147,9 @@ def process(value):


colspecs = {
sqltypes.Date: Date,
sqltypes.DateTime: DateTime,
sqltypes.Date: Date
sqltypes.TIMESTAMP: DateTime,
}


Expand Down
53 changes: 31 additions & 22 deletions tests/datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import absolute_import

from datetime import datetime, tzinfo, timedelta
from datetime import tzinfo, timedelta
import datetime as dt
from unittest import TestCase, skipIf
from unittest.mock import patch, MagicMock
Expand All @@ -31,6 +31,7 @@
from sqlalchemy.orm import Session, sessionmaker

from sqlalchemy_cratedb import SA_VERSION, SA_1_4
from sqlalchemy_cratedb.dialect import DateTime

try:
from sqlalchemy.orm import declarative_base
Expand Down Expand Up @@ -78,7 +79,7 @@ class Character(Base):
__tablename__ = 'characters'
name = sa.Column(sa.String, primary_key=True)
date = sa.Column(sa.Date)
timestamp = sa.Column(sa.DateTime)
datetime = sa.Column(sa.DateTime)

fake_cursor.description = (
('characters_name', None, None, None, None, None, None),
Expand All @@ -100,7 +101,7 @@ def test_date_can_handle_datetime(self):
def test_date_can_handle_tz_aware_datetime(self):
character = self.Character()
character.name = "Athur"
character.timestamp = INPUT_DATETIME_NOTZ
character.datetime = INPUT_DATETIME_NOTZ
self.session.add(character)


Expand All @@ -111,8 +112,8 @@ class FooBar(Base):
__tablename__ = "foobar"
name = sa.Column(sa.String, primary_key=True)
date = sa.Column(sa.Date)
datetime = sa.Column(sa.DateTime)
timestamp = sa.Column(sa.TIMESTAMP)
datetime_notz = sa.Column(DateTime(timezone=False))
datetime_tz = sa.Column(DateTime(timezone=True))


@pytest.fixture
Expand All @@ -135,23 +136,27 @@ def test_datetime_notz(session):
foo_item = FooBar(
name="foo",
date=INPUT_DATE,
datetime=INPUT_DATETIME_NOTZ,
timestamp=INPUT_DATETIME_NOTZ,
datetime_notz=INPUT_DATETIME_NOTZ,
datetime_tz=INPUT_DATETIME_NOTZ,
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime, FooBar.timestamp)).mappings().first()
result = session.execute(sa.select(
FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first()

# Compare outcome.
assert result["date"] == OUTPUT_DATE
assert result["datetime"] == OUTPUT_DATETIME_NOTZ
assert result["timestamp"] == OUTPUT_DATETIME_NOTZ
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == dt.time(19, 19, 30, 123000)
assert result["datetime"].tzinfo is None
assert result["datetime_notz"] == OUTPUT_DATETIME_NOTZ
assert result["datetime_notz"].tzname() is None
assert result["datetime_notz"].timetz() == OUTPUT_TIME
assert result["datetime_notz"].tzinfo is None
assert result["datetime_tz"] == OUTPUT_DATETIME_NOTZ
assert result["datetime_tz"].tzname() is None
assert result["datetime_tz"].timetz() == OUTPUT_TIME
assert result["datetime_tz"].tzinfo is None


@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3")
Expand All @@ -163,21 +168,25 @@ def test_datetime_tz(session):
# Insert record.
foo_item = FooBar(
name="foo",
date=dt.date(2009, 5, 13),
datetime=INPUT_DATETIME_TZ,
timestamp=INPUT_DATETIME_TZ,
date=INPUT_DATE,
datetime_notz=INPUT_DATETIME_TZ,
datetime_tz=INPUT_DATETIME_TZ,
)
session.add(foo_item)
session.commit()
session.execute(sa.text("REFRESH TABLE foobar"))

# Query record.
result = session.execute(sa.select(FooBar.name, FooBar.date, FooBar.datetime, FooBar.timestamp)).mappings().first()
result = session.execute(sa.select(
FooBar.name, FooBar.date, FooBar.datetime_notz, FooBar.datetime_tz)).mappings().first()

# Compare outcome.
assert result["date"] == OUTPUT_DATE
assert result["datetime"] == OUTPUT_DATETIME_TZ
assert result["timestamp"] == OUTPUT_DATETIME_TZ
assert result["datetime"].tzname() is None
assert result["datetime"].timetz() == OUTPUT_TIME
assert result["datetime"].tzinfo is None
assert result["datetime_notz"] == OUTPUT_DATETIME_TZ
assert result["datetime_notz"].tzname() is None
assert result["datetime_notz"].timetz() == OUTPUT_TIME
assert result["datetime_notz"].tzinfo is None
assert result["datetime_tz"] == OUTPUT_DATETIME_TZ
assert result["datetime_tz"].tzname() is None
assert result["datetime_tz"].timetz() == OUTPUT_TIME
assert result["datetime_tz"].tzinfo is None

0 comments on commit 1258b4a

Please sign in to comment.