From c2b403bd11967afac122b0f56fa1d1370ea8bccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Brunner?= Date: Thu, 11 Jul 2024 15:39:50 +0200 Subject: [PATCH] Don't do 2 update per editing In papyrus there is a flush that create a first commit The idea is to update the fields in the before_* callback to avoid having tow commits --- .../lib/dbreflection.py | 13 +- .../c2cgeoportal_geoportal/views/layers.py | 264 +++++++++--------- 2 files changed, 145 insertions(+), 132 deletions(-) diff --git a/geoportal/c2cgeoportal_geoportal/lib/dbreflection.py b/geoportal/c2cgeoportal_geoportal/lib/dbreflection.py index bb46b1ccc0b..6fcaaaa8b50 100644 --- a/geoportal/c2cgeoportal_geoportal/lib/dbreflection.py +++ b/geoportal/c2cgeoportal_geoportal/lib/dbreflection.py @@ -130,14 +130,19 @@ def get_table( # create table and reflect it with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Did not recognize type 'geometry' of column", SAWarning) - args = [tablename, metadata] + args = [] if primary_key is not None: # Ensure we have a primary key to be able to edit views args.append(Column(primary_key, Integer, primary_key=True)) with _get_table_lock: - table = Table(*args, schema=schema, autoload_with=engine) # type: ignore[arg-type] - print(f"Table {tablename} loaded") - print([c.name for c in table.columns]) + table = Table( + tablename, + metadata, + *args, + schema=schema, + autoload_with=engine, + keep_existing=True, + ) return table diff --git a/geoportal/c2cgeoportal_geoportal/views/layers.py b/geoportal/c2cgeoportal_geoportal/views/layers.py index 75697146d07..90b6473d5f2 100644 --- a/geoportal/c2cgeoportal_geoportal/views/layers.py +++ b/geoportal/c2cgeoportal_geoportal/views/layers.py @@ -30,7 +30,7 @@ import os from collections.abc import Generator from datetime import datetime -from typing import TYPE_CHECKING, Any, TypedDict, cast +from typing import Any, TypedDict, cast import geoalchemy2.elements import geojson.geometry @@ -39,6 +39,7 @@ import shapely.geometry import sqlalchemy.ext.declarative import sqlalchemy.orm +import sqlalchemy.orm.query from geoalchemy2 import Geometry from geoalchemy2.shape import from_shape, to_shape from geojson.feature import Feature, FeatureCollection @@ -61,17 +62,124 @@ from sqlalchemy.sql import and_, or_ from c2cgeoportal_commons import models +from c2cgeoportal_commons.models import main from c2cgeoportal_geoportal.lib import get_roles_id from c2cgeoportal_geoportal.lib.caching import get_region from c2cgeoportal_geoportal.lib.common_headers import Cache, set_common_headers from c2cgeoportal_geoportal.lib.dbreflection import _AssociationProxy, get_class, get_table -if TYPE_CHECKING: - from c2cgeoportal_commons.models import main # pylint: disable=ungrouped-imports.useless-suppression _LOG = logging.getLogger(__name__) _CACHE_REGION = get_region("std") +class _BaseCallback: + def __init__(self, layer: main.Layer): + self.layer = layer + + def update(self, request: pyramid.request.Request, obj: Any) -> None: + last_update_date = Layers.get_metadata(self.layer, "lastUpdateDateColumn") + if last_update_date is not None: + setattr(obj, last_update_date, datetime.now()) + + last_update_user = Layers.get_metadata(self.layer, "lastUpdateUserColumn") + if last_update_user is not None: + setattr(obj, last_update_user, request.user.id) + + def _get_geometry_check_base_query( + self, request: pyramid.request.Request + ) -> sqlalchemy.orm.query.RowReturningQuery[tuple[int]]: + from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel + Layer, + RestrictionArea, + Role, + ) + + assert models.DBSession is not None + allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable + allowed = allowed.join(RestrictionArea.roles) + allowed = allowed.join(RestrictionArea.layers) + allowed = allowed.filter(RestrictionArea.readwrite.is_(True)) + allowed = allowed.filter(Role.id.in_(get_roles_id(request))) + allowed = allowed.filter(Layer.id == self.layer.id) + return allowed + + +class _InsertCallback(_BaseCallback): + def __call__(self, request: pyramid.request.Request, feature: Feature, obj: Any) -> None: + from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel + RestrictionArea, + ) + + assert models.DBSession is not None + + geom = feature.geometry + if geom and not isinstance(geom, geojson.geometry.Default): + shape = shapely.geometry.shape(geom) + srid = Layers._get_geom_col_info(self.layer)[1] + spatial_elt = from_shape(shape, srid=srid) + allowed = self._get_geometry_check_base_query(request) + allowed = allowed.filter( + or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt)) + ) + if allowed.scalar() == 0: + raise HTTPForbidden() + + # Check if geometry is valid + if Layers._get_validation_setting(self.layer, request): + Layers._validate_geometry(spatial_elt) + + self.update(request, obj) + + +class _UpdateCallback(_BaseCallback): + def __call__(self, request: pyramid.request.Request, feature: Feature, obj: Any) -> None: + from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel + RestrictionArea, + ) + + assert models.DBSession is not None + + # we need both the "original" and "new" geometry to be + # within the restriction area + geom_attr, srid = Layers._get_geom_col_info(self.layer) + geom_attr = getattr(obj, geom_attr) + geom = feature.geometry + allowed = self._get_geometry_check_base_query(request) + allowed = allowed.filter( + or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr)) + ) + spatial_elt = None + if geom and not isinstance(geom, geojson.geometry.Default): + shape = shapely.geometry.shape(geom) + spatial_elt = from_shape(shape, srid=srid) + allowed = allowed.filter( + or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt)) + ) + if allowed.scalar() == 0: + raise HTTPForbidden() + + # Check is geometry is valid + if Layers._get_validation_setting(self.layer, request): + Layers._validate_geometry(spatial_elt) + + self.update(request, obj) + + +class _DeleteCallback(_BaseCallback): + def __call__(self, request: pyramid.request.Request, obj: Any) -> None: + from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel + RestrictionArea, + ) + + geom_attr = getattr(obj, Layers._get_geom_col_info(self.layer)[0]) + allowed = self._get_geometry_check_base_query(request) + allowed = allowed.filter( + or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr)) + ) + if allowed.scalar() == 0: + raise HTTPForbidden() + + class Layers: """ All the layers view (editing). @@ -81,8 +189,12 @@ class Layers: def __init__(self, request: pyramid.request.Request): self.request = request - self.settings = request.registry.settings.get("layers", {}) - self.layers_enum_config = self.settings.get("enum") + self.settings = self._get_settings(request) + self.layers_enum_config = self.settings.get("enum", {}) + + @staticmethod + def _get_settings(request: pyramid.request.Request) -> dict[str, Any]: + return cast(dict[str, Any], request.registry.settings.get("layers", {})) @staticmethod def _get_geom_col_info(layer: "main.Layer") -> tuple[str, int]: @@ -145,16 +257,24 @@ def _get_layer_for_request(self) -> "main.Layer": """Return a ``Layer`` object for the first layer id found in the ``layer_id`` matchdict.""" return next(self._get_layers_for_request()) - def _get_protocol_for_layer(self, layer: "main.Layer", **kwargs: Any) -> Protocol: + def _get_protocol_for_layer(self, layer: "main.Layer") -> Protocol: """Return a papyrus ``Protocol`` for the ``Layer`` object.""" cls = get_layer_class(layer) geom_attr = self._get_geom_col_info(layer)[0] - return Protocol(models.DBSession, cls, geom_attr, **kwargs) - def _get_protocol_for_request(self, **kwargs: Any) -> Protocol: + return Protocol( + models.DBSession, + cls, + geom_attr, + before_insert=_InsertCallback(layer), + before_update=_UpdateCallback(layer), + before_delete=_DeleteCallback(layer), + ) + + def _get_protocol_for_request(self) -> Protocol: """Return a papyrus ``Protocol`` for the first layer id found in the ``layer_id`` matchdict.""" layer = self._get_layer_for_request() - return self._get_protocol_for_layer(layer, **kwargs) + return self._get_protocol_for_layer(layer) def _proto_read(self, layer: "main.Layer") -> FeatureCollection: """Read features for the layer based on the self.request.""" @@ -265,12 +385,6 @@ def count(self) -> int: @view_config(route_name="layers_create", renderer="geojson") # type: ignore def create(self) -> FeatureCollection | None: - from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel - Layer, - RestrictionArea, - Role, - ) - set_common_headers(self.request, "layers", Cache.PRIVATE_NO) if self.request.user is None: @@ -278,43 +392,11 @@ def create(self) -> FeatureCollection | None: self.request.response.cache_control.no_cache = True - layer = self._get_layer_for_request() - - def check_geometry(_: Any, feature: Feature, obj: Any) -> None: - del obj # unused - assert models.DBSession is not None - - geom = feature.geometry - if geom and not isinstance(geom, geojson.geometry.Default): - shape = shapely.geometry.shape(geom) - srid = self._get_geom_col_info(layer)[1] - spatial_elt = from_shape(shape, srid=srid) - allowed = models.DBSession.query( - func.count(RestrictionArea.id) # pylint: disable=not-callable - ) - allowed = allowed.join(RestrictionArea.roles) - allowed = allowed.join(RestrictionArea.layers) - allowed = allowed.filter(RestrictionArea.readwrite.is_(True)) - allowed = allowed.filter(Role.id.in_(get_roles_id(self.request))) - allowed = allowed.filter(Layer.id == layer.id) - allowed = allowed.filter( - or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt)) - ) - if allowed.scalar() == 0: - raise HTTPForbidden() - - # Check if geometry is valid - if self._get_validation_setting(layer): - self._validate_geometry(spatial_elt) - - protocol = self._get_protocol_for_layer(layer, before_create=check_geometry) + protocol = self._get_protocol_for_request() try: features = protocol.create(self.request) if isinstance(features, HTTPException): raise features - if features is not None: - for feature in features.features: # pylint: disable=no-member - self._log_last_update(layer, feature) return features except TopologicalError as e: self.request.response.status_int = 400 @@ -327,12 +409,6 @@ def check_geometry(_: Any, feature: Feature, obj: Any) -> None: @view_config(route_name="layers_update", renderer="geojson") # type: ignore def update(self) -> Feature: - from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel - Layer, - RestrictionArea, - Role, - ) - set_common_headers(self.request, "layers", Cache.PRIVATE_NO) if self.request.user is None: @@ -341,45 +417,11 @@ def update(self) -> Feature: self.request.response.cache_control.no_cache = True feature_id = self.request.matchdict.get("feature_id") - layer = self._get_layer_for_request() - - def check_geometry(_: Any, feature: Feature, obj: Any) -> None: - assert models.DBSession is not None - - # we need both the "original" and "new" geometry to be - # within the restriction area - geom_attr, srid = self._get_geom_col_info(layer) - geom_attr = getattr(obj, geom_attr) - geom = feature.geometry - allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable - allowed = allowed.join(RestrictionArea.roles) - allowed = allowed.join(RestrictionArea.layers) - allowed = allowed.filter(RestrictionArea.readwrite.is_(True)) - allowed = allowed.filter(Role.id.in_(get_roles_id(self.request))) - allowed = allowed.filter(Layer.id == layer.id) - allowed = allowed.filter( - or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr)) - ) - spatial_elt = None - if geom and not isinstance(geom, geojson.geometry.Default): - shape = shapely.geometry.shape(geom) - spatial_elt = from_shape(shape, srid=srid) - allowed = allowed.filter( - or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt)) - ) - if allowed.scalar() == 0: - raise HTTPForbidden() - - # Check is geometry is valid - if self._get_validation_setting(layer): - self._validate_geometry(spatial_elt) - - protocol = self._get_protocol_for_layer(layer, before_update=check_geometry) + protocol = self._get_protocol_for_request() try: feature = protocol.update(self.request, feature_id) if isinstance(feature, HTTPException): raise feature - self._log_last_update(layer, feature) return cast(Feature, feature) except TopologicalError as e: self.request.response.status_int = 400 @@ -403,15 +445,6 @@ def _validate_geometry(geom: geoalchemy2.elements.WKBElement | None) -> None: reason = models.DBSession.query(func.ST_IsValidReason(func.ST_GeomFromEWKB(geom))).scalar() raise TopologicalError(reason) - def _log_last_update(self, layer: "main.Layer", feature: Feature) -> None: - last_update_date = self.get_metadata(layer, "lastUpdateDateColumn") - if last_update_date is not None: - setattr(feature, last_update_date, datetime.now()) - - last_update_user = self.get_metadata(layer, "lastUpdateUserColumn") - if last_update_user is not None: - setattr(feature, last_update_user, self.request.user.id) - @staticmethod def get_metadata(layer: "main.Layer", key: str, default: str | None = None) -> str | None: metadata = layer.get_metadata(key) @@ -420,46 +453,21 @@ def get_metadata(layer: "main.Layer", key: str, default: str | None = None) -> s return metadata.value return default - def _get_validation_setting(self, layer: "main.Layer") -> bool: + @classmethod + def _get_validation_setting(cls, layer: "main.Layer", request: pyramid.request.Request) -> bool: # The validation UIMetadata is stored as a string, not a boolean - should_validate = self.get_metadata(layer, "geometryValidation", None) + should_validate = cls.get_metadata(layer, "geometryValidation", None) if should_validate: return should_validate.lower() != "false" - return cast(bool, self.settings.get("geometry_validation", False)) + return cast(bool, cls._get_settings(request).get("geometry_validation", False)) @view_config(route_name="layers_delete") # type: ignore def delete(self) -> pyramid.response.Response: - assert models.DBSession is not None - - from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel - Layer, - RestrictionArea, - Role, - ) - if self.request.user is None: raise HTTPForbidden() feature_id = self.request.matchdict.get("feature_id") - layer = self._get_layer_for_request() - - def security_cb(_: Any, obj: Any) -> None: - assert models.DBSession is not None - - geom_attr = getattr(obj, self._get_geom_col_info(layer)[0]) - allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable - allowed = allowed.join(RestrictionArea.roles) - allowed = allowed.join(RestrictionArea.layers) - allowed = allowed.filter(RestrictionArea.readwrite.is_(True)) - allowed = allowed.filter(Role.id.in_(get_roles_id(self.request))) - allowed = allowed.filter(Layer.id == layer.id) - allowed = allowed.filter( - or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr)) - ) - if allowed.scalar() == 0: - raise HTTPForbidden() - - protocol = self._get_protocol_for_layer(layer, before_delete=security_cb) + protocol = self._get_protocol_for_request() response = protocol.delete(self.request, feature_id) if isinstance(response, HTTPException): raise response