From 75e28baf2bc9d1dbba796b3b574218bcbf201c81 Mon Sep 17 00:00:00 2001 From: Neil Kakkar Date: Thu, 24 Oct 2024 14:11:41 +0100 Subject: [PATCH] feat(experiments): Add holdout groups (#25764) --- ee/clickhouse/views/experiment_holdouts.py | 110 ++++++++++ ee/clickhouse/views/experiments.py | 29 ++- .../views/test/test_clickhouse_experiments.py | 189 ++++++++++++++++++ .../views/test/test_experiment_holdouts.py | 145 ++++++++++++++ latest_migrations.manifest | 2 +- posthog/api/__init__.py | 4 + .../api/test/__snapshots__/test_api_docs.ambr | 1 + .../test/__snapshots__/test_feature_flag.ambr | 1 + .../test_organization_feature_flag.ambr | 1 + ...97_experimentholdout_experiment_holdout.py | 66 ++++++ posthog/models/experiment.py | 16 ++ posthog/models/feature_flag/flag_matching.py | 2 +- posthog/test/test_feature_flag.py | 5 - 13 files changed, 562 insertions(+), 9 deletions(-) create mode 100644 ee/clickhouse/views/experiment_holdouts.py create mode 100644 ee/clickhouse/views/test/test_experiment_holdouts.py create mode 100644 posthog/migrations/0497_experimentholdout_experiment_holdout.py diff --git a/ee/clickhouse/views/experiment_holdouts.py b/ee/clickhouse/views/experiment_holdouts.py new file mode 100644 index 0000000000000..c7d8eff83ce5a --- /dev/null +++ b/ee/clickhouse/views/experiment_holdouts.py @@ -0,0 +1,110 @@ +from typing import Any +from rest_framework import serializers, viewsets +from rest_framework.exceptions import ValidationError +from rest_framework.request import Request +from rest_framework.response import Response +from django.db import transaction + + +from posthog.api.feature_flag import FeatureFlagSerializer +from posthog.api.routing import TeamAndOrgViewSetMixin +from posthog.api.shared import UserBasicSerializer +from posthog.models.experiment import ExperimentHoldout + + +class ExperimentHoldoutSerializer(serializers.ModelSerializer): + created_by = UserBasicSerializer(read_only=True) + + class Meta: + model = ExperimentHoldout + fields = [ + "id", + "name", + "description", + "filters", + "created_by", + "created_at", + "updated_at", + ] + read_only_fields = [ + "id", + "created_by", + "created_at", + "updated_at", + ] + + def _get_filters_with_holdout_id(self, id: int, filters: list) -> list: + variant_key = f"holdout-{id}" + updated_filters = [] + for filter in filters: + updated_filters.append( + { + **filter, + "variant": variant_key, + } + ) + return updated_filters + + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExperimentHoldout: + request = self.context["request"] + validated_data["created_by"] = request.user + validated_data["team_id"] = self.context["team_id"] + + if not validated_data.get("filters"): + raise ValidationError("Filters are required to create an holdout group") + + instance = super().create(validated_data) + instance.filters = self._get_filters_with_holdout_id(instance.id, instance.filters) + instance.save() + return instance + + def update(self, instance: ExperimentHoldout, validated_data): + filters = validated_data.get("filters") + if filters and instance.filters != filters: + # update flags on all experiments in this holdout group + new_filters = self._get_filters_with_holdout_id(instance.id, filters) + validated_data["filters"] = new_filters + with transaction.atomic(): + for experiment in instance.experiment_set.all(): + flag = experiment.feature_flag + existing_flag_serializer = FeatureFlagSerializer( + flag, + data={ + "filters": {**flag.filters, "holdout_groups": validated_data["filters"]}, + }, + partial=True, + context=self.context, + ) + existing_flag_serializer.is_valid(raise_exception=True) + existing_flag_serializer.save() + + return super().update(instance, validated_data) + + +class ExperimentHoldoutViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): + scope_object = "experiment" + queryset = ExperimentHoldout.objects.prefetch_related("created_by").all() + serializer_class = ExperimentHoldoutSerializer + ordering = "-created_at" + + def destroy(self, request: Request, *args: Any, **kwargs: Any) -> Response: + instance = self.get_object() + + with transaction.atomic(): + for experiment in instance.experiment_set.all(): + flag = experiment.feature_flag + existing_flag_serializer = FeatureFlagSerializer( + flag, + data={ + "filters": { + **flag.filters, + "holdout_groups": None, + } + }, + partial=True, + context={"request": request, "team": self.team, "team_id": self.team_id}, + ) + existing_flag_serializer.is_valid(raise_exception=True) + existing_flag_serializer.save() + + return super().destroy(request, *args, **kwargs) diff --git a/ee/clickhouse/views/experiments.py b/ee/clickhouse/views/experiments.py index 7aed519d29ee6..5a389073ef6b7 100644 --- a/ee/clickhouse/views/experiments.py +++ b/ee/clickhouse/views/experiments.py @@ -167,6 +167,7 @@ class Meta: "end_date", "feature_flag_key", "feature_flag", + "holdout", "exposure_cohort", "parameters", "secondary_metrics", @@ -221,6 +222,10 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Experiment: if properties: raise ValidationError("Experiments do not support global filter properties") + holdout_groups = None + if validated_data.get("holdout"): + holdout_groups = validated_data["holdout"].filters + default_variants = [ {"key": "control", "name": "Control Group", "rollout_percentage": 50}, {"key": "test", "name": "Test Variant", "rollout_percentage": 50}, @@ -230,6 +235,7 @@ def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Experiment: "groups": [{"properties": properties, "rollout_percentage": 100}], "multivariate": {"variants": variants or default_variants}, "aggregation_group_type_index": aggregation_group_type_index, + "holdout_groups": holdout_groups, } feature_flag_serializer = FeatureFlagSerializer( @@ -263,6 +269,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg "parameters", "archived", "secondary_metrics", + "holdout", } given_keys = set(validated_data.keys()) extra_keys = given_keys - expected_keys @@ -273,7 +280,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg if extra_keys: raise ValidationError(f"Can't update keys: {', '.join(sorted(extra_keys))} on Experiment") - # if an experiment has launched, we cannot edit its variants anymore. + # if an experiment has launched, we cannot edit its variants or holdout anymore. if not instance.is_draft: if "feature_flag_variants" in validated_data.get("parameters", {}): if len(validated_data["parameters"]["feature_flag_variants"]) != len(feature_flag.variants): @@ -285,13 +292,19 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg != 1 ): raise ValidationError("Can't update feature_flag_variants on Experiment") + if "holdout" in validated_data and validated_data["holdout"] != instance.holdout: + raise ValidationError("Can't update holdout on running Experiment") properties = validated_data.get("filters", {}).get("properties") if properties: raise ValidationError("Experiments do not support global filter properties") if instance.is_draft: - # if feature flag variants have changed, update the feature flag. + # if feature flag variants or holdout have changed, update the feature flag. + holdout_groups = instance.holdout.filters if instance.holdout else None + if "holdout" in validated_data: + holdout_groups = validated_data["holdout"].filters if validated_data["holdout"] else None + if validated_data.get("parameters"): variants = validated_data["parameters"].get("feature_flag_variants", []) aggregation_group_type_index = validated_data["parameters"].get("aggregation_group_type_index") @@ -312,6 +325,7 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg "groups": [{"properties": properties, "rollout_percentage": 100}], "multivariate": {"variants": variants or default_variants}, "aggregation_group_type_index": aggregation_group_type_index, + "holdout_groups": holdout_groups, } existing_flag_serializer = FeatureFlagSerializer( @@ -322,6 +336,17 @@ def update(self, instance: Experiment, validated_data: dict, *args: Any, **kwarg ) existing_flag_serializer.is_valid(raise_exception=True) existing_flag_serializer.save() + else: + # no parameters provided, just update the holdout if necessary + if "holdout" in validated_data: + existing_flag_serializer = FeatureFlagSerializer( + feature_flag, + data={"filters": {**feature_flag.filters, "holdout_groups": holdout_groups}}, + partial=True, + context=self.context, + ) + existing_flag_serializer.is_valid(raise_exception=True) + existing_flag_serializer.save() if instance.is_draft and has_start_date: feature_flag.active = True diff --git a/ee/clickhouse/views/test/test_clickhouse_experiments.py b/ee/clickhouse/views/test/test_clickhouse_experiments.py index fdd0c05656c7c..2db87bf8965b2 100644 --- a/ee/clickhouse/views/test/test_clickhouse_experiments.py +++ b/ee/clickhouse/views/test/test_clickhouse_experiments.py @@ -123,6 +123,192 @@ def test_creating_updating_basic_experiment(self): self.assertEqual(experiment.description, "Bazinga") self.assertEqual(experiment.end_date.strftime("%Y-%m-%dT%H:%M"), end_date) + def test_transferring_holdout_to_another_group(self): + response = self.client.post( + f"/api/projects/{self.team.id}/experiment_holdouts/", + data={ + "name": "Test Experiment holdout", + "filters": [ + { + "properties": [], + "rollout_percentage": 20, + "variant": "holdout", + } + ], + }, + format="json", + ) + + holdout_id = response.json()["id"] + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["name"], "Test Experiment holdout") + self.assertEqual( + response.json()["filters"], + [{"properties": [], "rollout_percentage": 20, "variant": f"holdout-{holdout_id}"}], + ) + + # Generate draft experiment to be part of holdout + ff_key = "a-b-tests" + response = self.client.post( + f"/api/projects/{self.team.id}/experiments/", + { + "name": "Test Experiment", + "description": "", + "start_date": None, + "end_date": None, + "feature_flag_key": ff_key, + "parameters": None, + "filters": { + "events": [ + {"order": 0, "id": "$pageview"}, + {"order": 1, "id": "$pageleave"}, + ], + "properties": [], + }, + "holdout": holdout_id, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["name"], "Test Experiment") + self.assertEqual(response.json()["feature_flag_key"], ff_key) + + created_ff = FeatureFlag.objects.get(key=ff_key) + + self.assertEqual(created_ff.key, ff_key) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 20, "variant": f"holdout-{holdout_id}"}], + ) + + exp_id = response.json()["id"] + + # new holdout, and update experiment + response = self.client.post( + f"/api/projects/{self.team.id}/experiment_holdouts/", + data={ + "name": "Test Experiment holdout 2", + "filters": [ + { + "properties": [], + "rollout_percentage": 5, + "variant": "holdout", + } + ], + }, + format="json", + ) + holdout_2_id = response.json()["id"] + + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"holdout": holdout_2_id}, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + experiment = Experiment.objects.get(pk=exp_id) + self.assertEqual(experiment.holdout_id, holdout_2_id) + + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 5, "variant": f"holdout-{holdout_2_id}"}], + ) + + # update parameters + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + { + "parameters": { + "feature_flag_variants": [ + { + "key": "control", + "name": "Control Group", + "rollout_percentage": 33, + }, + { + "key": "test_1", + "name": "Test Variant", + "rollout_percentage": 33, + }, + { + "key": "test_2", + "name": "Test Variant", + "rollout_percentage": 34, + }, + ] + }, + }, + ) + + experiment = Experiment.objects.get(pk=exp_id) + self.assertEqual(experiment.holdout_id, holdout_2_id) + + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 5, "variant": f"holdout-{holdout_2_id}"}], + ) + self.assertEqual( + created_ff.filters["multivariate"]["variants"], + [ + {"key": "control", "name": "Control Group", "rollout_percentage": 33}, + {"key": "test_1", "name": "Test Variant", "rollout_percentage": 33}, + {"key": "test_2", "name": "Test Variant", "rollout_percentage": 34}, + ], + ) + + # remove holdouts + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"holdout": None}, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + experiment = Experiment.objects.get(pk=exp_id) + self.assertEqual(experiment.holdout_id, None) + + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual(created_ff.filters["holdout_groups"], None) + + # try adding invalid holdout + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"holdout": 123456}, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json()["detail"], 'Invalid pk "123456" - object does not exist.') + + # add back holdout + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"holdout": holdout_2_id}, + ) + + # launch experiment and try updating holdouts again + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"start_date": "2021-12-01T10:23"}, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + response = self.client.patch( + f"/api/projects/{self.team.id}/experiments/{exp_id}", + {"holdout": holdout_id}, + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json()["detail"], "Can't update holdout on running Experiment") + + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 5, "variant": f"holdout-{holdout_2_id}"}], + ) + def test_adding_behavioral_cohort_filter_to_experiment_fails(self): cohort = Cohort.objects.create( team=self.team, @@ -1119,6 +1305,7 @@ def test_create_experiment_updates_feature_flag_cache(self): ] }, "aggregation_group_type_index": None, + "holdout_groups": None, }, ) @@ -1170,6 +1357,7 @@ def test_create_experiment_updates_feature_flag_cache(self): ] }, "aggregation_group_type_index": None, + "holdout_groups": None, }, ) @@ -1237,6 +1425,7 @@ def test_create_experiment_updates_feature_flag_cache(self): ] }, "aggregation_group_type_index": None, + "holdout_groups": None, }, ) diff --git a/ee/clickhouse/views/test/test_experiment_holdouts.py b/ee/clickhouse/views/test/test_experiment_holdouts.py new file mode 100644 index 0000000000000..37e9dc7b25e5f --- /dev/null +++ b/ee/clickhouse/views/test/test_experiment_holdouts.py @@ -0,0 +1,145 @@ +from rest_framework import status + +from ee.api.test.base import APILicensedTest +from posthog.models.experiment import Experiment +from posthog.models.feature_flag import FeatureFlag + + +class TestExperimentHoldoutCRUD(APILicensedTest): + def test_can_list_experiment_holdouts(self): + response = self.client.get(f"/api/projects/{self.team.id}/experiment_holdouts/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_create_update_experiment_holdouts(self) -> None: + response = self.client.post( + f"/api/projects/{self.team.id}/experiment_holdouts/", + data={ + "name": "Test Experiment holdout", + "filters": [ + { + "properties": [], + "rollout_percentage": 20, + "variant": "holdout", + } + ], + }, + format="json", + ) + + holdout_id = response.json()["id"] + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["name"], "Test Experiment holdout") + self.assertEqual( + response.json()["filters"], + [{"properties": [], "rollout_percentage": 20, "variant": f"holdout-{holdout_id}"}], + ) + + # Generate experiment to be part of holdout + ff_key = "a-b-tests" + response = self.client.post( + f"/api/projects/{self.team.id}/experiments/", + { + "name": "Test Experiment", + "description": "", + "start_date": "2021-12-01T10:23", + "end_date": None, + "feature_flag_key": ff_key, + "parameters": None, + "filters": { + "events": [ + {"order": 0, "id": "$pageview"}, + {"order": 1, "id": "$pageleave"}, + ], + "properties": [], + }, + "holdout": holdout_id, + }, + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.json()["name"], "Test Experiment") + self.assertEqual(response.json()["feature_flag_key"], ff_key) + + created_ff = FeatureFlag.objects.get(key=ff_key) + + self.assertEqual(created_ff.key, ff_key) + self.assertEqual(created_ff.filters["multivariate"]["variants"][0]["key"], "control") + self.assertEqual(created_ff.filters["multivariate"]["variants"][1]["key"], "test") + self.assertEqual(created_ff.filters["groups"][0]["properties"], []) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 20, "variant": f"holdout-{holdout_id}"}], + ) + + exp_id = response.json()["id"] + # Now try updating holdout + response = self.client.patch( + f"/api/projects/{self.team.id}/experiment_holdouts/{holdout_id}", + { + "name": "Test Experiment holdout 2", + "filters": [ + { + "properties": [], + "rollout_percentage": 30, + "variant": "holdout", + } + ], + }, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json()["name"], "Test Experiment holdout 2") + self.assertEqual( + response.json()["filters"], + [{"properties": [], "rollout_percentage": 30, "variant": f"holdout-{holdout_id}"}], + ) + + # make sure flag for experiment in question was updated as well + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual( + created_ff.filters["holdout_groups"], + [{"properties": [], "rollout_percentage": 30, "variant": f"holdout-{holdout_id}"}], + ) + + # now delete holdout + response = self.client.delete(f"/api/projects/{self.team.id}/experiment_holdouts/{holdout_id}") + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + # make sure flag for experiment in question was updated as well + created_ff = FeatureFlag.objects.get(key=ff_key) + self.assertEqual(created_ff.filters["holdout_groups"], None) + + # and same for experiment + exp = Experiment.objects.get(pk=exp_id) + self.assertEqual(exp.holdout, None) + + def test_invalid_create(self): + response = self.client.post( + f"/api/projects/{self.team.id}/experiment_holdouts/", + data={ + "name": None, # invalid + "filters": [ + { + "properties": [], + "rollout_percentage": 20, + "variant": "holdout", + } + ], + }, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json()["detail"], "This field may not be null.") + + response = self.client.post( + f"/api/projects/{self.team.id}/experiment_holdouts/", + data={ + "name": "xyz", + "filters": [], + }, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json()["detail"], "Filters are required to create an holdout group") diff --git a/latest_migrations.manifest b/latest_migrations.manifest index eef9d95450a07..8842f6c11851d 100644 --- a/latest_migrations.manifest +++ b/latest_migrations.manifest @@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name ee: 0016_rolemembership_organization_member otp_static: 0002_throttling otp_totp: 0002_auto_20190420_0723 -posthog: 0496_team_person_processing_opt_out +posthog: 0497_experimentholdout_experiment_holdout sessions: 0001_initial social_django: 0010_uid_db_index two_factor: 0007_auto_20201201_1019 diff --git a/posthog/api/__init__.py b/posthog/api/__init__.py index 173909d404df6..06c97742c3a27 100644 --- a/posthog/api/__init__.py +++ b/posthog/api/__init__.py @@ -411,11 +411,15 @@ def register_grandfathered_environment_nested_viewset( if EE_AVAILABLE: from ee.clickhouse.views.experiments import EnterpriseExperimentsViewSet + from ee.clickhouse.views.experiment_holdouts import ExperimentHoldoutViewSet from ee.clickhouse.views.groups import GroupsTypesViewSet, GroupsViewSet from ee.clickhouse.views.insights import EnterpriseInsightsViewSet from ee.clickhouse.views.person import EnterprisePersonViewSet, LegacyEnterprisePersonViewSet projects_router.register(r"experiments", EnterpriseExperimentsViewSet, "project_experiments", ["project_id"]) + projects_router.register( + r"experiment_holdouts", ExperimentHoldoutViewSet, "project_experiment_holdouts", ["project_id"] + ) register_grandfathered_environment_nested_viewset(r"groups", GroupsViewSet, "environment_groups", ["team_id"]) projects_router.register(r"groups_types", GroupsTypesViewSet, "project_groups_types", ["project_id"]) environment_insights_router, legacy_project_insights_router = register_grandfathered_environment_nested_viewset( diff --git a/posthog/api/test/__snapshots__/test_api_docs.ambr b/posthog/api/test/__snapshots__/test_api_docs.ambr index 5f9fbda86651a..6f47bd7322137 100644 --- a/posthog/api/test/__snapshots__/test_api_docs.ambr +++ b/posthog/api/test/__snapshots__/test_api_docs.ambr @@ -83,6 +83,7 @@ '/home/runner/work/posthog/posthog/posthog/api/team.py: Warning [TeamViewSet > TeamSerializer]: unable to resolve type hint for function "get_product_intents". Consider using a type hint or @extend_schema_field. Defaulting to string.', "/home/runner/work/posthog/posthog/posthog/api/event_definition.py: Error [EventDefinitionViewSet]: exception raised while getting serializer. Hint: Is get_serializer_class() returning None or is get_queryset() not working without a request? Ignoring the view for now. (Exception: 'AnonymousUser' object has no attribute 'organization')", '/home/runner/work/posthog/posthog/posthog/api/event_definition.py: Warning [EventDefinitionViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.event_definition.EventDefinition" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', + '/home/runner/work/posthog/posthog/ee/clickhouse/views/experiment_holdouts.py: Warning [ExperimentHoldoutViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.experiment.ExperimentHoldout" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', '/home/runner/work/posthog/posthog/ee/clickhouse/views/experiments.py: Warning [EnterpriseExperimentsViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.experiment.Experiment" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', '/home/runner/work/posthog/posthog/posthog/api/feature_flag.py: Warning [FeatureFlagViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.feature_flag.feature_flag.FeatureFlag" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', '/home/runner/work/posthog/posthog/ee/api/feature_flag_role_access.py: Warning [FeatureFlagRoleAccessViewSet]: could not derive type of path parameter "project_id" because model "ee.models.feature_flag_role_access.FeatureFlagRoleAccess" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', diff --git a/posthog/api/test/__snapshots__/test_feature_flag.ambr b/posthog/api/test/__snapshots__/test_feature_flag.ambr index cfcd27b078478..061a24249522b 100644 --- a/posthog/api/test/__snapshots__/test_feature_flag.ambr +++ b/posthog/api/test/__snapshots__/test_feature_flag.ambr @@ -1809,6 +1809,7 @@ "posthog_experiment"."created_by_id", "posthog_experiment"."feature_flag_id", "posthog_experiment"."exposure_cohort_id", + "posthog_experiment"."holdout_id", "posthog_experiment"."start_date", "posthog_experiment"."end_date", "posthog_experiment"."created_at", diff --git a/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr b/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr index d2a587f083c1e..d978be0b9714c 100644 --- a/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr +++ b/posthog/api/test/__snapshots__/test_organization_feature_flag.ambr @@ -1061,6 +1061,7 @@ "posthog_experiment"."created_by_id", "posthog_experiment"."feature_flag_id", "posthog_experiment"."exposure_cohort_id", + "posthog_experiment"."holdout_id", "posthog_experiment"."start_date", "posthog_experiment"."end_date", "posthog_experiment"."created_at", diff --git a/posthog/migrations/0497_experimentholdout_experiment_holdout.py b/posthog/migrations/0497_experimentholdout_experiment_holdout.py new file mode 100644 index 0000000000000..383fd61708d9f --- /dev/null +++ b/posthog/migrations/0497_experimentholdout_experiment_holdout.py @@ -0,0 +1,66 @@ +# Generated by Django 4.2.15 on 2024-10-24 11:57 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone + + +class Migration(migrations.Migration): + atomic = False # Added to support concurrent index creation + dependencies = [ + ("posthog", "0496_team_person_processing_opt_out"), + ] + + operations = [ + migrations.CreateModel( + name="ExperimentHoldout", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=400)), + ("description", models.CharField(blank=True, max_length=400, null=True)), + ("filters", models.JSONField(default=list)), + ("created_at", models.DateTimeField(default=django.utils.timezone.now)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "created_by", + models.ForeignKey( + null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL + ), + ), + ("team", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="posthog.team")), + ], + ), + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AddField( + model_name="experiment", + name="holdout", + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.SET_NULL, to="posthog.experimentholdout" + ), + ), + ], + database_operations=[ + # We add -- existing-table-constraint-ignore to ignore the constraint validation in CI. + migrations.RunSQL( + """ + ALTER TABLE "posthog_experiment" ADD COLUMN "holdout_id" integer NULL CONSTRAINT "posthog_experiment_holdout_id_ffd173dd_fk_posthog_e" REFERENCES "posthog_experimentholdout"("id") DEFERRABLE INITIALLY DEFERRED; -- existing-table-constraint-ignore + SET CONSTRAINTS "posthog_experiment_holdout_id_ffd173dd_fk_posthog_e" IMMEDIATE; -- existing-table-constraint-ignore + """, + reverse_sql=""" + ALTER TABLE "posthog_experiment" DROP COLUMN IF EXISTS "holdout_id"; + """, + ), + # We add CONCURRENTLY to the create command + migrations.RunSQL( + """ + CREATE INDEX CONCURRENTLY "posthog_experiment_holdout_id_ffd173dd_fk_posthog_e" ON "posthog_experiment" ("holdout_id"); + """, + reverse_sql=""" + DROP INDEX IF EXISTS "posthog_experiment_holdout_id_ffd173dd_fk_posthog_e"; + """, + ), + ], + ), + ] diff --git a/posthog/models/experiment.py b/posthog/models/experiment.py index f594c0faf5ed8..2266982282892 100644 --- a/posthog/models/experiment.py +++ b/posthog/models/experiment.py @@ -30,6 +30,8 @@ class ExperimentType(models.TextChoices): created_by = models.ForeignKey("User", on_delete=models.SET_NULL, null=True) feature_flag = models.ForeignKey("FeatureFlag", blank=False, on_delete=models.RESTRICT) exposure_cohort = models.ForeignKey("Cohort", on_delete=models.SET_NULL, null=True) + holdout = models.ForeignKey("ExperimentHoldout", on_delete=models.SET_NULL, null=True) + start_date = models.DateTimeField(null=True) end_date = models.DateTimeField(null=True) created_at = models.DateTimeField(default=timezone.now) @@ -46,3 +48,17 @@ def get_feature_flag_key(self): @property def is_draft(self): return not self.start_date + + +class ExperimentHoldout(models.Model): + name = models.CharField(max_length=400) + description = models.CharField(max_length=400, null=True, blank=True) + team = models.ForeignKey("Team", on_delete=models.CASCADE) + + # Filters define the definition of the holdout + # This is then replicated across flags for experiments in the holdout + filters = models.JSONField(default=list) + + created_by = models.ForeignKey("User", on_delete=models.SET_NULL, null=True) + created_at = models.DateTimeField(default=timezone.now) + updated_at = models.DateTimeField(auto_now=True) diff --git a/posthog/models/feature_flag/flag_matching.py b/posthog/models/feature_flag/flag_matching.py index e898c3ff9d315..573d766e65498 100644 --- a/posthog/models/feature_flag/flag_matching.py +++ b/posthog/models/feature_flag/flag_matching.py @@ -334,7 +334,7 @@ def is_holdout_condition_match(self, feature_flag: FeatureFlag) -> tuple[bool, s # rollout_percentage is None (=100%), or we are inside holdout rollout bound. # Thus, we match. Now get the variant override for the holdout condition. variant_override = condition.get("variant") - if variant_override in [variant["key"] for variant in feature_flag.variants]: + if variant_override: variant = variant_override else: variant = self.get_matching_variant(feature_flag) diff --git a/posthog/test/test_feature_flag.py b/posthog/test/test_feature_flag.py index 440fbc018eb05..cb36678a9296c 100644 --- a/posthog/test/test_feature_flag.py +++ b/posthog/test/test_feature_flag.py @@ -886,11 +886,6 @@ def test_feature_flag_with_holdout_filter(self): "name": "Third Variant", "rollout_percentage": 25, }, - { - "key": "holdout", - "name": "Hold out variant", - "rollout_percentage": 0, - }, ] } feature_flag = self.create_feature_flag(