Skip to content

Commit

Permalink
fix: Deleting role members (#20928)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjackwhite authored Mar 18, 2024
1 parent 9de88e7 commit 69b2bec
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 27 deletions.
15 changes: 10 additions & 5 deletions ee/api/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ee.models.feature_flag_role_access import FeatureFlagRoleAccess
from ee.models.organization_resource_access import OrganizationResourceAccess
from ee.models.role import Role, RoleMembership
from posthog.api.organization_member import OrganizationMemberSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models import OrganizationMembership
Expand Down Expand Up @@ -105,20 +106,24 @@ def get_queryset(self):

class RoleMembershipSerializer(serializers.ModelSerializer):
user = UserBasicSerializer(read_only=True)
organization_member = OrganizationMemberSerializer(read_only=True)
role_id = serializers.UUIDField(read_only=True)
user_uuid = serializers.UUIDField(required=True, write_only=True)

class Meta:
model = RoleMembership
fields = ["id", "role_id", "user", "joined_at", "updated_at", "user_uuid"]

read_only_fields = ["id", "role_id", "user"]
fields = ["id", "role_id", "organization_member", "user", "joined_at", "updated_at", "user_uuid"]
read_only_fields = ["id", "role_id", "organization_member", "user", "joined_at", "updated_at"]

def create(self, validated_data):
user_uuid = validated_data.pop("user_uuid")
try:
validated_data["user"] = User.objects.filter(is_active=True).get(uuid=user_uuid)
except User.DoesNotExist:
validated_data["organization_member"] = OrganizationMembership.objects.select_related("user").get(
organization_id=self.context["organization_id"], user__uuid=user_uuid, user__is_active=True
)

validated_data["user"] = validated_data["organization_member"].user
except OrganizationMembership.DoesNotExist:
raise serializers.ValidationError("User does not exist.")
validated_data["role_id"] = self.context["role_id"]
try:
Expand Down
57 changes: 46 additions & 11 deletions ee/api/test/test_role_membership.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,50 @@ def setUp(self):
self.eng_role = Role.objects.create(name="Engineering", organization=self.organization)
self.marketing_role = Role.objects.create(name="Marketing", organization=self.organization)

def test_adds_member_to_a_role(self):
user = User.objects.create_and_join(self.organization, "[email protected]", None)

self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
assert RoleMembership.objects.count() == 0

res = self.client.post(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
{"user_uuid": user.uuid},
)

assert res.status_code == status.HTTP_201_CREATED
assert res.json()["id"] == str(RoleMembership.objects.first().id)
assert res.json()["role_id"] == str(self.eng_role.id)
assert res.json()["organization_member"]["user"]["id"] == user.id
assert res.json()["user"]["id"] == user.id

def test_only_organization_admins_and_higher_can_add_users(self):
user_a = User.objects.create_and_join(self.organization, "[email protected]", None)
user_b = User.objects.create_and_join(self.organization, "[email protected]", None)
self.assertEqual(self.organization_membership.level, OrganizationMembership.Level.MEMBER)
assert self.organization_membership.level == OrganizationMembership.Level.MEMBER

add_user_b_res = self.client.post(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
{"user_uuid": user_b.uuid},
)
self.assertEqual(add_user_b_res.status_code, status.HTTP_403_FORBIDDEN)
assert add_user_b_res.status_code == status.HTTP_403_FORBIDDEN

self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
add_user_a_res = self.client.post(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
{"user_uuid": user_a.uuid},
)
self.assertEqual(add_user_a_res.status_code, status.HTTP_201_CREATED)
self.assertEqual(RoleMembership.objects.count(), 1)
self.assertEqual(RoleMembership.objects.first().user, user_a) # type: ignore
assert add_user_a_res.status_code == status.HTTP_201_CREATED
assert RoleMembership.objects.count() == 1
assert RoleMembership.objects.first().user == user_a

def test_user_can_belong_to_multiple_roles(self):
user_a = User.objects.create_and_join(self.organization, "[email protected]", None)
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
self.assertEqual(RoleMembership.objects.count(), 0)
assert RoleMembership.objects.count() == 0

self.client.post(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
Expand All @@ -47,7 +65,24 @@ def test_user_can_belong_to_multiple_roles(self):
f"/api/organizations/@current/roles/{self.marketing_role.id}/role_memberships",
{"user_uuid": user_a.uuid},
)
self.assertEqual(RoleMembership.objects.count(), 2)
assert RoleMembership.objects.count() == 2

def test_user_can_be_removed_from_role(self):
user_a = User.objects.create_and_join(self.organization, "[email protected]", None)
self.organization_membership.level = OrganizationMembership.Level.ADMIN
self.organization_membership.save()
assert RoleMembership.objects.count() == 0

res = self.client.post(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
{"user_uuid": user_a.uuid},
)
assert RoleMembership.objects.count() == 1
delete_response = self.client.delete(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships/{res.json()['id']}",
)
assert delete_response.status_code == status.HTTP_204_NO_CONTENT
assert RoleMembership.objects.count() == 0

def test_returns_correct_results_by_organization(self):
self.organization_membership.level = OrganizationMembership.Level.ADMIN
Expand All @@ -62,10 +97,10 @@ def test_returns_correct_results_by_organization(self):
)
other_org_same_name_role = Role.objects.create(organization=other_org, name="Engineering")
RoleMembership.objects.create(role=other_org_same_name_role, user=user_b)
self.assertEqual(RoleMembership.objects.count(), 2)
assert RoleMembership.objects.count() == 2
get_res = self.client.get(
f"/api/organizations/@current/roles/{self.eng_role.id}/role_memberships",
)
self.assertEqual(get_res.json()["count"], 1)
self.assertEqual(get_res.json()["results"][0]["user"]["distinct_id"], user_a.distinct_id)
self.assertNotContains(get_res, str(user_b.email))
assert get_res.json()["count"] == 1
assert get_res.json()["results"][0]["user"]["distinct_id"] == user_a.distinct_id
assert str(user_b.email) not in get_res.content.decode()
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# serializer version: 1
# name: ClickhouseTestExperimentSecondaryResults.test_basic_secondary_metric_results
'''
/* user_id:108 celery:posthog.tasks.tasks.sync_insight_caching_state */
/* user_id:107 celery:posthog.tasks.tasks.sync_insight_caching_state */
SELECT team_id,
date_diff('second', max(timestamp), now()) AS age
FROM events
Expand Down
25 changes: 25 additions & 0 deletions ee/migrations/0016_rolemembership_organization_member.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 4.1.13 on 2024-03-14 13:40

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("posthog", "0397_projects_backfill"),
("ee", "0015_add_verified_properties"),
]

operations = [
migrations.AddField(
model_name="rolemembership",
name="organization_member",
field=models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="role_memberships",
related_query_name="role_membership",
to="posthog.organizationmembership",
),
),
]
9 changes: 9 additions & 0 deletions ee/models/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,21 @@ class RoleMembership(UUIDModel):
related_name="roles",
related_query_name="role",
)
# TODO: Eventually remove this as we only need the organization membership
user: models.ForeignKey = models.ForeignKey(
"posthog.User",
on_delete=models.CASCADE,
related_name="role_memberships",
related_query_name="role_membership",
)

organization_member: models.ForeignKey = models.ForeignKey(
"posthog.OrganizationMembership",
on_delete=models.CASCADE,
related_name="role_memberships",
related_query_name="role_membership",
null=True,
)
joined_at: models.DateTimeField = models.DateTimeField(auto_now_add=True)
updated_at: models.DateTimeField = models.DateTimeField(auto_now=True)

Expand Down
2 changes: 1 addition & 1 deletion latest_migrations.manifest
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ admin: 0003_logentry_add_action_flag_choices
auth: 0012_alter_user_first_name_max_length
axes: 0006_remove_accesslog_trusted
contenttypes: 0002_remove_content_type_name
ee: 0015_add_verified_properties
ee: 0016_rolemembership_organization_member
otp_static: 0002_throttling
otp_totp: 0002_auto_20190420_0723
posthog: 0397_projects_backfill
Expand Down
5 changes: 3 additions & 2 deletions posthog/api/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.db.models import Model, QuerySet
from django.shortcuts import get_object_or_404
from django.views import View
from rest_framework import exceptions, permissions, serializers, viewsets
from rest_framework.request import Request

Expand Down Expand Up @@ -48,11 +49,11 @@ def has_permission(self, request: Request, view) -> bool:


class OrganizationPermissionsWithDelete(OrganizationAdminWritePermissions):
def has_object_permission(self, request: Request, view, object: Model) -> bool:
def has_object_permission(self, request: Request, view: View, object: Model) -> bool:
if request.method in permissions.SAFE_METHODS:
return True
# TODO: Optimize so that this computation is only done once, on `OrganizationMemberPermissions`
organization = extract_organization(object)
organization = extract_organization(object, view)
min_level = (
OrganizationMembership.Level.OWNER if request.method == "DELETE" else OrganizationMembership.Level.ADMIN
)
Expand Down
5 changes: 3 additions & 2 deletions posthog/api/organization_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.db.models import Model, Prefetch, QuerySet
from django.shortcuts import get_object_or_404
from django.views import View
from django_otp.plugins.otp_totp.models import TOTPDevice
from rest_framework import exceptions, mixins, serializers, viewsets
from rest_framework.permissions import SAFE_METHODS, BasePermission
Expand All @@ -22,10 +23,10 @@ class OrganizationMemberObjectPermissions(BasePermission):

message = "Your cannot edit other organization members."

def has_object_permission(self, request: Request, view, membership: OrganizationMembership) -> bool:
def has_object_permission(self, request: Request, view: View, membership: OrganizationMembership) -> bool:
if request.method in SAFE_METHODS:
return True
organization = extract_organization(membership)
organization = extract_organization(membership, view)
requesting_membership: OrganizationMembership = OrganizationMembership.objects.get(
user_id=cast(User, request.user).id,
organization=organization,
Expand Down
19 changes: 14 additions & 5 deletions posthog/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models import Model
from django.core.exceptions import ImproperlyConfigured

from django.views import View
from rest_framework.exceptions import PermissionDenied
from rest_framework.exceptions import NotFound
from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAdminUser
Expand All @@ -19,7 +20,15 @@
CREATE_METHODS = ["POST", "PUT"]


def extract_organization(object: Model) -> Organization:
def extract_organization(object: Model, view: View) -> Organization:
# This is set as part of the TeamAndOrgViewSetMixin to allow models that are not directly related to an organization
organization_id_rewrite = getattr(view, "filter_rewrite_rules", {}).get("organization_id")
if organization_id_rewrite:
for part in organization_id_rewrite.split("__"):
if part == "organization_id":
break
object = getattr(object, part)

if isinstance(object, Organization):
return object
try:
Expand Down Expand Up @@ -89,8 +98,8 @@ def has_permission(self, request: Request, view) -> bool:

return OrganizationMembership.objects.filter(user=cast(User, request.user), organization=organization).exists()

def has_object_permission(self, request: Request, view, object: Model) -> bool:
organization = extract_organization(object)
def has_object_permission(self, request: Request, view: View, object: Model) -> bool:
organization = extract_organization(object, view)
return OrganizationMembership.objects.filter(user=cast(User, request.user), organization=organization).exists()


Expand Down Expand Up @@ -119,12 +128,12 @@ def has_permission(self, request: Request, view) -> bool:
>= OrganizationMembership.Level.ADMIN
)

def has_object_permission(self, request: Request, view, object: Model) -> bool:
def has_object_permission(self, request: Request, view: View, object: Model) -> bool:
if request.method in SAFE_METHODS:
return True

# TODO: Optimize so that this computation is only done once, on `OrganizationMemberPermissions`
organization = extract_organization(object)
organization = extract_organization(object, view)

return (
OrganizationMembership.objects.get(user=cast(User, request.user), organization=organization).level
Expand Down

0 comments on commit 69b2bec

Please sign in to comment.