Skip to content

Commit

Permalink
Fix creation of columns using a TypeDecorator (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet authored Jan 5, 2022
1 parent 9b251d8 commit 5c864c7
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 6 deletions.
26 changes: 21 additions & 5 deletions geoalchemy2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sqlalchemy
from sqlalchemy import Table, event
from sqlalchemy.sql import select, func, expression, text
from sqlalchemy.types import TypeDecorator

from packaging import version

Expand All @@ -31,6 +32,16 @@ def _format_select_args(*args):
return args


def _check_spatial_type(tested_type, spatial_types):
return (
isinstance(tested_type, spatial_types)
or (
isinstance(tested_type, TypeDecorator)
and isinstance(tested_type.impl, spatial_types)
)
)


def _setup_ddl_event_listeners():
@event.listens_for(Table, "before_create")
def before_create(target, connection, **kw):
Expand All @@ -53,8 +64,8 @@ def dispatch(event, table, bind):
# Filter Geometry columns from the table with management=True
# Note: Geography and PostGIS >= 2.0 don't need this
gis_cols = [c for c in table.c if
isinstance(c.type, Geometry) and
c.type.management is True]
_check_spatial_type(c.type, Geometry)
and c.type.management is True]

# Find all other columns that are not managed Geometries
regular_cols = [x for x in table.c if x not in gis_cols]
Expand Down Expand Up @@ -91,7 +102,10 @@ def dispatch(event, table, bind):

for c in table.c:
# Add the managed Geometry columns with AddGeometryColumn()
if isinstance(c.type, Geometry) and c.type.management is True:
if (
_check_spatial_type(c.type, Geometry)
and c.type.management is True
):
args = [table.schema] if table.schema else []
args.extend([
table.name,
Expand All @@ -110,8 +124,10 @@ def dispatch(event, table, bind):
bind.execute(stmt)

# Add spatial indices for the Geometry and Geography columns
if isinstance(c.type, (Geometry, Geography)) and \
c.type.spatial_index is True:
if (
_check_spatial_type(c.type, (Geometry, Geography))
and c.type.spatial_index is True
):
if bind.dialect.name == 'sqlite':
stmt = select(*_format_select_args(func.CreateSpatialIndex(table.name,
c.name)))
Expand Down
22 changes: 22 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy.exc import DataError, IntegrityError, InternalError, ProgrammingError
from sqlalchemy.sql import select, func
from sqlalchemy.sql.expression import type_coerce
from sqlalchemy.types import TypeDecorator
from sqlalchemy import __version__ as SA_VERSION

from geoalchemy2 import Geometry, Geography, Raster
Expand Down Expand Up @@ -75,6 +76,20 @@ def __init__(self, geom):
self.geom = geom


class ThreeDGeometry(TypeDecorator):
"""This class is used to insert a ST_Force3D() in each insert."""
impl = Geometry

def bind_expression(self, bindvalue):
return func.ST_Force3D(self.impl.bind_expression(bindvalue))


class PointZ(Base):
__tablename__ = "point_z"
id = Column(Integer, primary_key=True)
three_d_geom = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))


class IndexTestWithSchema(Base):
__tablename__ = 'indextestwithschema'
__table_args__ = {'schema': 'gis'}
Expand Down Expand Up @@ -198,6 +213,13 @@ def test_index_without_schema(self):
assert not indices[1].get('unique')
assert indices[1].get('column_names')[0] in (u'geom1', u'geom2')

def test_type_decorator_index(self):
inspector = get_inspector(engine)
indices = inspector.get_indexes(PointZ.__tablename__)
assert len(indices) == 1
assert not indices[0].get('unique')
assert indices[0].get('column_names') == ['three_d_geom']


class TestTypMod():

Expand Down
64 changes: 63 additions & 1 deletion tests/test_functional_spatialite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from json import loads
import os
import re
from json import loads

from pkg_resources import parse_version
import pytest
import platform
Expand All @@ -11,6 +13,7 @@
from sqlalchemy.event import listen
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import select, func
from sqlalchemy.types import TypeDecorator

from geoalchemy2 import Geometry
from geoalchemy2.elements import WKTElement, WKBElement
Expand Down Expand Up @@ -51,6 +54,39 @@ def __init__(self, geom):
self.geom = geom


class TransformedGeometry(TypeDecorator):
"""This class is used to insert a ST_Transform() in each insert or select."""
impl = Geometry

def __init__(self, db_srid, app_srid, **kwargs):
kwargs["srid"] = db_srid
self.impl = self.__class__.impl(**kwargs)
self.app_srid = app_srid
self.db_srid = db_srid

def column_expression(self, col):
"""The column_expression() method is overrided to ensure that the
SRID of the resulting WKBElement is correct"""
return getattr(func, self.impl.as_binary)(
func.ST_Transform(col, self.app_srid),
type_=self.__class__.impl(srid=self.app_srid)
# srid could also be -1 so that the SRID is deduced from the
# WKB data
)

def bind_expression(self, bindvalue):
return func.ST_Transform(
self.impl.bind_expression(bindvalue), self.db_srid)


class LocalPoint(Base):
__tablename__ = "local_point"
id = Column(Integer, primary_key=True)
geom = Column(
TransformedGeometry(
db_srid=2154, app_srid=4326, geometry_type="POINT", management=True))


session = sessionmaker(bind=engine)()

session.execute('SELECT InitSpatialMetaData()')
Expand Down Expand Up @@ -150,6 +186,32 @@ def test_WKBElement(self):
srid = session.execute(lake.geom.ST_SRID()).scalar()
assert srid == 4326

def test_transform(self):
# Create new point instance
p = LocalPoint()
p.geom = "SRID=4326;POINT(5 45)" # Insert 2D geometry into 3D column

# Insert point
session.add(p)
session.flush()
session.expire(p)

# Query the point and check the result
pt = session.query(LocalPoint).one()
assert pt.id == 1
assert pt.geom.srid == 4326
pt_wkb = to_shape(pt.geom)
assert round(pt_wkb.x, 5) == 5
assert round(pt_wkb.y, 5) == 45

# Check that the data is correct in DB using raw query
q = "SELECT id, ST_AsText(geom) AS geom FROM local_point;"
res_q = session.execute(q).fetchone()
assert res_q.id == 1
x, y = re.match(r"POINT\((\d+\.\d*) (\d+\.\d*)\)", res_q.geom).groups()
assert round(float(x), 3) == 857581.899
assert round(float(y), 3) == 6435414.748


class TestShapely():

Expand Down

0 comments on commit 5c864c7

Please sign in to comment.