diff --git a/posthog/management/commands/create_batch_export_from_app.py b/posthog/management/commands/create_batch_export_from_app.py index 20b0b4c89ca86..eadf71532db02 100644 --- a/posthog/management/commands/create_batch_export_from_app.py +++ b/posthog/management/commands/create_batch_export_from_app.py @@ -48,6 +48,12 @@ def add_arguments(self, parser): default=False, help="Backfill the newly created BatchExport with the last period of data.", ) + parser.add_argument( + "--migrate-disabled-plugin-config", + action="store_true", + default=False, + help="Migrate a PluginConfig even if its disabled.", + ) def handle(self, *args, **options): """Handle creation of a BatchExport from a given PluginConfig.""" @@ -82,8 +88,8 @@ def handle(self, *args, **options): "destination_data": destination_data, } - if dry_run is True: - self.stdout.write("No BatchExport will be created as this is a dry run or confirmation check rejected.") + if dry_run is True or (options["migrate_disabled_plugin_config"] is False and plugin_config.enabled is False): + self.stdout.write("No BatchExport will be created as this is a dry run or existing plugin is disabled.") return json.dumps(batch_export_data, indent=4, default=str) else: destination = BatchExportDestination(**batch_export_data["destination_data"]) diff --git a/posthog/management/commands/test/test_create_batch_export_from_app.py b/posthog/management/commands/test/test_create_batch_export_from_app.py index 4a51975d86648..bbbb36079d013 100644 --- a/posthog/management/commands/test/test_create_batch_export_from_app.py +++ b/posthog/management/commands/test/test_create_batch_export_from_app.py @@ -1,4 +1,5 @@ import datetime as dt +import itertools import json import typing @@ -116,6 +117,20 @@ def plugin_config(request, s3_plugin_config, snowflake_plugin_config) -> PluginC raise ValueError(f"Unsupported plugin: {request.param}") +@pytest.fixture +def disabled_plugin_config(request, s3_plugin_config, snowflake_plugin_config) -> PluginConfig: + if request.param == "S3": + s3_plugin_config.enabled = False + s3_plugin_config.save() + return s3_plugin_config + elif request.param == "Snowflake": + snowflake_plugin_config.enabled = False + snowflake_plugin_config.save() + return snowflake_plugin_config + else: + raise ValueError(f"Unsupported plugin: {request.param}") + + @pytest.mark.django_db @pytest.mark.parametrize( "plugin_config,config,expected_type", @@ -155,7 +170,6 @@ def test_create_batch_export_from_app_fails_with_mismatched_team_id(plugin_confi @pytest.mark.parametrize("plugin_config", ["S3", "Snowflake"], indirect=True) def test_create_batch_export_from_app_dry_run(plugin_config): """Test a dry_run of the create_batch_export_from_app command.""" - output = call_command( "create_batch_export_from_app", f"--plugin-config-id={plugin_config.id}", @@ -166,6 +180,7 @@ def test_create_batch_export_from_app_dry_run(plugin_config): batch_export_data = json.loads(output) + assert "id" not in batch_export_data assert batch_export_data["team_id"] == plugin_config.team.id assert batch_export_data["interval"] == "hour" assert batch_export_data["name"] == f"{export_type} Export" @@ -178,19 +193,14 @@ def test_create_batch_export_from_app_dry_run(plugin_config): @pytest.mark.django_db @pytest.mark.parametrize( "interval,plugin_config,disable_plugin_config", - [ - ("hour", "S3", True), - ("hour", "S3", False), - ("day", "S3", True), - ("day", "S3", False), - ("hour", "Snowflake", True), - ("hour", "Snowflake", False), - ("day", "Snowflake", True), - ("day", "Snowflake", False), - ], + itertools.product(["hour", "day"], ["S3", "Snowflake"], [True, False]), indirect=["plugin_config"], ) -def test_create_batch_export_from_app(interval, plugin_config, disable_plugin_config): +def test_create_batch_export_from_app( + interval, + plugin_config, + disable_plugin_config, +): """Test a live run of the create_batch_export_from_app command.""" args = [ f"--plugin-config-id={plugin_config.id}", @@ -237,6 +247,69 @@ def test_create_batch_export_from_app(interval, plugin_config, disable_plugin_co assert args[key] == expected +@pytest.mark.django_db +@pytest.mark.parametrize( + "interval,disabled_plugin_config,migrate_disabled_plugin_config", + itertools.product(["hour", "day"], ["S3", "Snowflake"], [True, False]), + indirect=["disabled_plugin_config"], +) +def test_create_batch_export_from_app_with_disabled_plugin( + interval, + disabled_plugin_config, + migrate_disabled_plugin_config, +): + """Test a live run of the create_batch_export_from_app command.""" + args = [ + f"--plugin-config-id={disabled_plugin_config.id}", + f"--team-id={disabled_plugin_config.team.id}", + f"--interval={interval}", + ] + if migrate_disabled_plugin_config: + args.append("--migrate-disabled-plugin-config") + + output = call_command("create_batch_export_from_app", *args) + + disabled_plugin_config.refresh_from_db() + assert disabled_plugin_config.enabled is False + + export_type, config = map_plugin_config_to_destination(disabled_plugin_config) + + batch_export_data = json.loads(output) + + assert batch_export_data["team_id"] == disabled_plugin_config.team.id + assert batch_export_data["interval"] == interval + assert batch_export_data["name"] == f"{export_type} Export" + assert batch_export_data["destination_data"] == { + "type": export_type, + "config": config, + } + + if not migrate_disabled_plugin_config: + assert "id" not in batch_export_data + return + + assert "id" in batch_export_data + + temporal = sync_connect() + + schedule = describe_schedule(temporal, str(batch_export_data["id"])) + expected_interval = dt.timedelta(**{f"{interval}s": 1}) + assert schedule.schedule.spec.intervals[0].every == expected_interval + + codec = EncryptionCodec(settings=settings) + decoded_payload = async_to_sync(codec.decode)(schedule.schedule.action.args) + args = json.loads(decoded_payload[0].data) + + # Common inputs + assert args["team_id"] == disabled_plugin_config.team.pk + assert args["batch_export_id"] == str(batch_export_data["id"]) + assert args["interval"] == interval + + # Type specific inputs + for key, expected in config.items(): + assert args[key] == expected + + @async_to_sync async def list_workflows(temporal, schedule_id: str): """List Workflows scheduled by given Schedule."""