diff --git a/geoalchemy2/__init__.py b/geoalchemy2/__init__.py index 1d56c98d..e442cf55 100644 --- a/geoalchemy2/__init__.py +++ b/geoalchemy2/__init__.py @@ -18,6 +18,7 @@ import sqlalchemy from sqlalchemy import Table, event from sqlalchemy.sql import select, func, expression, text +from sqlalchemy.types import TypeDecorator from packaging import version @@ -31,6 +32,16 @@ def _format_select_args(*args): return args +def _check_spatial_type(tested_type, spatial_types): + return ( + isinstance(tested_type, spatial_types) + or ( + isinstance(tested_type, TypeDecorator) + and isinstance(tested_type.impl, spatial_types) + ) + ) + + def _setup_ddl_event_listeners(): @event.listens_for(Table, "before_create") def before_create(target, connection, **kw): @@ -53,8 +64,8 @@ def dispatch(event, table, bind): # Filter Geometry columns from the table with management=True # Note: Geography and PostGIS >= 2.0 don't need this gis_cols = [c for c in table.c if - isinstance(c.type, Geometry) and - c.type.management is True] + _check_spatial_type(c.type, Geometry) + and c.type.management is True] # Find all other columns that are not managed Geometries regular_cols = [x for x in table.c if x not in gis_cols] @@ -91,7 +102,10 @@ def dispatch(event, table, bind): for c in table.c: # Add the managed Geometry columns with AddGeometryColumn() - if isinstance(c.type, Geometry) and c.type.management is True: + if ( + _check_spatial_type(c.type, Geometry) + and c.type.management is True + ): args = [table.schema] if table.schema else [] args.extend([ table.name, @@ -110,8 +124,10 @@ def dispatch(event, table, bind): bind.execute(stmt) # Add spatial indices for the Geometry and Geography columns - if isinstance(c.type, (Geometry, Geography)) and \ - c.type.spatial_index is True: + if ( + _check_spatial_type(c.type, (Geometry, Geography)) + and c.type.spatial_index is True + ): if bind.dialect.name == 'sqlite': stmt = select(*_format_select_args(func.CreateSpatialIndex(table.name, c.name))) diff --git a/tests/test_functional.py b/tests/test_functional.py index 2cfc81e4..3b16f4ad 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -19,6 +19,7 @@ from sqlalchemy.exc import DataError, IntegrityError, InternalError, ProgrammingError from sqlalchemy.sql import select, func from sqlalchemy.sql.expression import type_coerce +from sqlalchemy.types import TypeDecorator from sqlalchemy import __version__ as SA_VERSION from geoalchemy2 import Geometry, Geography, Raster @@ -75,6 +76,20 @@ def __init__(self, geom): self.geom = geom +class ThreeDGeometry(TypeDecorator): + """This class is used to insert a ST_Force3D() in each insert.""" + impl = Geometry + + def bind_expression(self, bindvalue): + return func.ST_Force3D(self.impl.bind_expression(bindvalue)) + + +class PointZ(Base): + __tablename__ = "point_z" + id = Column(Integer, primary_key=True) + three_d_geom = Column(ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3)) + + class IndexTestWithSchema(Base): __tablename__ = 'indextestwithschema' __table_args__ = {'schema': 'gis'} @@ -198,6 +213,13 @@ def test_index_without_schema(self): assert not indices[1].get('unique') assert indices[1].get('column_names')[0] in (u'geom1', u'geom2') + def test_type_decorator_index(self): + inspector = get_inspector(engine) + indices = inspector.get_indexes(PointZ.__tablename__) + assert len(indices) == 1 + assert not indices[0].get('unique') + assert indices[0].get('column_names') == ['three_d_geom'] + class TestTypMod(): diff --git a/tests/test_functional_spatialite.py b/tests/test_functional_spatialite.py index eb3581e8..192e9673 100644 --- a/tests/test_functional_spatialite.py +++ b/tests/test_functional_spatialite.py @@ -1,5 +1,7 @@ -from json import loads import os +import re +from json import loads + from pkg_resources import parse_version import pytest import platform @@ -11,6 +13,7 @@ from sqlalchemy.event import listen from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import select, func +from sqlalchemy.types import TypeDecorator from geoalchemy2 import Geometry from geoalchemy2.elements import WKTElement, WKBElement @@ -51,6 +54,39 @@ def __init__(self, geom): self.geom = geom +class TransformedGeometry(TypeDecorator): + """This class is used to insert a ST_Transform() in each insert or select.""" + impl = Geometry + + def __init__(self, db_srid, app_srid, **kwargs): + kwargs["srid"] = db_srid + self.impl = self.__class__.impl(**kwargs) + self.app_srid = app_srid + self.db_srid = db_srid + + def column_expression(self, col): + """The column_expression() method is overrided to ensure that the + SRID of the resulting WKBElement is correct""" + return getattr(func, self.impl.as_binary)( + func.ST_Transform(col, self.app_srid), + type_=self.__class__.impl(srid=self.app_srid) + # srid could also be -1 so that the SRID is deduced from the + # WKB data + ) + + def bind_expression(self, bindvalue): + return func.ST_Transform( + self.impl.bind_expression(bindvalue), self.db_srid) + + +class LocalPoint(Base): + __tablename__ = "local_point" + id = Column(Integer, primary_key=True) + geom = Column( + TransformedGeometry( + db_srid=2154, app_srid=4326, geometry_type="POINT", management=True)) + + session = sessionmaker(bind=engine)() session.execute('SELECT InitSpatialMetaData()') @@ -150,6 +186,32 @@ def test_WKBElement(self): srid = session.execute(lake.geom.ST_SRID()).scalar() assert srid == 4326 + def test_transform(self): + # Create new point instance + p = LocalPoint() + p.geom = "SRID=4326;POINT(5 45)" # Insert 2D geometry into 3D column + + # Insert point + session.add(p) + session.flush() + session.expire(p) + + # Query the point and check the result + pt = session.query(LocalPoint).one() + assert pt.id == 1 + assert pt.geom.srid == 4326 + pt_wkb = to_shape(pt.geom) + assert round(pt_wkb.x, 5) == 5 + assert round(pt_wkb.y, 5) == 45 + + # Check that the data is correct in DB using raw query + q = "SELECT id, ST_AsText(geom) AS geom FROM local_point;" + res_q = session.execute(q).fetchone() + assert res_q.id == 1 + x, y = re.match(r"POINT\((\d+\.\d*) (\d+\.\d*)\)", res_q.geom).groups() + assert round(float(x), 3) == 857581.899 + assert round(float(y), 3) == 6435414.748 + class TestShapely():