Skip to content

Commit

Permalink
Merge pull request #382 from MetaCell/feature/SCKAN-349-add-ownership
Browse files Browse the repository at this point in the history
Feature/sckan 349 add ownership check for alerts system + update displaying alerts for invalid/exported statements
  • Loading branch information
ddelpiano authored Dec 11, 2024
2 parents cf517e8 + cf69b21 commit 48ac0b9
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 118 deletions.
37 changes: 27 additions & 10 deletions backend/composer/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from composer.models import ConnectivityStatement, Sentence



# Permission Checks: Only staff users can update a Connectivity Statement when it is in state exported
class IsStaffUserIfExportedStateInConnectivityStatement(permissions.BasePermission):
def has_object_permission(self, request, view, obj):
if (request.method not in permissions.SAFE_METHODS) and (obj.state == CSState.EXPORTED):
if (request.method not in permissions.SAFE_METHODS) and (
obj.state == CSState.EXPORTED
):
return request.user.is_staff
return True

Expand All @@ -26,7 +27,7 @@ def has_permission(self, request, view):
return True

# If creating a new instance, ensure related entity ownership
if request.method == 'POST' and view.action == 'create':
if request.method == "POST" and view.action == "create":
return check_related_entity_ownership(request)

# For unsafe methods (PATCH, PUT, DELETE), allow only authenticated users
Expand All @@ -40,33 +41,47 @@ def has_object_permission(self, request, view, obj):
return True

# Allow 'assign_owner' action to any authenticated user
if view.action == 'assign_owner':
if view.action == "assign_owner":
return request.user.is_authenticated

# Write and delete permissions (PATCH, PUT, DELETE) are only allowed to the owner
return obj.owner == request.user


class IsOwnerOfConnectivityStatementOrReadOnly(permissions.BasePermission):
"""
Custom permission to allow only the owner of the related ConnectivityStatement to modify.
"""

def has_permission(self, request, view):
# Allow safe methods (GET, HEAD, OPTIONS) for all users
if request.method in permissions.SAFE_METHODS:
return True

# If creating a new instance, ensure related entity ownership
if request.method == "POST" and view.action == "create":
return check_related_entity_ownership(request)

# For unsafe methods (PATCH, PUT, DELETE), allow only authenticated users
# Object-level permissions (e.g., ownership) are handled by has_object_permission
return request.user.is_authenticated

def has_object_permission(self, request, view, obj):
# Read permissions are allowed to any request
if request.method in permissions.SAFE_METHODS:
return True

# Write permissions are only allowed to the owner of the related ConnectivityStatement
return obj.connectivity_statement.owner == request.user



def check_related_entity_ownership(request):
"""
Helper method to check ownership of sentence or connectivity statement.
Raises PermissionDenied if the user is not the owner.
"""
sentence_id = request.data.get('sentence_id')
connectivity_statement_id = request.data.get('connectivity_statement_id')
sentence_id = request.data.get("sentence_id")
connectivity_statement_id = request.data.get("connectivity_statement_id")

# Check ownership for sentence_id
if sentence_id:
Expand All @@ -80,10 +95,12 @@ def check_related_entity_ownership(request):
# Check ownership for connectivity_statement_id
if connectivity_statement_id:
try:
connectivity_statement = ConnectivityStatement.objects.get(id=connectivity_statement_id)
connectivity_statement = ConnectivityStatement.objects.get(
id=connectivity_statement_id
)
except ConnectivityStatement.DoesNotExist:
raise PermissionDenied()
if connectivity_statement.owner != request.user:
raise PermissionDenied()
return True

return True
22 changes: 11 additions & 11 deletions backend/composer/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ class Meta:
class StatementAlertSerializer(serializers.ModelSerializer):
id = serializers.IntegerField(required=False)

connectivity_statement = serializers.PrimaryKeyRelatedField(
connectivity_statement_id = serializers.PrimaryKeyRelatedField(
queryset=ConnectivityStatement.objects.all(), required=True
)
alert_type = serializers.PrimaryKeyRelatedField(
Expand All @@ -533,7 +533,7 @@ class Meta:
"saved_by",
"created_at",
"updated_at",
"connectivity_statement",
"connectivity_statement_id",
)
read_only_fields = ("created_at", "updated_at", "saved_by")
validators = []
Expand All @@ -542,24 +542,24 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# If 'connectivity_statement' is provided in context, make it not required
if 'connectivity_statement' in self.context:
self.fields['connectivity_statement'].required = False
if 'connectivity_statement_id' in self.context:
self.fields['connectivity_statement_id'].required = False

# If updating an instance, set 'alert_type' and 'connectivity_statement' as read-only
if self.instance:
self.fields['alert_type'].read_only = True
self.fields['connectivity_statement'].read_only = True
self.fields['connectivity_statement_id'].read_only = True

def validate(self, data):
# Get 'connectivity_statement' from context or instance
connectivity_statement = self.context.get('connectivity_statement') or data.get('connectivity_statement')
connectivity_statement = self.context.get('connectivity_statement_id') or data.get('connectivity_statement_id')
if not connectivity_statement and self.instance:
connectivity_statement = self.instance.connectivity_statement
if not connectivity_statement:
raise serializers.ValidationError({
'connectivity_statement': 'This field is required.'
'connectivity_statement_id': 'This field is required.'
})
data['connectivity_statement'] = connectivity_statement
data['connectivity_statement_id'] = connectivity_statement.id

# Get 'alert_type' from data or instance
alert_type = data.get('alert_type') or getattr(self.instance, 'alert_type', None)
Expand Down Expand Up @@ -850,13 +850,13 @@ def _update_statement_alerts(self, instance, alerts_data):
alert_instance = existing_alerts[alert_id]
# Remove 'alert_type' and 'connectivity_statement' from alert_data
alert_data.pop('alert_type', None)
alert_data.pop('connectivity_statement', None)
alert_data.pop('connectivity_statement_id', None)
serializer = StatementAlertSerializer(
alert_instance,
data=alert_data,
context={
"request": self.context.get("request"),
"connectivity_statement": instance, # Pass the parent instance
"connectivity_statement_id": instance.id, # Pass the parent instance
},
)
serializer.is_valid(raise_exception=True)
Expand All @@ -867,7 +867,7 @@ def _update_statement_alerts(self, instance, alerts_data):
data=alert_data,
context={
"request": self.context.get("request"),
"connectivity_statement": instance, # Pass the parent instance
"connectivity_statement_id": instance.id, # Pass the parent instance
},
)
serializer.is_valid(raise_exception=True)
Expand Down
29 changes: 15 additions & 14 deletions backend/composer/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class AlertTypeViewSet(viewsets.ReadOnlyModelViewSet):
queryset = AlertType.objects.all()
serializer_class = AlertTypeSerializer


class ConnectivityStatementViewSet(
ProvenanceMixin,
SpecieMixin,
Expand Down Expand Up @@ -456,13 +457,18 @@ class SentenceViewSet(

def get_queryset(self):
if "ordering" not in self.request.query_params:
return super().get_queryset().annotate(
is_current_user=Case(
When(owner=self.request.user, then=Value(1)),
default=Value(0),
output_field=IntegerField(),
return (
super()
.get_queryset()
.annotate(
is_current_user=Case(
When(owner=self.request.user, then=Value(1)),
default=Value(0),
output_field=IntegerField(),
)
)
).order_by("-is_current_user", "-modified_date")
.order_by("-is_current_user", "-modified_date")
)
return super().get_queryset()


Expand Down Expand Up @@ -531,21 +537,16 @@ class DestinationViewSet(viewsets.ModelViewSet):
permission_classes = [IsOwnerOfConnectivityStatementOrReadOnly]
filterset_class = DestinationFilter


class StatementAlertViewSet(viewsets.ModelViewSet):
"""
StatementAlert
"""

queryset = StatementAlert.objects.all()
serializer_class = StatementAlertSerializer
permission_classes = [IsOwnerOfConnectivityStatementOrReadOnly]

def create(self, request, *args, **kwargs):
try:
return super().create(request, *args, **kwargs)
except Exception as e:
raise


@extend_schema(
responses=OpenApiTypes.OBJECT,
)
Expand All @@ -560,7 +561,7 @@ def jsonschemas(request):
ProvenanceSerializer,
SpecieSerializer,
NoteSerializer,
StatementAlertSerializer
StatementAlertSerializer,
]

schema = {}
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/apiclient/backend/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,7 @@ export interface PatchedStatementAlert {
* @type {number}
* @memberof PatchedStatementAlert
*/
'connectivity_statement'?: number;
'connectivity_statement_id'?: number;
}
/**
* Via
Expand Down Expand Up @@ -2610,7 +2610,7 @@ export interface StatementAlert {
* @type {number}
* @memberof StatementAlert
*/
'connectivity_statement': number;
'connectivity_statement_id': number;
}
/**
* Note Tag
Expand Down
Loading

0 comments on commit 48ac0b9

Please sign in to comment.