From dbe6f630243b8237fe3422cc7f2e8068db9631b9 Mon Sep 17 00:00:00 2001 From: Ben White Date: Mon, 18 Mar 2024 16:36:34 +0100 Subject: [PATCH] Fixes --- ee/api/rbac/access_control.py | 51 ++++++++++++++- ee/api/rbac/test/test_access_control.py | 64 +++++++++++++++++++ posthog/api/routing.py | 19 +++++- .../rbac/test/test_user_access_control.py | 13 +++- {ee => posthog}/rbac/user_access_control.py | 46 ++++++++++--- 5 files changed, 177 insertions(+), 16 deletions(-) create mode 100644 ee/api/rbac/test/test_access_control.py rename {ee => posthog}/rbac/test/test_user_access_control.py (98%) rename {ee => posthog}/rbac/user_access_control.py (87%) diff --git a/ee/api/rbac/access_control.py b/ee/api/rbac/access_control.py index 368341f139075..ab22fa6ed34c0 100644 --- a/ee/api/rbac/access_control.py +++ b/ee/api/rbac/access_control.py @@ -1,4 +1,6 @@ -from rest_framework import serializers, mixins, status, viewsets +from typing import cast + +from rest_framework import exceptions, mixins, serializers, status, viewsets from rest_framework.request import Request from rest_framework.response import Response @@ -7,7 +9,7 @@ from posthog.constants import AvailableFeature from posthog.models.personal_api_key import API_SCOPE_OBJECTS from posthog.permissions import PremiumFeaturePermission - +from posthog.rbac.user_access_control import UserAccessControl # TODO: Validate that an access control can only have one of team, organization_member, or role @@ -41,6 +43,51 @@ def validate(self, data): if sum([bool(data.get("team")), bool(data.get("organization_member")), bool(data.get("role"))]) != 1: raise serializers.ValidationError("Exactly one of 'team', 'organization_member', or 'role' must be set.") + access_control = cast(UserAccessControl, self.context["view"].user_access_control) + resource = data["resource"] + resource_id = data.get("resource_id") + + if resource == "project" and resource_id: + # Special check for modifying a specific project's access + if not access_control.check_access_level_for_object("project", data["resource_id"], "admin"): + raise exceptions.PermissionDenied("You do not have the required access to this project.") + + # team: Team = self.context["get_team"]() + # if not team.access_control: + # raise exceptions.ValidationError( + # "Explicit members can only be accessed for projects with project-based permissioning enabled." + # ) + # requesting_user: User = self.context["request"].user + # membership_being_accessed = cast(Optional[ExplicitTeamMembership], self.instance) + # try: + # requesting_level = self.user_permissions.team(team).effective_membership_level + # except OrganizationMembership.DoesNotExist: + # # Requesting user does not belong to the project's organization, so we spoof a 404 for enhanced security + # raise exceptions.NotFound("Project not found.") + + # new_level = attrs.get("level") + + # if requesting_level is None: + # raise exceptions.PermissionDenied("You do not have the required access to this project.") + + # if attrs.get("user_uuid") == requesting_user.uuid: + # # Create-only check + # raise exceptions.PermissionDenied("You can't explicitly add yourself to projects.") + + # if new_level is not None and new_level > requesting_level: + # raise exceptions.PermissionDenied("You can only set access level to lower or equal to your current one.") + + # if membership_being_accessed is not None: + # # Update-only checks + # if membership_being_accessed.parent_membership.user_id != requesting_user.id: + # # Requesting user updating someone else + # if membership_being_accessed.level > requesting_level: + # raise exceptions.PermissionDenied("You can only edit others with level lower or equal to you.") + # else: + # # Requesting user updating themselves + # if new_level is not None: + # raise exceptions.PermissionDenied("You can't set your own access level.") + return data diff --git a/ee/api/rbac/test/test_access_control.py b/ee/api/rbac/test/test_access_control.py new file mode 100644 index 0000000000000..c942cc903b702 --- /dev/null +++ b/ee/api/rbac/test/test_access_control.py @@ -0,0 +1,64 @@ +from rest_framework import status + +from ee.api.test.base import APILicensedTest +from posthog.constants import AvailableFeature +from posthog.models.organization import OrganizationMembership + + +class TestAccessControlAPI(APILicensedTest): + # def _create_access_control( + # self, resource="project", resource_id=None, access_level="admin", organization_member=None, team=None, role=None + # ): + # return AccessControl.objects.create( + # organization=self.organization, + # resource=resource, + # resource_id=resource_id or self.team.id, + # access_level=access_level, + # # Targets + # organization_member=organization_member, + # team=team, + # role=role, + # ) + + def setUp(self): + super().setUp() + self.organization.available_features = [ + AvailableFeature.PROJECT_BASED_PERMISSIONING, + AvailableFeature.ROLE_BASED_ACCESS, + ] + self.organization.save() + + def _put_access_control(self, data): + payload = { + "resource": "project", + "resource_id": self.team.id, + "access_level": "admin", + } + + payload.update(data) + return self.client.put( + "/api/organizations/@current/access_controls", + payload, + ) + + def _org_membership(self, level: OrganizationMembership.Level = OrganizationMembership.Level.ADMIN): + self.organization_membership.level = level + self.organization_membership.save() + + def test_project_change_rejected_if_not_org_admin(self): + self._org_membership(OrganizationMembership.Level.MEMBER) + res = self._put_access_control({"team": self.team.id}) + assert res.status_code == status.HTTP_403_FORBIDDEN, res.json() + + def test_project_change_accepted_if_org_admin(self): + self._org_membership(OrganizationMembership.Level.ADMIN) + res = self._put_access_control({"team": self.team.id}) + assert res.status_code == status.HTTP_200_OK, res.json() + + def test_project_change_if_in_access_control(self): + self._org_membership(OrganizationMembership.Level.ADMIN) + # Add ourselves to access + res = self._put_access_control({"team": self.team.id}) + assert res.status_code == status.HTTP_200_OK, res.json() + + # TODO diff --git a/posthog/api/routing.py b/posthog/api/routing.py index b768538c05d50..36c193720864d 100644 --- a/posthog/api/routing.py +++ b/posthog/api/routing.py @@ -24,6 +24,7 @@ SharingTokenPermission, TeamMemberAccessPermission, ) +from posthog.rbac.user_access_control import UserAccessControl from posthog.user_permissions import UserPermissions if TYPE_CHECKING: @@ -55,7 +56,6 @@ class TeamAndOrgViewSetMixin(_GenericViewSet): authentication_classes = [] permission_classes = [] - # NOTE: Could we type this? Would be pretty cool as a helper scope_object: Optional[APIScopeObjectOrNotSupported] = None required_scopes: Optional[list[str]] = None sharing_enabled_actions: list[str] = [] @@ -239,6 +239,23 @@ def _get_team_from_request(self) -> Optional["Team"]: def user_permissions(self) -> "UserPermissions": return UserPermissions(user=cast(User, self.request.user), team=self.team) + @cached_property + def user_access_control(self) -> "UserAccessControl": + organization = self.organization + team: Optional[Team] = None + try: + # TODO: Check this is correct... + if self.request.resolver_match.url_name.startswith("team-"): + # /projects/ endpoint handling + team = self.get_object() + else: + team = self.team + # TODO: I don't think this will work - we will need to know about the underlying object to get the team I think... + except (Team.DoesNotExist, KeyError): + pass + + return UserAccessControl(user=self.request.user, organization=organization, team=team) + # Stdout tracing to see what legacy endpoints (non-project-nested) are still requested by the frontend # TODO: Delete below when no legacy endpoints are used anymore diff --git a/ee/rbac/test/test_user_access_control.py b/posthog/rbac/test/test_user_access_control.py similarity index 98% rename from ee/rbac/test/test_user_access_control.py rename to posthog/rbac/test/test_user_access_control.py index f4fdec27f017b..c1f3274a40111 100644 --- a/ee/rbac/test/test_user_access_control.py +++ b/posthog/rbac/test/test_user_access_control.py @@ -1,11 +1,18 @@ -from ee.models.rbac.access_control import AccessControl -from ee.models.rbac.role import Role, RoleMembership -from ee.rbac.user_access_control import UserAccessControl +import pytest from posthog.constants import AvailableFeature from posthog.models.user import User +from posthog.rbac.user_access_control import UserAccessControl from posthog.test.base import BaseTest +try: + from ee.models.rbac.access_control import AccessControl + from ee.models.rbac.role import Role, RoleMembership +except ImportError: + pass + + +@pytest.mark.ee class TestUserTeamPermissions(BaseTest): user_access_control: UserAccessControl diff --git a/ee/rbac/user_access_control.py b/posthog/rbac/user_access_control.py similarity index 87% rename from ee/rbac/user_access_control.py rename to posthog/rbac/user_access_control.py index 0e78ea6079181..3b2dfe921cb8a 100644 --- a/ee/rbac/user_access_control.py +++ b/posthog/rbac/user_access_control.py @@ -2,9 +2,8 @@ from functools import cached_property from django.db.models import Q, QuerySet -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional -from ee.models.rbac.access_control import AccessControl from posthog.constants import AvailableFeature from posthog.models import ( Organization, @@ -15,6 +14,21 @@ from posthog.models.personal_api_key import APIScopeObject from posthog.permissions import extract_organization + +if TYPE_CHECKING: + from ee.models import AccessControl + + _AccessControl = AccessControl +else: + _AccessControl = object + + +try: + from ee.models.rbac.access_control import AccessControl +except ImportError: + pass + + MEMBER_BASED_ACCESS_LEVELS = ["member", "admin"] RESOURCE_BASED_ACCESS_LEVELS = ["viewer", "editor"] @@ -39,8 +53,9 @@ def __init__(self, user: User, organization: Organization, team: Optional[Team] self._organization = organization @cached_property - def _organization_membership(self, organization: Organization) -> Optional[OrganizationMembership]: - return OrganizationMembership.objects.get(organization=organization, user=self.user) + def _organization_membership(self) -> Optional[OrganizationMembership]: + # TODO: Don't throw if none + return OrganizationMembership.objects.get(organization=self._organization, user=self._user) @property def _rbac_supported(self) -> bool: @@ -56,7 +71,7 @@ def _access_controls_supported(self) -> bool: ) or self._organization.is_feature_available(AvailableFeature.ADVANCED_PERMISSIONS) # @cached_property - def _access_controls_for_object(self, resource: APIScopeObject, resource_id: str) -> List[AccessControl]: + def _access_controls_for_object(self, resource: APIScopeObject, resource_id: str) -> List[_AccessControl]: """ Used when checking an individual object - gets all access controls for the object and its type """ @@ -80,7 +95,7 @@ def _access_controls_for_object(self, resource: APIScopeObject, resource_id: str ) ) - def access_control_for_object(self, resource: APIScopeObject, resource_id: str) -> Optional[AccessControl]: + def access_control_for_object(self, resource: APIScopeObject, resource_id: str) -> Optional[_AccessControl]: """ Access levels are strings - the order of which is determined at run time. We find all relevant access controls and then return the highest value @@ -113,12 +128,23 @@ def check_access_level_for_object( Returns true or false if access controls are applied, otherwise None """ - access_control = self.access_control_for_object(resource, resource_id) + org_membership = self._organization_membership - if not access_control: - return + if not org_membership: + # NOTE: Here we need to change it to indicate they aren't an org member + return False + + # Org admins always have object level access + if org_membership.level == OrganizationMembership.Level.ADMIN: + return True - return access_level_satisfied(resource, access_control.access_level, required_level) + access_control = self.access_control_for_object(resource, resource_id) + + return ( + None + if not access_control + else access_level_satisfied(resource, access_control.access_level, required_level) + ) # Used for filtering a queryset by access level def filter_queryset_by_access_level(self, queryset: QuerySet) -> QuerySet: