Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sckan 349 add ownership check for alerts system + update displaying alerts for invalid/exported statements #382

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions backend/composer/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from composer.models import ConnectivityStatement, Sentence



# Permission Checks: Only staff users can update a Connectivity Statement when it is in state exported
class IsStaffUserIfExportedStateInConnectivityStatement(permissions.BasePermission):
def has_object_permission(self, request, view, obj):
if (request.method not in permissions.SAFE_METHODS) and (obj.state == CSState.EXPORTED):
if (request.method not in permissions.SAFE_METHODS) and (
obj.state == CSState.EXPORTED
):
return request.user.is_staff
return True

Expand All @@ -26,7 +27,7 @@ def has_permission(self, request, view):
return True

# If creating a new instance, ensure related entity ownership
if request.method == 'POST' and view.action == 'create':
if request.method == "POST" and view.action == "create":
return check_related_entity_ownership(request)

# For unsafe methods (PATCH, PUT, DELETE), allow only authenticated users
Expand All @@ -40,33 +41,47 @@ def has_object_permission(self, request, view, obj):
return True

# Allow 'assign_owner' action to any authenticated user
if view.action == 'assign_owner':
if view.action == "assign_owner":
return request.user.is_authenticated

# Write and delete permissions (PATCH, PUT, DELETE) are only allowed to the owner
return obj.owner == request.user


class IsOwnerOfConnectivityStatementOrReadOnly(permissions.BasePermission):
"""
Custom permission to allow only the owner of the related ConnectivityStatement to modify.
"""

def has_permission(self, request, view):
# Allow safe methods (GET, HEAD, OPTIONS) for all users
if request.method in permissions.SAFE_METHODS:
return True

# If creating a new instance, ensure related entity ownership
if request.method == "POST" and view.action == "create":
return check_related_entity_ownership(request)

# For unsafe methods (PATCH, PUT, DELETE), allow only authenticated users
# Object-level permissions (e.g., ownership) are handled by has_object_permission
return request.user.is_authenticated

def has_object_permission(self, request, view, obj):
# Read permissions are allowed to any request
if request.method in permissions.SAFE_METHODS:
return True

# Write permissions are only allowed to the owner of the related ConnectivityStatement
return obj.connectivity_statement.owner == request.user



def check_related_entity_ownership(request):
"""
Helper method to check ownership of sentence or connectivity statement.
Raises PermissionDenied if the user is not the owner.
"""
sentence_id = request.data.get('sentence_id')
connectivity_statement_id = request.data.get('connectivity_statement_id')
sentence_id = request.data.get("sentence_id")
connectivity_statement_id = request.data.get("connectivity_statement_id")

# Check ownership for sentence_id
if sentence_id:
Expand All @@ -80,10 +95,12 @@ def check_related_entity_ownership(request):
# Check ownership for connectivity_statement_id
if connectivity_statement_id:
try:
connectivity_statement = ConnectivityStatement.objects.get(id=connectivity_statement_id)
connectivity_statement = ConnectivityStatement.objects.get(
id=connectivity_statement_id
)
except ConnectivityStatement.DoesNotExist:
raise PermissionDenied()
if connectivity_statement.owner != request.user:
raise PermissionDenied()
return True

return True
22 changes: 11 additions & 11 deletions backend/composer/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ class Meta:
class StatementAlertSerializer(serializers.ModelSerializer):
id = serializers.IntegerField(required=False)

connectivity_statement = serializers.PrimaryKeyRelatedField(
connectivity_statement_id = serializers.PrimaryKeyRelatedField(
queryset=ConnectivityStatement.objects.all(), required=True
)
alert_type = serializers.PrimaryKeyRelatedField(
Expand All @@ -533,7 +533,7 @@ class Meta:
"saved_by",
"created_at",
"updated_at",
"connectivity_statement",
"connectivity_statement_id",
)
read_only_fields = ("created_at", "updated_at", "saved_by")
validators = []
Expand All @@ -542,24 +542,24 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# If 'connectivity_statement' is provided in context, make it not required
if 'connectivity_statement' in self.context:
self.fields['connectivity_statement'].required = False
if 'connectivity_statement_id' in self.context:
self.fields['connectivity_statement_id'].required = False

# If updating an instance, set 'alert_type' and 'connectivity_statement' as read-only
if self.instance:
self.fields['alert_type'].read_only = True
self.fields['connectivity_statement'].read_only = True
self.fields['connectivity_statement_id'].read_only = True

def validate(self, data):
# Get 'connectivity_statement' from context or instance
connectivity_statement = self.context.get('connectivity_statement') or data.get('connectivity_statement')
connectivity_statement = self.context.get('connectivity_statement_id') or data.get('connectivity_statement_id')
if not connectivity_statement and self.instance:
connectivity_statement = self.instance.connectivity_statement
if not connectivity_statement:
raise serializers.ValidationError({
'connectivity_statement': 'This field is required.'
'connectivity_statement_id': 'This field is required.'
})
data['connectivity_statement'] = connectivity_statement
data['connectivity_statement_id'] = connectivity_statement.id

# Get 'alert_type' from data or instance
alert_type = data.get('alert_type') or getattr(self.instance, 'alert_type', None)
Expand Down Expand Up @@ -850,13 +850,13 @@ def _update_statement_alerts(self, instance, alerts_data):
alert_instance = existing_alerts[alert_id]
# Remove 'alert_type' and 'connectivity_statement' from alert_data
alert_data.pop('alert_type', None)
alert_data.pop('connectivity_statement', None)
alert_data.pop('connectivity_statement_id', None)
serializer = StatementAlertSerializer(
alert_instance,
data=alert_data,
context={
"request": self.context.get("request"),
"connectivity_statement": instance, # Pass the parent instance
"connectivity_statement_id": instance.id, # Pass the parent instance
},
)
serializer.is_valid(raise_exception=True)
Expand All @@ -867,7 +867,7 @@ def _update_statement_alerts(self, instance, alerts_data):
data=alert_data,
context={
"request": self.context.get("request"),
"connectivity_statement": instance, # Pass the parent instance
"connectivity_statement_id": instance.id, # Pass the parent instance
},
)
serializer.is_valid(raise_exception=True)
Expand Down
29 changes: 15 additions & 14 deletions backend/composer/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class AlertTypeViewSet(viewsets.ReadOnlyModelViewSet):
queryset = AlertType.objects.all()
serializer_class = AlertTypeSerializer


class ConnectivityStatementViewSet(
ProvenanceMixin,
SpecieMixin,
Expand Down Expand Up @@ -456,13 +457,18 @@ class SentenceViewSet(

def get_queryset(self):
if "ordering" not in self.request.query_params:
return super().get_queryset().annotate(
is_current_user=Case(
When(owner=self.request.user, then=Value(1)),
default=Value(0),
output_field=IntegerField(),
return (
super()
.get_queryset()
.annotate(
is_current_user=Case(
When(owner=self.request.user, then=Value(1)),
default=Value(0),
output_field=IntegerField(),
)
)
).order_by("-is_current_user", "-modified_date")
.order_by("-is_current_user", "-modified_date")
)
return super().get_queryset()


Expand Down Expand Up @@ -531,21 +537,16 @@ class DestinationViewSet(viewsets.ModelViewSet):
permission_classes = [IsOwnerOfConnectivityStatementOrReadOnly]
filterset_class = DestinationFilter


class StatementAlertViewSet(viewsets.ModelViewSet):
"""
StatementAlert
"""

queryset = StatementAlert.objects.all()
serializer_class = StatementAlertSerializer
permission_classes = [IsOwnerOfConnectivityStatementOrReadOnly]

def create(self, request, *args, **kwargs):
try:
return super().create(request, *args, **kwargs)
except Exception as e:
raise


@extend_schema(
responses=OpenApiTypes.OBJECT,
)
Expand All @@ -560,7 +561,7 @@ def jsonschemas(request):
ProvenanceSerializer,
SpecieSerializer,
NoteSerializer,
StatementAlertSerializer
StatementAlertSerializer,
]

schema = {}
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/apiclient/backend/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,7 @@ export interface PatchedStatementAlert {
* @type {number}
* @memberof PatchedStatementAlert
*/
'connectivity_statement'?: number;
'connectivity_statement_id'?: number;
}
/**
* Via
Expand Down Expand Up @@ -2610,7 +2610,7 @@ export interface StatementAlert {
* @type {number}
* @memberof StatementAlert
*/
'connectivity_statement': number;
'connectivity_statement_id': number;
}
/**
* Note Tag
Expand Down
Loading
Loading