diff --git a/lib/filter.py b/lib/filter.py index 23da8ee..a36d23a 100644 --- a/lib/filter.py +++ b/lib/filter.py @@ -122,17 +122,11 @@ def decide_geometric_relation_type(self, value, column_name, compare_type): for column in columns: if column.get('column_name') == column_name: if column.get('type') == 'GEOMETRYCOLLECTION': - ids = self.extract_geometry_collection_db(value, compare_type, column_name) - for id_value in ids: - self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value) + self.extract_geometry_collection_db(value, compare_type, column_name) elif 'GEOMETRYCOLLECTION' in value: - ids = self.extract_geometry_collection_input(value, compare_type, column_name) - for id_value in ids: - self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value) + self.extract_geometry_collection_input(value, compare_type, column_name) elif 'GEOMETRYCOLLECTION' in value and column.get('type') == 'GEOMETRYCOLLECTION': - ids = self.extract_geometry_collection_input_and_db(value, compare_type, column_name) - for id_value in ids: - self.filter_list_or.append(getattr(self.mapped_class, self.mapped_class.description().get('pk_name')) == id_value) + self.extract_geometry_collection_input_and_db(value, compare_type, column_name) elif column.get('type') != 'GEOMETRYCOLLECTION': if compare_type == 'ST_Intersects': self.filter_list.append(getattr(self.mapped_class, column_name).intersects(WKTSpatialElement(value, srid=2056))) @@ -144,7 +138,6 @@ def decide_geometric_relation_type(self, value, column_name, compare_type): self.filter_list.append(getattr(self.mapped_class, column_name).within(WKTSpatialElement(value, srid=2056))) def extract_geometry_collection_db(self, compare_geometry, compare_type, column_name): - result_ids = [] db_path_list = [ self.mapped_class.__table_args__.get('schema'), self.mapped_class.__table__.name, @@ -154,19 +147,11 @@ def extract_geometry_collection_db(self, compare_geometry, compare_type, column_ sql_text_point = '{0}(ST_CollectionExtract({1}, 1), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry) sql_text_line = '{0}(ST_CollectionExtract({1}, 2), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry) sql_text_polygon = '{0}(ST_CollectionExtract({1}, 3), ST_GeomFromText(\'{2}\', 2056))'.format(compare_type, db_path, compare_geometry) - sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all() - sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all() - sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all() - for point in sub_query_point: - result_ids.append(getattr(point, self.mapped_class.description().get('pk_name'))) - for line in sub_query_line: - result_ids.append(getattr(line, self.mapped_class.description().get('pk_name'))) - for polygon in sub_query_polygon: - result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name'))) - return result_ids + self.filter_list_or.append(text(sql_text_point)) + self.filter_list_or.append(text(sql_text_line)) + self.filter_list_or.append(text(sql_text_polygon)) def extract_geometry_collection_input(self, compare_geometry, compare_type, column_name): - result_ids = [] db_path_list = [ self.mapped_class.__table_args__.get('schema'), self.mapped_class.__table__.name, @@ -176,19 +161,11 @@ def extract_geometry_collection_input(self, compare_geometry, compare_type, colu sql_text_point = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 1))'.format(compare_type, db_path, compare_geometry) sql_text_line = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 2))'.format(compare_type, db_path, compare_geometry) sql_text_polygon = '{0}({1}, ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 3))'.format(compare_type, db_path, compare_geometry) - sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all() - sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all() - sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all() - for point in sub_query_point: - result_ids.append(getattr(point, self.mapped_class.description().get('pk_name'))) - for line in sub_query_line: - result_ids.append(getattr(line, self.mapped_class.description().get('pk_name'))) - for polygon in sub_query_polygon: - result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name'))) - return result_ids + self.filter_list_or.append(text(sql_text_point)) + self.filter_list_or.append(text(sql_text_line)) + self.filter_list_or.append(text(sql_text_polygon)) def extract_geometry_collection_input_and_db(self, compare_geometry, compare_type, column_name): - result_ids = [] db_path_list = [ self.mapped_class.__table_args__.get('schema'), self.mapped_class.__table__.name, @@ -198,13 +175,6 @@ def extract_geometry_collection_input_and_db(self, compare_geometry, compare_typ sql_text_point = '{0}((ST_CollectionExtract({1}, 1), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 1))'.format(compare_type, db_path, compare_geometry) sql_text_line = '{0}((ST_CollectionExtract({1}, 2), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 2))'.format(compare_type, db_path, compare_geometry) sql_text_polygon = '{0}((ST_CollectionExtract({1}, 3), ST_CollectionExtract(ST_GeomFromText(\'{2}\', 2056), 3))'.format(compare_type, db_path, compare_geometry) - sub_query_point = self.db_session.query(self.mapped_class).filter(text(sql_text_point)).all() - sub_query_line = self.db_session.query(self.mapped_class).filter(text(sql_text_line)).all() - sub_query_polygon = self.db_session.query(self.mapped_class).filter(text(sql_text_polygon)).all() - for point in sub_query_point: - result_ids.append(getattr(point, self.mapped_class.description().get('pk_name'))) - for line in sub_query_line: - result_ids.append(getattr(line, self.mapped_class.description().get('pk_name'))) - for polygon in sub_query_polygon: - result_ids.append(getattr(polygon, self.mapped_class.description().get('pk_name'))) - return result_ids \ No newline at end of file + self.filter_list_or.append(text(sql_text_point)) + self.filter_list_or.append(text(sql_text_line)) + self.filter_list_or.append(text(sql_text_polygon)) \ No newline at end of file