Skip to content

Commit

Permalink
Improve coverage (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet authored Jul 28, 2022
1 parent 82b0f15 commit 9d89df1
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 45 deletions.
76 changes: 34 additions & 42 deletions geoalchemy2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,35 +125,32 @@ def after_parent_attach(column, table):
except AttributeError:
pass

kwargs = {
'postgresql_using': 'gist',
'_column_flag': True,
}
col = column
if _check_spatial_type(column.type, (Geometry, Geography)):
if column.type.use_N_D_index:
postgresql_ops = {column.name: "gist_geometry_ops_nd"}
else:
postgresql_ops = {}
table.append_constraint(
Index(
_spatial_idx_name(table.name, column.name),
column,
postgresql_using='gist',
postgresql_ops=postgresql_ops,
_column_flag=True,
)
)
kwargs['postgresql_ops'] = {column.name: "gist_geometry_ops_nd"}
elif _check_spatial_type(column.type, Raster):
table.append_constraint(
Index(
_spatial_idx_name(table.name, column.name),
func.ST_ConvexHull(column),
postgresql_using='gist',
_column_flag=True,
)
col = func.ST_ConvexHull(column)

table.append_constraint(
Index(
_spatial_idx_name(table.name, column.name),
col,
**kwargs,
)
)

def dispatch(current_event, table, bind):
if current_event in ('before-create', 'before-drop'):
dialect = bind.dialect

# Filter Geometry columns from the table with management=True
# Note: Geography and PostGIS >= 2.0 don't need this
gis_cols = _get_gis_cols(table, Geometry, bind.dialect, check_col_management=True)
gis_cols = _get_gis_cols(table, Geometry, dialect, check_col_management=True)

# Find all other columns that are not managed Geometries
regular_cols = [x for x in table.columns if x not in gis_cols]
Expand All @@ -171,7 +168,7 @@ def dispatch(current_event, table, bind):
if current_event == 'before-drop':
# Drop the managed Geometry columns
for col in gis_cols:
if bind.dialect.name == 'sqlite':
if dialect.name == 'sqlite':
drop_func = 'DiscardGeometryColumn'

# Disable spatial indexes if present
Expand All @@ -198,10 +195,10 @@ def dispatch(current_event, table, bind):
)
)
)
elif bind.dialect.name == 'postgresql':
elif dialect.name == 'postgresql':
drop_func = 'DropGeometryColumn'
else:
raise ArgumentError('dialect {} is not supported'.format(bind.dialect.name))
raise ArgumentError('dialect {} is not supported'.format(dialect.name))
args = [table.schema] if table.schema else []
args.extend([table.name, col.name])

Expand All @@ -217,16 +214,16 @@ def dispatch(current_event, table, bind):
for idx in current_indexes:
for col in table.info["_saved_columns"]:
if (
_check_spatial_type(col.type, Geometry, bind.dialect)
and check_management(col, bind.dialect)
_check_spatial_type(col.type, Geometry, dialect)
and check_management(col, dialect)
) and col in idx.columns.values():
table.indexes.remove(idx)
if (
idx.name != _spatial_idx_name(table.name, col.name)
or not getattr(col.type, "spatial_index", False)
):
table.info["_after_create_indexes"].append(idx)
if bind.dialect.name == 'sqlite':
if dialect.name == 'sqlite':
for col in gis_cols:
# Add dummy columns with GEOMETRY type
col._actual_type = col.type
Expand All @@ -236,16 +233,18 @@ def dispatch(current_event, table, bind):

elif current_event == 'after-create':
# Restore original column list including managed Geometry columns
dialect = bind.dialect

table.columns = table.info.pop('_saved_columns')

for col in table.columns:
# Add the managed Geometry columns with AddGeometryColumn()
if (
_check_spatial_type(col.type, Geometry, bind.dialect)
and check_management(col, bind.dialect)
_check_spatial_type(col.type, Geometry, dialect)
and check_management(col, dialect)
):
dimension = col.type.dimension
if bind.dialect.name == 'sqlite':
if dialect.name == 'sqlite':
col.type = col._actual_type
del col._actual_type
create_func = func.RecoverGeometryColumn
Expand All @@ -268,7 +267,7 @@ def dispatch(current_event, table, bind):
col.type.geometry_type,
dimension
])
if col.type.use_typmod is not None:
if col.type.use_typmod is not None and dialect.name != 'sqlite':
args.append(col.type.use_typmod)

stmt = select(*_format_select_args(create_func(*args)))
Expand All @@ -277,20 +276,20 @@ def dispatch(current_event, table, bind):

# Add spatial indices for the Geometry and Geography columns
if (
_check_spatial_type(col.type, (Geometry, Geography), bind.dialect)
_check_spatial_type(col.type, (Geometry, Geography), dialect)
and col.type.spatial_index is True
):
if bind.dialect.name == 'sqlite':
if dialect.name == 'sqlite':
stmt = select(*_format_select_args(func.CreateSpatialIndex(table.name,
col.name)))
stmt = stmt.execution_options(autocommit=True)
bind.execute(stmt)
elif bind.dialect.name == 'postgresql':
elif dialect.name == 'postgresql':
# If the index does not exist (which might be the case when
# management=False), define it and create it
if (
not [i for i in table.indexes if col in i.columns.values()]
and check_management(col, bind.dialect)
and check_management(col, dialect)
):
if col.type.use_N_D_index:
postgresql_ops = {col.name: "gist_geometry_ops_nd"}
Expand All @@ -306,14 +305,7 @@ def dispatch(current_event, table, bind):
idx.create(bind=bind)

else:
raise ArgumentError('dialect {} is not supported'.format(bind.dialect.name))

if (
isinstance(col.type, (Geometry, Geography))
and col.type.spatial_index is False
and col.type.use_N_D_index is True
):
raise ArgumentError('Arg Error(use_N_D_index): spatial_index must be True')
raise ArgumentError('dialect {} is not supported'.format(dialect.name))

for idx in table.info.pop("_after_create_indexes"):
table.indexes.add(idx)
Expand Down
2 changes: 1 addition & 1 deletion geoalchemy2/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __setstate__(self, state):

@staticmethod
def _data_from_desc(desc):
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover


class WKTElement(_SpatialElement):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_alembic_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def alembic_env(
process_revision_directives=alembic_helpers.writer,
render_item=alembic_helpers.render_item,
include_object=alembic_helpers.include_object,
render_as_batch=True
render_as_batch={}
)
try:
Expand All @@ -195,7 +195,8 @@ def alembic_env(
engine.dispose()
""".format(
str(test_script_path)
str(test_script_path),
True if engine.dialect.name == "sqlite" else False
)
)
with test_script_path.open(mode="w", encoding="utf8") as f:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def test_eq(self):
b = WKTElement('POINT(1 2)')
assert a == b

def test_hash(self):
a = WKTElement('POINT(1 2)')
b = WKTElement('POINT(10 20)')
c = WKTElement('POINT(10 20)')
assert set([a, b, c]) == set([a, b, c])
assert len(set([a, b, c])) == 2


class TestExtendedWKTElement():

Expand Down Expand Up @@ -134,6 +141,13 @@ def test_eq(self):
b = WKTElement(self._ewkt, extended=True)
assert a == b

def test_hash(self):
a = WKTElement('SRID=3857;POINT (1 2 3)', extended=True)
b = WKTElement('SRID=3857;POINT (10 20 30)', extended=True)
c = WKTElement('SRID=3857;POINT (10 20 30)', extended=True)
assert set([a, b, c]) == set([a, b, c])
assert len(set([a, b, c])) == 2

def test_missing_srid(self):
with pytest.raises(ArgumentError, match="invalid EWKT string"):
WKTElement(self._wkt, extended=True)
Expand Down Expand Up @@ -270,6 +284,13 @@ def test_eq(self):
b = WKBElement(self._bin, extended=True)
assert a == b

def test_hash(self):
a = WKBElement(str('010100002003000000000000000000f03f0000000000000040'), extended=True)
b = WKBElement(str('010100002003000000000000000000f02f0000000000000040'), extended=True)
c = WKBElement(str('010100002003000000000000000000f02f0000000000000040'), extended=True)
assert set([a, b, c]) == set([a, b, c])
assert len(set([a, b, c])) == 2


class TestWKBElement():

Expand Down Expand Up @@ -301,6 +322,13 @@ def test_eq(self):
b = WKBElement(b'\x01\x02')
assert a == b

def test_hash(self):
a = WKBElement(b'\x01\x02')
b = WKBElement(b'\x01\x03')
c = WKBElement(b'\x01\x03')
assert set([a, b, c]) == set([a, b, c])
assert len(set([a, b, c])) == 2


class TestNotEqualSpatialElement():

Expand Down Expand Up @@ -385,6 +413,14 @@ def test_pickle_unpickle(self):
u'raster_1': self.hex_rast_data,
}

def test_hash(self):
new_hex_rast_data = self.hex_rast_data.replace('f', 'e')
a = WKBElement(self.hex_rast_data)
b = WKBElement(new_hex_rast_data)
c = WKBElement(new_hex_rast_data)
assert set([a, b, c]) == set([a, b, c])
assert len(set([a, b, c])) == 2


class TestCompositeElement():

Expand Down
53 changes: 53 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sqlalchemy.sql import func
from sqlalchemy.testing.assertions import ComparesTables

import geoalchemy2
from geoalchemy2 import Geometry
from geoalchemy2 import Raster
from geoalchemy2 import _get_spatialite_attrs
Expand All @@ -46,10 +47,43 @@
from . import skip_case_insensitivity
from . import skip_pg12_sa1217
from . import skip_postgis1
from . import test_only_with_dialects

SQLA_LT_2 = parse_version(SA_VERSION) <= parse_version("1.999")


class TestAdmin():

def test_create_drop_tables(
self,
conn,
metadata,
Lake,
Poi,
Summit,
Ocean,
PointZ,
LocalPoint,
IndexTestWithSchema,
IndexTestWithNDIndex,
IndexTestWithoutSchema,
):
metadata.drop_all(conn, checkfirst=True)
metadata.create_all(conn)
metadata.drop_all(conn, checkfirst=True)


class TestMiscellaneous():

@test_only_with_dialects("sqlite")
def test_load_spatialite(self, monkeypatch, conn):
geoalchemy2.load_spatialite(conn.connection.dbapi_connection, None)

monkeypatch.delenv("SPATIALITE_LIBRARY_PATH")
with pytest.raises(RuntimeError):
geoalchemy2.load_spatialite(conn.connection.dbapi_connection, None)


class TestInsertionCore():

def test_insert(self, conn, Lake, setup_tables):
Expand Down Expand Up @@ -745,6 +779,25 @@ def test_raster_reflection(self, conn, Ocean, setup_tables):
type_ = t.c.rast.type
assert isinstance(type_, Raster)

@test_only_with_dialects("sqlite")
def test_sqlite_reflection_with_discarded_col(self, conn, Lake, setup_tables):
"""Test that a discarded geometry column is not properly reflected with SQLite."""
conn.execute("""DELETE FROM "geometry_columns" WHERE f_table_name = 'lake';""")
t = Table(
'lake',
MetaData(),
autoload_with=conn,
)

# In this case the reflected type is generic with default values
assert t.c.geom.type.geometry_type == "GEOMETRY"
assert t.c.geom.type.dimension == 2
assert t.c.geom.type.extended
assert not t.c.geom.type.management
assert t.c.geom.type.nullable
assert t.c.geom.type.spatial_index
assert t.c.geom.type.srid == -1

@pytest.fixture
def ocean_view(self, conn, Ocean):
conn.execute("CREATE VIEW test_view AS SELECT * FROM ocean;")
Expand Down

0 comments on commit 9d89df1

Please sign in to comment.