Skip to content

Commit

Permalink
feat(experiments): Add holdout groups (#25764)
Browse files Browse the repository at this point in the history
  • Loading branch information
neilkakkar authored Oct 24, 2024
1 parent f87d8da commit 75e28ba
Show file tree
Hide file tree
Showing 13 changed files with 562 additions and 9 deletions.
110 changes: 110 additions & 0 deletions ee/clickhouse/views/experiment_holdouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Any
from rest_framework import serializers, viewsets
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from django.db import transaction


from posthog.api.feature_flag import FeatureFlagSerializer
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.models.experiment import ExperimentHoldout


class ExperimentHoldoutSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)

class Meta:
model = ExperimentHoldout
fields = [
"id",
"name",
"description",
"filters",
"created_by",
"created_at",
"updated_at",
]
read_only_fields = [
"id",
"created_by",
"created_at",
"updated_at",
]

def _get_filters_with_holdout_id(self, id: int, filters: list) -> list:
variant_key = f"holdout-{id}"
updated_filters = []
for filter in filters:
updated_filters.append(
{
**filter,
"variant": variant_key,
}
)
return updated_filters

def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExperimentHoldout:
request = self.context["request"]
validated_data["created_by"] = request.user
validated_data["team_id"] = self.context["team_id"]

if not validated_data.get("filters"):
raise ValidationError("Filters are required to create an holdout group")

instance = super().create(validated_data)
instance.filters = self._get_filters_with_holdout_id(instance.id, instance.filters)
instance.save()
return instance

def update(self, instance: ExperimentHoldout, validated_data):
filters = validated_data.get("filters")
if filters and instance.filters != filters:
# update flags on all experiments in this holdout group
new_filters = self._get_filters_with_holdout_id(instance.id, filters)
validated_data["filters"] = new_filters
with transaction.atomic():
for experiment in instance.experiment_set.all():
flag = experiment.feature_flag
existing_flag_serializer = FeatureFlagSerializer(
flag,
data={
"filters": {**flag.filters, "holdout_groups": validated_data["filters"]},
},
partial=True,
context=self.context,
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()

return super().update(instance, validated_data)


class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
scope_object = "experiment"
queryset = ExperimentHoldout.objects.prefetch_related("created_by").all()
serializer_class = ExperimentHoldoutSerializer
ordering = "-created_at"

def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response:
instance = self.get_object()

with transaction.atomic():
for experiment in instance.experiment_set.all():
flag = experiment.feature_flag
existing_flag_serializer = FeatureFlagSerializer(
flag,
data={
"filters": {
**flag.filters,
"holdout_groups": None,
}
},
partial=True,
context={"request": request, "team": self.team, "team_id": self.team_id},
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()

return super().destroy(request, *args, **kwargs)
29 changes: 27 additions & 2 deletions ee/clickhouse/views/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class Meta:
"end_date",
"feature_flag_key",
"feature_flag",
"holdout",
"exposure_cohort",
"parameters",
"secondary_metrics",
Expand Down Expand Up @@ -221,6 +222,10 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Experiment:
if properties:
raise ValidationError("Experiments do not support global filter properties")

holdout_groups = None
if validated_data.get("holdout"):
holdout_groups = validated_data["holdout"].filters

default_variants = [
{"key": "control", "name": "Control Group", "rollout_percentage": 50},
{"key": "test", "name": "Test Variant", "rollout_percentage": 50},
Expand All @@ -230,6 +235,7 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Experiment:
"groups": [{"properties": properties, "rollout_percentage": 100}],
"multivariate": {"variants": variants or default_variants},
"aggregation_group_type_index": aggregation_group_type_index,
"holdout_groups": holdout_groups,
}

feature_flag_serializer = FeatureFlagSerializer(
Expand Down Expand Up @@ -263,6 +269,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg
"parameters",
"archived",
"secondary_metrics",
"holdout",
}
given_keys = set(validated_data.keys())
extra_keys = given_keys - expected_keys
Expand All @@ -273,7 +280,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg
if extra_keys:
raise ValidationError(f"Can't update keys: {', '.join(sorted(extra_keys))} on Experiment")

# if an experiment has launched, we cannot edit its variants anymore.
# if an experiment has launched, we cannot edit its variants or holdout anymore.
if not instance.is_draft:
if "feature_flag_variants" in validated_data.get("parameters", {}):
if len(validated_data["parameters"]["feature_flag_variants"]) != len(feature_flag.variants):
Expand All @@ -285,13 +292,19 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg
!= 1
):
raise ValidationError("Can't update feature_flag_variants on Experiment")
if "holdout" in validated_data and validated_data["holdout"] != instance.holdout:
raise ValidationError("Can't update holdout on running Experiment")

properties = validated_data.get("filters", {}).get("properties")
if properties:
raise ValidationError("Experiments do not support global filter properties")

if instance.is_draft:
# if feature flag variants have changed, update the feature flag.
# if feature flag variants or holdout have changed, update the feature flag.
holdout_groups = instance.holdout.filters if instance.holdout else None
if "holdout" in validated_data:
holdout_groups = validated_data["holdout"].filters if validated_data["holdout"] else None

if validated_data.get("parameters"):
variants = validated_data["parameters"].get("feature_flag_variants", [])
aggregation_group_type_index = validated_data["parameters"].get("aggregation_group_type_index")
Expand All @@ -312,6 +325,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg
"groups": [{"properties": properties, "rollout_percentage": 100}],
"multivariate": {"variants": variants or default_variants},
"aggregation_group_type_index": aggregation_group_type_index,
"holdout_groups": holdout_groups,
}

existing_flag_serializer = FeatureFlagSerializer(
Expand All @@ -322,6 +336,17 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()
else:
# no parameters provided, just update the holdout if necessary
if "holdout" in validated_data:
existing_flag_serializer = FeatureFlagSerializer(
feature_flag,
data={"filters": {**feature_flag.filters, "holdout_groups": holdout_groups}},
partial=True,
context=self.context,
)
existing_flag_serializer.is_valid(raise_exception=True)
existing_flag_serializer.save()

if instance.is_draft and has_start_date:
feature_flag.active = True
Expand Down
Loading

0 comments on commit 75e28ba

Please sign in to comment.