Skip to content

Commit

Permalink
use geometry collection filters directly as filter_list_or
Browse files Browse the repository at this point in the history
  • Loading branch information
vvmruder committed Oct 7, 2015
1 parent 95915ed commit 4b290de
Showing 1 changed file with 12 additions and 42 deletions.
54 changes: 12 additions & 42 deletions lib/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,11 @@ def decide_geometric_relation_type(self, value, column_name, compare_type):
for column in columns:
if column.get('column_name') == column_name:
if column.get('type') == 'GEOMETRYCOLLECTION':
ids = self.extract_geometry_collection_db(value, compare_type, column_name)
for id_value in ids:
self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value)
self.extract_geometry_collection_db(value, compare_type, column_name)
elif 'GEOMETRYCOLLECTION' in value:
ids = self.extract_geometry_collection_input(value, compare_type, column_name)
for id_value in ids:
self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value)
self.extract_geometry_collection_input(value, compare_type, column_name)
elif 'GEOMETRYCOLLECTION' in value and column.get('type') == 'GEOMETRYCOLLECTION':
ids = self.extract_geometry_collection_input_and_db(value, compare_type, column_name)
for id_value in ids:
self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value)
self.extract_geometry_collection_input_and_db(value, compare_type, column_name)
elif column.get('type') != 'GEOMETRYCOLLECTION':
if compare_type == 'ST_Intersects':
self.filter_list.append(getattr(self.mapped_class, column_name).intersects(WKTSpatialElement(value, srid=2056)))
Expand All @@ -144,7 +138,6 @@ def decide_geometric_relation_type(self, value, column_name, compare_type):
self.filter_list.append(getattr(self.mapped_class, column_name).within(WKTSpatialElement(value, srid=2056)))

def extract_geometry_collection_db(self, compare_geometry, compare_type, column_name):
result_ids = []
db_path_list = [
self.mapped_class.__table_args__.get('schema'),
self.mapped_class.__table__.name,
Expand All @@ -154,19 +147,11 @@ def extract_geometry_collection_db(self, compare_geometry, compare_type, column_
sql_text_point = '{0}(ST_CollectionExtract({1}, 1), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry)
sql_text_line = '{0}(ST_CollectionExtract({1}, 2), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry)
sql_text_polygon = '{0}(ST_CollectionExtract({1}, 3), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry)
sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all()
sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all()
sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all()
for point in sub_query_point:
result_ids.append(getattr(point, self.mapped_class.description().get('pk_name')))
for line in sub_query_line:
result_ids.append(getattr(line, self.mapped_class.description().get('pk_name')))
for polygon in sub_query_polygon:
result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name')))
return result_ids
self.filter_list_or.append(text(sql_text_point))
self.filter_list_or.append(text(sql_text_line))
self.filter_list_or.append(text(sql_text_polygon))

def extract_geometry_collection_input(self, compare_geometry, compare_type, column_name):
result_ids = []
db_path_list = [
self.mapped_class.__table_args__.get('schema'),
self.mapped_class.__table__.name,
Expand All @@ -176,19 +161,11 @@ def extract_geometry_collection_input(self, compare_geometry, compare_type, colu
sql_text_point = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 1))'.format(compare_type, db_path, compare_geometry)
sql_text_line = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 2))'.format(compare_type, db_path, compare_geometry)
sql_text_polygon = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 3))'.format(compare_type, db_path, compare_geometry)
sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all()
sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all()
sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all()
for point in sub_query_point:
result_ids.append(getattr(point, self.mapped_class.description().get('pk_name')))
for line in sub_query_line:
result_ids.append(getattr(line, self.mapped_class.description().get('pk_name')))
for polygon in sub_query_polygon:
result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name')))
return result_ids
self.filter_list_or.append(text(sql_text_point))
self.filter_list_or.append(text(sql_text_line))
self.filter_list_or.append(text(sql_text_polygon))

def extract_geometry_collection_input_and_db(self, compare_geometry, compare_type, column_name):
result_ids = []
db_path_list = [
self.mapped_class.__table_args__.get('schema'),
self.mapped_class.__table__.name,
Expand All @@ -198,13 +175,6 @@ def extract_geometry_collection_input_and_db(self, compare_geometry, compare_typ
sql_text_point = '{0}((ST_CollectionExtract({1}, 1), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 1))'.format(compare_type, db_path, compare_geometry)
sql_text_line = '{0}((ST_CollectionExtract({1}, 2), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 2))'.format(compare_type, db_path, compare_geometry)
sql_text_polygon = '{0}((ST_CollectionExtract({1}, 3), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 3))'.format(compare_type, db_path, compare_geometry)
sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all()
sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all()
sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all()
for point in sub_query_point:
result_ids.append(getattr(point, self.mapped_class.description().get('pk_name')))
for line in sub_query_line:
result_ids.append(getattr(line, self.mapped_class.description().get('pk_name')))
for polygon in sub_query_polygon:
result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name')))
return result_ids
self.filter_list_or.append(text(sql_text_point))
self.filter_list_or.append(text(sql_text_line))
self.filter_list_or.append(text(sql_text_polygon))

0 comments on commit 4b290de

Please sign in to comment.