From 9d89df18a4e96f1edb25b407f3a056c68037b063 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Thu, 28 Jul 2022 17:41:51 +0200 Subject: [PATCH] Improve coverage (#394) --- geoalchemy2/__init__.py | 76 ++++++++++++++------------------ geoalchemy2/elements.py | 2 +- tests/test_alembic_migrations.py | 5 ++- tests/test_elements.py | 36 +++++++++++++++ tests/test_functional.py | 53 ++++++++++++++++++++++ 5 files changed, 127 insertions(+), 45 deletions(-) diff --git a/geoalchemy2/__init__.py b/geoalchemy2/__init__.py index c1b941a9..0bdecd27 100644 --- a/geoalchemy2/__init__.py +++ b/geoalchemy2/__init__.py @@ -125,35 +125,32 @@ def after_parent_attach(column, table): except AttributeError: pass + kwargs = { + 'postgresql_using': 'gist', + '_column_flag': True, + } + col = column if _check_spatial_type(column.type, (Geometry, Geography)): if column.type.use_N_D_index: - postgresql_ops = {column.name: "gist_geometry_ops_nd"} - else: - postgresql_ops = {} - table.append_constraint( - Index( - _spatial_idx_name(table.name, column.name), - column, - postgresql_using='gist', - postgresql_ops=postgresql_ops, - _column_flag=True, - ) - ) + kwargs['postgresql_ops'] = {column.name: "gist_geometry_ops_nd"} elif _check_spatial_type(column.type, Raster): - table.append_constraint( - Index( - _spatial_idx_name(table.name, column.name), - func.ST_ConvexHull(column), - postgresql_using='gist', - _column_flag=True, - ) + col = func.ST_ConvexHull(column) + + table.append_constraint( + Index( + _spatial_idx_name(table.name, column.name), + col, + **kwargs, ) + ) def dispatch(current_event, table, bind): if current_event in ('before-create', 'before-drop'): + dialect = bind.dialect + # Filter Geometry columns from the table with management=True # Note: Geography and PostGIS >= 2.0 don't need this - gis_cols = _get_gis_cols(table, Geometry, bind.dialect, check_col_management=True) + gis_cols = _get_gis_cols(table, Geometry, dialect, check_col_management=True) # Find all other columns that are not managed Geometries regular_cols = [x for x in table.columns if x not in gis_cols] @@ -171,7 +168,7 @@ def dispatch(current_event, table, bind): if current_event == 'before-drop': # Drop the managed Geometry columns for col in gis_cols: - if bind.dialect.name == 'sqlite': + if dialect.name == 'sqlite': drop_func = 'DiscardGeometryColumn' # Disable spatial indexes if present @@ -198,10 +195,10 @@ def dispatch(current_event, table, bind): ) ) ) - elif bind.dialect.name == 'postgresql': + elif dialect.name == 'postgresql': drop_func = 'DropGeometryColumn' else: - raise ArgumentError('dialect {} is not supported'.format(bind.dialect.name)) + raise ArgumentError('dialect {} is not supported'.format(dialect.name)) args = [table.schema] if table.schema else [] args.extend([table.name, col.name]) @@ -217,8 +214,8 @@ def dispatch(current_event, table, bind): for idx in current_indexes: for col in table.info["_saved_columns"]: if ( - _check_spatial_type(col.type, Geometry, bind.dialect) - and check_management(col, bind.dialect) + _check_spatial_type(col.type, Geometry, dialect) + and check_management(col, dialect) ) and col in idx.columns.values(): table.indexes.remove(idx) if ( @@ -226,7 +223,7 @@ def dispatch(current_event, table, bind): or not getattr(col.type, "spatial_index", False) ): table.info["_after_create_indexes"].append(idx) - if bind.dialect.name == 'sqlite': + if dialect.name == 'sqlite': for col in gis_cols: # Add dummy columns with GEOMETRY type col._actual_type = col.type @@ -236,16 +233,18 @@ def dispatch(current_event, table, bind): elif current_event == 'after-create': # Restore original column list including managed Geometry columns + dialect = bind.dialect + table.columns = table.info.pop('_saved_columns') for col in table.columns: # Add the managed Geometry columns with AddGeometryColumn() if ( - _check_spatial_type(col.type, Geometry, bind.dialect) - and check_management(col, bind.dialect) + _check_spatial_type(col.type, Geometry, dialect) + and check_management(col, dialect) ): dimension = col.type.dimension - if bind.dialect.name == 'sqlite': + if dialect.name == 'sqlite': col.type = col._actual_type del col._actual_type create_func = func.RecoverGeometryColumn @@ -268,7 +267,7 @@ def dispatch(current_event, table, bind): col.type.geometry_type, dimension ]) - if col.type.use_typmod is not None: + if col.type.use_typmod is not None and dialect.name != 'sqlite': args.append(col.type.use_typmod) stmt = select(*_format_select_args(create_func(*args))) @@ -277,20 +276,20 @@ def dispatch(current_event, table, bind): # Add spatial indices for the Geometry and Geography columns if ( - _check_spatial_type(col.type, (Geometry, Geography), bind.dialect) + _check_spatial_type(col.type, (Geometry, Geography), dialect) and col.type.spatial_index is True ): - if bind.dialect.name == 'sqlite': + if dialect.name == 'sqlite': stmt = select(*_format_select_args(func.CreateSpatialIndex(table.name, col.name))) stmt = stmt.execution_options(autocommit=True) bind.execute(stmt) - elif bind.dialect.name == 'postgresql': + elif dialect.name == 'postgresql': # If the index does not exist (which might be the case when # management=False), define it and create it if ( not [i for i in table.indexes if col in i.columns.values()] - and check_management(col, bind.dialect) + and check_management(col, dialect) ): if col.type.use_N_D_index: postgresql_ops = {col.name: "gist_geometry_ops_nd"} @@ -306,14 +305,7 @@ def dispatch(current_event, table, bind): idx.create(bind=bind) else: - raise ArgumentError('dialect {} is not supported'.format(bind.dialect.name)) - - if ( - isinstance(col.type, (Geometry, Geography)) - and col.type.spatial_index is False - and col.type.use_N_D_index is True - ): - raise ArgumentError('Arg Error(use_N_D_index): spatial_index must be True') + raise ArgumentError('dialect {} is not supported'.format(dialect.name)) for idx in table.info.pop("_after_create_indexes"): table.indexes.add(idx) diff --git a/geoalchemy2/elements.py b/geoalchemy2/elements.py index 57ebaa8b..921f1bd4 100644 --- a/geoalchemy2/elements.py +++ b/geoalchemy2/elements.py @@ -101,7 +101,7 @@ def __setstate__(self, state): @staticmethod def _data_from_desc(desc): - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover class WKTElement(_SpatialElement): diff --git a/tests/test_alembic_migrations.py b/tests/test_alembic_migrations.py index 84443ae3..77a41eb9 100644 --- a/tests/test_alembic_migrations.py +++ b/tests/test_alembic_migrations.py @@ -184,7 +184,7 @@ def alembic_env( process_revision_directives=alembic_helpers.writer, render_item=alembic_helpers.render_item, include_object=alembic_helpers.include_object, - render_as_batch=True + render_as_batch={} ) try: @@ -195,7 +195,8 @@ def alembic_env( engine.dispose() """.format( - str(test_script_path) + str(test_script_path), + True if engine.dialect.name == "sqlite" else False ) ) with test_script_path.open(mode="w", encoding="utf8") as f: diff --git a/tests/test_elements.py b/tests/test_elements.py index aca744cd..c77cff88 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -72,6 +72,13 @@ def test_eq(self): b = WKTElement('POINT(1 2)') assert a == b + def test_hash(self): + a = WKTElement('POINT(1 2)') + b = WKTElement('POINT(10 20)') + c = WKTElement('POINT(10 20)') + assert set([a, b, c]) == set([a, b, c]) + assert len(set([a, b, c])) == 2 + class TestExtendedWKTElement(): @@ -134,6 +141,13 @@ def test_eq(self): b = WKTElement(self._ewkt, extended=True) assert a == b + def test_hash(self): + a = WKTElement('SRID=3857;POINT (1 2 3)', extended=True) + b = WKTElement('SRID=3857;POINT (10 20 30)', extended=True) + c = WKTElement('SRID=3857;POINT (10 20 30)', extended=True) + assert set([a, b, c]) == set([a, b, c]) + assert len(set([a, b, c])) == 2 + def test_missing_srid(self): with pytest.raises(ArgumentError, match="invalid EWKT string"): WKTElement(self._wkt, extended=True) @@ -270,6 +284,13 @@ def test_eq(self): b = WKBElement(self._bin, extended=True) assert a == b + def test_hash(self): + a = WKBElement(str('010100002003000000000000000000f03f0000000000000040'), extended=True) + b = WKBElement(str('010100002003000000000000000000f02f0000000000000040'), extended=True) + c = WKBElement(str('010100002003000000000000000000f02f0000000000000040'), extended=True) + assert set([a, b, c]) == set([a, b, c]) + assert len(set([a, b, c])) == 2 + class TestWKBElement(): @@ -301,6 +322,13 @@ def test_eq(self): b = WKBElement(b'\x01\x02') assert a == b + def test_hash(self): + a = WKBElement(b'\x01\x02') + b = WKBElement(b'\x01\x03') + c = WKBElement(b'\x01\x03') + assert set([a, b, c]) == set([a, b, c]) + assert len(set([a, b, c])) == 2 + class TestNotEqualSpatialElement(): @@ -385,6 +413,14 @@ def test_pickle_unpickle(self): u'raster_1': self.hex_rast_data, } + def test_hash(self): + new_hex_rast_data = self.hex_rast_data.replace('f', 'e') + a = WKBElement(self.hex_rast_data) + b = WKBElement(new_hex_rast_data) + c = WKBElement(new_hex_rast_data) + assert set([a, b, c]) == set([a, b, c]) + assert len(set([a, b, c])) == 2 + class TestCompositeElement(): diff --git a/tests/test_functional.py b/tests/test_functional.py index 22ee4e14..3cb3814c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -31,6 +31,7 @@ from sqlalchemy.sql import func from sqlalchemy.testing.assertions import ComparesTables +import geoalchemy2 from geoalchemy2 import Geometry from geoalchemy2 import Raster from geoalchemy2 import _get_spatialite_attrs @@ -46,10 +47,43 @@ from . import skip_case_insensitivity from . import skip_pg12_sa1217 from . import skip_postgis1 +from . import test_only_with_dialects SQLA_LT_2 = parse_version(SA_VERSION) <= parse_version("1.999") +class TestAdmin(): + + def test_create_drop_tables( + self, + conn, + metadata, + Lake, + Poi, + Summit, + Ocean, + PointZ, + LocalPoint, + IndexTestWithSchema, + IndexTestWithNDIndex, + IndexTestWithoutSchema, + ): + metadata.drop_all(conn, checkfirst=True) + metadata.create_all(conn) + metadata.drop_all(conn, checkfirst=True) + + +class TestMiscellaneous(): + + @test_only_with_dialects("sqlite") + def test_load_spatialite(self, monkeypatch, conn): + geoalchemy2.load_spatialite(conn.connection.dbapi_connection, None) + + monkeypatch.delenv("SPATIALITE_LIBRARY_PATH") + with pytest.raises(RuntimeError): + geoalchemy2.load_spatialite(conn.connection.dbapi_connection, None) + + class TestInsertionCore(): def test_insert(self, conn, Lake, setup_tables): @@ -745,6 +779,25 @@ def test_raster_reflection(self, conn, Ocean, setup_tables): type_ = t.c.rast.type assert isinstance(type_, Raster) + @test_only_with_dialects("sqlite") + def test_sqlite_reflection_with_discarded_col(self, conn, Lake, setup_tables): + """Test that a discarded geometry column is not properly reflected with SQLite.""" + conn.execute("""DELETE FROM "geometry_columns" WHERE f_table_name = 'lake';""") + t = Table( + 'lake', + MetaData(), + autoload_with=conn, + ) + + # In this case the reflected type is generic with default values + assert t.c.geom.type.geometry_type == "GEOMETRY" + assert t.c.geom.type.dimension == 2 + assert t.c.geom.type.extended + assert not t.c.geom.type.management + assert t.c.geom.type.nullable + assert t.c.geom.type.spatial_index + assert t.c.geom.type.srid == -1 + @pytest.fixture def ocean_view(self, conn, Ocean): conn.execute("CREATE VIEW test_view AS SELECT * FROM ocean;")