Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
benjackwhite committed Mar 20, 2024
1 parent d2cddc2 commit 839ee79
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 126 deletions.
129 changes: 40 additions & 89 deletions ee/api/rbac/access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,71 +80,53 @@ def validate(self, data):
return data


class AccessControlLimitOffsetPagination(LimitOffsetPagination):
class AccessControlViewSetMixin:
"""
To help the UI do its job we can return information about the access levels for the requested resource
Adds an "access_controls" action to the viewset that handles access control for the given resource
"""

def get_paginated_response(self, data):
return Response(
OrderedDict(
[
("count", self.count),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("available_access_levels", ordered_access_levels(self.request.GET.get("resource"))),
("results", data),
]
)
)

def get_paginated_response_schema(self, schema):
schema = super().get_paginated_response_schema(schema)
# TODO: Now that we are on the viewset we can
# 1. Know that the project level access is covered by the Permission check
# 2. Get the actual object which we can pass to the serializer to check if the user created it
# 3. We can also use the serializer to check the access level for the object

schema["properties"]["available_access_levels"] = {
"type": "array",
"items": {"type": "string"},
}
def _get_access_control_serializer(self, *args, **kwargs):
kwargs.setdefault("context", self.get_serializer_context())
return AccessControlSerializer(*args, **kwargs)

return schema


class AccessControlViewSet(
TeamAndOrgViewSetMixin,
mixins.ListModelMixin,
viewsets.GenericViewSet,
):
scope_object = "INTERNAL"
serializer_class = AccessControlSerializer
queryset = AccessControl.objects.all()
permission_classes = [PremiumFeaturePermission]
# NOTE: DashboardCollaborators that should be replaced by this use ADVANCED_PERMISSIONS - what do with that?
premium_feature = AvailableFeature.PROJECT_BASED_PERMISSIONING
pagination_class = AccessControlLimitOffsetPagination

def filter_queryset(self, queryset):
params = self.request.GET
def _get_access_controls(self, request: Request):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = obj.id

if params.get("resource"):
queryset = queryset.filter(resource=params["resource"])
access_controls = AccessControl.objects.filter(team=self.team, resource=resource, resource_id=resource_id).all()
serializer = self._get_access_control_serializer(instance=access_controls, many=True)

if params.get("resource_id"):
queryset = queryset.filter(resource_id=params["resource_id"])
elif params.get("resource"):
queryset = queryset.filter(resource_id=None)
return Response(
{
"access_controls": serializer.data,
"available_access_levels": ordered_access_levels(resource),
}
)

return queryset
def _update_access_controls(self, request: Request):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = str(obj.id)

def put(self, request: Request, *args, **kwargs):
# Generically validate the incoming data
partial_serializer = self.get_serializer(data=request.data)
data = request.data
data["resource"] = resource
data["resource_id"] = resource_id

partial_serializer = self._get_access_control_serializer(data=request.data)
partial_serializer.is_valid(raise_exception=True)
params = partial_serializer.validated_data

instance = self.queryset.filter(
instance = AccessControl.objects.filter(
team=self.team,
resource=params["resource"],
resource_id=params.get("resource_id"),
resource=resource,
resource_id=resource_id,
organization_member=params.get("organization_member"),
role=params.get("role"),
).first()
Expand All @@ -156,50 +138,19 @@ def put(self, request: Request, *args, **kwargs):

# Perform the upsert
if instance:
serializer = self.get_serializer(instance, data=request.data)
serializer = self._get_access_control_serializer(instance, data=request.data)
else:
serializer = self.get_serializer(data=request.data)
serializer = self._get_access_control_serializer(data=request.data)

serializer.is_valid(raise_exception=True)
serializer.validated_data["team"] = self.team
serializer.save()

return Response(serializer.data, status=status.HTTP_200_OK)

@action(methods=["GET"], detail=False)
def check(self, request: Request, *args, **kwargs):
resource = request.GET.get("resource")
resource_id = request.GET.get("resource_id")

if not resource:
raise exceptions.ValidationError("Resource must be provided.")

control = self.user_access_control.access_control_for_object(resource, resource_id)
return Response(
{
"access_level": control.access_level if control else None,
"available_access_levels": ordered_access_levels(resource),
},
status=status.HTTP_403_FORBIDDEN,
)


class AccessControlViewSetMixin:
"""
Adds an "access_control" action to the viewset that handles access control for the given resource
"""

@action(methods=["GET", "PATCH"], detail=True)
@action(methods=["GET", "PUT"], detail=True)
def access_controls(self, request: Request, *args, **kwargs):
resource = getattr(self, "scope_object", None)
obj = self.get_object()
resource_id = obj.id

control = self.user_access_control.access_control_for_object(resource, resource_id)
return Response(
{
"access_level": control.access_level if control else None,
"available_access_levels": ordered_access_levels(resource),
},
status=status.HTTP_403_FORBIDDEN,
)
if request.method == "GET":
return self._get_access_controls(request)
if request.method == "PUT":
return self._update_access_controls(request)
54 changes: 37 additions & 17 deletions ee/api/rbac/test/test_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

from ee.api.test.base import APILicensedTest
from posthog.constants import AvailableFeature
from posthog.models.notebook.notebook import Notebook
from posthog.models.organization import OrganizationMembership


class BaseAccessControlTest(APILicensedTest):
default_resource = "project"
default_resource_id = None
default_access_level = "admin"

def setUp(self):
super().setUp()
self.organization.available_features = [
Expand All @@ -21,11 +18,7 @@ def setUp(self):
self.default_resource_id = self.team.id

def _put_access_control(self, data={}):
payload = {
"resource": self.default_resource,
"resource_id": self.default_resource_id,
"access_level": self.default_access_level,
}
payload = {"access_level": "admin"}

payload.update(data)
return self.client.put(
Expand Down Expand Up @@ -82,7 +75,7 @@ def test_project_change_if_in_access_control(self):
{"organization_member": str(self.organization_membership.id), "access_level": "admin"}
)
assert res.status_code == status.HTTP_403_FORBIDDEN
assert res.json()["detail"] == "You must be an admin to modify project permissions."
assert res.json()["detail"] == "Must be admin to modify project permissions."

def test_project_change_rejected_if_not_in_organization(self):
self.organization_membership.delete()
Expand All @@ -98,18 +91,45 @@ def test_project_change_rejected_if_bad_access_level(self):


class TestAccessControlResourceLevelAPI(BaseAccessControlTest):
default_resource = "dashboard"
default_resource = "notebook"
default_resource_id = 1
default_access_level = "editor"

def setUp(self):
super().setUp()

self.notebook = Notebook.objects.create(
team=self.team, created_by=self.user, short_id="01234", title="first notebook"
)

def _get_access_controls(self, data={}):
return self.client.get(f"/api/projects/@current/notebooks/{self.notebook.short_id}/access_controls")

def _put_access_control(self, data={}):
payload = {
"access_level": self.default_access_level,
}

payload.update(data)
return self.client.put(
f"/api/projects/@current/notebooks/{self.notebook.short_id}/access_controls",
payload,
)

def test_get_access_controls(self):
self._org_membership(OrganizationMembership.Level.MEMBER)
res = self._get_access_controls()
assert res.status_code == status.HTTP_200_OK, res.json()
assert res.json() == {"access_controls": [], "available_access_levels": ["viewer", "editor"]}

def test_change_rejected_if_not_org_admin(self):
self._org_membership(OrganizationMembership.Level.MEMBER)
res = self._put_access_control()
assert res.status_code == status.HTTP_403_FORBIDDEN, res.json()

def test_change_permitted_if_creator_of_the_resource(self):
# TODO: Implement this test
assert False
# self._org_membership(OrganizationMembership.Level.MEMBER)
# res = self._put_access_control()
# assert res.status_code == status.HTTP_403_FORBIDDEN, res.json()
# def test_change_permitted_if_creator_of_the_resource(self):
# # TODO: Implement this test
# assert False
# # self._org_membership(OrganizationMembership.Level.MEMBER)
# # res = self._put_access_control()
# # assert res.status_code == status.HTTP_403_FORBIDDEN, res.json()
9 changes: 1 addition & 8 deletions ee/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rest_framework_extensions.routers import NestedRegistryItem

from ee.api import integration, time_to_see_data
from .api.rbac import organization_resource_access, role, access_control
from .api.rbac import organization_resource_access, role
from posthog.api.routing import DefaultRouterPlusPlus

from .api import (
Expand Down Expand Up @@ -51,13 +51,6 @@ def extend_api_router(
["organization_id", "role_id"],
)

projects_router.register(
r"access_controls",
access_control.AccessControlViewSet,
"project_access_controls",
["team_id"],
)

# ROUTES TO BE DEPRECATED
project_feature_flags_router.register(
r"role_access",
Expand Down
5 changes: 1 addition & 4 deletions posthog/api/dashboards/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from posthog.models.tagged_item import TaggedItem
from posthog.models.team.team import check_is_feature_available_for_team
from posthog.models.user import User
from posthog.rbac.access_control_api_mixin import AccessControlViewSetMixin
from posthog.user_permissions import UserPermissionsSerializerMixin

logger = structlog.get_logger(__name__)
Expand Down Expand Up @@ -402,9 +401,7 @@ def _update_creation_mode(self, validated_data, use_template: str, use_dashboard
return {**validated_data, "creation_mode": "default"}


class DashboardsViewSet(
TeamAndOrgViewSetMixin, TaggedItemViewSetMixin, ForbidDestroyModel, viewsets.ModelViewSet, AccessControlViewSetMixin
):
class DashboardsViewSet(TeamAndOrgViewSetMixin, TaggedItemViewSetMixin, ForbidDestroyModel, viewsets.ModelViewSet):
scope_object = "dashboard"
queryset = Dashboard.objects.order_by("name")
permission_classes = [CanEditDashboard]
Expand Down
3 changes: 2 additions & 1 deletion posthog/api/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from posthog.models.activity_logging.activity_page import activity_page_response
from posthog.models.notebook.notebook import Notebook
from posthog.models.utils import UUIDT
from posthog.rbac.access_control_api_mixin import AccessControlViewSetMixin
from posthog.utils import relative_date_parse
from loginas.utils import is_impersonated_session

Expand Down Expand Up @@ -233,7 +234,7 @@ def update(self, instance: Notebook, validated_data: Dict, **kwargs) -> Notebook
],
)
)
class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.ModelViewSet):
class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.ModelViewSet, AccessControlViewSetMixin):
scope_object = "notebook"
queryset = Notebook.objects.all()
filter_backends = [DjangoFilterBackend]
Expand Down
6 changes: 3 additions & 3 deletions posthog/api/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from posthog.models.signals import mute_selected_signals
from posthog.models.team.util import delete_bulky_postgres_data
from posthog.permissions import (
CREATE_METHODS,
CREATE_ACTIONS,
APIScopePermission,
OrganizationAdminWritePermissions,
extract_organization,
Expand All @@ -32,12 +32,12 @@ class PremiumMultiorganizationPermissions(permissions.BasePermission):

message = "You must upgrade your PostHog plan to be able to create and manage multiple organizations."

def has_permission(self, request: Request, view) -> bool:
def has_permission(self, request: Request, view: View) -> bool:
user = cast(User, request.user)
if (
# Make multiple orgs only premium on self-hosted, since enforcement of this wouldn't make sense on Cloud
not is_cloud()
and request.method in CREATE_METHODS
and view.action in CREATE_ACTIONS
and (
user.organization is None
or not user.organization.is_feature_available(AvailableFeature.ORGANIZATIONS_PROJECTS)
Expand Down
7 changes: 4 additions & 3 deletions posthog/api/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@
from posthog.models.team.util import delete_batch_exports, delete_bulky_postgres_data
from posthog.models.utils import generate_random_token_project, UUIDT
from posthog.permissions import (
CREATE_METHODS,
CREATE_ACTIONS,
APIScopePermission,
OrganizationAdminWritePermissions,
OrganizationMemberPermissions,
TeamMemberLightManagementPermission,
TeamMemberStrictManagementPermission,
get_organization_from_view,
)
from posthog.rbac.access_control_api_mixin import AccessControlViewSetMixin
from posthog.tasks.demo_create_data import create_data_for_demo_team
from posthog.user_permissions import UserPermissions, UserPermissionsSerializerMixin
from posthog.utils import get_ip_address, get_week_start_for_country_code
Expand All @@ -57,7 +58,7 @@ class PremiumMultiProjectPermissions(BasePermission):
message = "You must upgrade your PostHog plan to be able to create and manage multiple projects."

def has_permission(self, request: request.Request, view) -> bool:
if request.method in CREATE_METHODS:
if view.action in CREATE_ACTIONS:
try:
organization = get_organization_from_view(view)
except ValueError:
Expand Down Expand Up @@ -390,7 +391,7 @@ def update(self, instance: Team, validated_data: Dict[str, Any]) -> Team:
return updated_team


class TeamViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
class TeamViewSet(TeamAndOrgViewSetMixin, AccessControlViewSetMixin, viewsets.ModelViewSet):
"""
Projects for the current organization.
"""
Expand Down
2 changes: 1 addition & 1 deletion posthog/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from posthog.models.personal_api_key import APIScopeObjectOrNotSupported
from posthog.utils import get_can_create_org

CREATE_METHODS = ["POST", "PUT"]
CREATE_ACTIONS = ["create", "update", "partial_update"]


def extract_organization(object: Model, view: View) -> Organization:
Expand Down

0 comments on commit 839ee79

Please sign in to comment.