diff --git a/bin/migrate_kafka_data.py b/bin/migrate_kafka_data.py index 87eaf391657e2..3da55ed538c06 100755 --- a/bin/migrate_kafka_data.py +++ b/bin/migrate_kafka_data.py @@ -21,7 +21,6 @@ import argparse import sys -from typing import List from kafka import KafkaAdminClient, KafkaConsumer, KafkaProducer from kafka.errors import KafkaError @@ -192,7 +191,7 @@ def handle(**options): print("Polling for messages") # noqa: T201 messages_by_topic = consumer.poll(timeout_ms=timeout_ms) - futures: List[FutureRecordMetadata] = [] + futures: list[FutureRecordMetadata] = [] if not messages_by_topic: break diff --git a/ee/api/authentication.py b/ee/api/authentication.py index 2dfb6c7b9f053..f2850bcfb5f61 100644 --- a/ee/api/authentication.py +++ b/ee/api/authentication.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Union from django.core.exceptions import ValidationError as DjangoValidationError from django.http.response import HttpResponse @@ -91,8 +91,8 @@ def auth_url(self): def _get_attr( self, - response_attributes: Dict[str, Any], - attribute_names: List[str], + response_attributes: dict[str, Any], + attribute_names: list[str], optional: bool = False, ) -> str: """ diff --git a/ee/api/dashboard_collaborator.py b/ee/api/dashboard_collaborator.py index 998eeba8238f9..6a004215d96a3 100644 --- a/ee/api/dashboard_collaborator.py +++ b/ee/api/dashboard_collaborator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast from django.db import IntegrityError from rest_framework import exceptions, mixins, serializers, viewsets @@ -45,7 +45,7 @@ class Meta: ] read_only_fields = ["id", "dashboard_id", "user", "user"] - def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: dashboard: Dashboard = self.context["dashboard"] dashboard_permissions = self.user_permissions.dashboard(dashboard) if dashboard_permissions.effective_restriction_level <= Dashboard.RestrictionLevel.EVERYONE_IN_PROJECT_CAN_EDIT: @@ -96,7 +96,7 @@ class DashboardCollaboratorViewSet( serializer_class = DashboardCollaboratorSerializer filter_rewrite_rules = {"team_id": "dashboard__team_id"} - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() try: context["dashboard"] = Dashboard.objects.get(id=context["dashboard_id"]) diff --git a/ee/api/role.py b/ee/api/role.py index 44909f504eece..0c4894c2779ce 100644 --- a/ee/api/role.py +++ b/ee/api/role.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from django.db import IntegrityError from rest_framework import mixins, serializers, viewsets @@ -76,7 +76,7 @@ def get_members(self, role: Role): return RoleMembershipSerializer(members, many=True).data def get_associated_flags(self, role: Role): - associated_flags: List[dict] = [] + associated_flags: list[dict] = [] role_access_objects = FeatureFlagRoleAccess.objects.filter(role=role).values_list("feature_flag_id") flags = FeatureFlag.objects.filter(id__in=role_access_objects) diff --git a/ee/api/sentry_stats.py b/ee/api/sentry_stats.py index 52b16647c2cbf..06b4e53b1bd59 100644 --- a/ee/api/sentry_stats.py +++ b/ee/api/sentry_stats.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import requests from django.http import HttpRequest, JsonResponse @@ -9,8 +9,8 @@ from 
posthog.models.instance_setting import get_instance_settings -def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]: - sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) +def get_sentry_stats(start_time: str, end_time: str) -> tuple[dict, int]: + sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) org_slug = sentry_config.get("SENTRY_ORGANIZATION") token = sentry_config.get("SENTRY_AUTH_TOKEN") @@ -41,9 +41,9 @@ def get_sentry_stats(start_time: str, end_time: str) -> Tuple[dict, int]: def get_tagged_issues_stats( - start_time: str, end_time: str, tags: Dict[str, str], target_issues: List[str] -) -> Dict[str, Any]: - sentry_config: Dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) + start_time: str, end_time: str, tags: dict[str, str], target_issues: list[str] +) -> dict[str, Any]: + sentry_config: dict[str, str] = get_instance_settings(["SENTRY_AUTH_TOKEN", "SENTRY_ORGANIZATION"]) org_slug = sentry_config.get("SENTRY_ORGANIZATION") token = sentry_config.get("SENTRY_AUTH_TOKEN") @@ -58,7 +58,7 @@ def get_tagged_issues_stats( for tag, value in tags.items(): query += f" {tag}:{value}" - params: Dict[str, Union[list, str]] = { + params: dict[str, Union[list, str]] = { "start": start_time, "end": end_time, "sort": "freq", @@ -89,8 +89,8 @@ def get_stats_for_timerange( base_end_time: str, target_start_time: str, target_end_time: str, - tags: Optional[Dict[str, str]] = None, -) -> Tuple[int, int]: + tags: Optional[dict[str, str]] = None, +) -> tuple[int, int]: base_counts, base_total_count = get_sentry_stats(base_start_time, base_end_time) target_counts, target_total_count = get_sentry_stats(target_start_time, target_end_time) diff --git a/ee/api/subscription.py b/ee/api/subscription.py index 412ddc5cfaff3..9f8881026fbdb 100644 --- a/ee/api/subscription.py +++ b/ee/api/subscription.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import jwt from django.db.models import QuerySet @@ -67,7 +67,7 @@ def validate(self, attrs): return attrs - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Subscription: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Subscription: request = self.context["request"] validated_data["team_id"] = self.context["team_id"] validated_data["created_by"] = request.user diff --git a/ee/api/test/base.py b/ee/api/test/base.py index 55e7930bfadf1..066dcc373d6d5 100644 --- a/ee/api/test/base.py +++ b/ee/api/test/base.py @@ -1,5 +1,5 @@ import datetime -from typing import Dict, Optional, cast +from typing import Optional, cast from zoneinfo import ZoneInfo @@ -20,7 +20,7 @@ class LicensedTestMixin: def license_required_response( self, message: str = "This feature is part of the premium PostHog offering. Self-hosted licenses are no longer available for purchase. 
Please contact sales@posthog.com to discuss options.", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "server_error", "code": "payment_required", diff --git a/ee/api/test/fixtures/available_product_features.py b/ee/api/test/fixtures/available_product_features.py index 5be816a169ba3..8cc5413754db1 100644 --- a/ee/api/test/fixtures/available_product_features.py +++ b/ee/api/test/fixtures/available_product_features.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List +from typing import Any -AVAILABLE_PRODUCT_FEATURES: List[Dict[str, Any]] = [ +AVAILABLE_PRODUCT_FEATURES: list[dict[str, Any]] = [ { "description": "Create playlists of certain session recordings to easily find and watch them again in the future.", "key": "recordings_playlists", diff --git a/ee/api/test/test_authentication.py b/ee/api/test/test_authentication.py index 00fca66b2914b..451efdd3d3777 100644 --- a/ee/api/test/test_authentication.py +++ b/ee/api/test/test_authentication.py @@ -364,7 +364,6 @@ def test_can_login_with_saml(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -407,7 +406,6 @@ def test_saml_jit_provisioning_and_assertion_with_different_attribute_names(self with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_alt_attribute_names"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -474,7 +472,6 @@ def test_cannot_login_with_improperly_signed_payload(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -514,7 +511,6 @@ def test_cannot_signup_with_saml_if_jit_provisioning_is_disabled(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -552,7 +548,6 @@ def test_cannot_create_account_without_first_name_in_payload(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response_no_first_name"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -594,7 +589,6 @@ def test_cannot_login_with_saml_on_unverified_domain(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() @@ -683,7 +677,6 @@ def test_cannot_use_saml_without_enterprise_license(self): with open( os.path.join(CURRENT_FOLDER, "fixtures/saml_login_response"), - "r", encoding="utf_8", ) as f: saml_response = f.read() diff --git a/ee/api/test/test_billing.py b/ee/api/test/test_billing.py index c1698bd1cae7f..94eed34d29d79 100644 --- a/ee/api/test/test_billing.py +++ b/ee/api/test/test_billing.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List +from typing import Any from unittest.mock import MagicMock, patch from uuid import uuid4 from zoneinfo import ZoneInfo @@ -22,7 +22,7 @@ from posthog.test.base import APIBaseTest, _create_event, flush_persons_and_events -def create_billing_response(**kwargs) -> Dict[str, Any]: +def create_billing_response(**kwargs) -> dict[str, Any]: data: Any = {"license": {"type": "cloud"}} data.update(kwargs) return data @@ -106,7 +106,7 @@ def create_billing_customer(**kwargs) -> CustomerInfo: return data -def create_billing_products_response(**kwargs) -> Dict[str, List[CustomerProduct]]: +def create_billing_products_response(**kwargs) -> dict[str, list[CustomerProduct]]: data: Any = { "products": [ CustomerProduct( diff --git a/ee/api/test/test_capture.py 
b/ee/api/test/test_capture.py index 891a9759a80c5..4f716d785098c 100644 --- a/ee/api/test/test_capture.py +++ b/ee/api/test/test_capture.py @@ -68,26 +68,26 @@ def test_produce_to_kafka(self, kafka_produce): self.assertEqual(event2_data["properties"]["distinct_id"], "id2") # Make sure we're producing data correctly in the way the plugin server expects - self.assertEquals(type(kafka_produce_call1["data"]["distinct_id"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["distinct_id"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["distinct_id"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["distinct_id"]), str) self.assertIn(type(kafka_produce_call1["data"]["ip"]), [str, type(None)]) self.assertIn(type(kafka_produce_call2["data"]["ip"]), [str, type(None)]) - self.assertEquals(type(kafka_produce_call1["data"]["site_url"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["site_url"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["site_url"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["site_url"]), str) - self.assertEquals(type(kafka_produce_call1["data"]["token"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["token"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["token"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["token"]), str) - self.assertEquals(type(kafka_produce_call1["data"]["sent_at"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["sent_at"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["sent_at"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["sent_at"]), str) - self.assertEquals(type(event1_data["properties"]), dict) - self.assertEquals(type(event2_data["properties"]), dict) + self.assertEqual(type(event1_data["properties"]), dict) + self.assertEqual(type(event2_data["properties"]), dict) - self.assertEquals(type(kafka_produce_call1["data"]["uuid"]), str) - self.assertEquals(type(kafka_produce_call2["data"]["uuid"]), str) + self.assertEqual(type(kafka_produce_call1["data"]["uuid"]), str) + self.assertEqual(type(kafka_produce_call2["data"]["uuid"]), str) @patch("posthog.kafka_client.client._KafkaProducer.produce") def test_capture_event_with_uuid_in_payload(self, kafka_produce): diff --git a/ee/api/test/test_dashboard.py b/ee/api/test/test_dashboard.py index 8c39a17135db0..e494dfbce7a44 100644 --- a/ee/api/test/test_dashboard.py +++ b/ee/api/test/test_dashboard.py @@ -106,7 +106,7 @@ def test_cannot_set_dashboard_to_restrict_editing_as_other_user_who_is_project_m response_data = response.json() self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEquals( + self.assertEqual( response_data, self.permission_denied_response( "Only the dashboard owner and project admins have the restriction rights required to change the dashboard's restriction level." 
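# Three mechanical cleanups recur across the test hunks here: the deprecated
# assertEquals alias becomes assertEqual (it has warned since Python 3.2 and
# is removed in Python 3.12), the redundant "r" argument to open() is dropped
# (reading is the default mode), and "...".encode("utf-8") on a literal
# becomes a plain bytes literal. A runnable sketch of all three patterns,
# with made-up names and values:
import hashlib
import hmac
import os
import tempfile
import unittest


class CleanupSketch(unittest.TestCase):
    def test_patterns(self) -> None:
        data = {"distinct_id": "id1"}  # stand-in for a kafka_produce payload
        self.assertEqual(type(data["distinct_id"]), str)  # not assertEquals

        with tempfile.NamedTemporaryFile("w", delete=False) as f:
            f.write("saml response")
        with open(f.name, encoding="utf_8") as fh:  # "r" mode is the default
            self.assertEqual(fh.read(), "saml response")
        os.unlink(f.name)

        signature = hmac.new(b"not-so-secret", b"payload", digestmod=hashlib.sha256)
        self.assertEqual(len(signature.hexdigest()), 64)  # bytes-literal key


if __name__ == "__main__":
    unittest.main()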
@@ -178,7 +178,7 @@ def test_cannot_edit_restricted_dashboard_as_other_user_who_is_project_member(se response_data = response.json() self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEquals( + self.assertEqual( response_data, self.permission_denied_response("You don't have edit permissions for this dashboard."), ) @@ -262,7 +262,7 @@ def test_sharing_edits_limited_to_collaborators(self): response_data = response.json() self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEquals( + self.assertEqual( response_data, self.permission_denied_response("You don't have edit permissions for this dashboard."), ) diff --git a/ee/api/test/test_event_definition.py b/ee/api/test/test_event_definition.py index 6e3cbb8775fb9..2aa87e63e2e65 100644 --- a/ee/api/test/test_event_definition.py +++ b/ee/api/test/test_event_definition.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import cast, Optional, List, Dict, Any +from typing import cast, Optional, Any import dateutil.parser from django.utils import timezone @@ -26,7 +26,7 @@ class TestEventDefinitionEnterpriseAPI(APIBaseTest): Ignoring the verified field we'd expect ordering purchase, watched_movie, entered_free_trial, $pageview With it we expect watched_movie, entered_free_trial, purchase, $pageview """ - EXPECTED_EVENT_DEFINITIONS: List[Dict[str, Any]] = [ + EXPECTED_EVENT_DEFINITIONS: list[dict[str, Any]] = [ {"name": "purchase", "verified": None}, {"name": "entered_free_trial", "verified": True}, {"name": "watched_movie", "verified": True}, diff --git a/ee/api/test/test_insight.py b/ee/api/test/test_insight.py index 00863551500ee..7db46bf79dea1 100644 --- a/ee/api/test/test_insight.py +++ b/ee/api/test/test_insight.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import cast, Optional, List, Dict +from typing import cast, Optional from django.test import override_settings from django.utils import timezone from freezegun import freeze_time @@ -305,7 +305,7 @@ def test_cannot_update_restricted_insight_as_other_user_who_is_project_member(se dashboard.refresh_from_db() self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - self.assertEquals( + self.assertEqual( response_data, self.permission_denied_response( "This insight is on a dashboard that can only be edited by its owner, team members invited to editing the dashboard, and project admins." 
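# Nearly every hunk in this diff is the same PEP 585 rewrite: from Python 3.9
# the builtin collections are generic, so the typing.List / Dict / Tuple /
# Set / Type aliases (deprecated since 3.9) can simply go away. Optional and
# Union keep their typing import for now; the one place the diff reaches
# further is the ee/clickhouse/queries/column_optimizer.py hunk below, where
# isinstance(x, (A, B)) becomes isinstance(x, A | B) -- PEP 604 union syntax
# that needs Python 3.10+. A self-contained sketch of the rewrites, with
# throwaway names:
from typing import Optional, Union

query_counts: list[int] = []                       # was List[int]
params: dict[str, Union[list, str]] = {}           # was Dict[str, Union[list, str]]
suggestion: tuple[str, str] = ("events", "$host")  # was Tuple[str, str]
events_seen: set[Union[int, str]] = set()          # was Set[Union[int, str]]
trend_class: type = dict                           # was Type; any class object
maybe_name: Optional[str] = None                   # unchanged: still typing

assert isinstance(maybe_name, str | type(None))    # PEP 604 union, 3.10+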
@@ -547,7 +547,7 @@ def test_an_insight_on_restricted_dashboard_does_not_restrict_admin(self) -> Non @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False) @snapshot_postgres_queries def test_listing_insights_does_not_nplus1(self) -> None: - query_counts: List[int] = [] + query_counts: list[int] = [] queries = [] for i in range(5): @@ -587,10 +587,10 @@ def test_listing_insights_does_not_nplus1(self) -> None: f"received query counts\n\n{query_counts}", ) - def assert_insight_activity(self, insight_id: Optional[int], expected: List[Dict]): + def assert_insight_activity(self, insight_id: Optional[int], expected: list[dict]): activity_response = self.dashboard_api.get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None assert activity == expected diff --git a/ee/api/test/test_integration.py b/ee/api/test/test_integration.py index d675415e4bd81..7f30635b5afdf 100644 --- a/ee/api/test/test_integration.py +++ b/ee/api/test/test_integration.py @@ -25,7 +25,7 @@ def _headers_for_payload(self, payload: Any): signature = ( "v0=" + hmac.new( - "not-so-secret".encode("utf-8"), + b"not-so-secret", sig_basestring.encode("utf-8"), digestmod=hashlib.sha256, ).hexdigest() diff --git a/ee/api/test/test_property_definition.py b/ee/api/test/test_property_definition.py index ef8d4dd928540..effa43a9f4b18 100644 --- a/ee/api/test/test_property_definition.py +++ b/ee/api/test/test_property_definition.py @@ -1,4 +1,4 @@ -from typing import cast, Optional, List, Dict +from typing import cast, Optional from freezegun import freeze_time import pytest from django.db.utils import IntegrityError @@ -450,7 +450,7 @@ def test_list_property_definitions(self): plan="enterprise", valid_until=timezone.datetime(2500, 1, 19, 3, 14, 7) ) - properties: List[Dict] = [ + properties: list[dict] = [ {"name": "1_when_verified", "verified": True}, {"name": "2_when_verified", "verified": True}, {"name": "3_when_verified", "verified": True}, diff --git a/ee/api/test/test_time_to_see_data.py b/ee/api/test/test_time_to_see_data.py index 4c5a50d51e58f..1ad6b4b08135b 100644 --- a/ee/api/test/test_time_to_see_data.py +++ b/ee/api/test/test_time_to_see_data.py @@ -1,6 +1,6 @@ import json from dataclasses import asdict, dataclass, field -from typing import Any, List +from typing import Any from unittest import mock import pytest @@ -64,7 +64,7 @@ def test_sessions_api(self): ) response = self.client.post("/api/time_to_see_data/sessions").json() - self.assertEquals( + self.assertEqual( response, [ { @@ -209,18 +209,18 @@ class QueryLogRow: query_time_range_days: int = 1 has_joins: int = 0 has_json_operations: int = 0 - filter_by_type: List[str] = field(default_factory=list) - breakdown_by: List[str] = field(default_factory=list) - entity_math: List[str] = field(default_factory=list) + filter_by_type: list[str] = field(default_factory=list) + breakdown_by: list[str] = field(default_factory=list) + entity_math: list[str] = field(default_factory=list) filter: str = "" ProfileEvents: dict = field(default_factory=dict) - tables: List[str] = field(default_factory=list) - columns: List[str] = field(default_factory=list) + tables: list[str] = field(default_factory=list) + columns: list[str] = field(default_factory=list) query: str = "" log_comment = "" -def insert(table: str, rows: List): +def insert(table: str, rows: list): columns = asdict(rows[0]).keys() all_values, params = [], {} diff --git 
a/ee/benchmarks/benchmarks.py b/ee/benchmarks/benchmarks.py index 83e82df068f9b..d999467779a4c 100644 --- a/ee/benchmarks/benchmarks.py +++ b/ee/benchmarks/benchmarks.py @@ -2,7 +2,6 @@ # Needs to be first to set up django environment from .helpers import benchmark_clickhouse, no_materialized_columns, now from datetime import timedelta -from typing import List, Tuple from ee.clickhouse.materialized_columns.analyze import ( backfill_materialized_columns, get_materialized_columns, @@ -29,7 +28,7 @@ from posthog.models.property import PropertyName, TableWithProperties from posthog.constants import FunnelCorrelationType -MATERIALIZED_PROPERTIES: List[Tuple[TableWithProperties, PropertyName]] = [ +MATERIALIZED_PROPERTIES: list[tuple[TableWithProperties, PropertyName]] = [ ("events", "$host"), ("events", "$current_url"), ("events", "$event_type"), diff --git a/ee/billing/billing_manager.py b/ee/billing/billing_manager.py index da95c0871f55a..c301e80f8c27f 100644 --- a/ee/billing/billing_manager.py +++ b/ee/billing/billing_manager.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast import jwt import requests @@ -53,7 +53,7 @@ class BillingManager: def __init__(self, license): self.license = license or get_cached_instance_license() - def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> Dict[str, Any]: + def get_billing(self, organization: Optional[Organization], plan_keys: Optional[str]) -> dict[str, Any]: if organization and self.license and self.license.is_v2_license: billing_service_response = self._get_billing(organization) @@ -63,7 +63,7 @@ def get_billing(self, organization: Optional[Organization], plan_keys: Optional[ if organization and billing_service_response: self.update_org_details(organization, billing_service_response) - response: Dict[str, Any] = {"available_features": []} + response: dict[str, Any] = {"available_features": []} response["license"] = {"plan": self.license.plan} @@ -102,7 +102,7 @@ def get_billing(self, organization: Optional[Organization], plan_keys: Optional[ return response - def update_billing(self, organization: Organization, data: Dict[str, Any]) -> None: + def update_billing(self, organization: Organization, data: dict[str, Any]) -> None: res = requests.patch( f"{BILLING_SERVICE_URL}/api/billing/", headers=self.get_auth_headers(organization), diff --git a/ee/billing/billing_types.py b/ee/billing/billing_types.py index 6151ad3288051..0761e02e807ef 100644 --- a/ee/billing/billing_types.py +++ b/ee/billing/billing_types.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Dict, List, Optional, TypedDict +from typing import Optional, TypedDict from posthog.constants import AvailableFeature @@ -18,7 +18,7 @@ class CustomerProduct(TypedDict): image_url: Optional[str] type: str free_allocation: int - tiers: List[Tier] + tiers: list[Tier] tiered: bool unit_amount_usd: Optional[Decimal] current_amount_usd: Decimal @@ -51,16 +51,16 @@ class CustomerInfo(TypedDict): deactivated: bool has_active_subscription: bool billing_period: BillingPeriod - available_features: List[AvailableFeature] + available_features: list[AvailableFeature] current_total_amount_usd: Optional[str] current_total_amount_usd_after_discount: Optional[str] - products: Optional[List[CustomerProduct]] - custom_limits_usd: Optional[Dict[str, str]] - usage_summary: Optional[Dict[str, Dict[str, Optional[int]]]] + products: Optional[list[CustomerProduct]] + 
custom_limits_usd: Optional[dict[str, str]] + usage_summary: Optional[dict[str, dict[str, Optional[int]]]] free_trial_until: Optional[str] discount_percent: Optional[int] discount_amount_usd: Optional[str] - customer_trust_scores: Dict[str, int] + customer_trust_scores: dict[str, int] class BillingStatus(TypedDict): diff --git a/ee/billing/quota_limiting.py b/ee/billing/quota_limiting.py index 1c50b69803a57..8f5864c3ed513 100644 --- a/ee/billing/quota_limiting.py +++ b/ee/billing/quota_limiting.py @@ -1,7 +1,8 @@ import copy from datetime import datetime, timedelta from enum import Enum -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, TypedDict, cast +from typing import Optional, TypedDict, cast +from collections.abc import Mapping, Sequence import dateutil.parser import posthoganalytics @@ -66,13 +67,13 @@ def add_limited_team_tokens(resource: QuotaResource, tokens: Mapping[str, int], redis_client.zadd(f"{cache_key}{resource.value}", tokens) # type: ignore # (zadd takes a Mapping[str, int] but the derived Union type is wrong) -def remove_limited_team_tokens(resource: QuotaResource, tokens: List[str], cache_key: QuotaLimitingCaches) -> None: +def remove_limited_team_tokens(resource: QuotaResource, tokens: list[str], cache_key: QuotaLimitingCaches) -> None: redis_client = get_client() redis_client.zrem(f"{cache_key}{resource.value}", *tokens) @cache_for(timedelta(seconds=30), background_refresh=True) -def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> List[str]: +def list_limited_team_attributes(resource: QuotaResource, cache_key: QuotaLimitingCaches) -> list[str]: now = timezone.now() redis_client = get_client() results = redis_client.zrangebyscore(f"{cache_key}{resource.value}", min=now.timestamp(), max="+inf") @@ -86,7 +87,7 @@ class UsageCounters(TypedDict): def org_quota_limited_until( - organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: List[str] + organization: Organization, resource: QuotaResource, previously_quota_limited_team_tokens: list[str] ) -> Optional[OrgQuotaLimitingInformation]: if not organization.usage: return None @@ -265,7 +266,7 @@ def sync_org_quota_limits(organization: Organization): def get_team_attribute_by_quota_resource(organization: Organization, resource: QuotaResource): if resource in [QuotaResource.EVENTS, QuotaResource.RECORDINGS]: - team_tokens: List[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x] + team_tokens: list[str] = [x for x in list(organization.teams.values_list("api_token", flat=True)) if x] if not team_tokens: capture_exception(Exception(f"quota_limiting: No team tokens found for organization: {organization.id}")) @@ -274,7 +275,7 @@ def get_team_attribute_by_quota_resource(organization: Organization, resource: Q return team_tokens if resource == QuotaResource.ROWS_SYNCED: - team_ids: List[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x] + team_ids: list[str] = [x for x in list(organization.teams.values_list("id", flat=True)) if x] if not team_ids: capture_exception(Exception(f"quota_limiting: No team ids found for organization: {organization.id}")) @@ -322,7 +323,7 @@ def set_org_usage_summary( def update_all_org_billing_quotas( dry_run: bool = False, -) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: +) -> tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: period = get_current_day() period_start, period_end = period @@ -352,8 +353,8 @@ def 
update_all_org_billing_quotas( ) ) - todays_usage_report: Dict[str, UsageCounters] = {} - orgs_by_id: Dict[str, Organization] = {} + todays_usage_report: dict[str, UsageCounters] = {} + orgs_by_id: dict[str, Organization] = {} # we iterate through all teams, and add their usage to the organization they belong to for team in teams: @@ -373,12 +374,12 @@ def update_all_org_billing_quotas( for field in team_report: org_report[field] += team_report[field] # type: ignore - quota_limited_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource} - quota_limiting_suspended_orgs: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource} + quota_limited_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource} + quota_limiting_suspended_orgs: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource} # Get the current quota limits so we can track to poshog if it changes orgs_with_changes = set() - previously_quota_limited_team_tokens: Dict[str, List[str]] = {x.value: [] for x in QuotaResource} + previously_quota_limited_team_tokens: dict[str, list[str]] = {x.value: [] for x in QuotaResource} for field in quota_limited_orgs: previously_quota_limited_team_tokens[field] = list_limited_team_attributes( @@ -405,8 +406,8 @@ def update_all_org_billing_quotas( elif quota_limited_until: quota_limited_orgs[field][org_id] = quota_limited_until - quota_limited_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource} - quota_limiting_suspended_teams: Dict[str, Dict[str, int]] = {x.value: {} for x in QuotaResource} + quota_limited_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource} + quota_limiting_suspended_teams: dict[str, dict[str, int]] = {x.value: {} for x in QuotaResource} # Convert the org ids to team tokens for team in teams: diff --git a/ee/clickhouse/materialized_columns/analyze.py b/ee/clickhouse/materialized_columns/analyze.py index dac1aa6abc0f4..e8801fe17f606 100644 --- a/ee/clickhouse/materialized_columns/analyze.py +++ b/ee/clickhouse/materialized_columns/analyze.py @@ -1,6 +1,7 @@ import re from datetime import timedelta -from typing import Dict, Generator, List, Optional, Set, Tuple +from typing import Optional +from collections.abc import Generator import structlog @@ -27,18 +28,18 @@ from posthog.models.property_definition import PropertyDefinition from posthog.models.team import Team -Suggestion = Tuple[TableWithProperties, TableColumn, PropertyName] +Suggestion = tuple[TableWithProperties, TableColumn, PropertyName] logger = structlog.get_logger(__name__) class TeamManager: @instance_memoize - def person_properties(self, team_id: str) -> Set[str]: + def person_properties(self, team_id: str) -> set[str]: return self._get_properties(GET_PERSON_PROPERTIES_COUNT, team_id) @instance_memoize - def event_properties(self, team_id: str) -> Set[str]: + def event_properties(self, team_id: str) -> set[str]: return set( PropertyDefinition.objects.filter(team_id=team_id, type=PropertyDefinition.Type.EVENT).values_list( "name", flat=True @@ -46,17 +47,17 @@ def event_properties(self, team_id: str) -> Set[str]: ) @instance_memoize - def person_on_events_properties(self, team_id: str) -> Set[str]: + def person_on_events_properties(self, team_id: str) -> set[str]: return self._get_properties(GET_EVENT_PROPERTIES_COUNT.format(column_name="person_properties"), team_id) @instance_memoize - def group_on_events_properties(self, group_type_index: int, team_id: str) -> Set[str]: + def group_on_events_properties(self, group_type_index: int, 
team_id: str) -> set[str]: return self._get_properties( GET_EVENT_PROPERTIES_COUNT.format(column_name=f"group{group_type_index}_properties"), team_id, ) - def _get_properties(self, query, team_id) -> Set[str]: + def _get_properties(self, query, team_id) -> set[str]: rows = sync_execute(query, {"team_id": team_id}) return {name for name, _ in rows} @@ -86,12 +87,12 @@ def team_id(self) -> Optional[str]: return matches[0] if matches else None @cached_property - def _all_properties(self) -> List[Tuple[str, PropertyName]]: + def _all_properties(self) -> list[tuple[str, PropertyName]]: return re.findall(r"JSONExtract\w+\((\S+), '([^']+)'\)", self.query_string) def properties( self, team_manager: TeamManager - ) -> Generator[Tuple[TableWithProperties, TableColumn, PropertyName], None, None]: + ) -> Generator[tuple[TableWithProperties, TableColumn, PropertyName], None, None]: # Reverse-engineer whether a property is an "event" or "person" property by getting their event definitions. # :KLUDGE: Note that the same property will be found on both tables if both are used. # We try to hone in on the right column by looking at the column from which the property is extracted. @@ -124,7 +125,7 @@ def properties( yield "events", "group4_properties", property -def _analyze(since_hours_ago: int, min_query_time: int) -> List[Suggestion]: +def _analyze(since_hours_ago: int, min_query_time: int) -> list[Suggestion]: "Finds columns that should be materialized" raw_queries = sync_execute( @@ -179,7 +180,7 @@ def _analyze(since_hours_ago: int, min_query_time: int) -> List[Suggestion]: def materialize_properties_task( - columns_to_materialize: Optional[List[Suggestion]] = None, + columns_to_materialize: Optional[list[Suggestion]] = None, time_to_analyze_hours: int = MATERIALIZE_COLUMNS_ANALYSIS_PERIOD_HOURS, maximum: int = MATERIALIZE_COLUMNS_MAX_AT_ONCE, min_query_time: int = MATERIALIZE_COLUMNS_MINIMUM_QUERY_TIME, @@ -203,7 +204,7 @@ def materialize_properties_task( else: logger.info("Found no columns to materialize.") - properties: Dict[TableWithProperties, List[Tuple[PropertyName, TableColumn]]] = { + properties: dict[TableWithProperties, list[tuple[PropertyName, TableColumn]]] = { "events": [], "person": [], } diff --git a/ee/clickhouse/materialized_columns/columns.py b/ee/clickhouse/materialized_columns/columns.py index 71bfd5adcc751..1340abde0a682 100644 --- a/ee/clickhouse/materialized_columns/columns.py +++ b/ee/clickhouse/materialized_columns/columns.py @@ -1,6 +1,6 @@ import re from datetime import timedelta -from typing import Dict, List, Literal, Tuple, Union, cast +from typing import Literal, Union, cast from clickhouse_driver.errors import ServerException from django.utils.timezone import now @@ -36,7 +36,7 @@ @cache_for(timedelta(minutes=15)) def get_materialized_columns( table: TablesWithMaterializedColumns, -) -> Dict[Tuple[PropertyName, TableColumn], ColumnName]: +) -> dict[tuple[PropertyName, TableColumn], ColumnName]: rows = sync_execute( """ SELECT comment, name @@ -141,7 +141,7 @@ def add_minmax_index(table: TablesWithMaterializedColumns, column_name: str): def backfill_materialized_columns( table: TableWithProperties, - properties: List[Tuple[PropertyName, TableColumn]], + properties: list[tuple[PropertyName, TableColumn]], backfill_period: timedelta, test_settings=None, ) -> None: @@ -215,7 +215,7 @@ def _materialized_column_name( return f"{prefix}{property_str}{suffix}" -def _extract_property(comment: str) -> Tuple[PropertyName, TableColumn]: +def _extract_property(comment: str) -> 
tuple[PropertyName, TableColumn]: # Old style comments have the format "column_materializer::property", dealing with the default table column. # Otherwise, it's "column_materializer::table_column::property" split_column = comment.split("::", 2) diff --git a/ee/clickhouse/models/test/test_action.py b/ee/clickhouse/models/test/test_action.py index 692844e55c1e4..4f06b3e871a88 100644 --- a/ee/clickhouse/models/test/test_action.py +++ b/ee/clickhouse/models/test/test_action.py @@ -1,5 +1,4 @@ import dataclasses -from typing import List from posthog.client import sync_execute from posthog.hogql.hogql import HogQLContext @@ -22,7 +21,7 @@ class MockEvent: distinct_id: str -def _get_events_for_action(action: Action) -> List[MockEvent]: +def _get_events_for_action(action: Action) -> list[MockEvent]: hogql_context = HogQLContext(team_id=action.team_id) formatted_query, params = format_action_filter( team_id=action.team_id, action=action, prepend="", hogql_context=hogql_context diff --git a/ee/clickhouse/models/test/test_property.py b/ee/clickhouse/models/test/test_property.py index 913058d4ae1bd..6348697d84435 100644 --- a/ee/clickhouse/models/test/test_property.py +++ b/ee/clickhouse/models/test/test_property.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Literal, Union, cast +from typing import Literal, Union, cast from uuid import UUID import pytest @@ -43,7 +43,7 @@ class TestPropFormat(ClickhouseTestMixin, BaseTest): CLASS_DATA_LEVEL_SETUP = False - def _run_query(self, filter: Filter, **kwargs) -> List: + def _run_query(self, filter: Filter, **kwargs) -> list: query, params = parse_prop_grouped_clauses( property_group=filter.property_groups, allow_denormalized_props=True, @@ -776,7 +776,7 @@ def test_parse_groups_persons(self): class TestPropDenormalized(ClickhouseTestMixin, BaseTest): CLASS_DATA_LEVEL_SETUP = False - def _run_query(self, filter: Filter, join_person_tables=False) -> List: + def _run_query(self, filter: Filter, join_person_tables=False) -> list: outer_properties = PropertyOptimizer().parse_property_groups(filter.property_groups).outer query, params = parse_prop_grouped_clauses( team_id=self.team.pk, @@ -1232,7 +1232,7 @@ def test_parse_groups_persons_edge_case_with_single_filter(snapshot): @pytest.mark.parametrize("breakdown, table, query_alias, column, expected", TEST_BREAKDOWN_PROCESSING) def test_breakdown_query_expression( clean_up_materialised_columns, - breakdown: Union[str, List[str]], + breakdown: Union[str, list[str]], table: TableWithProperties, query_alias: Literal["prop", "value"], column: str, @@ -1281,7 +1281,7 @@ def test_breakdown_query_expression( ) def test_breakdown_query_expression_materialised( clean_up_materialised_columns, - breakdown: Union[str, List[str]], + breakdown: Union[str, list[str]], table: TableWithProperties, query_alias: Literal["prop", "value"], column: str, @@ -1317,7 +1317,7 @@ def test_breakdown_query_expression_materialised( @pytest.fixture -def test_events(db, team) -> List[UUID]: +def test_events(db, team) -> list[UUID]: return [ _create_event( event="$pageview", @@ -1958,7 +1958,7 @@ def test_combine_group_properties(): ], } - combined_group = PropertyGroup(PropertyOperatorType.AND, cast(List[Property], [])).combine_properties( + combined_group = PropertyGroup(PropertyOperatorType.AND, cast(list[Property], [])).combine_properties( PropertyOperatorType.OR, [propertyC, propertyD] ) assert combined_group.to_dict() == { diff --git a/ee/clickhouse/queries/column_optimizer.py 
b/ee/clickhouse/queries/column_optimizer.py index dd62154dd2037..b1bf142aa3d1e 100644 --- a/ee/clickhouse/queries/column_optimizer.py +++ b/ee/clickhouse/queries/column_optimizer.py @@ -1,5 +1,5 @@ -from typing import Counter as TCounter -from typing import Set, cast +from collections import Counter as TCounter +from typing import cast from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.constants import TREND_FILTER_TYPE_ACTIONS, FunnelCorrelationType @@ -20,16 +20,16 @@ class EnterpriseColumnOptimizer(FOSSColumnOptimizer): @cached_property - def group_types_to_query(self) -> Set[GroupTypeIndex]: + def group_types_to_query(self) -> set[GroupTypeIndex]: used_properties = self.used_properties_with_type("group") return {cast(GroupTypeIndex, group_type_index) for _, _, group_type_index in used_properties} @cached_property - def group_on_event_columns_to_query(self) -> Set[ColumnName]: + def group_on_event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table group columns containing materialized properties that this query needs" used_properties = self.used_properties_with_type("group") - columns_to_query: Set[ColumnName] = set() + columns_to_query: set[ColumnName] = set() for group_type_index in range(5): columns_to_query = columns_to_query.union( @@ -120,7 +120,7 @@ def properties_used_in_filter(self) -> TCounter[PropertyIdentifier]: counter += get_action_tables_and_properties(entity.get_action()) if ( - not isinstance(self.filter, (StickinessFilter, PropertiesTimelineFilter)) + not isinstance(self.filter, StickinessFilter | PropertiesTimelineFilter) and self.filter.correlation_type == FunnelCorrelationType.PROPERTIES and self.filter.correlation_property_names ): diff --git a/ee/clickhouse/queries/enterprise_cohort_query.py b/ee/clickhouse/queries/enterprise_cohort_query.py index a748a64adf06a..814b61e9a8bf5 100644 --- a/ee/clickhouse/queries/enterprise_cohort_query.py +++ b/ee/clickhouse/queries/enterprise_cohort_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, cast +from typing import Any, cast from posthog.constants import PropertyOperatorType from posthog.models.cohort.util import get_count_operator @@ -15,18 +15,18 @@ from posthog.schema import PersonsOnEventsMode -def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]: +def check_negation_clause(prop: PropertyGroup) -> tuple[bool, bool]: has_negation_clause = False has_primary_clase = False if len(prop.values): if isinstance(prop.values[0], PropertyGroup): - for p in cast(List[PropertyGroup], prop.values): + for p in cast(list[PropertyGroup], prop.values): has_neg, has_primary = check_negation_clause(p) has_negation_clause = has_negation_clause or has_neg has_primary_clase = has_primary_clase or has_primary else: - for property in cast(List[Property], prop.values): + for property in cast(list[Property], prop.values): if property.negation: has_negation_clause = True else: @@ -42,7 +42,7 @@ def check_negation_clause(prop: PropertyGroup) -> Tuple[bool, bool]: class EnterpriseCohortQuery(FOSSCohortQuery): - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: if not self._outer_property_groups: # everything is pushed down, no behavioral stuff to do # thus, use personQuery directly @@ -87,9 +87,9 @@ def get_query(self) -> Tuple[str, Dict[str, Any]]: return final_query, self.params - def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def 
_get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if prop.type == "behavioral": if prop.value == "performed_event": @@ -117,7 +117,7 @@ def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> return res, params - def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"stopped_event_condition_{prepend}_{idx}" @@ -152,7 +152,7 @@ def get_stopped_performing_event(self, prop: Property, prepend: str, idx: int) - }, ) - def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"restarted_event_condition_{prepend}_{idx}" @@ -191,7 +191,7 @@ def get_restarted_performing_event(self, prop: Property, prepend: str, idx: int) }, ) - def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) entity_query, entity_params = self._get_entity(event, prepend, idx) @@ -212,7 +212,7 @@ def get_performed_event_first_time(self, prop: Property, prepend: str, idx: int) {f"{date_param}": date_value, **entity_params}, ) - def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) entity_query, entity_params = self._get_entity(event, prepend, idx) @@ -266,7 +266,7 @@ def get_performed_event_regularly(self, prop: Property, prepend: str, idx: int) ) @cached_property - def sequence_filters_to_query(self) -> List[Property]: + def sequence_filters_to_query(self) -> list[Property]: props = [] for prop in self._filter.property_groups.flat: if prop.value == "performed_event_sequence": @@ -274,13 +274,13 @@ def sequence_filters_to_query(self) -> List[Property]: return props @cached_property - def sequence_filters_lookup(self) -> Dict[str, str]: + def sequence_filters_lookup(self) -> dict[str, str]: lookup = {} for idx, prop in enumerate(self.sequence_filters_to_query): lookup[str(prop.to_dict())] = f"{idx}" return lookup - def _get_sequence_query(self) -> Tuple[str, Dict[str, Any], str]: + def _get_sequence_query(self) -> tuple[str, dict[str, Any], str]: params = {} materialized_columns = list(self._column_optimizer.event_columns_to_query) @@ -356,7 +356,7 @@ def _get_sequence_query(self) -> Tuple[str, Dict[str, Any], str]: self.FUNNEL_QUERY_ALIAS, ) - def _get_sequence_filter(self, prop: Property, idx: int) -> Tuple[List[str], List[str], List[str], Dict[str, Any]]: + def _get_sequence_filter(self, prop: Property, idx: int) -> tuple[list[str], list[str], list[str], dict[str, Any]]: event = validate_entity((prop.event_type, prop.key)) entity_query, entity_params = self._get_entity(event, f"event_sequence_{self._cohort_pk}", idx) seq_event = validate_entity((prop.seq_event_type, prop.seq_event)) @@ -405,7 +405,7 @@ def _get_sequence_filter(self, prop: 
Property, idx: int) -> Tuple[List[str], Lis }, ) - def get_performed_event_sequence(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_sequence(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: return ( f"{self.SEQUENCE_FIELD_ALIAS}_{self.sequence_filters_lookup[str(prop.to_dict())]}", {}, diff --git a/ee/clickhouse/queries/event_query.py b/ee/clickhouse/queries/event_query.py index b1b4dbb695e63..0e16abc780049 100644 --- a/ee/clickhouse/queries/event_query.py +++ b/ee/clickhouse/queries/event_query.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from ee.clickhouse.materialized_columns.columns import ColumnName from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer @@ -33,9 +33,9 @@ def __init__( should_join_distinct_ids=False, should_join_persons=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, **kwargs, @@ -62,7 +62,7 @@ def __init__( self._column_optimizer = EnterpriseColumnOptimizer(self._filter, self._team_id) - def _get_groups_query(self) -> Tuple[str, Dict]: + def _get_groups_query(self) -> tuple[str, dict]: if isinstance(self._filter, PropertiesTimelineFilter): raise Exception("Properties Timeline never needs groups query") return GroupsJoinQuery( diff --git a/ee/clickhouse/queries/experiments/funnel_experiment_result.py b/ee/clickhouse/queries/experiments/funnel_experiment_result.py index ab117b07c69e2..845cce75d505c 100644 --- a/ee/clickhouse/queries/experiments/funnel_experiment_result.py +++ b/ee/clickhouse/queries/experiments/funnel_experiment_result.py @@ -1,7 +1,7 @@ from dataclasses import asdict, dataclass from datetime import datetime import json -from typing import List, Optional, Tuple, Type +from typing import Optional from zoneinfo import ZoneInfo from numpy.random import default_rng @@ -56,7 +56,7 @@ def __init__( feature_flag: FeatureFlag, experiment_start_date: datetime, experiment_end_date: Optional[datetime] = None, - funnel_class: Type[ClickhouseFunnel] = ClickhouseFunnel, + funnel_class: type[ClickhouseFunnel] = ClickhouseFunnel, ): breakdown_key = f"$feature/{feature_flag.key}" self.variants = [variant["key"] for variant in feature_flag.variants] @@ -148,9 +148,9 @@ def get_variants(self, funnel_results): @staticmethod def calculate_results( control_variant: Variant, - test_variants: List[Variant], - priors: Tuple[int, int] = (1, 1), - ) -> List[Probability]: + test_variants: list[Variant], + priors: tuple[int, int] = (1, 1), + ) -> list[Probability]: """ Calculates probability that A is better than B. First variant is control, rest are test variants. 
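# calculate_results above is Beta-posterior Monte Carlo: with the default
# Beta(1, 1) (uniform) priors, each variant's conversion rate has posterior
# Beta(prior_success + successes, prior_failure + failures), and the win
# probability is the fraction of samples in which a variant comes out ahead
# (simulate_winning_variant_for_conversion below is the multi-variant form
# of this). A generic two-variant sketch; the counts and sample size are
# illustrative, not PostHog's exact routine:
from numpy.random import default_rng


def probability_test_beats_control(
    control: tuple[int, int],  # (success_count, failure_count)
    test: tuple[int, int],
    priors: tuple[int, int] = (1, 1),
    simulations: int = 100_000,
) -> float:
    rng = default_rng()
    prior_success, prior_failure = priors
    control_rates = rng.beta(prior_success + control[0], prior_failure + control[1], simulations)
    test_rates = rng.beta(prior_success + test[0], prior_failure + test[1], simulations)
    # Fraction of simulated worlds in which the test variant converts better.
    return float((test_rates > control_rates).mean())


# probability_test_beats_control((100, 900), (120, 880)) comes out near 0.92.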
@@ -186,9 +186,9 @@ def calculate_results( @staticmethod def are_results_significant( control_variant: Variant, - test_variants: List[Variant], - probabilities: List[Probability], - ) -> Tuple[ExperimentSignificanceCode, Probability]: + test_variants: list[Variant], + probabilities: list[Probability], + ) -> tuple[ExperimentSignificanceCode, Probability]: def get_conversion_rate(variant: Variant): return variant.success_count / (variant.success_count + variant.failure_count) @@ -226,7 +226,7 @@ def get_conversion_rate(variant: Variant): return ExperimentSignificanceCode.SIGNIFICANT, expected_loss -def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) -> float: +def calculate_expected_loss(target_variant: Variant, variants: list[Variant]) -> float: """ Calculates expected loss in conversion rate for a given variant. Loss calculation comes from VWO's SmartStats technical paper: @@ -268,7 +268,7 @@ def calculate_expected_loss(target_variant: Variant, variants: List[Variant]) -> return loss / simulations_count -def simulate_winning_variant_for_conversion(target_variant: Variant, variants: List[Variant]) -> Probability: +def simulate_winning_variant_for_conversion(target_variant: Variant, variants: list[Variant]) -> Probability: random_sampler = default_rng() prior_success = 1 prior_failure = 1 @@ -300,7 +300,7 @@ def simulate_winning_variant_for_conversion(target_variant: Variant, variants: L return winnings / simulations_count -def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]: +def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]: """ Calculates the probability of winning for each variant. """ diff --git a/ee/clickhouse/queries/experiments/secondary_experiment_result.py b/ee/clickhouse/queries/experiments/secondary_experiment_result.py index 4926d2920afbd..bd485c43622bf 100644 --- a/ee/clickhouse/queries/experiments/secondary_experiment_result.py +++ b/ee/clickhouse/queries/experiments/secondary_experiment_result.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from rest_framework.exceptions import ValidationError from ee.clickhouse.queries.experiments.funnel_experiment_result import ClickhouseFunnelExperimentResult @@ -55,7 +55,7 @@ def get_results(self): return {"result": variants, **significance_results} - def get_funnel_conversion_rate_for_variants(self, insight_results) -> Dict[str, float]: + def get_funnel_conversion_rate_for_variants(self, insight_results) -> dict[str, float]: variants = {} for result in insight_results: total = result[0]["count"] @@ -67,7 +67,7 @@ def get_funnel_conversion_rate_for_variants(self, insight_results) -> Dict[str, return variants - def get_trend_count_data_for_variants(self, insight_results) -> Dict[str, float]: + def get_trend_count_data_for_variants(self, insight_results) -> dict[str, float]: # this assumes the Trend insight is Cumulative, unless using count per user variants = {} diff --git a/ee/clickhouse/queries/experiments/test_experiment_result.py b/ee/clickhouse/queries/experiments/test_experiment_result.py index 20b737efa1767..18eb673bf9ac8 100644 --- a/ee/clickhouse/queries/experiments/test_experiment_result.py +++ b/ee/clickhouse/queries/experiments/test_experiment_result.py @@ -1,7 +1,6 @@ import unittest from functools import lru_cache from math import exp, lgamma, log -from typing import List from flaky import flaky @@ -31,7 +30,7 @@ def logbeta(x: int, y: int) -> float: # 
calculation: https://www.evanmiller.org/bayesian-ab-testing.html#binary_ab -def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: List[Variant]) -> Probability: +def calculate_probability_of_winning_for_target(target_variant: Variant, other_variants: list[Variant]) -> Probability: """ Calculates the probability of winning for target variant. """ @@ -455,7 +454,7 @@ def test_calculate_results_many_variants_control_is_significant(self): # calculation: https://www.evanmiller.org/bayesian-ab-testing.html#count_ab def calculate_probability_of_winning_for_target_count_data( - target_variant: CountVariant, other_variants: List[CountVariant] + target_variant: CountVariant, other_variants: list[CountVariant] ) -> Probability: """ Calculates the probability of winning for target variant. diff --git a/ee/clickhouse/queries/experiments/trend_experiment_result.py b/ee/clickhouse/queries/experiments/trend_experiment_result.py index 02974d8bd8252..0370e0a684a88 100644 --- a/ee/clickhouse/queries/experiments/trend_experiment_result.py +++ b/ee/clickhouse/queries/experiments/trend_experiment_result.py @@ -3,7 +3,7 @@ from datetime import datetime from functools import lru_cache from math import exp, lgamma, log -from typing import List, Optional, Tuple, Type +from typing import Optional from zoneinfo import ZoneInfo from numpy.random import default_rng @@ -78,7 +78,7 @@ def __init__( feature_flag: FeatureFlag, experiment_start_date: datetime, experiment_end_date: Optional[datetime] = None, - trend_class: Type[Trends] = Trends, + trend_class: type[Trends] = Trends, custom_exposure_filter: Optional[Filter] = None, ): breakdown_key = f"$feature/{feature_flag.key}" @@ -316,7 +316,7 @@ def get_variants(self, insight_results, exposure_results): return control_variant, test_variants @staticmethod - def calculate_results(control_variant: Variant, test_variants: List[Variant]) -> List[Probability]: + def calculate_results(control_variant: Variant, test_variants: list[Variant]) -> list[Probability]: """ Calculates probability that A is better than B. First variant is control, rest are test variants. @@ -346,9 +346,9 @@ def calculate_results(control_variant: Variant, test_variants: List[Variant]) -> @staticmethod def are_results_significant( control_variant: Variant, - test_variants: List[Variant], - probabilities: List[Probability], - ) -> Tuple[ExperimentSignificanceCode, Probability]: + test_variants: list[Variant], + probabilities: list[Probability], + ) -> tuple[ExperimentSignificanceCode, Probability]: # TODO: Experiment with Expected Loss calculations for trend experiments for variant in test_variants: @@ -375,7 +375,7 @@ def are_results_significant( return ExperimentSignificanceCode.SIGNIFICANT, p_value -def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: List[Variant]) -> float: +def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants: list[Variant]) -> float: random_sampler = default_rng() simulations_count = 100_000 @@ -399,7 +399,7 @@ def simulate_winning_variant_for_arrival_rates(target_variant: Variant, variants return winnings / simulations_count -def calculate_probability_of_winning_for_each(variants: List[Variant]) -> List[Probability]: +def calculate_probability_of_winning_for_each(variants: list[Variant]) -> list[Probability]: """ Calculates the probability of winning for each variant. 
""" @@ -458,7 +458,7 @@ def poisson_p_value(control_count, control_exposure, test_count, test_exposure): return min(1, 2 * min(low_p_value, high_p_value)) -def calculate_p_value(control_variant: Variant, test_variants: List[Variant]) -> Probability: +def calculate_p_value(control_variant: Variant, test_variants: list[Variant]) -> Probability: best_test_variant = max(test_variants, key=lambda variant: variant.count) return poisson_p_value( diff --git a/ee/clickhouse/queries/experiments/utils.py b/ee/clickhouse/queries/experiments/utils.py index 88418e3e354d2..c0211e4c9de24 100644 --- a/ee/clickhouse/queries/experiments/utils.py +++ b/ee/clickhouse/queries/experiments/utils.py @@ -1,4 +1,4 @@ -from typing import Set, Union +from typing import Union from posthog.client import sync_execute from posthog.constants import TREND_FILTER_TYPE_ACTIONS @@ -20,7 +20,7 @@ def requires_flag_warning(filter: Filter, team: Team) -> bool: {parsed_date_to} """ - events: Set[Union[int, str]] = set() + events: set[Union[int, str]] = set() entities_to_use = filter.entities for entity in entities_to_use: diff --git a/ee/clickhouse/queries/funnels/funnel_correlation.py b/ee/clickhouse/queries/funnels/funnel_correlation.py index ed3995968a001..c25763167f2bf 100644 --- a/ee/clickhouse/queries/funnels/funnel_correlation.py +++ b/ee/clickhouse/queries/funnels/funnel_correlation.py @@ -2,12 +2,8 @@ import urllib.parse from typing import ( Any, - Dict, - List, Literal, Optional, - Set, - Tuple, TypedDict, Union, cast, @@ -40,7 +36,7 @@ class EventDefinition(TypedDict): event: str - properties: Dict[str, Any] + properties: dict[str, Any] elements: list @@ -74,7 +70,7 @@ class FunnelCorrelationResponse(TypedDict): queries, but we could use, for example, a dataclass """ - events: List[EventOddsRatioSerialized] + events: list[EventOddsRatioSerialized] skewed: bool @@ -153,7 +149,7 @@ def __init__( ) @property - def properties_to_include(self) -> List[str]: + def properties_to_include(self) -> list[str]: props_to_include = [] if ( self._team.person_on_events_mode != PersonsOnEventsMode.disabled @@ -203,7 +199,7 @@ def support_autocapture_elements(self) -> bool: return True return False - def get_contingency_table_query(self) -> Tuple[str, Dict[str, Any]]: + def get_contingency_table_query(self) -> tuple[str, dict[str, Any]]: """ Returns a query string and params, which are used to generate the contingency table. The query returns success and failure count for event / property values, along with total success and failure counts. 
@@ -216,7 +212,7 @@ def get_contingency_table_query(self) -> Tuple[str, Dict[str, Any]]: return self.get_event_query() - def get_event_query(self) -> Tuple[str, Dict[str, Any]]: + def get_event_query(self) -> tuple[str, dict[str, Any]]: funnel_persons_query, funnel_persons_params = self.get_funnel_actors_cte() event_join_query = self._get_events_join_query() @@ -279,7 +275,7 @@ def get_event_query(self) -> Tuple[str, Dict[str, Any]]: return query, params - def get_event_property_query(self) -> Tuple[str, Dict[str, Any]]: + def get_event_property_query(self) -> tuple[str, dict[str, Any]]: if not self._filter.correlation_event_names: raise ValidationError("Event Property Correlation expects atleast one event name to run correlation on") @@ -359,7 +355,7 @@ def get_event_property_query(self) -> Tuple[str, Dict[str, Any]]: return query, params - def get_properties_query(self) -> Tuple[str, Dict[str, Any]]: + def get_properties_query(self) -> tuple[str, dict[str, Any]]: if not self._filter.correlation_property_names: raise ValidationError("Property Correlation expects atleast one Property to run correlation on") @@ -580,7 +576,7 @@ def _get_properties_prop_clause(self): ) def _get_funnel_step_names(self): - events: Set[Union[int, str]] = set() + events: set[Union[int, str]] = set() for entity in self._filter.entities: if entity.type == TREND_FILTER_TYPE_ACTIONS: action = entity.get_action() @@ -590,7 +586,7 @@ def _get_funnel_step_names(self): return sorted(events) - def _run(self) -> Tuple[List[EventOddsRatio], bool]: + def _run(self) -> tuple[list[EventOddsRatio], bool]: """ Run the diagnose query. @@ -834,7 +830,7 @@ def construct_person_properties_people_url( ).to_params() return f"{self._base_uri}api/person/funnel/correlation?{urllib.parse.urlencode(params)}&cache_invalidation_key={cache_invalidation_key}" - def format_results(self, results: Tuple[List[EventOddsRatio], bool]) -> FunnelCorrelationResponse: + def format_results(self, results: tuple[list[EventOddsRatio], bool]) -> FunnelCorrelationResponse: odds_ratios, skewed_totals = results return { "events": [self.serialize_event_odds_ratio(odds_ratio=odds_ratio) for odds_ratio in odds_ratios], @@ -847,7 +843,7 @@ def run(self) -> FunnelCorrelationResponse: return self.format_results(self._run()) - def get_partial_event_contingency_tables(self) -> Tuple[List[EventContingencyTable], int, int]: + def get_partial_event_contingency_tables(self) -> tuple[list[EventContingencyTable], int, int]: """ For each event a person that started going through the funnel, gets stats for how many of these users are sucessful and how many are unsuccessful. 
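# The per-event success/failure contingency tables built above feed an odds
# ratio per event (see get_entity_odds_ratio just below, which takes a prior
# for smoothing). A textbook smoothed odds ratio over such counts looks like
# the following; this is a generic sketch, not PostHog's exact formula:
def smoothed_odds_ratio(
    success_with_event: int,
    success_total: int,
    failure_with_event: int,
    failure_total: int,
    prior: float = 1.0,
) -> float:
    # Add-`prior` smoothing keeps empty cells from dividing by zero.
    odds_given_success = (success_with_event + prior) / (success_total - success_with_event + prior)
    odds_given_failure = (failure_with_event + prior) / (failure_total - failure_with_event + prior)
    return odds_given_success / odds_given_failure


# Ratios well above 1 flag events over-represented among converting users.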
@@ -888,7 +884,7 @@ def get_partial_event_contingency_tables(self) -> Tuple[List[EventContingencyTab failure_total, ) - def get_funnel_actors_cte(self) -> Tuple[str, Dict[str, Any]]: + def get_funnel_actors_cte(self) -> tuple[str, dict[str, Any]]: extra_fields = ["steps", "final_timestamp", "first_timestamp"] for prop in self.properties_to_include: @@ -975,12 +971,12 @@ def get_entity_odds_ratio(event_contingency_table: EventContingencyTable, prior_ ) -def build_selector(elements: List[Dict[str, Any]]) -> str: +def build_selector(elements: list[dict[str, Any]]) -> str: # build a CSS select given an "elements_chain" # NOTE: my source of what this should be doing is # https://github.com/PostHog/posthog/blob/cc054930a47fb59940531e99a856add49a348ee5/frontend/src/scenes/events/createActionFromEvent.tsx#L36:L36 # - def element_to_selector(element: Dict[str, Any]) -> str: + def element_to_selector(element: dict[str, Any]) -> str: if attr_id := element.get("attr_id"): return f'[id="{attr_id}"]' diff --git a/ee/clickhouse/queries/funnels/funnel_correlation_persons.py b/ee/clickhouse/queries/funnels/funnel_correlation_persons.py index 6a0cfe3655103..b02a8b8e9b6cb 100644 --- a/ee/clickhouse/queries/funnels/funnel_correlation_persons.py +++ b/ee/clickhouse/queries/funnels/funnel_correlation_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from django.db.models.query import QuerySet from rest_framework.exceptions import ValidationError @@ -52,9 +52,9 @@ def actor_query(self, limit_actors: Optional[bool] = True): def get_actors( self, - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], int, ]: if self._filter.correlation_type == FunnelCorrelationType.PROPERTIES: @@ -167,7 +167,7 @@ def aggregation_group_type_index(self): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): if not self._filter.correlation_property_values: raise ValidationError("Property Correlation expects atleast one Property to get persons for") diff --git a/ee/clickhouse/queries/funnels/test/breakdown_cases.py b/ee/clickhouse/queries/funnels/test/breakdown_cases.py index f4fb2689d87b7..7a1b2076776d0 100644 --- a/ee/clickhouse/queries/funnels/test/breakdown_cases.py +++ b/ee/clickhouse/queries/funnels/test/breakdown_cases.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List +from typing import Any from posthog.constants import INSIGHT_FUNNELS from posthog.models.filters import Filter @@ -51,8 +51,8 @@ def _create_groups(self): properties={"industry": "random"}, ) - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, diff --git a/ee/clickhouse/queries/groups_join_query.py b/ee/clickhouse/queries/groups_join_query.py index db1d12a3c6c46..7a3dc46daf993 100644 --- a/ee/clickhouse/queries/groups_join_query.py +++ b/ee/clickhouse/queries/groups_join_query.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import 
Optional, Union from ee.clickhouse.queries.column_optimizer import EnterpriseColumnOptimizer from posthog.models import Filter @@ -35,7 +35,7 @@ def __init__( self._join_key = join_key self._person_on_events_mode = person_on_events_mode - def get_join_query(self) -> Tuple[str, Dict]: + def get_join_query(self) -> tuple[str, dict]: join_queries, params = [], {} if self._person_on_events_mode != PersonsOnEventsMode.disabled and groups_on_events_querying_enabled(): @@ -63,7 +63,7 @@ def get_join_query(self) -> Tuple[str, Dict]: return "\n".join(join_queries), params - def get_filter_query(self, group_type_index: GroupTypeIndex) -> Tuple[str, Dict]: + def get_filter_query(self, group_type_index: GroupTypeIndex) -> tuple[str, dict]: var = f"group_index_{group_type_index}" params = { "team_id": self._team_id, diff --git a/ee/clickhouse/queries/paths/paths.py b/ee/clickhouse/queries/paths/paths.py index a5b9968da589e..f20744ee6729e 100644 --- a/ee/clickhouse/queries/paths/paths.py +++ b/ee/clickhouse/queries/paths/paths.py @@ -1,5 +1,5 @@ from re import escape -from typing import Dict, Literal, Optional, Tuple, Union, cast +from typing import Literal, Optional, Union, cast from jsonschema import ValidationError @@ -34,8 +34,8 @@ def __init__(self, filter: PathFilter, team: Team, funnel_filter: Optional[Filte ): raise ValidationError("Max Edge weight can't be lower than min edge weight") - def get_edge_weight_clause(self) -> Tuple[str, Dict]: - params: Dict[str, int] = {} + def get_edge_weight_clause(self) -> tuple[str, dict]: + params: dict[str, int] = {} conditions = [] @@ -60,8 +60,8 @@ def get_target_point_filter(self) -> str: else: return "" - def get_target_clause(self) -> Tuple[str, Dict]: - params: Dict[str, Union[str, None]] = { + def get_target_clause(self) -> tuple[str, dict]: + params: dict[str, Union[str, None]] = { "target_point": None, "secondary_target_point": None, } @@ -152,7 +152,7 @@ def get_array_compacting_function(self) -> Literal["arrayResize", "arraySlice"]: else: return "arraySlice" - def get_filtered_path_ordering(self) -> Tuple[str, ...]: + def get_filtered_path_ordering(self) -> tuple[str, ...]: fields_to_include = ["filtered_path", "filtered_timings"] + [ f"filtered_{field}s" for field in self.extra_event_fields_and_properties ] diff --git a/ee/clickhouse/queries/related_actors_query.py b/ee/clickhouse/queries/related_actors_query.py index 9c031a3b66221..e4cd462ace4f4 100644 --- a/ee/clickhouse/queries/related_actors_query.py +++ b/ee/clickhouse/queries/related_actors_query.py @@ -1,6 +1,6 @@ from datetime import timedelta from functools import cached_property -from typing import List, Optional, Union +from typing import Optional, Union from django.utils.timezone import now @@ -38,8 +38,8 @@ def __init__( self.group_type_index = validate_group_type_index("group_type_index", group_type_index) self.id = id - def run(self) -> List[SerializedActor]: - results: List[SerializedActor] = [] + def run(self) -> list[SerializedActor]: + results: list[SerializedActor] = [] results.extend(self._query_related_people()) for group_type_mapping in GroupTypeMapping.objects.filter(team_id=self.team.pk): results.extend(self._query_related_groups(group_type_mapping.group_type_index)) @@ -49,7 +49,7 @@ def run(self) -> List[SerializedActor]: def is_aggregating_by_groups(self) -> bool: return self.group_type_index is not None - def _query_related_people(self) -> List[SerializedPerson]: + def _query_related_people(self) -> list[SerializedPerson]: if not self.is_aggregating_by_groups: return 
[] @@ -72,7 +72,7 @@ def _query_related_people(self) -> List[SerializedPerson]: _, serialized_people = get_people(self.team, person_ids) return serialized_people - def _query_related_groups(self, group_type_index: GroupTypeIndex) -> List[SerializedGroup]: + def _query_related_groups(self, group_type_index: GroupTypeIndex) -> list[SerializedGroup]: if group_type_index == self.group_type_index: return [] @@ -102,7 +102,7 @@ def _query_related_groups(self, group_type_index: GroupTypeIndex) -> List[Serial _, serialize_groups = get_groups(self.team.pk, group_type_index, group_ids) return serialize_groups - def _take_first(self, rows: List) -> List: + def _take_first(self, rows: list) -> list: return [row[0] for row in rows] @property diff --git a/ee/clickhouse/queries/test/test_paths.py b/ee/clickhouse/queries/test/test_paths.py index fdaf25a043a6d..69f673e4489ca 100644 --- a/ee/clickhouse/queries/test/test_paths.py +++ b/ee/clickhouse/queries/test/test_paths.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import Tuple from unittest.mock import MagicMock from uuid import UUID @@ -2905,7 +2904,7 @@ def test_start_and_end(self): @snapshot_clickhouse_queries def test_properties_queried_using_path_filter(self): - def should_query_list(filter) -> Tuple[bool, bool]: + def should_query_list(filter) -> tuple[bool, bool]: path_query = PathEventQuery(filter, self.team) return (path_query._should_query_url(), path_query._should_query_screen()) diff --git a/ee/clickhouse/views/experiments.py b/ee/clickhouse/views/experiments.py index f50b9921c926c..b37d4e4d765df 100644 --- a/ee/clickhouse/views/experiments.py +++ b/ee/clickhouse/views/experiments.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional +from typing import Any, Optional +from collections.abc import Callable from django.utils.timezone import now from rest_framework import serializers, viewsets diff --git a/ee/clickhouse/views/groups.py b/ee/clickhouse/views/groups.py index e539de4673d60..4c67072b11de2 100644 --- a/ee/clickhouse/views/groups.py +++ b/ee/clickhouse/views/groups.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, cast +from typing import cast from django.db.models import Q from drf_spectacular.types import OpenApiTypes @@ -34,7 +34,7 @@ class ClickhouseGroupsTypesView(TeamAndOrgViewSetMixin, mixins.ListModelMixin, v @action(detail=False, methods=["PATCH"], name="Update group types metadata") def update_metadata(self, request: request.Request, *args, **kwargs): - for row in cast(List[Dict], request.data): + for row in cast(list[dict], request.data): instance = GroupTypeMapping.objects.get(team=self.team, group_type_index=row["group_type_index"]) serializer = self.get_serializer(instance, data=row) serializer.is_valid(raise_exception=True) diff --git a/ee/clickhouse/views/insights.py b/ee/clickhouse/views/insights.py index ff772b71aaef8..e6adf49e7ff9e 100644 --- a/ee/clickhouse/views/insights.py +++ b/ee/clickhouse/views/insights.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from rest_framework.decorators import action from rest_framework.permissions import SAFE_METHODS, BasePermission @@ -47,7 +47,7 @@ def funnel_correlation(self, request: Request, *args: Any, **kwargs: Any) -> Res return Response(result) @cached_by_filters - def calculate_funnel_correlation(self, request: Request) -> Dict[str, Any]: + def calculate_funnel_correlation(self, request: Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, team=team) diff --git 
a/ee/clickhouse/views/person.py b/ee/clickhouse/views/person.py index d01dba65da928..f3f8432ad6871 100644 --- a/ee/clickhouse/views/person.py +++ b/ee/clickhouse/views/person.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Optional from rest_framework import request, response from rest_framework.decorators import action @@ -28,7 +28,7 @@ def funnel_correlation(self, request: request.Request, **kwargs) -> response.Res @cached_by_filters def calculate_funnel_correlation_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[list, Optional[str], Optional[str], int]]: filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) if not filter.correlation_person_limit: filter = filter.shallow_clone({FUNNEL_CORRELATION_PERSON_LIMIT: 100}) diff --git a/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py b/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py index 829232d1bd94f..f5ff3722008b8 100644 --- a/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py +++ b/ee/clickhouse/views/test/funnel/test_clickhouse_funnel_correlation.py @@ -552,15 +552,15 @@ def test_properties_correlation_endpoint_provides_people_drill_down_urls(self): ), ) - (browser_correlation,) = [ + (browser_correlation,) = ( correlation for correlation in odds["result"]["events"] if correlation["event"]["event"] == "$browser::1" - ] + ) - (notset_correlation,) = [ + (notset_correlation,) = ( correlation for correlation in odds["result"]["events"] if correlation["event"]["event"] == "$browser::" - ] + ) assert get_people_for_correlation_ok(client=self.client, correlation=browser_correlation) == { "success": ["Person 2"], diff --git a/ee/clickhouse/views/test/funnel/util.py b/ee/clickhouse/views/test/funnel/util.py index 8d2c304cb8b4c..45984ee41ba29 100644 --- a/ee/clickhouse/views/test/funnel/util.py +++ b/ee/clickhouse/views/test/funnel/util.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, Literal, Optional, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union from django.test.client import Client @@ -12,7 +12,7 @@ class EventPattern(TypedDict, total=False): id: str type: Union[Literal["events"], Literal["actions"]] order: int - properties: Dict[str, Any] + properties: dict[str, Any] @dataclasses.dataclass @@ -46,7 +46,7 @@ def get_funnel(client: Client, team_id: int, request: FunnelRequest): ) -def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> Dict[str, Any]: +def get_funnel_ok(client: Client, team_id: int, request: FunnelRequest) -> dict[str, Any]: response = get_funnel(client=client, team_id=team_id, request=request) assert response.status_code == 200, response.content @@ -73,14 +73,14 @@ def get_funnel_correlation(client: Client, team_id: int, request: FunnelCorrelat ) -def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> Dict[str, Any]: +def get_funnel_correlation_ok(client: Client, team_id: int, request: FunnelCorrelationRequest) -> dict[str, Any]: response = get_funnel_correlation(client=client, team_id=team_id, request=request) assert response.status_code == 200, response.content return response.json() -def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> Dict[str, Any]: +def get_people_for_correlation_ok(client: Client, correlation: EventOddsRatioSerialized) -> dict[str, Any]: """ 
Helper for getting people for a correlation. Note we keep checking to just inclusion of name, to make the stable to changes in other people props. diff --git a/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py b/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py index 232312ec6449f..e7f9ebf7e2c3e 100644 --- a/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py +++ b/ee/clickhouse/views/test/test_clickhouse_experiment_secondary_results.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from flaky import flaky @@ -7,7 +7,7 @@ from posthog.test.base import ClickhouseTestMixin, snapshot_clickhouse_queries from posthog.test.test_journeys import journeys_for -DEFAULT_JOURNEYS_FOR_PAYLOAD: Dict[str, List[Dict[str, Any]]] = { +DEFAULT_JOURNEYS_FOR_PAYLOAD: dict[str, list[dict[str, Any]]] = { # For a trend pageview metric "person1": [ { diff --git a/ee/clickhouse/views/test/test_clickhouse_retention.py b/ee/clickhouse/views/test/test_clickhouse_retention.py index 0e5a8ad0fafdf..5deff716a2658 100644 --- a/ee/clickhouse/views/test/test_clickhouse_retention.py +++ b/ee/clickhouse/views/test/test_clickhouse_retention.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import List, Literal, Optional, TypedDict, Union +from typing import Literal, Optional, TypedDict, Union from django.test.client import Client @@ -719,10 +719,10 @@ class RetentionRequest: period: Union[Literal["Hour"], Literal["Day"], Literal["Week"], Literal["Month"]] retention_type: Literal["retention_first_time", "retention"] # probably not an exhaustive list - breakdowns: Optional[List[Breakdown]] = None + breakdowns: Optional[list[Breakdown]] = None breakdown_type: Optional[Literal["person", "event"]] = None - properties: Optional[List[PropertyFilter]] = None + properties: Optional[list[PropertyFilter]] = None filter_test_accounts: Optional[str] = None limit: Optional[int] = None @@ -734,26 +734,26 @@ class Value(TypedDict): class Cohort(TypedDict): - values: List[Value] + values: list[Value] date: str label: str class RetentionResponse(TypedDict): - result: List[Cohort] + result: list[Cohort] class Person(TypedDict): - distinct_ids: List[str] + distinct_ids: list[str] class RetentionTableAppearance(TypedDict): person: Person - appearances: List[int] + appearances: list[int] class RetentionTablePeopleResponse(TypedDict): - result: List[RetentionTableAppearance] + result: list[RetentionTableAppearance] def get_retention_ok(client: Client, team_id: int, request: RetentionRequest) -> RetentionResponse: diff --git a/ee/clickhouse/views/test/test_clickhouse_trends.py b/ee/clickhouse/views/test/test_clickhouse_trends.py index 8ce3809263a4b..4de1f00e53401 100644 --- a/ee/clickhouse/views/test/test_clickhouse_trends.py +++ b/ee/clickhouse/views/test/test_clickhouse_trends.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from unittest.case import skip from unittest.mock import ANY @@ -420,20 +420,20 @@ class TrendsRequest: insight: Optional[str] = None display: Optional[str] = None compare: Optional[bool] = None - events: List[Dict[str, Any]] = field(default_factory=list) - properties: List[Dict[str, Any]] = field(default_factory=list) + events: list[dict[str, Any]] = field(default_factory=list) + properties: list[dict[str, Any]] = field(default_factory=list) smoothing_intervals: 
Optional[int] = 1 refresh: Optional[bool] = False @dataclass class TrendsRequestBreakdown(TrendsRequest): - breakdown: Optional[Union[List[int], str]] = None + breakdown: Optional[Union[list[int], str]] = None breakdown_type: Optional[str] = None def get_trends(client, request: Union[TrendsRequestBreakdown, TrendsRequest], team: Team): - data: Dict[str, Any] = { + data: dict[str, Any] = { "date_from": request.date_from, "date_to": request.date_to, "interval": request.interval, @@ -471,7 +471,7 @@ class NormalizedTrendResult: def get_trends_time_series_ok( client: Client, request: TrendsRequest, team: Team, with_order: bool = False -) -> Dict[str, Dict[str, NormalizedTrendResult]]: +) -> dict[str, dict[str, NormalizedTrendResult]]: data = get_trends_ok(client=client, request=request, team=team) res = {} for item in data["result"]: @@ -491,7 +491,7 @@ def get_trends_time_series_ok( return res -def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> Dict[str, NormalizedTrendResult]: +def get_trends_aggregate_ok(client: Client, request: TrendsRequest, team: Team) -> dict[str, NormalizedTrendResult]: data = get_trends_ok(client=client, request=request, team=team) res = {} for item in data["result"]: diff --git a/ee/migrations/0001_initial.py b/ee/migrations/0001_initial.py index fd3cad3892708..5b668bc772b6a 100644 --- a/ee/migrations/0001_initial.py +++ b/ee/migrations/0001_initial.py @@ -1,6 +1,5 @@ # Generated by Django 3.0.7 on 2020-08-07 09:15 -from typing import List from django.db import migrations, models @@ -8,7 +7,7 @@ class Migration(migrations.Migration): initial = True - dependencies: List = [] + dependencies: list = [] operations = [ migrations.CreateModel( diff --git a/ee/migrations/0012_migrate_tags_v2.py b/ee/migrations/0012_migrate_tags_v2.py index 9a2cf8e3d39c4..540cd281338d4 100644 --- a/ee/migrations/0012_migrate_tags_v2.py +++ b/ee/migrations/0012_migrate_tags_v2.py @@ -1,5 +1,5 @@ # Generated by Django 3.2.5 on 2022-03-02 22:44 -from typing import Any, List, Tuple +from typing import Any from django.core.paginator import Paginator from django.db import migrations @@ -19,7 +19,7 @@ def forwards(apps, schema_editor): EnterpriseEventDefinition = apps.get_model("ee", "EnterpriseEventDefinition") EnterprisePropertyDefinition = apps.get_model("ee", "EnterprisePropertyDefinition") - createables: List[Tuple[Any, Any]] = [] + createables: list[tuple[Any, Any]] = [] batch_size = 1_000 # Collect event definition tags and taggeditems diff --git a/ee/models/license.py b/ee/models/license.py index f0e12d3d2f440..35530b89687ac 100644 --- a/ee/models/license.py +++ b/ee/models/license.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from django.contrib.auth import get_user_model from django.db import models @@ -85,7 +85,7 @@ class License(models.Model): PLAN_TO_SORTING_VALUE = {SCALE_PLAN: 10, ENTERPRISE_PLAN: 20} @property - def available_features(self) -> List[AvailableFeature]: + def available_features(self) -> list[AvailableFeature]: return self.PLANS.get(self.plan, []) @property diff --git a/ee/session_recordings/ai/embeddings_queries.py b/ee/session_recordings/ai/embeddings_queries.py index 6d657d111096d..2034a9f190152 100644 --- a/ee/session_recordings/ai/embeddings_queries.py +++ b/ee/session_recordings/ai/embeddings_queries.py @@ -1,6 +1,5 @@ from django.conf import settings -from typing import List from posthog.models import Team from posthog.clickhouse.client import sync_execute @@ -9,7 +8,7 @@ 
MIN_DURATION_INCLUDE_SECONDS = settings.REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS -def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[str]: +def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> list[str]: query = """ WITH embedded_sessions AS ( SELECT @@ -47,7 +46,7 @@ def fetch_errors_by_session_without_embeddings(team_id: int, offset=0) -> List[s ) -def fetch_recordings_without_embeddings(team_id: int, offset=0) -> List[str]: +def fetch_recordings_without_embeddings(team_id: int, offset=0) -> list[str]: team = Team.objects.get(id=team_id) query = """ diff --git a/ee/session_recordings/ai/embeddings_runner.py b/ee/session_recordings/ai/embeddings_runner.py index 101c7175acb61..413e9f45368fe 100644 --- a/ee/session_recordings/ai/embeddings_runner.py +++ b/ee/session_recordings/ai/embeddings_runner.py @@ -3,7 +3,7 @@ import datetime import pytz -from typing import Dict, Any, List, Tuple +from typing import Any from abc import ABC, abstractmethod from prometheus_client import Histogram, Counter @@ -88,7 +88,7 @@ class EmbeddingPreparation(ABC): @staticmethod @abstractmethod - def prepare(item, team) -> Tuple[str, str]: + def prepare(item, team) -> tuple[str, str]: raise NotImplementedError() @@ -100,7 +100,7 @@ def __init__(self, team: Team): self.team = team self.openai_client = OpenAI() - def run(self, items: List[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None: + def run(self, items: list[Any], embeddings_preparation: type[EmbeddingPreparation]) -> None: source_type = embeddings_preparation.source_type try: @@ -196,7 +196,7 @@ def _num_tokens_for_input(self, string: str) -> int: """Returns the number of tokens in a text string.""" return len(encoding.encode(string)) - def _flush_embeddings_to_clickhouse(self, embeddings: List[Dict[str, Any]], source_type: str) -> None: + def _flush_embeddings_to_clickhouse(self, embeddings: list[dict[str, Any]], source_type: str) -> None: try: sync_execute( "INSERT INTO session_replay_embeddings (session_id, team_id, embeddings, source_type, input) VALUES", @@ -213,7 +213,7 @@ class ErrorEmbeddingsPreparation(EmbeddingPreparation): source_type = "error" @staticmethod - def prepare(item: Tuple[str, str], _): + def prepare(item: tuple[str, str], _): session_id = item[0] error_message = item[1] return session_id, error_message @@ -286,7 +286,7 @@ def prepare(session_id: str, team: Team): return session_id, input @staticmethod - def _compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str: + def _compact_result(event_name: str, current_url: int, elements_chain: dict[str, str] | str) -> str: elements_string = ( elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain) ) diff --git a/ee/session_recordings/ai/utils.py b/ee/session_recordings/ai/utils.py index a1d5f31460de0..1b7770a136128 100644 --- a/ee/session_recordings/ai/utils.py +++ b/ee/session_recordings/ai/utils.py @@ -1,7 +1,7 @@ import dataclasses from datetime import datetime -from typing import List, Dict, Any +from typing import Any from posthog.models.element import chain_to_elements from hashlib import shake_256 @@ -12,11 +12,11 @@ class SessionSummaryPromptData: # we may allow customisation of columns included in the future, # and we alter the columns present as we process the data # so want to stay as loose as possible here - columns: List[str] = dataclasses.field(default_factory=list) - results: List[List[Any]] = dataclasses.field(default_factory=list) + 
columns: list[str] = dataclasses.field(default_factory=list) + results: list[list[Any]] = dataclasses.field(default_factory=list) # in order to reduce the number of tokens in the prompt # we replace URLs with a placeholder and then pass this mapping of placeholder to URL into the prompt - url_mapping: Dict[str, str] = dataclasses.field(default_factory=dict) + url_mapping: dict[str, str] = dataclasses.field(default_factory=dict) def is_empty(self) -> bool: return not self.columns or not self.results @@ -63,7 +63,7 @@ def simplify_window_id(session_events: SessionSummaryPromptData) -> SessionSumma # find window_id column index window_id_index = session_events.column_index("$window_id") - window_id_mapping: Dict[str, int] = {} + window_id_mapping: dict[str, int] = {} simplified_results = [] for result in session_events.results: if window_id_index is None: @@ -128,7 +128,7 @@ def deduplicate_urls(session_events: SessionSummaryPromptData) -> SessionSummary # find url column index url_index = session_events.column_index("$current_url") - url_mapping: Dict[str, str] = {} + url_mapping: dict[str, str] = {} deduplicated_results = [] for result in session_events.results: if url_index is None: diff --git a/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py b/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py index 71196ec0ecadc..797ac453e69e0 100644 --- a/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py +++ b/ee/session_recordings/queries/test/test_session_recording_list_from_session_replay.py @@ -1,5 +1,4 @@ from itertools import product -from typing import Dict from unittest import mock from uuid import uuid4 @@ -131,7 +130,7 @@ def test_effect_of_poe_settings_on_query_generated( poe_v2: bool, allow_denormalized_props: bool, expected_poe_mode: PersonsOnEventsMode, - expected_query_params: Dict, + expected_query_params: dict, unmaterialized_person_column_used: bool, materialized_event_column_used: bool, ) -> None: diff --git a/ee/session_recordings/session_recording_playlist.py b/ee/session_recordings/session_recording_playlist.py index a54f8e38a6bdd..7d2b9fe0b0cb2 100644 --- a/ee/session_recordings/session_recording_playlist.py +++ b/ee/session_recordings/session_recording_playlist.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional +from typing import Any, Optional import structlog from django.db.models import Q, QuerySet @@ -49,7 +49,7 @@ def log_playlist_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: """ Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance @@ -101,7 +101,7 @@ class Meta: created_by = UserBasicSerializer(read_only=True) last_modified_by = UserBasicSerializer(read_only=True) - def create(self, validated_data: Dict, *args, **kwargs) -> SessionRecordingPlaylist: + def create(self, validated_data: dict, *args, **kwargs) -> SessionRecordingPlaylist: request = self.context["request"] team = self.context["get_team"]() @@ -128,7 +128,7 @@ def create(self, validated_data: Dict, *args, **kwargs) -> SessionRecordingPlayl return playlist - def update(self, instance: SessionRecordingPlaylist, validated_data: Dict, **kwargs) -> SessionRecordingPlaylist: + def update(self, instance: SessionRecordingPlaylist, validated_data: dict, **kwargs) -> SessionRecordingPlaylist: try: before_update = 
SessionRecordingPlaylist.objects.get(pk=instance.id) except SessionRecordingPlaylist.DoesNotExist: diff --git a/ee/session_recordings/test/test_session_recording_extensions.py b/ee/session_recordings/test/test_session_recording_extensions.py index 35fd5d2bc8b7a..ad545e5cec33f 100644 --- a/ee/session_recordings/test/test_session_recording_extensions.py +++ b/ee/session_recordings/test/test_session_recording_extensions.py @@ -103,7 +103,7 @@ def test_persists_recording_from_blob_ingested_storage(self): for file in ["a", "b", "c"]: blob_path = f"{TEST_BUCKET}/team_id/{self.team.pk}/session_id/{session_id}/data" file_name = f"{blob_path}/{file}" - write(file_name, f"my content-{file}".encode("utf-8")) + write(file_name, f"my content-{file}".encode()) recording: SessionRecording = SessionRecording.objects.create(team=self.team, session_id=session_id) @@ -164,7 +164,7 @@ def test_can_save_content_to_new_location(self, mock_write: MagicMock): mock_write.assert_called_with( f"{expected_path}/12345000-12346000", - gzip.compress("the new content".encode("utf-8")), + gzip.compress(b"the new content"), extras={ "ContentEncoding": "gzip", "ContentType": "application/json", diff --git a/ee/settings.py b/ee/settings.py index 7342bdf98f987..d9a863c3f816b 100644 --- a/ee/settings.py +++ b/ee/settings.py @@ -3,14 +3,13 @@ """ import os -from typing import Dict, List from posthog.settings import AUTHENTICATION_BACKENDS, DEMO, SITE_URL, DEBUG from posthog.settings.utils import get_from_env from posthog.utils import str_to_bool # Zapier REST hooks -HOOK_EVENTS: Dict[str, str] = { +HOOK_EVENTS: dict[str, str] = { # "event_name": "App.Model.Action" (created/updated/deleted) "action_performed": "posthog.Action.performed", } @@ -43,7 +42,7 @@ SOCIAL_AUTH_GOOGLE_OAUTH2_KEY = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_KEY") SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET = os.getenv("SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET") if "SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS" in os.environ: - SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: List[str] = os.environ[ + SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS: list[str] = os.environ[ "SOCIAL_AUTH_GOOGLE_OAUTH2_WHITELISTED_DOMAINS" ].split(",") elif DEMO: diff --git a/ee/tasks/auto_rollback_feature_flag.py b/ee/tasks/auto_rollback_feature_flag.py index d1b7e606976a6..f676f91d0c4bf 100644 --- a/ee/tasks/auto_rollback_feature_flag.py +++ b/ee/tasks/auto_rollback_feature_flag.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from typing import Dict from zoneinfo import ZoneInfo from celery import shared_task @@ -30,7 +29,7 @@ def check_feature_flag_rollback_conditions(feature_flag_id: int) -> None: flag.save() -def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str) -> float: +def calculate_rolling_average(threshold_metric: dict, team: Team, timezone: str) -> float: curr = datetime.now(tz=ZoneInfo(timezone)) rolling_average_days = 7 @@ -54,7 +53,7 @@ def calculate_rolling_average(threshold_metric: Dict, team: Team, timezone: str) return sum(data) / rolling_average_days -def check_condition(rollback_condition: Dict, feature_flag: FeatureFlag) -> bool: +def check_condition(rollback_condition: dict, feature_flag: FeatureFlag) -> bool: if rollback_condition["threshold_type"] == "sentry": created_date = feature_flag.created_at base_start_date = created_date.strftime("%Y-%m-%dT%H:%M:%S") diff --git a/ee/tasks/replay.py b/ee/tasks/replay.py index 036925b279a91..fcf57196c2dc5 100644 --- a/ee/tasks/replay.py +++ b/ee/tasks/replay.py @@ -1,4 +1,4 @@ -from typing import 
Any, List +from typing import Any import structlog from celery import shared_task @@ -25,7 +25,7 @@ # we currently are allowed 500 calls per minute, so let's rate limit each worker # to much less than that @shared_task(ignore_result=False, queue=CeleryQueue.SESSION_REPLAY_EMBEDDINGS.value, rate_limit="75/m") -def embed_batch_of_recordings_task(recordings: List[Any], team_id: int) -> None: +def embed_batch_of_recordings_task(recordings: list[Any], team_id: int) -> None: try: team = Team.objects.get(id=team_id) runner = SessionEmbeddingsRunner(team=team) diff --git a/ee/tasks/slack.py b/ee/tasks/slack.py index 0137089b08bab..251e9fd26138b 100644 --- a/ee/tasks/slack.py +++ b/ee/tasks/slack.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict +from typing import Any from urllib.parse import urlparse import structlog @@ -16,7 +16,7 @@ SHARED_LINK_REGEX = r"\/(?:shared_dashboard|shared|embedded)\/(.+)" -def _block_for_asset(asset: ExportedAsset) -> Dict: +def _block_for_asset(asset: ExportedAsset) -> dict: image_url = asset.get_public_content_url() alt_text = None if asset.insight: diff --git a/ee/tasks/subscriptions/email_subscriptions.py b/ee/tasks/subscriptions/email_subscriptions.py index aa62b7d83a4e0..39e342bcec1dd 100644 --- a/ee/tasks/subscriptions/email_subscriptions.py +++ b/ee/tasks/subscriptions/email_subscriptions.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Optional +from typing import Optional import structlog @@ -15,7 +15,7 @@ def send_email_subscription_report( email: str, subscription: Subscription, - assets: List[ExportedAsset], + assets: list[ExportedAsset], invite_message: Optional[str] = None, total_asset_count: Optional[int] = None, ) -> None: diff --git a/ee/tasks/subscriptions/slack_subscriptions.py b/ee/tasks/subscriptions/slack_subscriptions.py index 1d35259a6f3c4..73643c7a97bbd 100644 --- a/ee/tasks/subscriptions/slack_subscriptions.py +++ b/ee/tasks/subscriptions/slack_subscriptions.py @@ -1,5 +1,3 @@ -from typing import Dict, List - import structlog from django.conf import settings @@ -12,7 +10,7 @@ UTM_TAGS_BASE = "utm_source=posthog&utm_campaign=subscription_report" -def _block_for_asset(asset: ExportedAsset) -> Dict: +def _block_for_asset(asset: ExportedAsset) -> dict: image_url = asset.get_public_content_url() alt_text = None if asset.insight: @@ -26,7 +24,7 @@ def _block_for_asset(asset: ExportedAsset) -> Dict: def send_slack_subscription_report( subscription: Subscription, - assets: List[ExportedAsset], + assets: list[ExportedAsset], total_asset_count: int, is_new_subscription: bool = False, ) -> None: diff --git a/ee/tasks/subscriptions/subscription_utils.py b/ee/tasks/subscriptions/subscription_utils.py index d89d73d4a3b40..6fa4b63960fc2 100644 --- a/ee/tasks/subscriptions/subscription_utils.py +++ b/ee/tasks/subscriptions/subscription_utils.py @@ -1,5 +1,5 @@ import datetime -from typing import List, Tuple, Union +from typing import Union from django.conf import settings import structlog from celery import chain @@ -28,7 +28,7 @@ def generate_assets( resource: Union[Subscription, SharingConfiguration], max_asset_count: int = DEFAULT_MAX_ASSET_COUNT, -) -> Tuple[List[Insight], List[ExportedAsset]]: +) -> tuple[list[Insight], list[ExportedAsset]]: with SUBSCRIPTION_ASSET_GENERATION_TIMER.time(): if resource.dashboard: tiles = get_tiles_ordered_by_position(resource.dashboard) diff --git a/ee/tasks/test/subscriptions/test_subscriptions.py b/ee/tasks/test/subscriptions/test_subscriptions.py index d6afe50b68f7f..c814b2a4ebc18 100644 --- 
a/ee/tasks/test/subscriptions/test_subscriptions.py +++ b/ee/tasks/test/subscriptions/test_subscriptions.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List from unittest.mock import MagicMock, call, patch from zoneinfo import ZoneInfo @@ -25,10 +24,10 @@ @patch("ee.tasks.subscriptions.generate_assets") @freeze_time("2022-02-02T08:55:00.000Z") class TestSubscriptionsTasks(APIBaseTest): - subscriptions: List[Subscription] = None # type: ignore + subscriptions: list[Subscription] = None # type: ignore dashboard: Dashboard insight: Insight - tiles: List[DashboardTile] = None # type: ignore + tiles: list[DashboardTile] = None # type: ignore asset: ExportedAsset def setUp(self) -> None: diff --git a/ee/tasks/test/subscriptions/test_subscriptions_utils.py b/ee/tasks/test/subscriptions/test_subscriptions_utils.py index c8ff89adcea65..edab23bbfb9ed 100644 --- a/ee/tasks/test/subscriptions/test_subscriptions_utils.py +++ b/ee/tasks/test/subscriptions/test_subscriptions_utils.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import MagicMock, patch import pytest @@ -21,7 +20,7 @@ class TestSubscriptionsTasksUtils(APIBaseTest): dashboard: Dashboard insight: Insight asset: ExportedAsset - tiles: List[DashboardTile] + tiles: list[DashboardTile] def setUp(self) -> None: self.dashboard = Dashboard.objects.create(team=self.team, name="private dashboard", created_by=self.user) diff --git a/ee/tasks/test/test_slack.py b/ee/tasks/test/test_slack.py index 03b28b8155cfe..64b227d7d1e64 100644 --- a/ee/tasks/test/test_slack.py +++ b/ee/tasks/test/test_slack.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import MagicMock, patch from freezegun import freeze_time @@ -14,7 +13,7 @@ from posthog.test.base import APIBaseTest -def create_mock_unfurl_event(team_id: str, links: List[str]): +def create_mock_unfurl_event(team_id: str, links: list[str]): return { "token": "XXYYZZ", "team_id": team_id, diff --git a/ee/urls.py b/ee/urls.py index a3851a2807583..2ee3f7d3a8fc0 100644 --- a/ee/urls.py +++ b/ee/urls.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from django.conf import settings from django.contrib import admin @@ -92,7 +92,7 @@ def extend_api_router( ) -urlpatterns: List[Any] = [ +urlpatterns: list[Any] = [ path("api/saml/metadata/", authentication.saml_metadata_view), path("api/sentry_stats/", sentry_stats.sentry_stats), *admin_urlpatterns, diff --git a/gunicorn.config.py b/gunicorn.config.py index 1e56182026068..acd7ba3f5f592 100644 --- a/gunicorn.config.py +++ b/gunicorn.config.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- import logging import os diff --git a/hogvm/python/execute.py b/hogvm/python/execute.py index 4e4a61a1af5a0..a1130c0d54c89 100644 --- a/hogvm/python/execute.py +++ b/hogvm/python/execute.py @@ -1,5 +1,5 @@ import re -from typing import List, Any, Dict +from typing import Any from hogvm.python.operation import Operation, HOGQL_BYTECODE_IDENTIFIER @@ -33,7 +33,7 @@ def to_concat_arg(arg) -> str: return str(arg) -def execute_bytecode(bytecode: List[Any], fields: Dict[str, Any]) -> Any: +def execute_bytecode(bytecode: list[Any], fields: dict[str, Any]) -> Any: try: stack = [] iterator = iter(bytecode) diff --git a/plugin-server/bin/generate_session_recordings_messages.py b/plugin-server/bin/generate_session_recordings_messages.py index 4b5462bebd3a7..cfd3d034d194b 100755 --- a/plugin-server/bin/generate_session_recordings_messages.py +++ b/plugin-server/bin/generate_session_recordings_messages.py @@ 
-53,7 +53,6 @@ import json import uuid from sys import stderr, stdout -from typing import List import numpy from faker import Faker @@ -144,7 +143,7 @@ def get_parser(): def chunked( data: str, chunk_size: int, -) -> List[str]: +) -> list[str]: return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] diff --git a/posthog/api/action.py b/posthog/api/action.py index 437f0227c817f..38eb33d10745e 100644 --- a/posthog/api/action.py +++ b/posthog/api/action.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, cast +from typing import Any, cast from django.db.models import Count, Prefetch from rest_framework import request, serializers, viewsets @@ -123,7 +123,7 @@ def create(self, validated_data: Any) -> Any: return instance - def update(self, instance: Any, validated_data: Dict[str, Any]) -> Any: + def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: steps = validated_data.pop("steps", None) # If there's no steps property at all we just ignore it # If there is a step property but it's an empty array [], we'll delete all the steps @@ -182,7 +182,7 @@ def get_queryset(self): def list(self, request: request.Request, *args: Any, **kwargs: Any) -> Response: actions = self.get_queryset() - actions_list: List[Dict[Any, Any]] = self.serializer_class( + actions_list: list[dict[Any, Any]] = self.serializer_class( actions, many=True, context={"request": request} ).data # type: ignore return Response({"results": actions_list}) diff --git a/posthog/api/activity_log.py b/posthog/api/activity_log.py index fefa2554d19a3..35ff30d5703a6 100644 --- a/posthog/api/activity_log.py +++ b/posthog/api/activity_log.py @@ -1,5 +1,5 @@ import time -from typing import Any, Optional, Dict +from typing import Any, Optional from django.db.models import Q, QuerySet @@ -49,7 +49,7 @@ class ActivityLogPagination(pagination.CursorPagination): # context manager for gathering a sequence of server timings class ServerTimingsGathered: # Class level dictionary to store timings - timings_dict: Dict[str, float] = {} + timings_dict: dict[str, float] = {} def __call__(self, name): self.name = name diff --git a/posthog/api/annotation.py b/posthog/api/annotation.py index 4806d5a632f25..7216efe6cd643 100644 --- a/posthog/api/annotation.py +++ b/posthog/api/annotation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from django.db.models import Q, QuerySet from django.db.models.signals import post_save @@ -40,11 +40,11 @@ class Meta: "updated_at", ] - def update(self, instance: Annotation, validated_data: Dict[str, Any]) -> Annotation: + def update(self, instance: Annotation, validated_data: dict[str, Any]) -> Annotation: instance.team_id = self.context["team_id"] return super().update(instance, validated_data) - def create(self, validated_data: Dict[str, Any], *args: Any, **kwargs: Any) -> Annotation: + def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> Annotation: request = self.context["request"] team = self.context["get_team"]() annotation = Annotation.objects.create( diff --git a/posthog/api/authentication.py b/posthog/api/authentication.py index 069acac50c95f..d7911059506ed 100644 --- a/posthog/api/authentication.py +++ b/posthog/api/authentication.py @@ -1,6 +1,6 @@ import datetime import time -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from uuid import uuid4 from django.conf import settings @@ -92,7 +92,7 @@ class LoginSerializer(serializers.Serializer): email = serializers.EmailField() password = 
serializers.CharField() - def to_representation(self, instance: Any) -> Dict[str, Any]: + def to_representation(self, instance: Any) -> dict[str, Any]: return {"success": True} def _check_if_2fa_required(self, user: User) -> bool: @@ -113,7 +113,7 @@ def _check_if_2fa_required(self, user: User) -> bool: pass return True - def create(self, validated_data: Dict[str, str]) -> Any: + def create(self, validated_data: dict[str, str]) -> Any: # Check SSO enforcement (which happens at the domain level) sso_enforcement = OrganizationDomain.objects.get_sso_enforcement_for_email_address(validated_data["email"]) if sso_enforcement: @@ -159,10 +159,10 @@ def create(self, validated_data: Dict[str, str]) -> Any: class LoginPrecheckSerializer(serializers.Serializer): email = serializers.EmailField() - def to_representation(self, instance: Dict[str, str]) -> Dict[str, Any]: + def to_representation(self, instance: dict[str, str]) -> dict[str, Any]: return instance - def create(self, validated_data: Dict[str, str]) -> Any: + def create(self, validated_data: dict[str, str]) -> Any: email = validated_data.get("email", "") # TODO: Refactor methods below to remove duplicate queries return { diff --git a/posthog/api/capture.py b/posthog/api/capture.py index 31592e90e790d..9c223f8264acb 100644 --- a/posthog/api/capture.py +++ b/posthog/api/capture.py @@ -18,7 +18,8 @@ from sentry_sdk.api import capture_exception, start_span from statshog.defaults.django import statsd from token_bucket import Limiter, MemoryStorage -from typing import Any, Dict, Iterator, List, Optional, Tuple, Set +from typing import Any, Optional +from collections.abc import Iterator from ee.billing.quota_limiting import QuotaLimitingCaches from posthog.api.utils import get_data, get_token, safe_clickhouse_string @@ -129,12 +130,12 @@ def build_kafka_event_data( distinct_id: str, ip: Optional[str], site_url: str, - data: Dict, + data: dict, now: datetime, sent_at: Optional[datetime], event_uuid: UUIDT, token: str, -) -> Dict: +) -> dict: logger.debug("build_kafka_event_data", token=token) return { "uuid": str(event_uuid), @@ -168,10 +169,10 @@ def _kafka_topic(event_name: str, historical: bool = False, overflowing: bool = def log_event( - data: Dict, + data: dict, event_name: str, partition_key: Optional[str], - headers: Optional[List] = None, + headers: Optional[list] = None, historical: bool = False, overflowing: bool = False, ) -> FutureRecordMetadata: @@ -205,7 +206,7 @@ def _datetime_from_seconds_or_millis(timestamp: str) -> datetime: return datetime.fromtimestamp(timestamp_number, timezone.utc) -def _get_sent_at(data, request) -> Tuple[Optional[datetime], Any]: +def _get_sent_at(data, request) -> tuple[Optional[datetime], Any]: try: if request.GET.get("_"): # posthog-js sent_at = request.GET["_"] @@ -253,7 +254,7 @@ def _check_token_shape(token: Any) -> Optional[str]: return None -def get_distinct_id(data: Dict[str, Any]) -> str: +def get_distinct_id(data: dict[str, Any]) -> str: raw_value: Any = "" try: raw_value = data["$distinct_id"] @@ -274,12 +275,12 @@ def get_distinct_id(data: Dict[str, Any]) -> str: return str(raw_value)[0:200] -def drop_performance_events(events: List[Any]) -> List[Any]: +def drop_performance_events(events: list[Any]) -> list[Any]: cleaned_list = [event for event in events if event.get("event") != "$performance_event"] return cleaned_list -def drop_events_over_quota(token: str, events: List[Any]) -> List[Any]: +def drop_events_over_quota(token: str, events: list[Any]) -> list[Any]: if not settings.EE_AVAILABLE: 
return events @@ -381,7 +382,7 @@ def get_event(request): structlog.contextvars.bind_contextvars(token=token) - replay_events: List[Any] = [] + replay_events: list[Any] = [] historical = token in settings.TOKENS_HISTORICAL_DATA with start_span(op="request.process"): @@ -437,7 +438,7 @@ def get_event(request): generate_exception_response("capture", f"Invalid payload: {e}", code="invalid_payload"), ) - futures: List[FutureRecordMetadata] = [] + futures: list[FutureRecordMetadata] = [] with start_span(op="kafka.produce") as span: span.set_tag("event.count", len(processed_events)) @@ -536,7 +537,7 @@ def get_event(request): return cors_response(request, JsonResponse({"status": 1})) -def preprocess_events(events: List[Dict[str, Any]]) -> Iterator[Tuple[Dict[str, Any], UUIDT, str]]: +def preprocess_events(events: list[dict[str, Any]]) -> Iterator[tuple[dict[str, Any], UUIDT, str]]: for event in events: event_uuid = UUIDT() distinct_id = get_distinct_id(event) @@ -580,7 +581,7 @@ def capture_internal( event_uuid=None, token=None, historical=False, - extra_headers: List[Tuple[str, str]] | None = None, + extra_headers: list[tuple[str, str]] | None = None, ): if event_uuid is None: event_uuid = UUIDT() @@ -680,7 +681,7 @@ def is_randomly_partitioned(candidate_partition_key: str) -> bool: @cache_for(timedelta(seconds=30), background_refresh=True) -def _list_overflowing_keys(input_type: InputType) -> Set[str]: +def _list_overflowing_keys(input_type: InputType) -> set[str]: """Retrieve the active overflows from Redis with caching and pre-fetching cache_for will keep the old value if Redis is temporarily unavailable. diff --git a/posthog/api/cohort.py b/posthog/api/cohort.py index 64eb30db9b0ba..af85769fd0ff1 100644 --- a/posthog/api/cohort.py +++ b/posthog/api/cohort.py @@ -18,7 +18,7 @@ from posthog.metrics import LABEL_TEAM_ID from posthog.renderers import SafeJSONRenderer from datetime import datetime -from typing import Any, Dict, cast, Optional +from typing import Any, cast, Optional from django.conf import settings from django.db.models import QuerySet, Prefetch, prefetch_related_objects, OuterRef, Subquery @@ -133,7 +133,7 @@ class Meta: "experiment_set", ] - def _handle_static(self, cohort: Cohort, context: Dict, validated_data: Dict) -> None: + def _handle_static(self, cohort: Cohort, context: dict, validated_data: dict) -> None: request = self.context["request"] if request.FILES.get("csv"): self._calculate_static_by_csv(request.FILES["csv"], cohort) @@ -149,7 +149,7 @@ def _handle_static(self, cohort: Cohort, context: Dict, validated_data: Dict) -> if filter_data: insert_cohort_from_insight_filter.delay(cohort.pk, filter_data) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort: request = self.context["request"] validated_data["created_by"] = request.user @@ -176,7 +176,7 @@ def _calculate_static_by_csv(self, file, cohort: Cohort) -> None: distinct_ids_and_emails = [row[0] for row in reader if len(row) > 0 and row] calculate_cohort_from_list.delay(cohort.pk, distinct_ids_and_emails) - def validate_query(self, query: Optional[Dict]) -> Optional[Dict]: + def validate_query(self, query: Optional[dict]) -> Optional[dict]: if not query: return None if not isinstance(query, dict): @@ -186,7 +186,7 @@ def validate_query(self, query: Optional[Dict]) -> Optional[Dict]: ActorsQuery.model_validate(query) return query - def validate_filters(self, request_filters: Dict): + def validate_filters(self, 
request_filters: dict): if isinstance(request_filters, dict) and "properties" in request_filters: if self.context["request"].method == "PATCH": parsed_filter = Filter(data=request_filters) @@ -225,7 +225,7 @@ def validate_filters(self, request_filters: Dict): else: raise ValidationError("Filters must be a dictionary with a 'properties' key.") - def update(self, cohort: Cohort, validated_data: Dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore + def update(self, cohort: Cohort, validated_data: dict, *args: Any, **kwargs: Any) -> Cohort: # type: ignore request = self.context["request"] user = cast(User, request.user) @@ -498,7 +498,7 @@ def insert_cohort_query_actors_into_ch(cohort: Cohort): insert_actors_into_cohort_by_query(cohort, query, {}, context) -def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict): +def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: dict): from_existing_cohort_id = filter_data.get("from_cohort_id") context: HogQLContext @@ -561,7 +561,7 @@ def insert_cohort_actors_into_ch(cohort: Cohort, filter_data: Dict): insert_actors_into_cohort_by_query(cohort, query, params, context) -def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: Dict[str, Any], context: HogQLContext): +def insert_actors_into_cohort_by_query(cohort: Cohort, query: str, params: dict[str, Any], context: HogQLContext): try: sync_execute( INSERT_COHORT_ALL_PEOPLE_THROUGH_PERSON_ID.format(cohort_table=PERSON_STATIC_COHORT_TABLE, query=query), @@ -600,7 +600,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int, cohort = Cohort.objects.get(pk=cohort_id, team_id=team_id) matcher_cache = FlagsMatcherCache(team_id) uuids_to_add_to_cohort = [] - cohorts_cache: Dict[int, CohortOrEmpty] = {} + cohorts_cache: dict[int, CohortOrEmpty] = {} if feature_flag.uses_cohorts: # TODO: Consider disabling flags with cohorts for creating static cohorts @@ -709,7 +709,7 @@ def get_cohort_actors_for_feature_flag(cohort_id: int, flag: str, team_id: int, capture_exception(err) -def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortOrEmpty]): +def get_default_person_property(prop: Property, cohorts_cache: dict[int, CohortOrEmpty]): default_person_properties = {} if prop.operator not in ("is_set", "is_not_set") and prop.type == "person": @@ -725,7 +725,7 @@ def get_default_person_property(prop: Property, cohorts_cache: Dict[int, CohortO return default_person_properties -def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: Dict[int, CohortOrEmpty]) -> Dict[str, str]: +def get_default_person_properties_for_cohort(cohort: Cohort, cohorts_cache: dict[int, CohortOrEmpty]) -> dict[str, str]: """ Returns a dictionary of default person properties to use when evaluating a feature flag """ diff --git a/posthog/api/comments.py b/posthog/api/comments.py index 8b9a9174dda61..63ef5d1d33a16 100644 --- a/posthog/api/comments.py +++ b/posthog/api/comments.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast from django.db import transaction from django.db.models import QuerySet @@ -40,7 +40,7 @@ def create(self, validated_data: Any) -> Any: validated_data["team_id"] = self.context["team_id"] return super().create(validated_data) - def update(self, instance: Comment, validated_data: Dict, **kwargs) -> Comment: + def update(self, instance: Comment, validated_data: dict, **kwargs) -> Comment: request = self.context["request"] with transaction.atomic(): diff --git 
a/posthog/api/dashboards/dashboard.py b/posthog/api/dashboards/dashboard.py index a89d41814d616..850e29b52a4e3 100644 --- a/posthog/api/dashboards/dashboard.py +++ b/posthog/api/dashboards/dashboard.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast import structlog from django.db.models import Prefetch, QuerySet @@ -155,13 +155,13 @@ class Meta: ] read_only_fields = ["creation_mode", "effective_restriction_level", "is_shared"] - def validate_filters(self, value) -> Dict: + def validate_filters(self, value) -> dict: if not isinstance(value, dict): raise serializers.ValidationError("Filters must be a dictionary") return value - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard: request = self.context["request"] validated_data["created_by"] = request.user team_id = self.context["team_id"] @@ -260,7 +260,7 @@ def _deep_duplicate_tiles(self, dashboard: Dashboard, existing_tile: DashboardTi color=existing_tile.color, ) - def update(self, instance: Dashboard, validated_data: Dict, *args: Any, **kwargs: Any) -> Dashboard: + def update(self, instance: Dashboard, validated_data: dict, *args: Any, **kwargs: Any) -> Dashboard: can_user_restrict = self.user_permissions.dashboard(instance).can_restrict if "restriction_level" in validated_data and not can_user_restrict: raise exceptions.PermissionDenied( @@ -292,11 +292,11 @@ def update(self, instance: Dashboard, validated_data: Dict, *args: Any, **kwargs return instance @staticmethod - def _update_tiles(instance: Dashboard, tile_data: Dict, user: User) -> None: + def _update_tiles(instance: Dashboard, tile_data: dict, user: User) -> None: tile_data.pop("is_cached", None) # read only field if tile_data.get("text", None): - text_json: Dict = tile_data.get("text", {}) + text_json: dict = tile_data.get("text", {}) created_by_json = text_json.get("created_by", None) if created_by_json: last_modified_by = user @@ -348,7 +348,7 @@ def _undo_delete_related_tiles(instance: Dashboard) -> None: insights_to_undelete.append(tile.insight) Insight.objects.bulk_update(insights_to_undelete, ["deleted"]) - def get_tiles(self, dashboard: Dashboard) -> Optional[List[ReturnDict]]: + def get_tiles(self, dashboard: Dashboard) -> Optional[list[ReturnDict]]: if self.context["view"].action == "list": return None @@ -401,7 +401,7 @@ class DashboardsViewSet( queryset = Dashboard.objects_including_soft_deleted.order_by("name") permission_classes = [CanEditDashboard] - def get_serializer_class(self) -> Type[BaseSerializer]: + def get_serializer_class(self) -> type[BaseSerializer]: return DashboardBasicSerializer if self.action == "list" else DashboardSerializer def get_queryset(self) -> QuerySet: @@ -512,7 +512,7 @@ def create_from_template_json(self, request: Request, *args: Any, **kwargs: Any) class LegacyDashboardsViewSet(DashboardsViewSet): derive_current_team_from_user_only = True - def get_parents_query_dict(self) -> Dict[str, Any]: + def get_parents_query_dict(self) -> dict[str, Any]: if not self.request.user.is_authenticated or "share_token" in self.request.GET: return {} return {"team_id": self.team_id} diff --git a/posthog/api/dashboards/dashboard_template_json_schema_parser.py b/posthog/api/dashboards/dashboard_template_json_schema_parser.py index 3463601514e01..8f9149cd84d11 100644 --- a/posthog/api/dashboards/dashboard_template_json_schema_parser.py +++ 
b/posthog/api/dashboards/dashboard_template_json_schema_parser.py @@ -15,9 +15,7 @@ class DashboardTemplateCreationJSONSchemaParser(JSONParser): The template is sent in the "template" key""" def parse(self, stream, media_type=None, parser_context=None): - data = super(DashboardTemplateCreationJSONSchemaParser, self).parse( - stream, media_type or "application/json", parser_context - ) + data = super().parse(stream, media_type or "application/json", parser_context) try: template = data["template"] jsonschema.validate(template, dashboard_template_schema) diff --git a/posthog/api/dashboards/dashboard_templates.py b/posthog/api/dashboards/dashboard_templates.py index 6e8752e0cbd39..03740b06ebd6b 100644 --- a/posthog/api/dashboards/dashboard_templates.py +++ b/posthog/api/dashboards/dashboard_templates.py @@ -1,6 +1,5 @@ import json from pathlib import Path -from typing import Dict import structlog from django.db.models import Q @@ -50,7 +49,7 @@ class Meta: "scope", ] - def create(self, validated_data: Dict, *args, **kwargs) -> DashboardTemplate: + def create(self, validated_data: dict, *args, **kwargs) -> DashboardTemplate: if not validated_data["tiles"]: raise ValidationError(detail="You need to provide tiles for the template.") @@ -61,7 +60,7 @@ def create(self, validated_data: Dict, *args, **kwargs) -> DashboardTemplate: validated_data["team_id"] = self.context["team_id"] return super().create(validated_data, *args, **kwargs) - def update(self, instance: DashboardTemplate, validated_data: Dict, *args, **kwargs) -> DashboardTemplate: + def update(self, instance: DashboardTemplate, validated_data: dict, *args, **kwargs) -> DashboardTemplate: # if the original request was to make the template scope to team only, and the template is none then deny the request if validated_data.get("scope") == "team" and instance.scope == "global" and not instance.team_id: raise ValidationError(detail="The original templates cannot be made private as they would be lost.") diff --git a/posthog/api/dashboards/test/test_dashboard_templates.py b/posthog/api/dashboards/test/test_dashboard_templates.py index f07610ba90351..e562b3798d895 100644 --- a/posthog/api/dashboards/test/test_dashboard_templates.py +++ b/posthog/api/dashboards/test/test_dashboard_templates.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, List +from typing import Optional from rest_framework import status @@ -510,7 +510,7 @@ def test_filter_template_list_by_scope(self): assert flag_response.status_code == status.HTTP_200_OK assert [(r["id"], r["scope"]) for r in flag_response.json()["results"]] == [(flag_template_id, "feature_flag")] - def create_template(self, overrides: Dict[str, str | List[str]], team_id: Optional[int] = None) -> str: + def create_template(self, overrides: dict[str, str | list[str]], team_id: Optional[int] = None) -> str: template = {**variable_template, **overrides} response = self.client.post( f"/api/projects/{team_id or self.team.pk}/dashboard_templates", diff --git a/posthog/api/dead_letter_queue.py b/posthog/api/dead_letter_queue.py index 93e2b09370b0e..2bab687543568 100644 --- a/posthog/api/dead_letter_queue.py +++ b/posthog/api/dead_letter_queue.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from rest_framework import mixins, permissions, serializers, viewsets @@ -65,7 +65,7 @@ class DeadLetterQueueMetric: key: str = "" metric: str = "" value: Union[str, bool, int, None] = None - subrows: Optional[List[Any]] = None + 
subrows: Optional[list[Any]] = None def __init__(self, **kwargs): for field in ("key", "metric", "value", "subrows"): @@ -138,7 +138,7 @@ def get_dead_letter_queue_events_last_24h() -> int: )[0][0] -def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT error, count(*) AS c @@ -151,7 +151,7 @@ def get_dead_letter_queue_events_per_error(offset: Optional[int] = 0) -> List[Un ) -def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT error_location, count(*) AS c @@ -164,7 +164,7 @@ def get_dead_letter_queue_events_per_location(offset: Optional[int] = 0) -> List ) -def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT toDate(error_timestamp) as day, count(*) AS c @@ -177,7 +177,7 @@ def get_dead_letter_queue_events_per_day(offset: Optional[int] = 0) -> List[Unio ) -def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> List[Union[str, int]]: +def get_dead_letter_queue_events_per_tag(offset: Optional[int] = 0) -> list[Union[str, int]]: return sync_execute( f""" SELECT arrayJoin(tags) as tag, count(*) as c from events_dead_letter_queue diff --git a/posthog/api/decide.py b/posthog/api/decide.py index 3a6e08bc7a7a0..827194dea9c7b 100644 --- a/posthog/api/decide.py +++ b/posthog/api/decide.py @@ -1,6 +1,6 @@ import re from random import random -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from urllib.parse import urlparse import structlog @@ -56,7 +56,7 @@ def on_permitted_recording_domain(team: Team, request: HttpRequest) -> bool: return is_authorized_web_client or is_authorized_mobile_client -def hostname_in_allowed_url_list(allowed_url_list: Optional[List[str]], hostname: Optional[str]) -> bool: +def hostname_in_allowed_url_list(allowed_url_list: Optional[list[str]], hostname: Optional[str]) -> bool: if not hostname: return False @@ -182,7 +182,7 @@ def get_decide(request: HttpRequest): if geoip_enabled: property_overrides = get_geoip_properties(get_ip_address(request)) - all_property_overrides: Dict[str, Union[str, int]] = { + all_property_overrides: dict[str, Union[str, int]] = { **property_overrides, **(data.get("person_properties") or {}), } @@ -296,8 +296,8 @@ def get_decide(request: HttpRequest): return cors_response(request, JsonResponse(response)) -def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | Dict: - session_recording_config_response: bool | Dict = False +def _session_recording_config_response(request: HttpRequest, team: Team) -> bool | dict: + session_recording_config_response: bool | dict = False try: if team.session_recording_opt_in and ( @@ -312,7 +312,7 @@ def _session_recording_config_response(request: HttpRequest, team: Team) -> bool linked_flag = None linked_flag_config = team.session_recording_linked_flag or None - if isinstance(linked_flag_config, Dict): + if isinstance(linked_flag_config, dict): linked_flag_key = linked_flag_config.get("key", None) linked_flag_variant = linked_flag_config.get("variant", None) if linked_flag_variant is not None: @@ -330,7 +330,7 @@ def 
_session_recording_config_response(request: HttpRequest, team: Team) -> bool "networkPayloadCapture": team.session_recording_network_payload_capture_config or None, } - if isinstance(team.session_replay_config, Dict): + if isinstance(team.session_replay_config, dict): record_canvas = team.session_replay_config.get("record_canvas", False) session_recording_config_response.update( { diff --git a/posthog/api/documentation.py b/posthog/api/documentation.py index 47820a9cb2203..3cae48fcdb006 100644 --- a/posthog/api/documentation.py +++ b/posthog/api/documentation.py @@ -1,5 +1,5 @@ import re -from typing import Dict, get_args +from typing import get_args from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( @@ -215,7 +215,7 @@ def preprocess_exclude_path_format(endpoints, **kwargs): def custom_postprocessing_hook(result, generator, request, public): all_tags = [] - paths: Dict[str, Dict] = {} + paths: dict[str, dict] = {} for path, methods in result["paths"].items(): paths[path] = {} diff --git a/posthog/api/early_access_feature.py b/posthog/api/early_access_feature.py index 911c860a75a16..57885666fde7d 100644 --- a/posthog/api/early_access_feature.py +++ b/posthog/api/early_access_feature.py @@ -1,5 +1,3 @@ -from typing import Type - from django.http import JsonResponse from rest_framework.response import Response from posthog.api.feature_flag import FeatureFlagSerializer, MinimalFeatureFlagSerializer @@ -221,7 +219,7 @@ class EarlyAccessFeatureViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): scope_object = "early_access_feature" queryset = EarlyAccessFeature.objects.select_related("feature_flag").all() - def get_serializer_class(self) -> Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if self.request.method == "POST": return EarlyAccessFeatureSerializerCreateOnly else: diff --git a/posthog/api/element.py b/posthog/api/element.py index d7b721dee8195..b617ea8be28b5 100644 --- a/posthog/api/element.py +++ b/posthog/api/element.py @@ -1,4 +1,4 @@ -from typing import Literal, Tuple +from typing import Literal from rest_framework import request, response, serializers, viewsets from rest_framework.decorators import action @@ -128,8 +128,8 @@ def stats(self, request: request.Request, **kwargs) -> response.Response: else: return response.Response(serialized_elements) - def _events_filter(self, request) -> Tuple[Literal["$autocapture", "$rageclick"], ...]: - event_to_filter: Tuple[Literal["$autocapture", "$rageclick"], ...] = () + def _events_filter(self, request) -> tuple[Literal["$autocapture", "$rageclick"], ...]: + event_to_filter: tuple[Literal["$autocapture", "$rageclick"], ...] = () # when multiple includes are sent, they are expected as separate parameters # e.g. 
?include=a&include=b events_to_include = request.query_params.getlist("include", []) diff --git a/posthog/api/event.py b/posthog/api/event.py index 6366ee866f657..5c642a2612973 100644 --- a/posthog/api/event.py +++ b/posthog/api/event.py @@ -1,7 +1,7 @@ import json import urllib from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union # noqa: UP035 from django.db.models.query import Prefetch from drf_spectacular.types import OpenApiTypes @@ -94,7 +94,7 @@ def _build_next_url( self, request: request.Request, last_event_timestamp: datetime, - order_by: List[str], + order_by: list[str], ) -> str: params = request.GET.dict() reverse = "-timestamp" in order_by @@ -175,7 +175,7 @@ def list(self, request: request.Request, *args: Any, **kwargs: Any) -> response. team = self.team filter = Filter(request=request, team=self.team) - order_by: List[str] = ( + order_by: list[str] = ( list(json.loads(request.GET["orderBy"])) if request.GET.get("orderBy") else ["-timestamp"] ) @@ -217,11 +217,11 @@ def list(self, request: request.Request, *args: Any, **kwargs: Any) -> response. capture_exception(ex) raise ex - def _get_people(self, query_result: List[Dict], team: Team) -> Dict[str, Any]: + def _get_people(self, query_result: List[dict], team: Team) -> dict[str, Any]: # noqa: UP006 distinct_ids = [event["distinct_id"] for event in query_result] persons = get_persons_by_distinct_ids(team.pk, distinct_ids) persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache")) - distinct_to_person: Dict[str, Person] = {} + distinct_to_person: dict[str, Person] = {} for person in persons: for distinct_id in person.distinct_ids: distinct_to_person[distinct_id] = person diff --git a/posthog/api/event_definition.py b/posthog/api/event_definition.py index 82a9c0617bd74..76314578fb98f 100644 --- a/posthog/api/event_definition.py +++ b/posthog/api/event_definition.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Tuple, Type, cast +from typing import Any, Literal, cast from django.db.models import Manager, Prefetch from rest_framework import ( @@ -117,7 +117,7 @@ def get_queryset(self): def _ordering_params_from_request( self, - ) -> Tuple[str, Literal["ASC", "DESC"]]: + ) -> tuple[str, Literal["ASC", "DESC"]]: order_direction: Literal["ASC", "DESC"] ordering = self.request.GET.get("ordering") @@ -154,7 +154,7 @@ def get_object(self): return EventDefinition.objects.get(id=id, team_id=self.team_id) - def get_serializer_class(self) -> Type[serializers.ModelSerializer]: + def get_serializer_class(self) -> type[serializers.ModelSerializer]: serializer_class = self.serializer_class if EE_AVAILABLE and self.request.user.organization.is_feature_available( # type: ignore AvailableFeature.INGESTION_TAXONOMY diff --git a/posthog/api/exports.py b/posthog/api/exports.py index 2099b2f169e2e..9fbaea35df3c2 100644 --- a/posthog/api/exports.py +++ b/posthog/api/exports.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict +from typing import Any import structlog from django.http import HttpResponse @@ -40,7 +40,7 @@ class Meta: ] read_only_fields = ["id", "created_at", "has_content", "filename"] - def validate(self, data: Dict) -> Dict: + def validate(self, data: dict) -> dict: if not data.get("export_format"): raise ValidationError("Must provide export format") @@ -61,13 +61,13 @@ def validate(self, data: Dict) -> Dict: def synthetic_create(self, reason: str, *args: Any, **kwargs: Any) -> ExportedAsset: 
return self._create_asset(self.validated_data, user=None, reason=reason) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> ExportedAsset: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ExportedAsset: request = self.context["request"] return self._create_asset(validated_data, user=request.user, reason=None) def _create_asset( self, - validated_data: Dict, + validated_data: dict, user: User | None, reason: str | None, ) -> ExportedAsset: diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py index 8bf1dbb5d3cf4..bd53f02955252 100644 --- a/posthog/api/feature_flag.py +++ b/posthog/api/feature_flag.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from datetime import datetime from django.db.models import QuerySet, Q, deletion @@ -145,12 +145,12 @@ def get_is_simple_flag(self, feature_flag: FeatureFlag) -> bool: and feature_flag.aggregation_group_type_index is None ) - def get_features(self, feature_flag: FeatureFlag) -> Dict: + def get_features(self, feature_flag: FeatureFlag) -> dict: from posthog.api.early_access_feature import MinimalEarlyAccessFeatureSerializer return MinimalEarlyAccessFeatureSerializer(feature_flag.features, many=True).data - def get_surveys(self, feature_flag: FeatureFlag) -> Dict: + def get_surveys(self, feature_flag: FeatureFlag) -> dict: from posthog.api.survey import SurveyAPISerializer return SurveyAPISerializer(feature_flag.surveys_linked_flag, many=True).data @@ -263,7 +263,7 @@ def properties_all_match(predicate): return filters - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag: request = self.context["request"] validated_data["created_by"] = request.user validated_data["team_id"] = self.context["team_id"] @@ -299,7 +299,7 @@ def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag return instance - def update(self, instance: FeatureFlag, validated_data: Dict, *args: Any, **kwargs: Any) -> FeatureFlag: + def update(self, instance: FeatureFlag, validated_data: dict, *args: Any, **kwargs: Any) -> FeatureFlag: if "deleted" in validated_data and validated_data["deleted"] is True and instance.features.count() > 0: raise exceptions.ValidationError( "Cannot delete a feature flag that is in use with early access features. Please delete the early access feature before deleting the flag." 
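The `my_flags` hunk that follows strips a redundant pair of parentheses: when a generator expression is the sole argument in a call, Python accepts it bare, so `Response(... for ...)` and `Response((... for ...))` build the same response. A minimal sketch of the equivalence, with illustrative data that is not part of the patch:

```python
# Hypothetical stand-ins for the serialized flags and match results.
all_serialized_flags = [{"key": "beta-banner"}, {"key": "new-nav"}]
matches = {"new-nav": True}

# A bare generator expression as the sole call argument (the new form in the hunk).
rows = list(
    {"feature_flag": flag, "value": matches.get(flag["key"], False)}
    for flag in all_serialized_flags
)
assert [row["value"] for row in rows] == [False, True]

# The inner parentheses only become mandatory once a second argument appears:
first_true = next((row for row in rows if row["value"]), None)
assert first_true == {"feature_flag": {"key": "new-nav"}, "value": True}
```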
@@ -496,13 +496,11 @@ def my_flags(self, request: request.Request, **kwargs): feature_flags, many=True, context=self.get_serializer_context() ).data return Response( - ( - { - "feature_flag": feature_flag, - "value": matches.get(feature_flag["key"], False), - } - for feature_flag in all_serialized_flags - ) + { + "feature_flag": feature_flag, + "value": matches.get(feature_flag["key"], False), + } + for feature_flag in all_serialized_flags ) @action( @@ -516,7 +514,7 @@ def local_evaluation(self, request: request.Request, **kwargs): should_send_cohorts = "send_cohorts" in request.GET cohorts = {} - seen_cohorts_cache: Dict[int, CohortOrEmpty] = {} + seen_cohorts_cache: dict[int, CohortOrEmpty] = {} if should_send_cohorts: seen_cohorts_cache = { diff --git a/posthog/api/geoip.py b/posthog/api/geoip.py index d3d029cdd3f33..7a749c0b294c2 100644 --- a/posthog/api/geoip.py +++ b/posthog/api/geoip.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional import structlog from django.contrib.gis.geoip2 import GeoIP2 @@ -27,7 +27,7 @@ ] -def get_geoip_properties(ip_address: Optional[str]) -> Dict[str, str]: +def get_geoip_properties(ip_address: Optional[str]) -> dict[str, str]: """ Returns a dictionary of geoip properties for the given ip address. diff --git a/posthog/api/insight.py b/posthog/api/insight.py index 36495a5469b2e..a2fe0c53edc2c 100644 --- a/posthog/api/insight.py +++ b/posthog/api/insight.py @@ -1,6 +1,6 @@ import json from functools import lru_cache -from typing import Any, Dict, List, Optional, Type, Union, cast +from typing import Any, Optional, Union, cast import structlog from django.db import transaction @@ -118,7 +118,7 @@ def log_insight_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: """ Insight id and short_id are passed separately as some activities (like delete) alter the Insight instance @@ -148,7 +148,7 @@ class QuerySchemaParser(JSONParser): """ def parse(self, stream, media_type=None, parser_context=None): - data = super(QuerySchemaParser, self).parse(stream, media_type, parser_context) + data = super().parse(stream, media_type, parser_context) try: query = data.get("query", None) if query: @@ -197,7 +197,7 @@ class Meta: ] read_only_fields = ("short_id", "updated_at", "last_refresh", "refreshing") - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Any: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError() def to_representation(self, instance): @@ -306,7 +306,7 @@ class Meta: "is_cached", ) - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Insight: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Insight: request = self.context["request"] tags = validated_data.pop("tags", None) # tags are created separately as global tag relationships team_id = self.context["team_id"] @@ -345,8 +345,8 @@ def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Insight: return insight - def update(self, instance: Insight, validated_data: Dict, **kwargs) -> Insight: - dashboards_before_change: List[Union[str, Dict]] = [] + def update(self, instance: Insight, validated_data: dict, **kwargs) -> Insight: + dashboards_before_change: list[Union[str, dict]] = [] try: # since it is possible to be undeleting a soft deleted insight # the state captured before the update has to include soft deleted insights @@ -411,7 +411,7 @@ def 
_log_insight_update(self, before_update, dashboards_before_change, updated_i changes=changes, ) - def _synthetic_dashboard_changes(self, dashboards_before_change: List[Dict]) -> List[Change]: + def _synthetic_dashboard_changes(self, dashboards_before_change: list[dict]) -> list[Change]: artificial_dashboard_changes = self.context.get("after_dashboard_changes", []) if artificial_dashboard_changes: return [ @@ -426,7 +426,7 @@ def _synthetic_dashboard_changes(self, dashboards_before_change: List[Dict]) -> return [] - def _update_insight_dashboards(self, dashboards: List[Dashboard], instance: Insight) -> None: + def _update_insight_dashboards(self, dashboards: list[Dashboard], instance: Insight) -> None: old_dashboard_ids = [tile.dashboard_id for tile in instance.dashboard_tiles.all()] new_dashboard_ids = [d.id for d in dashboards if not d.deleted] @@ -598,14 +598,14 @@ class InsightViewSet( parser_classes = (QuerySchemaParser,) - def get_serializer_class(self) -> Type[serializers.BaseSerializer]: + def get_serializer_class(self) -> type[serializers.BaseSerializer]: if (self.action == "list" or self.action == "retrieve") and str_to_bool( self.request.query_params.get("basic", "0") ): return InsightBasicSerializer return super().get_serializer_class() - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() context["is_shared"] = isinstance(self.request.successful_authenticator, SharingAccessTokenAuthentication) return context @@ -867,7 +867,7 @@ def trend(self, request: request.Request, *args: Any, **kwargs: Any): return Response({**result, "next": next}) @cached_by_filters - def calculate_trends(self, request: request.Request) -> Dict[str, Any]: + def calculate_trends(self, request: request.Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, team=self.team) @@ -919,7 +919,7 @@ def funnel(self, request: request.Request, *args: Any, **kwargs: Any) -> Respons return Response(funnel) @cached_by_filters - def calculate_funnel(self, request: request.Request) -> Dict[str, Any]: + def calculate_funnel(self, request: request.Request) -> dict[str, Any]: team = self.team filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) @@ -959,7 +959,7 @@ def retention(self, request: request.Request, *args: Any, **kwargs: Any) -> Resp return Response(result) @cached_by_filters - def calculate_retention(self, request: request.Request) -> Dict[str, Any]: + def calculate_retention(self, request: request.Request) -> dict[str, Any]: team = self.team data = {} if not request.GET.get("date_from") and not request.data.get("date_from"): @@ -989,7 +989,7 @@ def path(self, request: request.Request, *args: Any, **kwargs: Any) -> Response: return Response(result) @cached_by_filters - def calculate_path(self, request: request.Request) -> Dict[str, Any]: + def calculate_path(self, request: request.Request) -> dict[str, Any]: team = self.team filter = PathFilter(request=request, data={"insight": INSIGHT_PATHS}, team=self.team) diff --git a/posthog/api/instance_settings.py b/posthog/api/instance_settings.py index dc0b41e5cb1da..13c1461ba5655 100644 --- a/posthog/api/instance_settings.py +++ b/posthog/api/instance_settings.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from rest_framework import exceptions, mixins, permissions, serializers, viewsets @@ -50,7 +50,7 @@ def __init__(self, **kwargs): 
setattr(self, field, kwargs.get(field, None)) -def get_instance_setting(key: str, setting_config: Optional[Tuple] = None) -> InstanceSettingHelper: +def get_instance_setting(key: str, setting_config: Optional[tuple] = None) -> InstanceSettingHelper: setting_config = setting_config or CONSTANCE_CONFIG[key] is_secret = key in SECRET_SETTINGS value = get_instance_setting_raw(key) @@ -73,7 +73,7 @@ class InstanceSettingsSerializer(serializers.Serializer): editable = serializers.BooleanField(read_only=True) is_secret = serializers.BooleanField(read_only=True) - def update(self, instance: InstanceSettingHelper, validated_data: Dict[str, Any]) -> InstanceSettingHelper: + def update(self, instance: InstanceSettingHelper, validated_data: dict[str, Any]) -> InstanceSettingHelper: if instance.key not in SETTINGS_ALLOWING_API_OVERRIDE: raise serializers.ValidationError("This setting cannot be updated from the API.", code="no_api_override") diff --git a/posthog/api/instance_status.py b/posthog/api/instance_status.py index c0dff3a3e4a1c..1e001b74703be 100644 --- a/posthog/api/instance_status.py +++ b/posthog/api/instance_status.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Union from django.conf import settings from django.db import connection @@ -40,7 +40,7 @@ def list(self, request: Request) -> Response: redis_alive = is_redis_alive() postgres_alive = is_postgres_alive() - metrics: List[Dict[str, Union[str, bool, int, float, Dict[str, Any]]]] = [] + metrics: list[dict[str, Union[str, bool, int, float, dict[str, Any]]]] = [] metrics.append( {"key": "posthog_git_sha", "metric": "PostHog Git SHA", "value": get_git_commit_short() or "unknown"} diff --git a/posthog/api/mixins.py b/posthog/api/mixins.py index 69b83d3469e01..a326eb3d1d2cd 100644 --- a/posthog/api/mixins.py +++ b/posthog/api/mixins.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Type +from typing import TypeVar from pydantic import BaseModel, ValidationError @@ -9,7 +9,7 @@ class PydanticModelMixin: - def get_model(self, data: dict, model: Type[T]) -> T: + def get_model(self, data: dict, model: type[T]) -> T: try: return model.model_validate(data) except ValidationError as exc: diff --git a/posthog/api/notebook.py b/posthog/api/notebook.py index 5910af4948c38..4125b79dd6551 100644 --- a/posthog/api/notebook.py +++ b/posthog/api/notebook.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any, Type +from typing import Optional, Any from django.db.models import Q import structlog from django.db import transaction @@ -58,7 +58,7 @@ def log_notebook_activity( team_id: int, user: User, was_impersonated: bool, - changes: Optional[List[Change]] = None, + changes: Optional[list[Change]] = None, ) -> None: short_id = str(notebook.short_id) @@ -118,7 +118,7 @@ class Meta: "last_modified_by", ] - def create(self, validated_data: Dict, *args, **kwargs) -> Notebook: + def create(self, validated_data: dict, *args, **kwargs) -> Notebook: request = self.context["request"] team = self.context["get_team"]() @@ -141,7 +141,7 @@ def create(self, validated_data: Dict, *args, **kwargs) -> Notebook: return notebook - def update(self, instance: Notebook, validated_data: Dict, **kwargs) -> Notebook: + def update(self, instance: Notebook, validated_data: dict, **kwargs) -> Notebook: try: before_update = Notebook.objects.get(pk=instance.id) except Notebook.DoesNotExist: @@ -240,7 +240,7 @@ class NotebookViewSet(TeamAndOrgViewSetMixin, ForbidDestroyModel, viewsets.Model filterset_fields = ["short_id"] lookup_field = 
"short_id" - def get_serializer_class(self) -> Type[BaseSerializer]: + def get_serializer_class(self) -> type[BaseSerializer]: return NotebookMinimalSerializer if self.action == "list" else NotebookSerializer def get_queryset(self) -> QuerySet: @@ -298,8 +298,8 @@ def _filter_request(self, request: Request, queryset: QuerySet) -> QuerySet: if target: # the JSONB query requires a specific structure - basic_structure = List[Dict[str, Any]] - nested_structure = basic_structure | List[Dict[str, basic_structure]] + basic_structure = list[dict[str, Any]] + nested_structure = basic_structure | list[dict[str, basic_structure]] presence_match_structure: basic_structure | nested_structure = [{"type": f"ph-{target}"}] diff --git a/posthog/api/organization.py b/posthog/api/organization.py index ea1a9f31615b1..f528d5413190a 100644 --- a/posthog/api/organization.py +++ b/posthog/api/organization.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from django.db.models import Model, QuerySet from django.shortcuts import get_object_or_404 @@ -108,7 +108,7 @@ class Meta: }, # slug is not required here as it's generated automatically for new organizations } - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Organization: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Organization: serializers.raise_errors_on_nested_writes("create", self, validated_data) user = self.context["request"].user organization, _, _ = Organization.objects.bootstrap(user, **validated_data) @@ -119,11 +119,11 @@ def get_membership_level(self, organization: Organization) -> Optional[Organizat membership = self.user_permissions.organization_memberships.get(organization.pk) return membership.level if membership is not None else None - def get_teams(self, instance: Organization) -> List[Dict[str, Any]]: + def get_teams(self, instance: Organization) -> list[dict[str, Any]]: visible_teams = instance.teams.filter(id__in=self.user_permissions.team_ids_visible_for_user) return TeamBasicSerializer(visible_teams, context=self.context, many=True).data # type: ignore - def get_metadata(self, instance: Organization) -> Dict[str, Union[str, int, object]]: + def get_metadata(self, instance: Organization) -> dict[str, Union[str, int, object]]: return { "instance_tag": settings.INSTANCE_TAG, } @@ -210,7 +210,7 @@ def perform_destroy(self, organization: Organization): ignore_conflicts=True, ) - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: return { **super().get_serializer_context(), "user_permissions": UserPermissions(cast(User, self.request.user)), diff --git a/posthog/api/organization_domain.py b/posthog/api/organization_domain.py index b3a4ada0b4e06..81b8c8efad8b7 100644 --- a/posthog/api/organization_domain.py +++ b/posthog/api/organization_domain.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, cast +from typing import Any, cast from rest_framework import exceptions, request, response, serializers from rest_framework.decorators import action @@ -38,7 +38,7 @@ class Meta: "has_saml": {"read_only": True}, } - def create(self, validated_data: Dict[str, Any]) -> OrganizationDomain: + def create(self, validated_data: dict[str, Any]) -> OrganizationDomain: validated_data["organization"] = self.context["view"].organization validated_data.pop( "jit_provisioning_enabled", None @@ -56,7 +56,7 @@ def validate_domain(self, domain: str) 
-> str: raise serializers.ValidationError("Please enter a valid domain or subdomain name.") return domain - def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: instance = cast(OrganizationDomain, self.instance) if instance and not instance.verified_at: diff --git a/posthog/api/organization_feature_flag.py b/posthog/api/organization_feature_flag.py index 0ed25ada28eef..d2468cb07ce12 100644 --- a/posthog/api/organization_feature_flag.py +++ b/posthog/api/organization_feature_flag.py @@ -1,4 +1,3 @@ -from typing import Dict from django.core.exceptions import ObjectDoesNotExist from rest_framework.response import Response from rest_framework.decorators import action @@ -95,13 +94,13 @@ def copy_flags(self, request, *args, **kwargs): continue # get all linked cohorts, sorted by creation order - seen_cohorts_cache: Dict[int, CohortOrEmpty] = {} + seen_cohorts_cache: dict[int, CohortOrEmpty] = {} sorted_cohort_ids = flag_to_copy.get_cohort_ids( seen_cohorts_cache=seen_cohorts_cache, sort_by_topological_order=True ) # destination cohort id is different from original cohort id - create mapping - name_to_dest_cohort_id: Dict[str, int] = {} + name_to_dest_cohort_id: dict[str, int] = {} # create cohorts in the destination project if len(sorted_cohort_ids): for cohort_id in sorted_cohort_ids: diff --git a/posthog/api/organization_invite.py b/posthog/api/organization_invite.py index 6a8140479a950..961f2cddba27d 100644 --- a/posthog/api/organization_invite.py +++ b/posthog/api/organization_invite.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast from rest_framework import ( exceptions, @@ -49,7 +49,7 @@ def validate_target_email(self, email: str): local_part, domain = email.split("@") return f"{local_part}@{domain.lower()}" - def create(self, validated_data: Dict[str, Any], *args: Any, **kwargs: Any) -> OrganizationInvite: + def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> OrganizationInvite: if OrganizationMembership.objects.filter( organization_id=self.context["organization_id"], user__email=validated_data["target_email"], diff --git a/posthog/api/person.py b/posthog/api/person.py index 942f07e9a9ef8..e242c7ffffe14 100644 --- a/posthog/api/person.py +++ b/posthog/api/person.py @@ -2,17 +2,14 @@ import posthoganalytics from posthog.renderers import SafeJSONRenderer from datetime import datetime -from typing import ( +from typing import ( # noqa: UP035 Any, - Callable, - Dict, List, Optional, - Tuple, - Type, TypeVar, cast, ) +from collections.abc import Callable from django.db.models import Prefetch from django.shortcuts import get_object_or_404 @@ -176,7 +173,7 @@ def get_name(self, person: Person) -> str: team = self.context["get_team"]() return get_person_name(team, person) - def to_representation(self, instance: Person) -> Dict[str, Any]: + def to_representation(self, instance: Person) -> dict[str, Any]: representation = super().to_representation(instance) representation["distinct_ids"] = sorted(representation["distinct_ids"], key=is_anonymous_id) return representation @@ -192,7 +189,7 @@ def get_distinct_ids(self, person): def get_funnel_actor_class(filter: Filter) -> Callable: - funnel_actor_class: Type[ActorBaseQuery] + funnel_actor_class: type[ActorBaseQuery] if filter.correlation_person_entity and EE_AVAILABLE: if EE_AVAILABLE: @@ -678,7 +675,7 @@ def _set_properties(self, properties, user): ) # PRAGMA: Methods for getting Persons via clickhouse queries - def 
_respond_with_cached_results(self, results_package: Dict[str, Tuple[List, Optional[str], Optional[str], int]]): + def _respond_with_cached_results(self, results_package: dict[str, tuple[List, Optional[str], Optional[str], int]]): # noqa: UP006 if not results_package: return response.Response(data=[]) @@ -705,7 +702,7 @@ def funnel(self, request: request.Request, **kwargs) -> response.Response: @cached_by_filters def calculate_funnel_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = Filter(request=request, data={"insight": INSIGHT_FUNNELS}, team=self.team) filter = prepare_actor_query_filter(filter) funnel_actor_class = get_funnel_actor_class(filter) @@ -734,7 +731,7 @@ def path(self, request: request.Request, **kwargs) -> response.Response: @cached_by_filters def calculate_path_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = PathFilter(request=request, data={"insight": INSIGHT_PATHS}, team=self.team) filter = prepare_actor_query_filter(filter) @@ -769,7 +766,7 @@ def trends(self, request: request.Request, *args: Any, **kwargs: Any) -> Respons @cached_by_filters def calculate_trends_persons( self, request: request.Request - ) -> Dict[str, Tuple[List, Optional[str], Optional[str], int]]: + ) -> dict[str, tuple[List, Optional[str], Optional[str], int]]: # noqa: UP006 filter = Filter(request=request, team=self.team) filter = prepare_actor_query_filter(filter) entity = get_target_entity(filter) diff --git a/posthog/api/plugin.py b/posthog/api/plugin.py index 2a6e00f325451..7a4dea1a8d7a9 100644 --- a/posthog/api/plugin.py +++ b/posthog/api/plugin.py @@ -2,7 +2,7 @@ import os import re import subprocess -from typing import Any, Dict, List, Optional, Set, cast, Literal +from typing import Any, Optional, cast, Literal import requests from dateutil.relativedelta import relativedelta @@ -64,8 +64,8 @@ def _update_plugin_attachments(request: request.Request, plugin_config: PluginCo def get_plugin_config_changes( - old_config: Dict[str, Any], new_config: Dict[str, Any], secret_fields=None -) -> List[Change]: + old_config: dict[str, Any], new_config: dict[str, Any], secret_fields=None +) -> list[Change]: if secret_fields is None: secret_fields = [] config_changes = dict_changes_between("Plugin", old_config, new_config) @@ -103,8 +103,8 @@ def log_enabled_change_activity( def log_config_update_activity( new_plugin_config: PluginConfig, - old_config: Dict[str, Any], - secret_fields: Set[str], + old_config: dict[str, Any], + secret_fields: set[str], old_enabled: bool, user: User, was_impersonated: bool, @@ -280,7 +280,7 @@ def get_latest_tag(self, plugin: Plugin) -> Optional[str]: def get_organization_name(self, plugin: Plugin) -> str: return plugin.organization.name - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Plugin: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> Plugin: validated_data["url"] = self.initial_data.get("url", None) validated_data["organization_id"] = self.context["organization_id"] validated_data["updated_at"] = now() @@ -291,7 +291,7 @@ def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> Plugin: return plugin - def update(self, plugin: Plugin, validated_data: Dict, *args: Any, **kwargs: Any) -> Plugin: # type: ignore + def update(self, 
plugin: Plugin, validated_data: dict, *args: Any, **kwargs: Any) -> Plugin: # type: ignore context_organization = self.context["get_organization"]() if ( "is_global" in validated_data @@ -387,7 +387,7 @@ def check_for_updates(self, request: request.Request, **kwargs): @action(methods=["GET"], detail=True) def source(self, request: request.Request, **kwargs): plugin = self.get_plugin_with_permissions(reason="source editing") - response: Dict[str, str] = {} + response: dict[str, str] = {} for source in PluginSourceFile.objects.filter(plugin=plugin): response[source.filename] = source.source return Response(response) @@ -395,7 +395,7 @@ def source(self, request: request.Request, **kwargs): @action(methods=["PATCH"], detail=True) def update_source(self, request: request.Request, **kwargs): plugin = self.get_plugin_with_permissions(reason="source editing") - sources: Dict[str, PluginSourceFile] = {} + sources: dict[str, PluginSourceFile] = {} performed_changes = False for plugin_source_file in PluginSourceFile.objects.filter(plugin=plugin): sources[plugin_source_file.filename] = plugin_source_file @@ -438,7 +438,7 @@ def update_source(self, request: request.Request, **kwargs): sources[key].error = error sources[key].save() - response: Dict[str, str] = {} + response: dict[str, str] = {} for _, source in sources.items(): response[source.filename] = source.source @@ -476,7 +476,7 @@ def upgrade(self, request: request.Request, **kwargs): Plugin.PluginType.SOURCE, Plugin.PluginType.LOCAL, ): - validated_data: Dict[str, Any] = {} + validated_data: dict[str, Any] = {} plugin_json = update_validated_data_from_url(validated_data, plugin.url) with transaction.atomic(): serializer.update(plugin, validated_data) @@ -647,7 +647,7 @@ def get_error(self, plugin_config: PluginConfig) -> None: # error details instead. 
return None - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> PluginConfig: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> PluginConfig: if not can_configure_plugins(self.context["get_organization"]()): raise ValidationError("Plugin configuration is not available for the current organization!") validated_data["team_id"] = self.context["team_id"] @@ -682,7 +682,7 @@ def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> PluginConfi def update( # type: ignore self, plugin_config: PluginConfig, - validated_data: Dict, + validated_data: dict, *args: Any, **kwargs: Any, ) -> PluginConfig: @@ -731,7 +731,7 @@ def get_queryset(self): queryset = queryset.filter(deleted=False) return queryset.order_by("order", "plugin_id") - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() if context["view"].action in ("retrieve", "list"): context["delivery_rates_1d"] = TeamPluginsDeliveryRateQuery(self.team).run() @@ -856,7 +856,7 @@ def frontend(self, request: request.Request, **kwargs): content = plugin_source.transpiled or "" return HttpResponse(content, content_type="application/javascript; charset=UTF-8") - obj: Dict[str, Any] = {} + obj: dict[str, Any] = {} if not plugin_source: obj = {"no_frontend": True} elif plugin_source.status is None or plugin_source.status == PluginSourceFile.Status.LOCKED: @@ -868,7 +868,7 @@ def frontend(self, request: request.Request, **kwargs): return HttpResponse(content, content_type="application/javascript; charset=UTF-8") -def _get_secret_fields_for_plugin(plugin: Plugin) -> Set[str]: +def _get_secret_fields_for_plugin(plugin: Plugin) -> set[str]: # A set of keys for config fields that have secret = true secret_fields = {field["key"] for field in plugin.config_schema if isinstance(field, dict) and field.get("secret")} return secret_fields diff --git a/posthog/api/property_definition.py b/posthog/api/property_definition.py index 584644f902b33..6a87fc6f348d7 100644 --- a/posthog/api/property_definition.py +++ b/posthog/api/property_definition.py @@ -1,6 +1,6 @@ import dataclasses import json -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast from django.db import connection from django.db.models import Prefetch @@ -125,7 +125,7 @@ class QueryContext: posthog_eventproperty_table_join_alias = "check_for_matching_event_property" - params: Dict = dataclasses.field(default_factory=dict) + params: dict = dataclasses.field(default_factory=dict) def with_properties_to_filter(self, properties_to_filter: Optional[str]) -> "QueryContext": if properties_to_filter: @@ -219,7 +219,7 @@ def with_event_property_filter( params={**self.params, "event_names": list(map(str, event_names or []))}, ) - def with_search(self, search_query: str, search_kwargs: Dict) -> "QueryContext": + def with_search(self, search_query: str, search_kwargs: dict) -> "QueryContext": return dataclasses.replace( self, search_query=search_query, @@ -443,7 +443,7 @@ def get_count(self, queryset) -> int: return self.count - def paginate_queryset(self, queryset, request, view=None) -> Optional[List[Any]]: + def paginate_queryset(self, queryset, request, view=None) -> Optional[list[Any]]: """ Assumes the queryset has already had pagination applied """ @@ -570,7 +570,7 @@ def get_queryset(self): return queryset.raw(query_context.as_sql(order_by_verified), params=query_context.params) - def get_serializer_class(self) -> 
Type[serializers.ModelSerializer]: + def get_serializer_class(self) -> type[serializers.ModelSerializer]: serializer_class = self.serializer_class if self.request.user.organization.is_feature_available(AvailableFeature.INGESTION_TAXONOMY): try: diff --git a/posthog/api/routing.py b/posthog/api/routing.py index b768538c05d50..02654051a3f12 100644 --- a/posthog/api/routing.py +++ b/posthog/api/routing.py @@ -1,5 +1,5 @@ from functools import cached_property, lru_cache -from typing import TYPE_CHECKING, Any, Dict, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from rest_framework.exceptions import AuthenticationFailed, NotFound, ValidationError from rest_framework.permissions import IsAuthenticated @@ -50,7 +50,7 @@ class TeamAndOrgViewSetMixin(_GenericViewSet): # Rewrite filter queries, so that for example foreign keys can be accessed # Example: {"team_id": "foo__team_id"} will make the viewset filtered by obj.foo.team_id instead of obj.team_id - filter_rewrite_rules: Dict[str, str] = {} + filter_rewrite_rules: dict[str, str] = {} authentication_classes = [] permission_classes = [] @@ -170,7 +170,7 @@ def filter_queryset_by_parents_lookups(self, queryset): return queryset @cached_property - def parents_query_dict(self) -> Dict[str, Any]: + def parents_query_dict(self) -> dict[str, Any]: # used to override the last visited project if there's a token in the request team_from_request = self._get_team_from_request() @@ -213,7 +213,7 @@ def parents_query_dict(self) -> Dict[str, Any]: result[query_lookup] = query_value return result - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: serializer_context = super().get_serializer_context() if hasattr(super(), "get_serializer_context") else {} serializer_context.update(self.parents_query_dict) # The below are lambdas for lazy evaluation (i.e. 
we only query Postgres for team/org if actually needed) diff --git a/posthog/api/scheduled_change.py b/posthog/api/scheduled_change.py index 5d1878ebfe4fb..2100f6b7bdc7d 100644 --- a/posthog/api/scheduled_change.py +++ b/posthog/api/scheduled_change.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from rest_framework import ( serializers, viewsets, @@ -29,7 +29,7 @@ class Meta: ] read_only_fields = ["id", "created_at", "created_by", "updated_at"] - def create(self, validated_data: Dict, *args: Any, **kwargs: Any) -> ScheduledChange: + def create(self, validated_data: dict, *args: Any, **kwargs: Any) -> ScheduledChange: request = self.context["request"] validated_data["created_by"] = request.user validated_data["team_id"] = self.context["team_id"] diff --git a/posthog/api/sharing.py b/posthog/api/sharing.py index c7ab40fb0f89d..3d4a2d693749c 100644 --- a/posthog/api/sharing.py +++ b/posthog/api/sharing.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from urllib.parse import urlparse, urlunparse from django.core.serializers.json import DjangoJSONEncoder @@ -87,7 +87,7 @@ class SharingConfigurationViewSet(TeamAndOrgViewSetMixin, mixins.ListModelMixin, def get_serializer_context( self, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: context = super().get_serializer_context() dashboard_id = context.get("dashboard_id") @@ -113,7 +113,7 @@ def get_serializer_context( return context - def _get_sharing_configuration(self, context: Dict[str, Any]): + def _get_sharing_configuration(self, context: dict[str, Any]): """ Gets but does not create a SharingConfiguration. Only once enabled do we actually store it """ @@ -247,7 +247,7 @@ def retrieve(self, request: Request, *args: Any, **kwargs: Any) -> Any: "user_permissions": UserPermissions(cast(User, request.user), resource.team), "is_shared": True, } - exported_data: Dict[str, Any] = {"type": "embed" if embedded else "scene"} + exported_data: dict[str, Any] = {"type": "embed" if embedded else "scene"} if isinstance(resource, SharingConfiguration) and request.path.endswith(f".png"): exported_data["accessToken"] = resource.access_token diff --git a/posthog/api/signup.py b/posthog/api/signup.py index c31f37b891eb3..8385dc7759798 100644 --- a/posthog/api/signup.py +++ b/posthog/api/signup.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Optional, Union, cast from urllib.parse import urlencode import structlog @@ -71,7 +71,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_social_signup = False - def get_fields(self) -> Dict[str, serializers.Field]: + def get_fields(self) -> dict[str, serializers.Field]: fields = super().get_fields() if settings.DEMO: # There's no password in the demo env @@ -156,7 +156,7 @@ def enter_demo(self, validated_data) -> User: def create_team(self, organization: Organization, user: User) -> Team: return Team.objects.create_with_data(user=user, organization=organization) - def to_representation(self, instance) -> Dict: + def to_representation(self, instance) -> dict: data = UserBasicSerializer(instance=instance).data data["redirect_url"] = get_redirect_url(data["uuid"], data["is_email_verified"]) return data @@ -185,7 +185,7 @@ def to_representation(self, instance): data["redirect_url"] = get_redirect_url(data["uuid"], data["is_email_verified"]) return data - def validate(self, data: Dict[str, Any]) -> Dict[str, Any]: + def 
validate(self, data: dict[str, Any]) -> dict[str, Any]: if "request" not in self.context or not self.context["request"].user.is_authenticated: # If there's no authenticated user and we're creating a new one, attributes are required. @@ -469,7 +469,7 @@ def social_create_user( return {"is_new": False} backend_processor = "social_create_user" - email = details["email"][0] if isinstance(details["email"], (list, tuple)) else details["email"] + email = details["email"][0] if isinstance(details["email"], list | tuple) else details["email"] full_name = ( details.get("fullname") or f"{details.get('first_name') or ''} {details.get('last_name') or ''}".strip() diff --git a/posthog/api/survey.py b/posthog/api/survey.py index cb991a5f95abe..3ffce982b8981 100644 --- a/posthog/api/survey.py +++ b/posthog/api/survey.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -from typing import Type from django.http import JsonResponse @@ -271,7 +270,7 @@ class SurveyViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): scope_object = "survey" queryset = Survey.objects.select_related("linked_flag", "targeting_flag").all() - def get_serializer_class(self) -> Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if self.request.method == "POST" or self.request.method == "PATCH": return SurveySerializerCreateUpdateOnly else: diff --git a/posthog/api/tagged_item.py b/posthog/api/tagged_item.py index 85aa08323a04c..d73275523b639 100644 --- a/posthog/api/tagged_item.py +++ b/posthog/api/tagged_item.py @@ -50,7 +50,7 @@ def _attempt_set_tags(self, tags, obj, force_create=False): obj.prefetched_tags = tagged_item_objects def to_representation(self, obj): - ret = super(TaggedItemSerializerMixin, self).to_representation(obj) + ret = super().to_representation(obj) ret["tags"] = [] if self._is_licensed(): if hasattr(obj, "prefetched_tags"): @@ -61,12 +61,12 @@ def to_representation(self, obj): def create(self, validated_data): validated_data.pop("tags", None) - instance = super(TaggedItemSerializerMixin, self).create(validated_data) + instance = super().create(validated_data) self._attempt_set_tags(self.initial_data.get("tags"), instance) return instance def update(self, instance, validated_data): - instance = super(TaggedItemSerializerMixin, self).update(instance, validated_data) + instance = super().update(instance, validated_data) self._attempt_set_tags(self.initial_data.get("tags"), instance) return instance @@ -96,7 +96,7 @@ def prefetch_tagged_items_if_available(self, queryset: QuerySet) -> QuerySet: return queryset def get_queryset(self): - queryset = super(TaggedItemViewSetMixin, self).get_queryset() + queryset = super().get_queryset() return self.prefetch_tagged_items_if_available(queryset) diff --git a/posthog/api/team.py b/posthog/api/team.py index c8b2513b6798c..39acc8c2a0a7e 100644 --- a/posthog/api/team.py +++ b/posthog/api/team.py @@ -1,6 +1,6 @@ import json from functools import cached_property -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Optional, cast from django.core.cache import cache from django.shortcuts import get_object_or_404 @@ -190,11 +190,11 @@ def get_has_group_types(self, team: Team) -> bool: def get_groups_on_events_querying_enabled(self, team: Team) -> bool: return groups_on_events_querying_enabled() - def validate_session_recording_linked_flag(self, value) -> Dict | None: + def validate_session_recording_linked_flag(self, value) -> dict | None: if value is None: return None - if not isinstance(value, 
Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") received_keys = value.keys() valid_keys = [ @@ -208,11 +208,11 @@ def validate_session_recording_linked_flag(self, value) -> Dict | None: return value - def validate_session_recording_network_payload_capture_config(self, value) -> Dict | None: + def validate_session_recording_network_payload_capture_config(self, value) -> dict | None: if value is None: return None - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") if not all(key in ["recordHeaders", "recordBody"] for key in value.keys()): @@ -222,11 +222,11 @@ def validate_session_recording_network_payload_capture_config(self, value) -> Di return value - def validate_session_replay_config(self, value) -> Dict | None: + def validate_session_replay_config(self, value) -> dict | None: if value is None: return None - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") known_keys = ["record_canvas", "ai_config"] @@ -240,9 +240,9 @@ def validate_session_replay_config(self, value) -> Dict | None: return value - def validate_session_replay_ai_summary_config(self, value: Dict | None) -> Dict | None: + def validate_session_replay_ai_summary_config(self, value: dict | None) -> dict | None: if value is not None: - if not isinstance(value, Dict): + if not isinstance(value, dict): raise exceptions.ValidationError("Must provide a dictionary or None.") allowed_keys = [ @@ -294,7 +294,7 @@ def validate(self, attrs: Any) -> Any: ) return super().validate(attrs) - def create(self, validated_data: Dict[str, Any], **kwargs) -> Team: + def create(self, validated_data: dict[str, Any], **kwargs) -> Team: serializers.raise_errors_on_nested_writes("create", self, validated_data) request = self.context["request"] organization = self.context["view"].organization # Use the org we used to validate permissions @@ -337,7 +337,7 @@ def _handle_timezone_update(self, team: Team) -> None: hashes = InsightCachingState.objects.filter(team=team).values_list("cache_key", flat=True) cache.delete_many(hashes) - def update(self, instance: Team, validated_data: Dict[str, Any]) -> Team: + def update(self, instance: Team, validated_data: dict[str, Any]) -> Team: before_update = instance.__dict__.copy() if "timezone" in validated_data and validated_data["timezone"] != instance.timezone: @@ -406,13 +406,13 @@ def get_queryset(self): visible_teams_ids = UserPermissions(cast(User, self.request.user)).team_ids_visible_for_user return super().get_queryset().filter(id__in=visible_teams_ids) - def get_serializer_class(self) -> Type[serializers.BaseSerializer]: + def get_serializer_class(self) -> type[serializers.BaseSerializer]: if self.action == "list": return TeamBasicSerializer return super().get_serializer_class() # NOTE: Team permissions are somewhat complex so we override the underlying viewset's get_permissions method - def get_permissions(self) -> List: + def get_permissions(self) -> list: """ Special permissions handling for create requests as the organization is inferred from the current user. 
""" diff --git a/posthog/api/test/dashboards/__init__.py b/posthog/api/test/dashboards/__init__.py index 79d1e435e64ec..ad6505b5a61a7 100644 --- a/posthog/api/test/dashboards/__init__.py +++ b/posthog/api/test/dashboards/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional from rest_framework import status @@ -15,7 +15,7 @@ def soft_delete( self, model_id: int, model_type: Literal["insights", "dashboards"], - extra_data: Optional[Dict] = None, + extra_data: Optional[dict] = None, expected_get_status: int = status.HTTP_404_NOT_FOUND, ) -> None: if extra_data is None: @@ -33,10 +33,10 @@ def soft_delete( def create_dashboard( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id response = self.client.post(f"/api/projects/{team_id}/dashboards/", data) @@ -49,10 +49,10 @@ def create_dashboard( def update_dashboard( self, dashboard_id: int, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id response = self.client.patch(f"/api/projects/{team_id}/dashboards/{dashboard_id}", data) @@ -67,8 +67,8 @@ def get_dashboard( dashboard_id: int, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + query_params: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: if team_id is None: team_id = self.team.id @@ -82,8 +82,8 @@ def list_dashboards( self, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict] = None, - ) -> Dict: + query_params: Optional[dict] = None, + ) -> dict: if team_id is None: team_id = self.team.id @@ -100,8 +100,8 @@ def list_insights( self, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict] = None, - ) -> Dict: + query_params: Optional[dict] = None, + ) -> dict: if team_id is None: team_id = self.team.id @@ -122,8 +122,8 @@ def get_insight( insight_id: int, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - query_params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + query_params: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: if team_id is None: team_id = self.team.id @@ -138,10 +138,10 @@ def get_insight( def create_insight( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -160,10 +160,10 @@ def create_insight( def update_insight( self, insight_id: int, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -177,10 +177,10 @@ def create_text_tile( self, dashboard_id: int, text: str = "I AM TEXT!", - extra_data: Optional[Dict] = None, + extra_data: Optional[dict] = None, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if 
team_id is None: team_id = self.team.id @@ -218,10 +218,10 @@ def get_insight_activity( def update_text_tile( self, dashboard_id: int, - tile: Dict, + tile: dict, team_id: Optional[int] = None, expected_status: int = status.HTTP_200_OK, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id @@ -271,7 +271,7 @@ def set_tile_layout(self, dashboard_id: int, expected_tiles_to_update: int) -> N def add_insight_to_dashboard( self, - dashboard_ids: List[int], + dashboard_ids: list[int], insight_id: int, expected_status: int = status.HTTP_200_OK, ): diff --git a/posthog/api/test/dashboards/test_dashboard.py b/posthog/api/test/dashboards/test_dashboard.py index 234123dde16a6..7e8f3fafc87cf 100644 --- a/posthog/api/test/dashboards/test_dashboard.py +++ b/posthog/api/test/dashboards/test_dashboard.py @@ -1,5 +1,4 @@ import json -from typing import Dict from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -21,7 +20,7 @@ from posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_queries, FuzzyInt from posthog.utils import generate_cache_key -valid_template: Dict = { +valid_template: dict = { "template_name": "Sign up conversion template with variables", "dashboard_description": "Use this template to see how many users sign up after visiting your pricing page.", "dashboard_filters": {}, @@ -1186,7 +1185,7 @@ def test_create_from_template_json(self, mock_capture) -> None: ) def test_create_from_template_json_must_provide_at_least_one_tile(self) -> None: - template: Dict = {**valid_template, "tiles": []} + template: dict = {**valid_template, "tiles": []} response = self.client.post( f"/api/projects/{self.team.id}/dashboards/create_from_template_json", @@ -1195,7 +1194,7 @@ def test_create_from_template_json_must_provide_at_least_one_tile(self) -> None: assert response.status_code == 400, response.json() def test_create_from_template_json_can_provide_text_tile(self) -> None: - template: Dict = { + template: dict = { **valid_template, "tiles": [{"type": "TEXT", "body": "hello world", "layouts": {}}], } @@ -1226,7 +1225,7 @@ def test_create_from_template_json_can_provide_text_tile(self) -> None: ] def test_create_from_template_json_can_provide_query_tile(self) -> None: - template: Dict = { + template: dict = { **valid_template, # client provides an incorrect "empty" filter alongside a query "tiles": [ diff --git a/posthog/api/test/dashboards/test_dashboard_duplication.py b/posthog/api/test/dashboards/test_dashboard_duplication.py index dbfa572e9c014..f477f9f1e0598 100644 --- a/posthog/api/test/dashboards/test_dashboard_duplication.py +++ b/posthog/api/test/dashboards/test_dashboard_duplication.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.api.test.dashboards import DashboardAPI from posthog.test.base import APIBaseTest, QueryMatchingTest @@ -85,7 +83,7 @@ def test_duplicating_dashboard_without_duplicating_tiles(self) -> None: ] @staticmethod - def _tile_child_ids_from(dashboard_json: Dict) -> List[int]: + def _tile_child_ids_from(dashboard_json: dict) -> list[int]: return [ (tile.get("insight", None) or {}).get("id", None) or (tile.get("text", None) or {}).get("id", None) for tile in dashboard_json["tiles"] diff --git a/posthog/api/test/dashboards/test_dashboard_text_tiles.py b/posthog/api/test/dashboards/test_dashboard_text_tiles.py index 34b9366da5aeb..d3f899d72d284 100644 --- a/posthog/api/test/dashboards/test_dashboard_text_tiles.py +++ 
b/posthog/api/test/dashboards/test_dashboard_text_tiles.py @@ -1,5 +1,5 @@ import datetime -from typing import Dict, Optional, Union +from typing import Optional, Union from unittest import mock from freezegun import freeze_time @@ -16,7 +16,7 @@ def setUp(self) -> None: self.dashboard_api = DashboardAPI(self.client, self.team, self.assertEqual) @staticmethod - def _serialised_user(user: Optional[User]) -> Optional[Dict[str, Optional[Union[int, str]]]]: + def _serialised_user(user: Optional[User]) -> Optional[dict[str, Optional[Union[int, str]]]]: if user is None: return None @@ -37,7 +37,7 @@ def _expected_text( last_modified_by: Optional[User] = None, text_id: Optional[int] = None, last_modified_at: str = "2022-04-01T12:45:00Z", - ) -> Dict: + ) -> dict: if not created_by: created_by = self.user @@ -62,7 +62,7 @@ def _expected_tile_with_text( text_id: Optional[int] = None, color: Optional[str] = None, last_modified_at: str = "2022-04-01T12:45:00Z", - ) -> Dict: + ) -> dict: if not tile_id: tile_id = mock.ANY return { @@ -82,7 +82,7 @@ def _expected_tile_with_text( } @staticmethod - def _tile_layout(lg: Optional[Dict] = None) -> Dict: + def _tile_layout(lg: Optional[dict] = None) -> dict: if lg is None: lg = {"x": "0", "y": "0", "w": "6", "h": "5"} diff --git a/posthog/api/test/notebooks/test_notebook.py b/posthog/api/test/notebooks/test_notebook.py index 2779f1a226c78..f01d8fd6bc694 100644 --- a/posthog/api/test/notebooks/test_notebook.py +++ b/posthog/api/test/notebooks/test_notebook.py @@ -1,4 +1,3 @@ -from typing import List, Dict from unittest import mock from freezegun import freeze_time @@ -11,7 +10,7 @@ class TestNotebooks(APIBaseTest, QueryMatchingTest): - def created_activity(self, item_id: str, short_id: str) -> Dict: + def created_activity(self, item_id: str, short_id: str) -> dict: return { "activity": "created", "created_at": mock.ANY, @@ -30,11 +29,11 @@ def created_activity(self, item_id: str, short_id: str) -> Dict: }, } - def assert_notebook_activity(self, expected: List[Dict]) -> None: + def assert_notebook_activity(self, expected: list[dict]) -> None: activity_response = self.client.get(f"/api/projects/{self.team.id}/notebooks/activity") assert activity_response.status_code == status.HTTP_200_OK - activity: List[Dict] = activity_response.json()["results"] + activity: list[dict] = activity_response.json()["results"] self.maxDiff = None assert activity == expected @@ -78,7 +77,7 @@ def test_cannot_list_deleted_notebook(self) -> None: ), ] ) - def test_create_a_notebook(self, _, content: Dict | None, text_content: str | None) -> None: + def test_create_a_notebook(self, _, content: dict | None, text_content: str | None) -> None: response = self.client.post( f"/api/projects/{self.team.id}/notebooks", data={"content": content, "text_content": text_content}, diff --git a/posthog/api/test/notebooks/test_notebook_filtering.py b/posthog/api/test/notebooks/test_notebook_filtering.py index bbe191892d8e8..06b543deca4bc 100644 --- a/posthog/api/test/notebooks/test_notebook_filtering.py +++ b/posthog/api/test/notebooks/test_notebook_filtering.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List +from typing import Any from parameterized import parameterized from rest_framework import status @@ -59,7 +59,7 @@ class TestNotebooksFiltering(APIBaseTest, QueryMatchingTest): - def _create_notebook_with_content(self, inner_content: List[Dict[str, Any]], title: str = "the title") -> str: + def _create_notebook_with_content(self, inner_content: list[dict[str, Any]], title: str = "the 
title") -> str: response = self.client.post( f"/api/projects/{self.team.id}/notebooks", data={ @@ -83,7 +83,7 @@ def _create_notebook_with_content(self, inner_content: List[Dict[str, Any]], tit ["random", []], ] ) - def test_filters_based_on_title(self, search_text: str, expected_match_indexes: List[int]) -> None: + def test_filters_based_on_title(self, search_text: str, expected_match_indexes: list[int]) -> None: notebook_ids = [ self._create_notebook_with_content([BASIC_TEXT("my important notes")], title="i ride around on a pony"), self._create_notebook_with_content([BASIC_TEXT("my important notes")], title="my hobby is to fish around"), @@ -108,7 +108,7 @@ def test_filters_based_on_title(self, search_text: str, expected_match_indexes: ["neither", []], ] ) - def test_filters_based_on_text_content(self, search_text: str, expected_match_indexes: List[int]) -> None: + def test_filters_based_on_text_content(self, search_text: str, expected_match_indexes: list[int]) -> None: notebook_ids = [ # will match both pony and ponies self._create_notebook_with_content([BASIC_TEXT("you may ride a pony")], title="never matches"), diff --git a/posthog/api/test/openapi_validation.py b/posthog/api/test/openapi_validation.py index e86bf5198bb53..20d2fb1e1a603 100644 --- a/posthog/api/test/openapi_validation.py +++ b/posthog/api/test/openapi_validation.py @@ -1,7 +1,7 @@ import gzip import json from io import BytesIO -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from urllib.parse import parse_qs import lzstring @@ -11,7 +11,7 @@ from jsonschema import validate -def validate_response(openapi_spec: Dict[str, Any], response: Any, path_override: Optional[str] = None): +def validate_response(openapi_spec: dict[str, Any], response: Any, path_override: Optional[str] = None): # Validates are response against the OpenAPI spec. If `path_override` is # provided, the path in the response will be overridden with the provided # value. This is useful for validating responses from e.g. 
the /batch diff --git a/posthog/api/test/test_activity_log.py b/posthog/api/test/test_activity_log.py index a7573f10cabd3..c386d30de6cfd 100644 --- a/posthog/api/test/test_activity_log.py +++ b/posthog/api/test/test_activity_log.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from freezegun import freeze_time from freezegun.api import FrozenDateTimeFactory, StepTickTimeFactory @@ -9,7 +9,7 @@ from posthog.test.base import APIBaseTest, QueryMatchingTest -def _feature_flag_json_payload(key: str) -> Dict: +def _feature_flag_json_payload(key: str) -> dict: return { "key": key, "name": "", @@ -103,7 +103,7 @@ def _create_and_edit_things(self): def _edit_them_all( self, - created_insights: List[int], + created_insights: list[int], flag_one: str, flag_two: str, notebook_short_id: str, @@ -269,10 +269,10 @@ def test_reading_notifications_marks_them_unread(self): def _create_insight( self, - data: Dict[str, Any], + data: dict[str, Any], team_id: Optional[int] = None, expected_status: int = status.HTTP_201_CREATED, - ) -> Tuple[int, Dict[str, Any]]: + ) -> tuple[int, dict[str, Any]]: if team_id is None: team_id = self.team.id diff --git a/posthog/api/test/test_capture.py b/posthog/api/test/test_capture.py index f771aca99b39d..1beb4e9724b39 100644 --- a/posthog/api/test/test_capture.py +++ b/posthog/api/test/test_capture.py @@ -25,7 +25,7 @@ from prance import ResolvingParser from rest_framework import status from token_bucket import Limiter, MemoryStorage -from typing import Any, Dict, List, Union, cast +from typing import Any, Union, cast from unittest.mock import ANY, MagicMock, call, patch from urllib.parse import quote @@ -60,7 +60,7 @@ def mocked_get_ingest_context_from_token(_: Any) -> None: url=str(pathlib.Path(__file__).parent / "../../../openapi/capture.yaml"), strict=True, ) -openapi_spec = cast(Dict[str, Any], parser.specification) +openapi_spec = cast(dict[str, Any], parser.specification) large_data_array = [ {"key": "".join(random.choice(string.ascii_letters) for _ in range(512 * 1024))} @@ -162,7 +162,7 @@ def setUp(self): # it is really important to know that /capture is CSRF exempt. 
Enforce checking in the client self.client = Client(enforce_csrf_checks=True) - def _to_json(self, data: Union[Dict, List]) -> str: + def _to_json(self, data: Union[dict, list]) -> str: return json.dumps(data) def _dict_to_b64(self, data: dict) -> str: @@ -188,7 +188,7 @@ def _to_arguments(self, patch_process_event_with_plugins: Any) -> dict: def _send_original_version_session_recording_event( self, number_of_events: int = 1, - event_data: Dict | None = None, + event_data: dict | None = None, snapshot_source=3, snapshot_type=1, session_id="abc123", @@ -229,7 +229,7 @@ def _send_original_version_session_recording_event( def _send_august_2023_version_session_recording_event( self, number_of_events: int = 1, - event_data: Dict | List[Dict] | None = None, + event_data: dict | list[dict] | None = None, session_id="abc123", window_id="def456", distinct_id="ghi789", @@ -241,7 +241,7 @@ def _send_august_2023_version_session_recording_event( # event_data is an array of RRWeb events event_data = [{"type": 3, "data": {"source": 1}}, {"type": 3, "data": {"source": 2}}] - if isinstance(event_data, Dict): + if isinstance(event_data, dict): event_data = [event_data] event = { @@ -260,7 +260,7 @@ def _send_august_2023_version_session_recording_event( "distinct_id": distinct_id, } - post_data: List[Dict[str, Any]] | Dict[str, Any] + post_data: list[dict[str, Any]] | dict[str, Any] if content_type == "application/json": post_data = [{**event, "api_key": self.team.api_token} for _ in range(number_of_events)] @@ -1254,7 +1254,7 @@ def test_js_library_underscore_sent_at(self, kafka_produce): } self.client.get( - "/e/?_=%s&data=%s" % (int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), + "/e/?_={}&data={}".format(int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), content_type="application/json", HTTP_ORIGIN="https://localhost", ) @@ -1283,7 +1283,7 @@ def test_long_distinct_id(self, kafka_produce): } self.client.get( - "/e/?_=%s&data=%s" % (int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), + "/e/?_={}&data={}".format(int(tomorrow_sent_at.timestamp()), quote(self._to_json(data))), content_type="application/json", HTTP_ORIGIN="https://localhost", ) @@ -1526,7 +1526,7 @@ def test_handle_invalid_snapshot(self): ), ] ) - def test_cors_allows_tracing_headers(self, _: str, path: str, headers: List[str]) -> None: + def test_cors_allows_tracing_headers(self, _: str, path: str, headers: list[str]) -> None: expected_headers = ",".join(["X-Requested-With", "Content-Type", *headers]) presented_headers = ",".join([*headers, "someotherrandomheader"]) response = self.client.options( diff --git a/posthog/api/test/test_cohort.py b/posthog/api/test/test_cohort.py index 0b1971f8f2cbb..4e1a3da2d526d 100644 --- a/posthog/api/test/test_cohort.py +++ b/posthog/api/test/test_cohort.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any from unittest.mock import patch from django.core.files.uploadedfile import SimpleUploadedFile @@ -1493,11 +1493,11 @@ def test_async_deletion_of_cohort_with_race_condition_multiple_updates(self, pat self.assertEqual(async_deletion.delete_verified_at is not None, True) -def create_cohort(client: Client, team_id: int, name: str, groups: List[Dict[str, Any]]): +def create_cohort(client: Client, team_id: int, name: str, groups: list[dict[str, Any]]): return client.post(f"/api/projects/{team_id}/cohorts", {"name": name, "groups": json.dumps(groups)}) -def create_cohort_ok(client: Client, 
team_id: int, name: str, groups: List[Dict[str, Any]]): +def create_cohort_ok(client: Client, team_id: int, name: str, groups: list[dict[str, Any]]): response = create_cohort(client=client, team_id=team_id, name=name, groups=groups) assert response.status_code == 201, response.content return response.json() diff --git a/posthog/api/test/test_decide.py b/posthog/api/test/test_decide.py index e89fb0b3c1270..af9b2db88e3e7 100644 --- a/posthog/api/test/test_decide.py +++ b/posthog/api/test/test_decide.py @@ -3457,9 +3457,11 @@ def test_decide_doesnt_error_out_when_database_is_down_and_database_check_isnt_c # remove database check cache values postgres_healthcheck.cache_clear() - with connection.execute_wrapper(QueryTimeoutWrapper()), snapshot_postgres_queries_context( - self - ), self.assertNumQueries(1): + with ( + connection.execute_wrapper(QueryTimeoutWrapper()), + snapshot_postgres_queries_context(self), + self.assertNumQueries(1), + ): response = self._post_decide(api_version=3, origin="https://random.example.com").json() response = self._post_decide(api_version=3, origin="https://random.example.com").json() response = self._post_decide(api_version=3, origin="https://random.example.com").json() @@ -3607,8 +3609,10 @@ def test_healthcheck_uses_read_replica(self): self.organization, self.team, self.user = org, team, user # this create fills up team cache^ - with freeze_time("2021-01-01T00:00:00Z"), self.assertNumQueries(1, using="replica"), self.assertNumQueries( - 1, using="default" + with ( + freeze_time("2021-01-01T00:00:00Z"), + self.assertNumQueries(1, using="replica"), + self.assertNumQueries(1, using="default"), ): response = self._post_decide() # Replica queries: @@ -4031,9 +4035,11 @@ def test_feature_flags_v3_consistent_flags(self, mock_is_connected): # now main database is down, but does not affect replica - with connections["default"].execute_wrapper(QueryTimeoutWrapper()), self.assertNumQueries( - 13, using="replica" - ), self.assertNumQueries(0, using="default"): + with ( + connections["default"].execute_wrapper(QueryTimeoutWrapper()), + self.assertNumQueries(13, using="replica"), + self.assertNumQueries(0, using="default"), + ): # Replica queries: # E 1. SET LOCAL statement_timeout = 300 # E 2. 
WITH some CTEs, diff --git a/posthog/api/test/test_element.py b/posthog/api/test/test_element.py index 72a97ea2b9b43..25cd01df35398 100644 --- a/posthog/api/test/test_element.py +++ b/posthog/api/test/test_element.py @@ -1,6 +1,5 @@ import json from datetime import timedelta -from typing import Dict, List from django.test import override_settings from freezegun import freeze_time @@ -17,7 +16,7 @@ snapshot_postgres_queries, ) -expected_autocapture_data_response_results: List[Dict] = [ +expected_autocapture_data_response_results: list[dict] = [ { "count": 3, "hash": None, @@ -78,7 +77,7 @@ }, ] -expected_rage_click_data_response_results: List[Dict] = [ +expected_rage_click_data_response_results: list[dict] = [ { "count": 1, "hash": None, diff --git a/posthog/api/test/test_event_definition.py b/posthog/api/test/test_event_definition.py index aa2a2c05a2428..c530708886b70 100644 --- a/posthog/api/test/test_event_definition.py +++ b/posthog/api/test/test_event_definition.py @@ -1,6 +1,6 @@ import dataclasses from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest.mock import ANY, patch from uuid import uuid4 @@ -20,7 +20,7 @@ class TestEventDefinitionAPI(APIBaseTest): demo_team: Team = None # type: ignore - EXPECTED_EVENT_DEFINITIONS: List[Dict[str, Any]] = [ + EXPECTED_EVENT_DEFINITIONS: list[dict[str, Any]] = [ {"name": "installed_app"}, {"name": "rated_app"}, {"name": "purchase"}, @@ -54,7 +54,7 @@ def test_list_event_definitions(self): self.assertEqual(len(response.json()["results"]), len(self.EXPECTED_EVENT_DEFINITIONS)) for item in self.EXPECTED_EVENT_DEFINITIONS: - response_item: Dict[str, Any] = next( + response_item: dict[str, Any] = next( (_i for _i in response.json()["results"] if _i["name"] == item["name"]), {}, ) @@ -199,7 +199,7 @@ class EventData: team_id: int distinct_id: str timestamp: datetime - properties: Dict[str, Any] + properties: dict[str, Any] def capture_event(event: EventData): @@ -222,7 +222,7 @@ def capture_event(event: EventData): ) -def create_event_definitions(event_definition: Dict, team_id: int) -> EventDefinition: +def create_event_definitions(event_definition: dict, team_id: int) -> EventDefinition: """ Create event definition for a team. 
""" diff --git a/posthog/api/test/test_exports.py b/posthog/api/test/test_exports.py index eead54e9055de..5e80486620693 100644 --- a/posthog/api/test/test_exports.py +++ b/posthog/api/test/test_exports.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import patch from datetime import datetime, timedelta import celery @@ -435,10 +435,10 @@ def _get_insight_activity(self, insight_id: int, expected_status: int = status.H self.assertEqual(activity.status_code, expected_status) return activity.json() - def _assert_logs_the_activity(self, insight_id: int, expected: List[Dict]) -> None: + def _assert_logs_the_activity(self, insight_id: int, expected: list[dict]) -> None: activity_response = self._get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertEqual(activity, expected) @@ -463,7 +463,7 @@ def test_can_list_exports(self) -> None: class TestExportMixin(APIBaseTest): - def _get_export_output(self, path: str) -> List[str]: + def _get_export_output(self, path: str) -> list[str]: """ Use this function to test the CSV output of exports in other tests """ diff --git a/posthog/api/test/test_feature_flag.py b/posthog/api/test/test_feature_flag.py index 18236c8332f00..4c353b98124df 100644 --- a/posthog/api/test/test_feature_flag.py +++ b/posthog/api/test/test_feature_flag.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import call, patch from django.core.cache import cache @@ -3657,10 +3657,10 @@ def _get_feature_flag_activity( self.assertEqual(activity.status_code, expected_status) return activity.json() - def assert_feature_flag_activity(self, flag_id: Optional[int], expected: List[Dict]): + def assert_feature_flag_activity(self, flag_id: Optional[int], expected: list[dict]): activity_response = self._get_feature_flag_activity(flag_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None assert activity == expected @@ -3898,7 +3898,7 @@ def test_feature_flag_dashboard(self): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) # check deleting the dashboard doesn't delete flag, but deletes the relationship dashboard.delete() @@ -3928,7 +3928,7 @@ def test_feature_flag_dashboard_patch(self): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) def test_feature_flag_dashboard_already_exists(self): another_feature_flag = FeatureFlag.objects.create( @@ -3954,7 +3954,7 @@ def test_feature_flag_dashboard_already_exists(self): self.assertEqual(response.status_code, status.HTTP_200_OK) response_json = response.json() - self.assertEquals(len(response_json["analytics_dashboards"]), 1) + self.assertEqual(len(response_json["analytics_dashboards"]), 1) @freeze_time("2021-01-01") @snapshot_clickhouse_queries @@ -3988,8 +3988,11 @@ def test_creating_static_cohort(self): ) flush_persons_and_events() - with snapshot_postgres_queries_context(self), self.settings( - CELERY_TASK_ALWAYS_EAGER=True, PERSON_ON_EVENTS_OVERRIDE=False, 
PERSON_ON_EVENTS_V2_OVERRIDE=False + with ( + snapshot_postgres_queries_context(self), + self.settings( + CELERY_TASK_ALWAYS_EAGER=True, PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False + ), ): response = self.client.post( f"/api/projects/{self.team.id}/feature_flags/{flag.id}/create_static_cohort_for_flag", @@ -5328,9 +5331,13 @@ def test_feature_flags_v3_with_a_working_slow_db(self, mock_postgres_check): self.assertFalse(errors) # now db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): mock_postgres_check.return_value = False all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5423,10 +5430,15 @@ def test_feature_flags_v3_with_skip_database_setting(self, mock_postgres_check): self.assertTrue(errors) # db is slow and times out, but shouldn't matter to us - with self.assertNumQueries(0), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, - ), self.settings(DECIDE_SKIP_POSTGRES_FLAGS=True): + with ( + self.assertNumQueries(0), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), + self.settings(DECIDE_SKIP_POSTGRES_FLAGS=True), + ): mock_postgres_check.return_value = False all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5536,10 +5548,15 @@ def test_feature_flags_v3_with_slow_db_doesnt_try_to_compute_conditions_again(se self.assertFalse(errors) # now db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, - ), self.assertNumQueries(4): + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), + self.assertNumQueries(4), + ): # no extra queries to get person properties for the second flag after first one failed all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id") @@ -5627,9 +5644,13 @@ def test_feature_flags_v3_with_group_properties_and_slow_db(self, mock_counter, self.assertFalse(errors) # now db is slow - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): with self.assertNumQueries(4): all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id", groups={"organization": "org:1"}) @@ -5737,9 +5758,13 @@ def test_feature_flags_v3_with_experience_continuity_working_slow_db(self, mock_ self.assertFalse(errors) # db is slow and times out - with snapshot_postgres_queries_context(self), connection.execute_wrapper(slow_query), patch( - "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", - 500, + with ( + snapshot_postgres_queries_context(self), + 
connection.execute_wrapper(slow_query), + patch( + "posthog.models.feature_flag.flag_matching.FLAG_MATCHING_QUERY_TIMEOUT_MS", + 500, + ), ): all_flags, _, _, errors = get_all_feature_flags(team_id, "example_id", hash_key_override="random") diff --git a/posthog/api/test/test_feature_flag_utils.py b/posthog/api/test/test_feature_flag_utils.py index 53369794dfe4f..c13bf04b6708a 100644 --- a/posthog/api/test/test_feature_flag_utils.py +++ b/posthog/api/test/test_feature_flag_utils.py @@ -1,4 +1,3 @@ -from typing import Set from posthog.models.cohort.cohort import CohortOrEmpty from posthog.test.base import ( APIBaseTest, @@ -68,7 +67,7 @@ def create_cohort(name): self.assertEqual(topologically_sorted_cohort_ids, destination_creation_order) def test_empty_cohorts_set(self): - cohort_ids: Set[int] = set() + cohort_ids: set[int] = set() seen_cohorts_cache: dict[int, CohortOrEmpty] = {} topologically_sorted_cohort_ids = sort_cohorts_topologically(cohort_ids, seen_cohorts_cache) self.assertEqual(topologically_sorted_cohort_ids, []) diff --git a/posthog/api/test/test_ingestion_warnings.py b/posthog/api/test/test_ingestion_warnings.py index bdf3996955909..05e893babfa3e 100644 --- a/posthog/api/test/test_ingestion_warnings.py +++ b/posthog/api/test/test_ingestion_warnings.py @@ -1,5 +1,4 @@ import json -from typing import Dict from freezegun.api import freeze_time from rest_framework import status @@ -13,7 +12,7 @@ from posthog.utils import cast_timestamp_or_now -def create_ingestion_warning(team_id: int, type: str, details: Dict, timestamp: str, source=""): +def create_ingestion_warning(team_id: int, type: str, details: dict, timestamp: str, source=""): timestamp = cast_timestamp_or_now(timestamp) data = { "team_id": team_id, diff --git a/posthog/api/test/test_insight.py b/posthog/api/test/test_insight.py index b427ce13a12a1..2184261eccd59 100644 --- a/posthog/api/test/test_insight.py +++ b/posthog/api/test/test_insight.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest import mock from unittest.case import skip from unittest.mock import patch @@ -343,7 +343,7 @@ def test_basic_results(self) -> None: @override_settings(PERSON_ON_EVENTS_OVERRIDE=False, PERSON_ON_EVENTS_V2_OVERRIDE=False) @snapshot_postgres_queries def test_listing_insights_does_not_nplus1(self) -> None: - query_counts: List[int] = [] + query_counts: list[int] = [] queries = [] for i in range(5): @@ -2059,7 +2059,7 @@ def test_insight_trends_allowed_if_project_open_and_org_member(self) -> None: ) self.assertEqual(response.status_code, status.HTTP_200_OK) - def _create_one_person_cohort(self, properties: List[Dict[str, Any]]) -> int: + def _create_one_person_cohort(self, properties: list[dict[str, Any]]) -> int: Person.objects.create(team=self.team, properties=properties) cohort_one_id = self.client.post( f"/api/projects/{self.team.id}/cohorts", @@ -2426,7 +2426,7 @@ def test_soft_delete_can_be_reversed_by_patch(self) -> None: # assert that undeletes end up in the activity log activity_response = self.dashboard_api.get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] # we will have three logged activities (in reverse order) undelete, delete, create assert [a["activity"] for a in activity] == ["updated", "updated", "created"] undelete_change_log = activity[0]["detail"]["changes"][0] @@ -2478,10 +2478,10 @@ def 
_get_insight_with_client_query_id(self, client_query_id: str) -> None: query_params = f"?events={json.dumps([{'id': '$pageview', }])}&client_query_id={client_query_id}" self.client.get(f"/api/projects/{self.team.id}/insights/trend/{query_params}").json() - def assert_insight_activity(self, insight_id: Optional[int], expected: List[Dict]): + def assert_insight_activity(self, insight_id: Optional[int], expected: list[dict]): activity_response = self.dashboard_api.get_insight_activity(insight_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None assert activity == expected diff --git a/posthog/api/test/test_insight_funnels.py b/posthog/api/test/test_insight_funnels.py index b02ebfec558da..3b4c1403a58ac 100644 --- a/posthog/api/test/test_insight_funnels.py +++ b/posthog/api/test/test_insight_funnels.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Union +from typing import Any, Union from django.test.client import Client from rest_framework import status @@ -1004,7 +1004,7 @@ def test_multi_property_breakdown(self): self.assertEqual(["Chrome", "95"], result[1][1]["breakdown_value"]) @staticmethod - def as_result(breakdown_properties: Union[str, List[str]]) -> Dict[str, Any]: + def as_result(breakdown_properties: Union[str, list[str]]) -> dict[str, Any]: return { "action_id": "$pageview", "name": "$pageview", diff --git a/posthog/api/test/test_insight_query.py b/posthog/api/test/test_insight_query.py index 6279999bbefcb..19044cd937bfd 100644 --- a/posthog/api/test/test_insight_query.py +++ b/posthog/api/test/test_insight_query.py @@ -1,5 +1,3 @@ -from typing import List - from rest_framework import status from ee.api.test.base import LicensedTestMixin @@ -213,7 +211,7 @@ def test_listing_insights_by_default_does_not_include_those_with_only_queries(se }, ) - created_insights: List[Insight] = list(Insight.objects.all()) + created_insights: list[Insight] = list(Insight.objects.all()) assert len(created_insights) == 2 listed_insights = self.dashboard_api.list_insights(query_params={"include_query_insights": False}) @@ -236,7 +234,7 @@ def test_can_list_insights_including_those_with_only_queries(self) -> None: }, ) - created_insights: List[Insight] = list(Insight.objects.all()) + created_insights: list[Insight] = list(Insight.objects.all()) assert len(created_insights) == 2 listed_insights = self.dashboard_api.list_insights(query_params={"include_query_insights": True}) diff --git a/posthog/api/test/test_kafka_inspector.py b/posthog/api/test/test_kafka_inspector.py index 6a42741a47ff1..b9a02d0464e14 100644 --- a/posthog/api/test/test_kafka_inspector.py +++ b/posthog/api/test/test_kafka_inspector.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Union +from typing import Union from unittest.mock import patch from rest_framework import status @@ -14,7 +14,7 @@ def setUp(self): self.user.is_staff = True self.user.save() - def _to_json(self, data: Union[Dict, List]) -> str: + def _to_json(self, data: Union[dict, list]) -> str: return json.dumps(data) @patch( diff --git a/posthog/api/test/test_organization_feature_flag.py b/posthog/api/test/test_organization_feature_flag.py index 41960032ca8b7..f1ad4ba26fb06 100644 --- a/posthog/api/test/test_organization_feature_flag.py +++ b/posthog/api/test/test_organization_feature_flag.py @@ -11,7 +11,7 @@ from posthog.models.early_access_feature import EarlyAccessFeature from posthog.api.dashboards.dashboard import Dashboard from 
posthog.test.base import APIBaseTest, QueryMatchingTest, snapshot_postgres_queries -from typing import Any, Dict +from typing import Any class TestOrganizationFeatureFlagGet(APIBaseTest, QueryMatchingTest): @@ -382,7 +382,7 @@ def test_copy_feature_flag_update_override_deleted(self): def test_copy_feature_flag_missing_fields(self): url = f"/api/organizations/{self.organization.id}/feature_flags/copy_flags" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} response = self.client.post(url, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/posthog/api/test/test_person.py b/posthog/api/test/test_person.py index 815f38c472978..a97e9d25de095 100644 --- a/posthog/api/test/test_person.py +++ b/posthog/api/test/test_person.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, cast +from typing import Optional, cast from unittest import mock from unittest.mock import patch, Mock @@ -982,10 +982,10 @@ def _get_person_activity( self.assertEqual(activity.status_code, expected_status) return activity.json() - def _assert_person_activity(self, person_id: Optional[str], expected: List[Dict]): + def _assert_person_activity(self, person_id: Optional[str], expected: list[dict]): activity_response = self._get_person_activity(person_id) - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertCountEqual(activity, expected) diff --git a/posthog/api/test/test_plugin.py b/posthog/api/test/test_plugin.py index 16e0fc4c0d1d0..06642b460980f 100644 --- a/posthog/api/test/test_plugin.py +++ b/posthog/api/test/test_plugin.py @@ -1,7 +1,7 @@ import base64 import json from datetime import datetime -from typing import Dict, List, cast +from typing import cast from unittest import mock from unittest.mock import ANY, patch @@ -52,10 +52,10 @@ def _get_plugin_activity(self, expected_status: int = status.HTTP_200_OK): self.assertEqual(activity.status_code, expected_status) return activity.json() - def assert_plugin_activity(self, expected: List[Dict]): + def assert_plugin_activity(self, expected: list[dict]): activity_response = self._get_plugin_activity() - activity: List[Dict] = activity_response["results"] + activity: list[dict] = activity_response["results"] self.maxDiff = None self.assertEqual(activity, expected) @@ -586,7 +586,7 @@ def test_create_plugin_version_range_eq_next_minor(self, mock_get, mock_reload): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], f'Currently running PostHog version {FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement "{FROZEN_POSTHOG_VERSION.next_minor()}".', ) @@ -608,7 +608,7 @@ def test_create_plugin_version_range_gt_next_major(self, mock_get, mock_reload): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], f'Currently running PostHog version {FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement ">= {FROZEN_POSTHOG_VERSION.next_major()}".', ) @@ -620,7 +620,7 @@ def test_create_plugin_version_range_lt_current(self, mock_get, mock_reload): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], f'Currently running PostHog version 
{FROZEN_POSTHOG_VERSION} does not match this plugin\'s semantic version requirement "< {FROZEN_POSTHOG_VERSION}".', ) @@ -642,7 +642,7 @@ def test_create_plugin_version_range_lt_invalid(self, mock_get, mock_reload): ) self.assertEqual(response.status_code, 400) self.assertEqual( - cast(Dict[str, str], response.json())["detail"], + cast(dict[str, str], response.json())["detail"], 'Invalid PostHog semantic version requirement "< ..."!', ) diff --git a/posthog/api/test/test_properties_timeline.py b/posthog/api/test/test_properties_timeline.py index 5243151c27e09..d8b8a11e9099a 100644 --- a/posthog/api/test/test_properties_timeline.py +++ b/posthog/api/test/test_properties_timeline.py @@ -1,7 +1,7 @@ import json import random import uuid -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from freezegun.api import freeze_time from rest_framework import status @@ -52,7 +52,7 @@ def _create_actor(self, properties: dict) -> str: return group.group_key def _create_event(self, event: str, timestamp: str, actor_properties: dict): - create_event_kwargs: Dict[str, Any] = {} + create_event_kwargs: dict[str, Any] = {} if actor_type == "person": create_event_kwargs["person_id"] = main_actor_id create_event_kwargs["person_properties"] = actor_properties diff --git a/posthog/api/test/test_property_definition.py b/posthog/api/test/test_property_definition.py index 77dca5e833076..378f66d7884a5 100644 --- a/posthog/api/test/test_property_definition.py +++ b/posthog/api/test/test_property_definition.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Union +from typing import Optional, Union from unittest.mock import ANY, patch from rest_framework import status @@ -17,7 +17,7 @@ class TestPropertyDefinitionAPI(APIBaseTest): - EXPECTED_PROPERTY_DEFINITIONS: List[Dict[str, Union[str, Optional[int], bool]]] = [ + EXPECTED_PROPERTY_DEFINITIONS: list[dict[str, Union[str, Optional[int], bool]]] = [ {"name": "$browser", "is_numerical": False}, {"name": "$current_url", "is_numerical": False}, {"name": "$lib", "is_numerical": False}, @@ -69,7 +69,7 @@ def test_list_property_definitions(self): self.assertEqual(len(response.json()["results"]), len(self.EXPECTED_PROPERTY_DEFINITIONS)) for item in self.EXPECTED_PROPERTY_DEFINITIONS: - response_item: Dict = next( + response_item: dict = next( (_i for _i in response.json()["results"] if _i["name"] == item["name"]), {}, ) diff --git a/posthog/api/test/test_signup.py b/posthog/api/test/test_signup.py index 1587c0b365e9e..532f7b945e7ca 100644 --- a/posthog/api/test/test_signup.py +++ b/posthog/api/test/test_signup.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Dict, Optional, cast +from typing import Optional, cast from unittest import mock from unittest.mock import ANY, patch from zoneinfo import ZoneInfo @@ -294,7 +294,7 @@ def test_cant_sign_up_with_required_attributes_null(self): required_attributes = ["first_name", "email"] for attribute in required_attributes: - body: Dict[str, Optional[str]] = { + body: dict[str, Optional[str]] = { "first_name": "Jane", "email": "invalid@posthog.com", "password": "notsecure", diff --git a/posthog/api/test/test_site_app.py b/posthog/api/test/test_site_app.py index 82823ac4cf4ed..9a428774c6ea7 100644 --- a/posthog/api/test/test_site_app.py +++ b/posthog/api/test/test_site_app.py @@ -1,5 +1,3 @@ -from typing import List - from django.test.client import Client from rest_framework import status @@ -44,7 +42,7 @@ def test_site_app(self): ) def 
test_get_site_config_from_schema(self): - schema: List[dict] = [{"key": "in_site", "site": True}, {"key": "not_in_site"}] + schema: list[dict] = [{"key": "in_site", "site": True}, {"key": "not_in_site"}] config = {"in_site": "123", "not_in_site": "12345"} self.assertEqual(get_site_config_from_schema(schema, config), {"in_site": "123"}) self.assertEqual(get_site_config_from_schema(None, None), {}) diff --git a/posthog/api/test/test_stickiness.py b/posthog/api/test/test_stickiness.py index 56d610c205e65..b3942414d5459 100644 --- a/posthog/api/test/test_stickiness.py +++ b/posthog/api/test/test_stickiness.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from dateutil.relativedelta import relativedelta from django.test import override_settings @@ -20,26 +20,26 @@ from posthog.utils import encode_get_request_params -def get_stickiness(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness(client: Client, team: Team, request: dict[str, Any]): return client.get(f"/api/projects/{team.pk}/insights/trend/", data=request) -def get_stickiness_ok(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness_ok(client: Client, team: Team, request: dict[str, Any]): response = get_stickiness(client=client, team=team, request=encode_get_request_params(data=request)) assert response.status_code == 200, response.content return response.json() -def get_stickiness_time_series_ok(client: Client, team: Team, request: Dict[str, Any]): +def get_stickiness_time_series_ok(client: Client, team: Team, request: dict[str, Any]): data = get_stickiness_ok(client=client, request=request, team=team) return get_time_series_ok(data) -def get_stickiness_people(client: Client, team_id: int, request: Dict[str, Any]): +def get_stickiness_people(client: Client, team_id: int, request: dict[str, Any]): return client.get("/api/person/stickiness/", data=request) -def get_stickiness_people_ok(client: Client, team_id: int, request: Dict[str, Any]): +def get_stickiness_people_ok(client: Client, team_id: int, request: dict[str, Any]): response = get_stickiness_people(client=client, team_id=team_id, request=encode_get_request_params(data=request)) assert response.status_code == 200 return response.json() diff --git a/posthog/api/test/test_team.py b/posthog/api/test/test_team.py index d23efe81cf7d8..0cae63e3b60b2 100644 --- a/posthog/api/test/test_team.py +++ b/posthog/api/test/test_team.py @@ -1,6 +1,6 @@ import json import uuid -from typing import List, cast, Dict, Optional, Any +from typing import cast, Optional, Any from unittest import mock from unittest.mock import MagicMock, call, patch, ANY @@ -27,7 +27,7 @@ class TestTeamAPI(APIBaseTest): - def _assert_activity_log(self, expected: List[Dict], team_id: Optional[int] = None) -> None: + def _assert_activity_log(self, expected: list[dict], team_id: Optional[int] = None) -> None: if not team_id: team_id = self.team.pk @@ -35,7 +35,7 @@ def _assert_activity_log(self, expected: List[Dict], team_id: Optional[int] = No assert starting_log_response.status_code == 200 assert starting_log_response.json()["results"] == expected - def _assert_organization_activity_log(self, expected: List[Dict]) -> None: + def _assert_organization_activity_log(self, expected: list[dict]) -> None: starting_log_response = self.client.get(f"/api/organizations/{self.organization.pk}/activity") assert starting_log_response.status_code == 200 
         assert starting_log_response.json()["results"] == expected
@@ -95,7 +95,7 @@ def test_cant_retrieve_project_from_another_org(self):
     @patch("posthog.api.team.get_geoip_properties")
     def test_ip_location_is_used_for_new_project_week_day_start(self, get_geoip_properties_mock: MagicMock):
-        self.organization.available_features = cast(List[str], [AvailableFeature.ORGANIZATIONS_PROJECTS])
+        self.organization.available_features = cast(list[str], [AvailableFeature.ORGANIZATIONS_PROJECTS])
         self.organization.save()
         self.organization_membership.level = OrganizationMembership.Level.ADMIN
         self.organization_membership.save()
@@ -1039,7 +1039,7 @@ def test_can_set_replay_configs_patch_session_replay_config_one_level_deep(self)
         # and the existing second level nesting is not preserved
         self._assert_replay_config_is({"ai_config": {"opt_in": None, "included_event_properties": ["and another"]}})
-    def _assert_replay_config_is(self, expected: Dict[str, Any] | None) -> HttpResponse:
+    def _assert_replay_config_is(self, expected: dict[str, Any] | None) -> HttpResponse:
         get_response = self.client.get("/api/projects/@current/")
         assert get_response.status_code == status.HTTP_200_OK, get_response.json()
         assert get_response.json()["session_replay_config"] == expected
@@ -1047,7 +1047,7 @@ def _assert_replay_config_is(self, expected: Dict[str, Any] | None) -> HttpRespo
         return get_response
     def _patch_session_replay_config(
-        self, config: Dict[str, Any] | None, expected_status: int = status.HTTP_200_OK
+        self, config: dict[str, Any] | None, expected_status: int = status.HTTP_200_OK
     ) -> HttpResponse:
         patch_response = self.client.patch(
             "/api/projects/@current/",
@@ -1057,13 +1057,13 @@
         return patch_response
-    def _assert_linked_flag_config(self, expected_config: Dict | None) -> HttpResponse:
+    def _assert_linked_flag_config(self, expected_config: dict | None) -> HttpResponse:
         response = self.client.get("/api/projects/@current/")
         assert response.status_code == status.HTTP_200_OK
         assert response.json()["session_recording_linked_flag"] == expected_config
         return response
-    def _patch_linked_flag_config(self, config: Dict | None, expected_status: int = status.HTTP_200_OK) -> HttpResponse:
+    def _patch_linked_flag_config(self, config: dict | None, expected_status: int = status.HTTP_200_OK) -> HttpResponse:
         response = self.client.patch("/api/projects/@current/", {"session_recording_linked_flag": config})
         assert response.status_code == expected_status, response.json()
         return response
diff --git a/posthog/api/test/test_user.py b/posthog/api/test/test_user.py
index 7113d50e5f7b5..4b682b4095e7f 100644
--- a/posthog/api/test/test_user.py
+++ b/posthog/api/test/test_user.py
@@ -1,6 +1,6 @@
 import datetime
 import uuid
-from typing import Dict, List, cast
+from typing import cast
 from unittest import mock
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import quote
@@ -326,7 +326,7 @@ def test_set_scene_personalisation_for_user(self, _mock_capture, _mock_identify_
     )
     def _assert_set_scene_choice(
-        self, scene: str, dashboard: Dashboard, user: User, expected_choices: List[Dict]
+        self, scene: str, dashboard: Dashboard, user: User, expected_choices: list[dict]
     ) -> None:
         response = self.client.post(
             "/api/users/@me/scene_personalisation",
diff --git a/posthog/api/uploaded_media.py b/posthog/api/uploaded_media.py
index d4cea157c69b0..aba0384caf861 100644
--- a/posthog/api/uploaded_media.py
+++ b/posthog/api/uploaded_media.py
@@ -1,5 +1,5 @@
 from io import BytesIO
-from typing import Dict, Optional
+from typing import Optional
 import structlog
 from django.http import HttpResponse
@@ -149,7 +149,7 @@ def create(self, request, *args, **kwargs) -> Response:
             detail="Object storage must be available to allow media uploads.",
         )
-    def get_success_headers(self, location: str) -> Dict:
+    def get_success_headers(self, location: str) -> dict:
         try:
             return {"Location": location}
         except (TypeError, KeyError):
diff --git a/posthog/api/utils.py b/posthog/api/utils.py
index d34530cda14cc..ed1a571e6e446 100644
--- a/posthog/api/utils.py
+++ b/posthog/api/utils.py
@@ -4,7 +4,7 @@
 import urllib.parse
 from enum import Enum, auto
 from ipaddress import ip_address
-from typing import List, Literal, Optional, Union, Tuple
+from typing import Literal, Optional, Union
 from uuid import UUID
 import structlog
@@ -64,7 +64,7 @@ def get_target_entity(filter: Union[Filter, StickinessFilter]) -> Entity:
     raise ValidationError("An entity must be provided for target entity to be determined")
-def entity_from_order(order: Optional[str], entities: List[Entity]) -> Optional[Entity]:
+def entity_from_order(order: Optional[str], entities: list[Entity]) -> Optional[Entity]:
     if not order:
         return None
@@ -78,8 +78,8 @@ def retrieve_entity_from(
     entity_id: Optional[str],
     entity_type: Optional[str],
     entity_math: MathType,
-    events: List[Entity],
-    actions: List[Entity],
+    events: list[Entity],
+    actions: list[Entity],
 ) -> Optional[Entity]:
     """
     Retrieves the entity from the events and actions.
@@ -251,7 +251,7 @@ def create_event_definitions_sql(
     event_type: EventDefinitionType,
     is_enterprise: bool = False,
     conditions: str = "",
-    order_expressions: Optional[List[Tuple[str, Literal["ASC", "DESC"]]]] = None,
+    order_expressions: Optional[list[tuple[str, Literal["ASC", "DESC"]]]] = None,
 ) -> str:
     if order_expressions is None:
         order_expressions = []
@@ -305,7 +305,7 @@ def get_pk_or_uuid(queryset: QuerySet, key: Union[int, str]) -> QuerySet:
     return queryset.filter(pk=key)
-def parse_bool(value: Union[str, List[str]]) -> bool:
+def parse_bool(value: Union[str, list[str]]) -> bool:
     if value == "true":
         return True
     return False
diff --git a/posthog/async_migrations/definition.py b/posthog/async_migrations/definition.py
index 859b8af08819d..52a53164bc770 100644
--- a/posthog/async_migrations/definition.py
+++ b/posthog/async_migrations/definition.py
@@ -1,13 +1,10 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Dict,
-    List,
     Optional,
-    Tuple,
     Union,
 )
+from collections.abc import Callable
 from posthog.constants import AnalyticsDBMS
 from posthog.models.utils import sane_repr
@@ -36,9 +33,9 @@ def __init__(
         self,
         *,
         sql: str,
-        sql_settings: Optional[Dict] = None,
+        sql_settings: Optional[dict] = None,
         rollback: Optional[str],
-        rollback_settings: Optional[Dict] = None,
+        rollback_settings: Optional[dict] = None,
         database: AnalyticsDBMS = AnalyticsDBMS.CLICKHOUSE,
         timeout_seconds: int = ASYNC_MIGRATIONS_DEFAULT_TIMEOUT_SECONDS,
         per_shard: bool = False,
@@ -58,7 +55,7 @@ def rollback_fn(self, query_id: str):
         if self.rollback is not None:
             self._execute_op(query_id, self.rollback, self.rollback_settings)
-    def _execute_op(self, query_id: str, sql: str, settings: Optional[Dict]):
+    def _execute_op(self, query_id: str, sql: str, settings: Optional[dict]):
         from posthog.async_migrations.utils import (
             execute_op_clickhouse,
             execute_op_postgres,
@@ -91,16 +88,16 @@ class AsyncMigrationDefinition:
     description = ""
     # list of versions accepted for the services the migration relies on e.g. ClickHouse, Postgres
-    service_version_requirements: List[ServiceVersionRequirement] = []
+    service_version_requirements: list[ServiceVersionRequirement] = []
     # list of operations the migration will perform _in order_
-    operations: List[AsyncMigrationOperation] = []
+    operations: list[AsyncMigrationOperation] = []
     # name of async migration this migration depends on
     depends_on: Optional[str] = None
     # optional parameters for this async migration. Shown in the UI when starting the migration
-    parameters: Dict[str, Tuple[(Optional[Union[int, str]], str, Callable[[Any], Any])]] = {}
+    parameters: dict[str, tuple[(Optional[Union[int, str]], str, Callable[[Any], Any])]] = {}
     def __init__(self, name: str):
         self.name = name
@@ -111,11 +108,11 @@ def is_required(self) -> bool:
         return True
     # run before starting the migration
-    def precheck(self) -> Tuple[bool, Optional[str]]:
+    def precheck(self) -> tuple[bool, Optional[str]]:
         return (True, None)
     # run at a regular interval while the migration is being executed
-    def healthcheck(self) -> Tuple[bool, Optional[str]]:
+    def healthcheck(self) -> tuple[bool, Optional[str]]:
         return (True, None)
     # return an int between 0-100 to specify how far along this migration is
diff --git a/posthog/async_migrations/migrations/0001_events_sample_by.py b/posthog/async_migrations/migrations/0001_events_sample_by.py
index 4098fd38f32a1..1d8fced273c1b 100644
--- a/posthog/async_migrations/migrations/0001_events_sample_by.py
+++ b/posthog/async_migrations/migrations/0001_events_sample_by.py
@@ -1,5 +1,3 @@
-from typing import List
-
 from posthog.async_migrations.definition import (
     AsyncMigrationDefinition,
     AsyncMigrationOperation,
@@ -17,7 +15,7 @@ class Migration(AsyncMigrationDefinition):
     posthog_max_version = "1.33.9"
-    operations: List[AsyncMigrationOperation] = []
+    operations: list[AsyncMigrationOperation] = []
     def is_required(self):
         return False
diff --git a/posthog/async_migrations/migrations/0002_events_sample_by.py b/posthog/async_migrations/migrations/0002_events_sample_by.py
index 7038975b2afbb..2157c380f2ddd 100644
--- a/posthog/async_migrations/migrations/0002_events_sample_by.py
+++ b/posthog/async_migrations/migrations/0002_events_sample_by.py
@@ -1,5 +1,4 @@
 from functools import cached_property
-from typing import List
 from django.conf import settings
@@ -76,7 +75,7 @@ def operations(self):
             # Note: This _should_ be impossible but hard to ensure.
            raise RuntimeError("Cannot run the migration as `events` table is already Distributed engine.")
-        create_table_op: List[AsyncMigrationOperation] = [
+        create_table_op: list[AsyncMigrationOperation] = [
             AsyncMigrationOperationSQL(
                 database=AnalyticsDBMS.CLICKHOUSE,
                 sql=f"""
diff --git a/posthog/async_migrations/migrations/0005_person_replacing_by_version.py b/posthog/async_migrations/migrations/0005_person_replacing_by_version.py
index 276d6c54abed3..8740456c5e1f7 100644
--- a/posthog/async_migrations/migrations/0005_person_replacing_by_version.py
+++ b/posthog/async_migrations/migrations/0005_person_replacing_by_version.py
@@ -1,6 +1,5 @@
 import json
 from functools import cached_property
-from typing import Dict, List, Tuple
 import structlog
 from django.conf import settings
@@ -238,9 +237,9 @@ def _copy_batch_from_postgres(self, query_id: str) -> bool:
         )
         return True
-    def _persons_insert_query(self, persons: List[Person]) -> Tuple[str, Dict]:
+    def _persons_insert_query(self, persons: list[Person]) -> tuple[str, dict]:
         values = []
-        params: Dict = {}
+        params: dict = {}
         for i, person in enumerate(persons):
             created_at = person.created_at.strftime("%Y-%m-%d %H:%M:%S")
             # :TRICKY: We use a custom _timestamp to identify rows migrated during this migration
diff --git a/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py
index 62f539f333481..75c5510c9ef49 100644
--- a/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py
+++ b/posthog/async_migrations/migrations/0006_persons_and_groups_on_events_backfill.py
@@ -1,5 +1,3 @@
-from typing import List
-
 from posthog.async_migrations.definition import (
     AsyncMigrationDefinition,
     AsyncMigrationOperation,
@@ -19,7 +17,7 @@ class Migration(AsyncMigrationDefinition):
     depends_on = "0005_person_replacing_by_version"
-    operations: List[AsyncMigrationOperation] = []
+    operations: list[AsyncMigrationOperation] = []
     def is_required(self):
         return False
diff --git a/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py
index 99216ee936b12..f51d171dfe855 100644
--- a/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py
+++ b/posthog/async_migrations/migrations/0007_persons_and_groups_on_events_backfill.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import Dict, Tuple, Union
+from typing import Union
 import structlog
 from django.conf import settings
@@ -289,7 +289,7 @@ def _postcheck(self, _: str):
         self._check_person_data()
         self._check_groups_data()
-    def _where_clause(self) -> Tuple[str, Dict[str, Union[str, int]]]:
+    def _where_clause(self) -> tuple[str, dict[str, Union[str, int]]]:
         team_id = self.get_parameter("TEAM_ID")
         team_id_filter = f" AND team_id = %(team_id)s" if team_id else ""
         where_clause = f"WHERE timestamp > toDateTime(%(timestamp_lower_bound)s) AND timestamp < toDateTime(%(timestamp_upper_bound)s) {team_id_filter}"
diff --git a/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py b/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py
index 9b4c64c9af869..d679643b8a538 100644
--- a/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py
+++ b/posthog/async_migrations/migrations/0009_minmax_indexes_for_materialized_columns.py
@@ -1,5 +1,3 @@
-from typing import List
-
 from posthog.async_migrations.definition import (
     AsyncMigrationDefinition,
     AsyncMigrationOperation,
@@ -16,4 +14,4 @@ class Migration(AsyncMigrationDefinition):
     def is_required(self):
         return False
-    operations: List[AsyncMigrationOperation] = []
+    operations: list[AsyncMigrationOperation] = []
diff --git a/posthog/async_migrations/runner.py b/posthog/async_migrations/runner.py
index 78f2afcf21201..05946cfd3c98a 100644
--- a/posthog/async_migrations/runner.py
+++ b/posthog/async_migrations/runner.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import Optional
 import structlog
 from semantic_version.base import SimpleSpec
@@ -281,7 +281,7 @@ def run_next_migration(candidate: str):
     trigger_migration(migration_instance)
-def is_migration_dependency_fulfilled(migration_name: str) -> Tuple[bool, str]:
+def is_migration_dependency_fulfilled(migration_name: str) -> tuple[bool, str]:
     dependency = get_async_migration_dependency(migration_name)
     dependency_ok: bool = (
@@ -292,8 +292,8 @@ def is_migration_dependency_fulfilled(migration_name: str) -> Tuple[bool, str]:
 def check_service_version_requirements(
-    service_version_requirements: List[ServiceVersionRequirement],
-) -> Tuple[bool, str]:
+    service_version_requirements: list[ServiceVersionRequirement],
+) -> tuple[bool, str]:
     for service_version_requirement in service_version_requirements:
         in_range, version = service_version_requirement.is_service_in_accepted_version()
         if not in_range:
diff --git a/posthog/async_migrations/setup.py b/posthog/async_migrations/setup.py
index 4493f137bd2a2..acc27b495431d 100644
--- a/posthog/async_migrations/setup.py
+++ b/posthog/async_migrations/setup.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 from django.core.exceptions import ImproperlyConfigured
 from infi.clickhouse_orm.utils import import_submodules
@@ -19,12 +19,12 @@ def reload_migration_definitions():
         ALL_ASYNC_MIGRATIONS[name] = module.Migration(name)
-ALL_ASYNC_MIGRATIONS: Dict[str, AsyncMigrationDefinition] = {}
+ALL_ASYNC_MIGRATIONS: dict[str, AsyncMigrationDefinition] = {}
-ASYNC_MIGRATION_TO_DEPENDENCY: Dict[str, Optional[str]] = {}
+ASYNC_MIGRATION_TO_DEPENDENCY: dict[str, Optional[str]] = {}
 # inverted mapping of ASYNC_MIGRATION_TO_DEPENDENCY
-DEPENDENCY_TO_ASYNC_MIGRATION: Dict[Optional[str], str] = {}
+DEPENDENCY_TO_ASYNC_MIGRATION: dict[Optional[str], str] = {}
 ASYNC_MIGRATIONS_MODULE_PATH = "posthog.async_migrations.migrations"
 ASYNC_MIGRATIONS_EXAMPLE_MODULE_PATH = "posthog.async_migrations.examples"
diff --git a/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py b/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py
index 4e6588ad45920..9a35ed05c827f 100644
--- a/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py
+++ b/posthog/async_migrations/test/test_0007_persons_and_groups_on_events_backfill.py
@@ -1,5 +1,4 @@
 import json
-from typing import Dict, List
 from uuid import uuid4
 import pytest
@@ -31,7 +30,7 @@
 MIGRATION_NAME = "0007_persons_and_groups_on_events_backfill"
-uuid1, uuid2, uuid3 = [UUIDT() for _ in range(3)]
+uuid1, uuid2, uuid3 = (UUIDT() for _ in range(3))
 # Clickhouse leaves behind blank/zero values for non-filled columns, these are checked against these constants
 ZERO_UUID = UUIDT(uuid_str="00000000-0000-0000-0000-000000000000")
 ZERO_DATE = "1970-01-01T00:00:00Z"
@@ -44,7+43,7 @@ def run_migration():
     return start_async_migration(MIGRATION_NAME, ignore_posthog_version=True)
-def query_events() -> List[Dict]:
+def query_events() -> list[dict]:
     return query_with_columns(
         """
         SELECT
@@ -351,7 +350,7 @@ def test_rollback(self):
         MIGRATION_DEFINITION.operations[-1].fn = old_fn
     def test_timestamp_boundaries(self):
-        _uuid1, _uuid2, _uuid3 = [UUIDT() for _ in range(3)]
+        _uuid1, _uuid2, _uuid3 = (UUIDT() for _ in range(3))
         create_event(
             event_uuid=_uuid1,
             team=self.team,
diff --git a/posthog/async_migrations/test/test_0010_move_old_partitions.py b/posthog/async_migrations/test/test_0010_move_old_partitions.py
index d316f5f50e625..e249f17a43412 100644
--- a/posthog/async_migrations/test/test_0010_move_old_partitions.py
+++ b/posthog/async_migrations/test/test_0010_move_old_partitions.py
@@ -14,7 +14,7 @@
 MIGRATION_NAME = "0010_move_old_partitions"
-uuid1, uuid2, uuid3 = [UUIDT() for _ in range(3)]
+uuid1, uuid2, uuid3 = (UUIDT() for _ in range(3))
 MIGRATION_DEFINITION = get_async_migration_definition(MIGRATION_NAME)
diff --git a/posthog/async_migrations/utils.py b/posthog/async_migrations/utils.py
index 20ad64cf7d75b..ee7ecdbe4d2ed 100644
--- a/posthog/async_migrations/utils.py
+++ b/posthog/async_migrations/utils.py
@@ -1,6 +1,7 @@
 import asyncio
 from datetime import datetime
-from typing import Callable, Optional
+from typing import Optional
+from collections.abc import Callable
 import posthoganalytics
 import structlog
diff --git a/posthog/auth.py b/posthog/auth.py
index 6154ecb1ca0ba..f536ff30c200e 100644
--- a/posthog/auth.py
+++ b/posthog/auth.py
@@ -1,7 +1,7 @@
 import functools
 import re
 from datetime import timedelta
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
 from urllib.parse import urlsplit
 import jwt
@@ -57,9 +57,9 @@ class PersonalAPIKeyAuthentication(authentication.BaseAuthentication):
     def find_key_with_source(
         cls,
         request: Union[HttpRequest, Request],
-        request_data: Optional[Dict[str, Any]] = None,
-        extra_data: Optional[Dict[str, Any]] = None,
-    ) -> Optional[Tuple[str, str]]:
+        request_data: Optional[dict[str, Any]] = None,
+        extra_data: Optional[dict[str, Any]] = None,
+    ) -> Optional[tuple[str, str]]:
         """Try to find personal API key in request and return it along with where it was found."""
         if "HTTP_AUTHORIZATION" in request.META:
             authorization_match = re.match(rf"^{cls.keyword}\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"])
@@ -80,8 +80,8 @@ def find_key_with_source(
     def find_key(
         cls,
         request: Union[HttpRequest, Request],
-        request_data: Optional[Dict[str, Any]] = None,
-        extra_data: Optional[Dict[str, Any]] = None,
+        request_data: Optional[dict[str, Any]] = None,
+        extra_data: Optional[dict[str, Any]] = None,
     ) -> Optional[str]:
         """Try to find personal API key in request and return it."""
         key_with_source = cls.find_key_with_source(request, request_data, extra_data)
@@ -121,7 +121,7 @@ def validate_key(cls, personal_api_key_with_source):
         return personal_api_key_object
-    def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]:
+    def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, None]]:
         personal_api_key_with_source = self.find_key_with_source(request)
         if not personal_api_key_with_source:
             return None
@@ -190,7 +190,7 @@ class JwtAuthentication(authentication.BaseAuthentication):
     keyword = "Bearer"
     @classmethod
-    def authenticate(cls, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, None]]:
+    def authenticate(cls, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, None]]:
         if "HTTP_AUTHORIZATION" in request.META:
             authorization_match = re.match(rf"^Bearer\s+(\S.+)$", request.META["HTTP_AUTHORIZATION"])
             if authorization_match:
@@ -222,7 +222,7 @@ class SharingAccessTokenAuthentication(authentication.BaseAuthentication):
     sharing_configuration: SharingConfiguration
-    def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[Tuple[Any, Any]]:
+    def authenticate(self, request: Union[HttpRequest, Request]) -> Optional[tuple[Any, Any]]:
         if sharing_access_token := request.GET.get("sharing_access_token"):
             if request.method not in ["GET", "HEAD"]:
                 raise AuthenticationFailed(detail="Sharing access token can only be used for GET requests.")
diff --git a/posthog/caching/calculate_results.py b/posthog/caching/calculate_results.py
index 2fcf0ff04ccdd..4089323202e50 100644
--- a/posthog/caching/calculate_results.py
+++ b/posthog/caching/calculate_results.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 import structlog
 from sentry_sdk import capture_exception
@@ -77,7 +77,7 @@ def get_cache_type_for_filter(cacheable: FilterType) -> CacheType:
     return CacheType.TRENDS
-def get_cache_type_for_query(cacheable: Dict) -> CacheType:
+def get_cache_type_for_query(cacheable: dict) -> CacheType:
     cache_type = None
     if cacheable.get("source"):
@@ -92,7 +92,7 @@ def get_cache_type_for_query(cacheable: Dict) -> CacheType:
     return cache_type
-def get_cache_type(cacheable: Optional[FilterType] | Optional[Dict]) -> CacheType:
+def get_cache_type(cacheable: Optional[FilterType] | Optional[dict]) -> CacheType:
     if isinstance(cacheable, dict):
         return get_cache_type_for_query(cacheable)
     elif cacheable is not None:
@@ -146,7 +146,7 @@ def calculate_for_query_based_insight(
 def calculate_for_filter_based_insight(
     insight: Insight, dashboard: Optional[Dashboard]
-) -> Tuple[str, str, List | Dict]:
+) -> tuple[str, str, list | dict]:
     filter = get_filter(data=insight.dashboard_filters(dashboard), team=insight.team)
     cache_key = generate_insight_cache_key(insight, dashboard)
     cache_type = get_cache_type(filter)
@@ -161,7 +161,7 @@ def calculate_for_filter_based_insight(
     return cache_key, cache_type, calculate_result_by_cache_type(cache_type, filter, insight.team)
-def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team: Team) -> List[Dict[str, Any]]:
+def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team: Team) -> list[dict[str, Any]]:
     if cache_type == CacheType.FUNNEL:
         return _calculate_funnel(filter, team)
     else:
@@ -169,7 +169,7 @@ def calculate_result_by_cache_type(cache_type: CacheType, filter: Filter, team:
 @timed("update_cache_item_timer.calculate_by_filter")
-def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType) -> List[Dict[str, Any]]:
+def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType) -> list[dict[str, Any]]:
     insight_class = CACHE_TYPE_TO_INSIGHT_CLASS[cache_type]
     if cache_type == CacheType.PATHS:
@@ -180,7 +180,7 @@ def _calculate_by_filter(filter: FilterType, team: Team, cache_type: CacheType)
 @timed("update_cache_item_timer.calculate_funnel")
-def _calculate_funnel(filter: Filter, team: Team) -> List[Dict[str, Any]]:
+def _calculate_funnel(filter: Filter, team: Team) -> list[dict[str, Any]]:
     if filter.funnel_viz_type == FunnelVizType.TRENDS:
         result = ClickhouseFunnelTrends(team=team, filter=filter).run()
     elif filter.funnel_viz_type == FunnelVizType.TIME_TO_CONVERT:
@@ -193,7 +193,7
@@ def _calculate_funnel(filter: Filter, team: Team) -> List[Dict[str, Any]]: def cache_includes_latest_events( - payload: Dict, filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter] + payload: dict, filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter] ) -> bool: """ event_definition has last_seen_at timestamp @@ -218,7 +218,7 @@ def cache_includes_latest_events( return False -def _events_from_filter(filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter]) -> List[str]: +def _events_from_filter(filter: Union[RetentionFilter, StickinessFilter, PathFilter, Filter]) -> list[str]: """ If a filter only represents a set of events then we can use their last_seen_at to determine if the cache is up-to-date diff --git a/posthog/caching/fetch_from_cache.py b/posthog/caching/fetch_from_cache.py index fcbeb0b72e341..fe5d46ace3d51 100644 --- a/posthog/caching/fetch_from_cache.py +++ b/posthog/caching/fetch_from_cache.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from django.utils.timezone import now from prometheus_client import Counter @@ -27,7 +27,7 @@ class InsightResult: is_cached: bool timezone: Optional[str] next_allowed_client_refresh: Optional[datetime] = None - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None @dataclass(frozen=True) diff --git a/posthog/caching/insight_cache.py b/posthog/caching/insight_cache.py index d73486234dfb1..97b5c691e4643 100644 --- a/posthog/caching/insight_cache.py +++ b/posthog/caching/insight_cache.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from time import perf_counter -from typing import Any, List, Optional, Tuple, cast +from typing import Any, Optional, cast from uuid import UUID import structlog @@ -49,7 +49,7 @@ def schedule_cache_updates(): logger.warn("No caches were found to be updated") -def fetch_states_in_need_of_updating(limit: int) -> List[Tuple[int, str, UUID]]: +def fetch_states_in_need_of_updating(limit: int) -> list[tuple[int, str, UUID]]: current_time = now() with connection.cursor() as cursor: cursor.execute( @@ -162,7 +162,7 @@ def update_cached_state( ) -def _extract_insight_dashboard(caching_state: InsightCachingState) -> Tuple[Insight, Optional[Dashboard]]: +def _extract_insight_dashboard(caching_state: InsightCachingState) -> tuple[Insight, Optional[Dashboard]]: if caching_state.dashboard_tile is not None: assert caching_state.dashboard_tile.insight is not None diff --git a/posthog/caching/insight_caching_state.py b/posthog/caching/insight_caching_state.py index a8ae36c14f05a..ae3eb269425f0 100644 --- a/posthog/caching/insight_caching_state.py +++ b/posthog/caching/insight_caching_state.py @@ -1,7 +1,7 @@ from datetime import timedelta from enum import Enum from functools import cached_property -from typing import List, Optional, Union +from typing import Optional, Union import structlog from django.core.paginator import Paginator @@ -232,10 +232,10 @@ def _iterate_large_queryset(queryset, page_size): yield page.object_list -def _execute_insert(states: List[Optional[InsightCachingState]]): +def _execute_insert(states: list[Optional[InsightCachingState]]): from django.db import connection - models: List[InsightCachingState] = list(filter(None, states)) + models: list[InsightCachingState] = list(filter(None, states)) if len(models) == 0: return diff --git a/posthog/caching/insights_api.py 
b/posthog/caching/insights_api.py index 35a75cdf8a0b1..11760e2dc4108 100644 --- a/posthog/caching/insights_api.py +++ b/posthog/caching/insights_api.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta from math import ceil from time import sleep -from typing import Optional, Tuple, Union +from typing import Optional, Union import zoneinfo from rest_framework import request @@ -37,7 +37,7 @@ def should_refresh_insight( *, request: request.Request, is_shared=False, -) -> Tuple[bool, timedelta]: +) -> tuple[bool, timedelta]: """Return whether the insight should be refreshed now, and what's the minimum wait time between refreshes. If a refresh already is being processed somewhere else, this function will wait for that to finish (or time out). diff --git a/posthog/caching/test/test_insight_cache.py b/posthog/caching/test/test_insight_cache.py index 9de2053f6c2f1..b86ac56a3de99 100644 --- a/posthog/caching/test/test_insight_cache.py +++ b/posthog/caching/test/test_insight_cache.py @@ -1,5 +1,6 @@ from datetime import timedelta -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable from unittest.mock import call, patch import pytest diff --git a/posthog/caching/test/test_insight_caching_state.py b/posthog/caching/test/test_insight_caching_state.py index 03a3652555202..47465786fb17b 100644 --- a/posthog/caching/test/test_insight_caching_state.py +++ b/posthog/caching/test/test_insight_caching_state.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Optional, Union, cast from unittest.mock import patch import pytest @@ -42,7 +42,7 @@ def create_insight( is_shared=True, filters=filter_dict, deleted=False, - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> Insight: if mock_active_teams: mock_active_teams.return_value = {team.pk} if team_should_be_active else set() @@ -77,7 +77,7 @@ def create_tile( dashboard_tile_deleted=False, is_dashboard_shared=True, text_tile=False, - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> DashboardTile: if mock_active_teams: mock_active_teams.return_value = {team.pk} if team_should_be_active else set() @@ -295,7 +295,7 @@ def test_calculate_target_age( team: Team, user: User, create_item, - create_item_kw: Dict, + create_item_kw: dict, expected_target_age: TargetCacheAge, ): item = cast( diff --git a/posthog/caching/utils.py b/posthog/caching/utils.py index d0c6450cc7dba..c56d0f33571d5 100644 --- a/posthog/caching/utils.py +++ b/posthog/caching/utils.py @@ -1,6 +1,6 @@ from datetime import datetime from dateutil.parser import isoparse -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union from zoneinfo import ZoneInfo from dateutil.parser import parser @@ -32,7 +32,7 @@ def ensure_is_date(candidate: Optional[Union[str, datetime]]) -> Optional[dateti return parser().parse(candidate) -def active_teams() -> Set[int]: +def active_teams() -> set[int]: """ Teams are stored in a sorted set. [{team_id: score}, {team_id: score}]. Their "score" is the number of seconds since last event. @@ -43,7 +43,7 @@ def active_teams() -> Set[int]: This assumes that the list of active teams is small enough to reasonably load in one go. 
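The sorted-set read in active_teams() maps directly onto redis-py's zrange with withscores=True, which yields (member, score) pairs. A minimal sketch of the same pattern, assuming a local Redis; the literal key here is a stand-in for RECENTLY_ACCESSED_TEAMS_REDIS_KEY:

import redis

r = redis.Redis()
# Each member is a team id (as bytes); each score is seconds since the team's last event.
teams: list[tuple[bytes, float]] = r.zrange("recently_accessed_teams", 0, -1, withscores=True)
active_team_ids: set[int] = {int(member) for member, _score in teams}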
""" redis = get_client() - all_teams: List[Tuple[bytes, float]] = redis.zrange(RECENTLY_ACCESSED_TEAMS_REDIS_KEY, 0, -1, withscores=True) + all_teams: list[tuple[bytes, float]] = redis.zrange(RECENTLY_ACCESSED_TEAMS_REDIS_KEY, 0, -1, withscores=True) if not all_teams: teams_by_recency = sync_execute( """ @@ -106,7 +106,7 @@ def is_stale(team: Team, date_to: datetime, interval: str, cached_result: Any) - return False last_refresh = ( - cached_result.get("last_refresh", None) if isinstance(cached_result, Dict) else cached_result.last_refresh + cached_result.get("last_refresh", None) if isinstance(cached_result, dict) else cached_result.last_refresh ) date_to = min([date_to, datetime.now(tz=ZoneInfo("UTC"))]) # can't be later than now diff --git a/posthog/celery.py b/posthog/celery.py index a78a7c94ad844..29c45c9b60729 100644 --- a/posthog/celery.py +++ b/posthog/celery.py @@ -1,6 +1,5 @@ import os import time -from typing import Dict from celery import Celery from celery.signals import ( @@ -71,7 +70,7 @@ app.steps["worker"].add(DjangoStructLogInitStep) -task_timings: Dict[str, float] = {} +task_timings: dict[str, float] = {} @setup_logging.connect diff --git a/posthog/clickhouse/client/connection.py b/posthog/clickhouse/client/connection.py index 31ae6cd291de0..35c72a305faea 100644 --- a/posthog/clickhouse/client/connection.py +++ b/posthog/clickhouse/client/connection.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from enum import Enum -from functools import lru_cache +from functools import cache from clickhouse_driver import Client as SyncClient from clickhouse_pool import ChPool @@ -65,7 +65,7 @@ def default_client(): ) -@lru_cache(maxsize=None) +@cache def make_ch_pool(**overrides) -> ChPool: kwargs = { "host": settings.CLICKHOUSE_HOST, diff --git a/posthog/clickhouse/client/execute.py b/posthog/clickhouse/client/execute.py index b588badfc07ea..17af5683a6f19 100644 --- a/posthog/clickhouse/client/execute.py +++ b/posthog/clickhouse/client/execute.py @@ -4,7 +4,8 @@ from contextlib import contextmanager from functools import lru_cache from time import perf_counter -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Optional, Union +from collections.abc import Sequence import sqlparse from clickhouse_driver import Client as SyncClient @@ -19,7 +20,7 @@ from posthog.utils import generate_short_id, patchable InsertParams = Union[list, tuple, types.GeneratorType] -NonInsertParams = Dict[str, Any] +NonInsertParams = dict[str, Any] QueryArgs = Optional[Union[InsertParams, NonInsertParams]] thread_local_storage = threading.local() @@ -39,7 +40,7 @@ @lru_cache(maxsize=1) -def default_settings() -> Dict: +def default_settings() -> dict: return { "join_algorithm": "direct,parallel_hash", "distributed_replica_max_ignored_errors": 1000, @@ -131,11 +132,11 @@ def query_with_columns( query: str, args: Optional[QueryArgs] = None, columns_to_remove: Optional[Sequence[str]] = None, - columns_to_rename: Optional[Dict[str, str]] = None, + columns_to_rename: Optional[dict[str, str]] = None, *, workload: Workload = Workload.DEFAULT, team_id: Optional[int] = None, -) -> List[Dict]: +) -> list[dict]: if columns_to_remove is None: columns_to_remove = [] if columns_to_rename is None: @@ -184,7 +185,7 @@ def _prepare_query( below predicate. 
""" prepared_args: Any = QueryArgs - if isinstance(args, (list, tuple, types.GeneratorType)): + if isinstance(args, list | tuple | types.GeneratorType): # If we get one of these it means we have an insert, let the clickhouse # client handle substitution here. rendered_sql = query diff --git a/posthog/clickhouse/client/migration_tools.py b/posthog/clickhouse/client/migration_tools.py index f71abd489fd64..aa3100b548bc0 100644 --- a/posthog/clickhouse/client/migration_tools.py +++ b/posthog/clickhouse/client/migration_tools.py @@ -1,4 +1,5 @@ -from typing import Callable, Union +from typing import Union +from collections.abc import Callable from infi.clickhouse_orm import migrations diff --git a/posthog/clickhouse/materialized_columns/column.py b/posthog/clickhouse/materialized_columns/column.py index 70aca94511ac9..a206c051395cc 100644 --- a/posthog/clickhouse/materialized_columns/column.py +++ b/posthog/clickhouse/materialized_columns/column.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Dict, List, Literal, Tuple, Union +from typing import Literal, Union from posthog.cache_utils import cache_for from posthog.models.property import PropertyName, TableColumn, TableWithProperties @@ -12,7 +12,7 @@ @cache_for(timedelta(minutes=15)) def get_materialized_columns( table: TablesWithMaterializedColumns, -) -> Dict[Tuple[PropertyName, TableColumn], ColumnName]: +) -> dict[tuple[PropertyName, TableColumn], ColumnName]: return {} @@ -28,7 +28,7 @@ def materialize( def backfill_materialized_columns( table: TableWithProperties, - properties: List[Tuple[PropertyName, TableColumn]], + properties: list[tuple[PropertyName, TableColumn]], backfill_period: timedelta, test_settings=None, ) -> None: diff --git a/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py b/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py index 877139c155ee0..85d6664e475be 100644 --- a/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py +++ b/posthog/clickhouse/migrations/0046_ensure_kafa_session_replay_table_exists.py @@ -1,6 +1,4 @@ -from typing import List - -operations: List = [ +operations: list = [ # this migration has been amended to be entirely No-op # it has applied successfully in Prod US where it was a no-op # as all tables/columns it affected already existed diff --git a/posthog/clickhouse/system_status.py b/posthog/clickhouse/system_status.py index e04c6bf7597f1..eec283f3b5ab2 100644 --- a/posthog/clickhouse/system_status.py +++ b/posthog/clickhouse/system_status.py @@ -1,6 +1,6 @@ from datetime import timedelta from os.path import abspath, dirname, join -from typing import Dict, Generator, List, Tuple +from collections.abc import Generator from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -27,7 +27,7 @@ CLICKHOUSE_FLAMEGRAPH_EXECUTABLE = abspath(join(dirname(__file__), "bin", "clickhouse-flamegraph")) FLAMEGRAPH_PL = abspath(join(dirname(__file__), "bin", "flamegraph.pl")) -SystemStatusRow = Dict +SystemStatusRow = dict def system_status() -> Generator[SystemStatusRow, None, None]: @@ -179,7 +179,7 @@ def is_alive() -> bool: return False -def dead_letter_queue_ratio() -> Tuple[bool, int]: +def dead_letter_queue_ratio() -> tuple[bool, int]: dead_letter_queue_events_last_day = get_dead_letter_queue_events_last_24h() total_events_ingested_last_day = sync_execute( @@ -199,14 +199,14 @@ def dead_letter_queue_ratio_ok_cached() -> bool: return dead_letter_queue_ratio()[0] -def 
get_clickhouse_running_queries() -> List[Dict]: +def get_clickhouse_running_queries() -> list[dict]: return query_with_columns( "SELECT elapsed as duration, query, * FROM system.processes ORDER BY duration DESC", columns_to_remove=["address", "initial_address", "elapsed"], ) -def get_clickhouse_slow_log() -> List[Dict]: +def get_clickhouse_slow_log() -> list[dict]: return query_with_columns( f""" SELECT query_duration_ms as duration, query, * diff --git a/posthog/conftest.py b/posthog/conftest.py index 7d2895eb38015..8f3f233358ca8 100644 --- a/posthog/conftest.py +++ b/posthog/conftest.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any import pytest from django.conf import settings @@ -22,7 +22,7 @@ def create_clickhouse_tables(num_tables: int): ) # REMEMBER TO ADD ANY NEW CLICKHOUSE TABLES TO THIS ARRAY! - CREATE_TABLE_QUERIES: Tuple[Any, ...] = CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES + CREATE_TABLE_QUERIES: tuple[Any, ...] = CREATE_MERGETREE_TABLE_QUERIES + CREATE_DISTRIBUTED_TABLE_QUERIES # Check if all the tables have already been created if num_tables == len(CREATE_TABLE_QUERIES): diff --git a/posthog/decorators.py b/posthog/decorators.py index 955bb9d085195..bb012701033e2 100644 --- a/posthog/decorators.py +++ b/posthog/decorators.py @@ -1,6 +1,7 @@ from enum import Enum from functools import wraps -from typing import Any, Callable, Dict, List, TypeVar, Union, cast +from typing import Any, TypeVar, Union, cast +from collections.abc import Callable from django.urls import resolve from django.utils.timezone import now @@ -25,7 +26,7 @@ class CacheType(str, Enum): PATHS = "Path" -ResultPackage = Union[Dict[str, Any], List[Dict[str, Any]]] +ResultPackage = Union[dict[str, Any], list[dict[str, Any]]] T = TypeVar("T", bound=ResultPackage) U = TypeVar("U", bound=GenericViewSet) diff --git a/posthog/demo/legacy/data_generator.py b/posthog/demo/legacy/data_generator.py index ccc9f163e6c3c..d507e65c31c67 100644 --- a/posthog/demo/legacy/data_generator.py +++ b/posthog/demo/legacy/data_generator.py @@ -1,4 +1,3 @@ -from typing import Dict, List from uuid import uuid4 from posthog.models import Person, PersonDistinctId, Team @@ -13,9 +12,9 @@ def __init__(self, team: Team, n_days=14, n_people=100): self.team = team self.n_days = n_days self.n_people = n_people - self.events: List[Dict] = [] - self.snapshots: List[Dict] = [] - self.distinct_ids: List[str] = [] + self.events: list[dict] = [] + self.snapshots: list[dict] = [] + self.distinct_ids: list[str] = [] def create(self, dashboards=True): self.create_missing_events_and_properties() diff --git a/posthog/demo/legacy/web_data_generator.py b/posthog/demo/legacy/web_data_generator.py index aa0836d3db732..811270092250f 100644 --- a/posthog/demo/legacy/web_data_generator.py +++ b/posthog/demo/legacy/web_data_generator.py @@ -1,7 +1,7 @@ import json import random from datetime import timedelta -from typing import Any, Dict, List +from typing import Any from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -199,11 +199,11 @@ def make_person(self, index): return super().make_person(index) @cached_property - def demo_data(self) -> List[Dict[str, Any]]: - with open(get_absolute_path("demo/legacy/demo_people.json"), "r") as demo_data_file: + def demo_data(self) -> list[dict[str, Any]]: + with open(get_absolute_path("demo/legacy/demo_people.json")) as demo_data_file: return json.load(demo_data_file) @cached_property - def demo_recording(self) -> Dict[str, Any]: - with 
open(get_absolute_path("demo/legacy/hogflix_session_recording.json"), "r") as demo_session_file: + def demo_recording(self) -> dict[str, Any]: + with open(get_absolute_path("demo/legacy/hogflix_session_recording.json")) as demo_session_file: return json.load(demo_session_file) diff --git a/posthog/demo/matrix/manager.py b/posthog/demo/matrix/manager.py index 507ea09581d51..ce073a6126f8a 100644 --- a/posthog/demo/matrix/manager.py +++ b/posthog/demo/matrix/manager.py @@ -1,7 +1,7 @@ import datetime as dt import json from time import sleep -from typing import Any, Dict, List, Literal, Optional, Tuple, cast +from typing import Any, Literal, Optional, cast from django.conf import settings from django.core import exceptions @@ -55,13 +55,13 @@ def ensure_account_and_save( password: Optional[str] = None, is_staff: bool = False, disallow_collision: bool = False, - ) -> Tuple[Organization, Team, User]: + ) -> tuple[Organization, Team, User]: """If there's an email collision in signup in the demo environment, we treat it as a login.""" existing_user: Optional[User] = User.objects.filter(email=email).first() if existing_user is None: if self.print_steps: print(f"Creating demo organization, project, and user...") - organization_kwargs: Dict[str, Any] = {"name": organization_name} + organization_kwargs: dict[str, Any] = {"name": organization_name} if settings.DEMO: organization_kwargs["plugins_access_level"] = Organization.PluginsAccessLevel.INSTALL with transaction.atomic(): @@ -241,7 +241,7 @@ def _sync_postgres_with_clickhouse_data(cls, source_team_id: int, target_team_id ["team_id", "is_deleted", "_timestamp", "_offset", "_partition"], {"id": "uuid"}, ) - bulk_persons: Dict[str, Person] = {} + bulk_persons: dict[str, Person] = {} for row in clickhouse_persons: properties = json.loads(row.pop("properties", "{}")) bulk_persons[row["uuid"]] = Person(team_id=target_team_id, properties=properties, **row) @@ -317,7 +317,7 @@ def _save_sim_person(self, team: Team, subject: SimPerson): self._save_future_sim_events(team, subject.future_events) @staticmethod - def _save_past_sim_events(team: Team, events: List[SimEvent]): + def _save_past_sim_events(team: Team, events: list[SimEvent]): """Past events are saved into ClickHouse right away (via Kafka of course).""" from posthog.models.event.util import create_event @@ -346,7 +346,7 @@ def _save_past_sim_events(team: Team, events: List[SimEvent]): ) @staticmethod - def _save_future_sim_events(team: Team, events: List[SimEvent]): + def _save_future_sim_events(team: Team, events: list[SimEvent]): """Future events are not saved immediately, instead they're scheduled for ingestion via event buffer.""" # TODO: This used the plugin server's Graphile Worker-based event buffer, but the event buffer is no more @@ -356,7 +356,7 @@ def _save_sim_group( team: Team, type_index: Literal[0, 1, 2, 3, 4], key: str, - properties: Dict[str, Any], + properties: dict[str, Any], timestamp: dt.datetime, ): from posthog.models.group.util import raw_create_group_ch diff --git a/posthog/demo/matrix/matrix.py b/posthog/demo/matrix/matrix.py index c2d3a5f2eb4f4..382e70d85b78d 100644 --- a/posthog/demo/matrix/matrix.py +++ b/posthog/demo/matrix/matrix.py @@ -3,13 +3,7 @@ from collections import defaultdict, deque from typing import ( Any, - DefaultDict, - Deque, - Dict, - List, Optional, - Set, - Type, ) import mimesis @@ -38,7 +32,7 @@ class Cluster(ABC): end: timezone.datetime # End of the simulation (might be same as now or later) radius: int - people_matrix: List[List[SimPerson]] # 
Grid containing all people in the cluster + people_matrix: list[list[SimPerson]] # Grid containing all people in the cluster random: mimesis.random.Random properties_provider: PropertiesProvider @@ -52,7 +46,7 @@ class Cluster(ABC): _simulation_time: dt.datetime _reached_now: bool - _scheduled_effects: Deque[Effect] + _scheduled_effects: deque[Effect] def __init__(self, *, index: int, matrix: "Matrix") -> None: self.index = index @@ -98,7 +92,7 @@ def initiation_distribution(self) -> float: """Return a value between 0 and 1 determining how far into the overall simulation should this cluster be initiated.""" return self.random.random() - def list_neighbors(self, person: SimPerson) -> List[SimPerson]: + def list_neighbors(self, person: SimPerson) -> list[SimPerson]: """Return a list of neighbors of a person at (x, y).""" x, y = person.x, person.y neighbors = [] @@ -141,7 +135,7 @@ def _apply_due_effects(self, until: dt.datetime): while self._scheduled_effects and self._scheduled_effects[0].timestamp <= until: effect = self._scheduled_effects.popleft() self.simulation_time = effect.timestamp - resolved_targets: List[SimPerson] + resolved_targets: list[SimPerson] if effect.target == Effect.Target.SELF: resolved_targets = [effect.source] elif effect.target == Effect.Target.ALL_NEIGHBORS: @@ -155,7 +149,7 @@ def _apply_due_effects(self, until: dt.datetime): effect.callback(target) @property - def people(self) -> Set[SimPerson]: + def people(self) -> set[SimPerson]: return {person for row in self.people_matrix for person in row} @property @@ -198,17 +192,17 @@ class Matrix(ABC): """ PRODUCT_NAME: str - CLUSTER_CLASS: Type[Cluster] - PERSON_CLASS: Type[SimPerson] + CLUSTER_CLASS: type[Cluster] + PERSON_CLASS: type[SimPerson] start: dt.datetime now: dt.datetime end: dt.datetime group_type_index_offset: int # A mapping of groups. The first key is the group type, the second key is the group key. 
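Since Python 3.9, runtime classes such as defaultdict are subscriptable in annotations, so the typing.DefaultDict alias being dropped here is redundant. The nested mapping the comment above describes can be built as follows (a sketch with a made-up group type and key, not the Matrix initializer itself):

from collections import defaultdict
from typing import Any

# group type -> group key -> properties
groups: defaultdict[str, defaultdict[str, dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
groups["company"]["hedgebox"].update({"plan": "scale"})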
- groups: DefaultDict[str, DefaultDict[str, Dict[str, Any]]] - distinct_id_to_person: Dict[str, SimPerson] - clusters: List[Cluster] + groups: defaultdict[str, defaultdict[str, dict[str, Any]]] + distinct_id_to_person: dict[str, SimPerson] + clusters: list[Cluster] is_complete: Optional[bool] server_client: SimServerClient @@ -257,7 +251,7 @@ def __init__( self.is_complete = None @property - def people(self) -> List[SimPerson]: + def people(self) -> list[SimPerson]: return [person for cluster in self.clusters for person in cluster.people] @abstractmethod @@ -273,7 +267,7 @@ def simulate(self): cluster.simulate() self.is_complete = True - def _update_group(self, group_type: str, group_key: str, set_properties: Dict[str, Any]): + def _update_group(self, group_type: str, group_key: str, set_properties: dict[str, Any]): if len(self.groups) == GROUP_TYPES_LIMIT and group_type not in self.groups: raise Exception(f"Cannot add group type {group_type} to simulation, limit of {GROUP_TYPES_LIMIT} reached!") self.groups[group_type][group_key].update(set_properties) diff --git a/posthog/demo/matrix/models.py b/posthog/demo/matrix/models.py index e1698d7dd7b3b..c09fcae8cbb03 100644 --- a/posthog/demo/matrix/models.py +++ b/posthog/demo/matrix/models.py @@ -8,17 +8,12 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - DefaultDict, - Dict, Generic, - Iterable, - List, Literal, Optional, - Set, TypeVar, ) +from collections.abc import Callable, Iterable from urllib.parse import urlparse, parse_qs from uuid import UUID @@ -77,7 +72,7 @@ class Target(Enum): "$referrer", } -Properties = Dict[str, Any] +Properties = dict[str, Any] class SimSessionIntent(Enum): @@ -330,23 +325,23 @@ class SimPerson(ABC): timezone: str # Exposed state - present - past_events: List[SimEvent] - future_events: List[SimEvent] + past_events: list[SimEvent] + future_events: list[SimEvent] # Exposed state - at `now` - distinct_ids_at_now: Set[str] + distinct_ids_at_now: set[str] properties_at_now: Properties first_seen_at: Optional[dt.datetime] last_seen_at: Optional[dt.datetime] # Internal state active_client: SimBrowserClient # Client being used by person - all_time_pageview_counts: DefaultDict[str, int] # Pageview count per URL across all time - session_pageview_counts: DefaultDict[str, int] # Pageview count per URL across the ongoing session + all_time_pageview_counts: defaultdict[str, int] # Pageview count per URL across all time + session_pageview_counts: defaultdict[str, int] # Pageview count per URL across the ongoing session active_session_intent: Optional[SimSessionIntent] wake_up_by: dt.datetime - _groups: Dict[str, str] - _distinct_ids: Set[str] + _groups: dict[str, str] + _distinct_ids: set[str] _properties: Properties def __init__(self, *, kernel: bool, cluster: "Cluster", x: int, y: int): @@ -397,7 +392,7 @@ def attempt_session(self): # Abstract methods - def decide_feature_flags(self) -> Dict[str, Any]: + def decide_feature_flags(self) -> dict[str, Any]: """Determine feature flags in force at present.""" return {} diff --git a/posthog/demo/matrix/randomization.py b/posthog/demo/matrix/randomization.py index ca6bcfd588640..d017c295321dc 100644 --- a/posthog/demo/matrix/randomization.py +++ b/posthog/demo/matrix/randomization.py @@ -1,10 +1,9 @@ from enum import Enum -from typing import Dict, List, Tuple import mimesis import mimesis.random -WeightedPool = Tuple[List[str], List[int]] +WeightedPool = tuple[list[str], list[int]] class Industry(str, Enum): @@ -27,12 +26,12 @@ class 
PropertiesProvider(mimesis.BaseProvider): ["Desktop", "Mobile", "Tablet"], [8, 1, 1], ) - OS_WEIGHTED_POOLS: Dict[str, WeightedPool] = { + OS_WEIGHTED_POOLS: dict[str, WeightedPool] = { "Desktop": (["Windows", "Mac OS X", "Linux", "Chrome OS"], [18, 16, 7, 1]), "Mobile": (["iOS", "Android"], [1, 1]), "Tablet": (["iOS", "Android"], [1, 1]), } - BROWSER_WEIGHTED_POOLS: Dict[str, WeightedPool] = { + BROWSER_WEIGHTED_POOLS: dict[str, WeightedPool] = { "Windows": ( ["Chrome", "Firefox", "Opera", "Microsoft Edge", "Internet Explorer"], [12, 4, 2, 1, 1], @@ -65,7 +64,7 @@ class PropertiesProvider(mimesis.BaseProvider): random: mimesis.random.Random - def device_type_os_browser(self) -> Tuple[str, str, str]: + def device_type_os_browser(self) -> tuple[str, str, str]: device_type_pool, device_type_weights = self.DEVICE_TYPE_WEIGHTED_POOL device_type = self.random.choices(device_type_pool, device_type_weights)[0] os_pool, os_weights = self.OS_WEIGHTED_POOLS[device_type] diff --git a/posthog/demo/matrix/taxonomy_inference.py b/posthog/demo/matrix/taxonomy_inference.py index cc5686de96b0b..e05dc67f33368 100644 --- a/posthog/demo/matrix/taxonomy_inference.py +++ b/posthog/demo/matrix/taxonomy_inference.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional, Tuple +from typing import Optional from django.utils import timezone @@ -9,7 +9,7 @@ from posthog.models.property_definition import PropertyType -def infer_taxonomy_for_team(team_id: int) -> Tuple[int, int, int]: +def infer_taxonomy_for_team(team_id: int) -> tuple[int, int, int]: """Infer event and property definitions based on ClickHouse data. In production, the plugin server is responsible for this - but in demo data we insert directly to ClickHouse. @@ -55,13 +55,13 @@ def infer_taxonomy_for_team(team_id: int) -> Tuple[int, int, int]: return len(event_definitions), len(property_definitions), len(event_properties) -def _get_events_last_seen_at(team_id: int) -> Dict[str, timezone.datetime]: +def _get_events_last_seen_at(team_id: int) -> dict[str, timezone.datetime]: from posthog.client import sync_execute return dict(sync_execute(_GET_EVENTS_LAST_SEEN_AT, {"team_id": team_id})) -def _get_property_types(team_id: int) -> Dict[str, Optional[PropertyType]]: +def _get_property_types(team_id: int) -> dict[str, Optional[PropertyType]]: """Determine property types based on ClickHouse data.""" from posthog.client import sync_execute @@ -87,14 +87,14 @@ def _infer_property_type(sample_json_value: str) -> Optional[PropertyType]: parsed_value = json.loads(sample_json_value) if isinstance(parsed_value, bool): return PropertyType.Boolean - if isinstance(parsed_value, (float, int)): + if isinstance(parsed_value, float | int): return PropertyType.Numeric if isinstance(parsed_value, str): return PropertyType.String return None -def _get_event_property_pairs(team_id: int) -> List[Tuple[str, str]]: +def _get_event_property_pairs(team_id: int) -> list[tuple[str, str]]: """Determine which properties have been seen with which events based on ClickHouse data.""" from posthog.client import sync_execute diff --git a/posthog/demo/products/hedgebox/models.py b/posthog/demo/products/hedgebox/models.py index 324dc6b473762..dd694f64aac41 100644 --- a/posthog/demo/products/hedgebox/models.py +++ b/posthog/demo/products/hedgebox/models.py @@ -5,11 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Set, - Tuple, cast, ) from urllib.parse import urlencode, urlparse, urlunparse @@ -114,9 +110,9 @@ def __hash__(self) -> int: class

HedgeboxAccount: id: str created_at: dt.datetime - team_members: Set["HedgeboxPerson"] + team_members: set["HedgeboxPerson"] plan: HedgeboxPlan - files: Set[HedgeboxFile] = field(default_factory=set) + files: set[HedgeboxFile] = field(default_factory=set) was_billing_scheduled: bool = field(default=False) @property @@ -247,7 +243,7 @@ def has_signed_up(self) -> bool: # Abstract methods - def decide_feature_flags(self) -> Dict[str, Any]: + def decide_feature_flags(self) -> dict[str, Any]: if ( self.cluster.simulation_time >= self.cluster.matrix.new_signup_page_experiment_start and self.cluster.simulation_time < self.cluster.matrix.new_signup_page_experiment_end @@ -292,7 +288,7 @@ def determine_session_intent(self) -> Optional[HedgeboxSessionIntent]: # Very low affinity users aren't interested # Non-kernel business users can't log in or sign up return None - possible_intents_with_weights: List[Tuple[HedgeboxSessionIntent, float]] = [] + possible_intents_with_weights: list[tuple[HedgeboxSessionIntent, float]] = [] if self.invite_to_use_id: possible_intents_with_weights.append((HedgeboxSessionIntent.JOIN_TEAM, 1)) elif self.file_to_view: @@ -342,8 +338,8 @@ def determine_session_intent(self) -> Optional[HedgeboxSessionIntent]: if possible_intents_with_weights: possible_intents, weights = zip(*possible_intents_with_weights) return self.cluster.random.choices( - cast(Tuple[HedgeboxSessionIntent], possible_intents), - cast(Tuple[float], weights), + cast(tuple[HedgeboxSessionIntent], possible_intents), + cast(tuple[float], weights), )[0] else: return None @@ -807,10 +803,10 @@ def log_out(self): self.advance_timer(self.cluster.random.uniform(0.1, 0.2)) @property - def invitable_neighbors(self) -> List["HedgeboxPerson"]: + def invitable_neighbors(self) -> list["HedgeboxPerson"]: return [ neighbor - for neighbor in cast(List[HedgeboxPerson], self.cluster.list_neighbors(self)) + for neighbor in cast(list[HedgeboxPerson], self.cluster.list_neighbors(self)) if neighbor.is_invitable ] diff --git a/posthog/email.py b/posthog/email.py index 61edb7ae593d2..3590723f7084b 100644 --- a/posthog/email.py +++ b/posthog/email.py @@ -1,5 +1,5 @@ import sys -from typing import Dict, List, Optional +from typing import Optional import lxml import toronado @@ -54,9 +54,9 @@ def is_email_available(with_absolute_urls: bool = False) -> bool: @shared_task(**EMAIL_TASK_KWARGS) def _send_email( campaign_key: str, - to: List[Dict[str, str]], + to: list[dict[str, str]], subject: str, - headers: Dict, + headers: dict, txt_body: str = "", html_body: str = "", reply_to: Optional[str] = None, @@ -65,8 +65,8 @@ def _send_email( Sends built email message asynchronously. 
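The intent selection in determine_session_intent above pairs the zip(*pairs) unzip idiom with random.choices for a single weighted draw. The same pattern in isolation, with made-up intents and weights:

import random

pairs = [("join_team", 1.0), ("upload_file", 3.5), ("churn", 0.5)]
intents, weights = zip(*pairs)                # split pairs into two parallel tuples
picked = random.choices(intents, weights)[0]  # choices() returns a one-element list here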
""" - messages: List = [] - records: List = [] + messages: list = [] + records: list = [] with transaction.atomic(): for dest in to: @@ -135,8 +135,8 @@ def __init__( campaign_key: str, subject: str, template_name: str, - template_context: Optional[Dict] = None, - headers: Optional[Dict] = None, + template_context: Optional[dict] = None, + headers: Optional[dict] = None, reply_to: Optional[str] = None, ): if template_context is None: @@ -153,7 +153,7 @@ def __init__( self.html_body = inline_css(template.render(template_context)) self.txt_body = "" self.headers = headers if headers else {} - self.to: List[Dict[str, str]] = [] + self.to: list[dict[str, str]] = [] self.reply_to = reply_to def add_recipient(self, email: str, name: Optional[str] = None) -> None: diff --git a/posthog/errors.py b/posthog/errors.py index d028522a599c0..70b3d46dd3c31 100644 --- a/posthog/errors.py +++ b/posthog/errors.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import Dict, Optional +from typing import Optional from clickhouse_driver.errors import ServerException @@ -91,7 +91,7 @@ def look_up_error_code_meta(error: ServerException) -> ErrorCodeMeta: # # Remember to add back the `user_safe` args though! CLICKHOUSE_UNKNOWN_EXCEPTION = ErrorCodeMeta("UNKNOWN_EXCEPTION") -CLICKHOUSE_ERROR_CODE_LOOKUP: Dict[int, ErrorCodeMeta] = { +CLICKHOUSE_ERROR_CODE_LOOKUP: dict[int, ErrorCodeMeta] = { 0: ErrorCodeMeta("OK"), 1: ErrorCodeMeta("UNSUPPORTED_METHOD"), 2: ErrorCodeMeta("UNSUPPORTED_PARAMETER"), diff --git a/posthog/event_usage.py b/posthog/event_usage.py index ae8432c6b2731..cf74b59936365 100644 --- a/posthog/event_usage.py +++ b/posthog/event_usage.py @@ -2,7 +2,7 @@ Module to centralize event reporting on the server-side. """ -from typing import Dict, List, Optional +from typing import Optional import posthoganalytics @@ -107,7 +107,7 @@ def report_user_logged_in( ) -def report_user_updated(user: User, updated_attrs: List[str]) -> None: +def report_user_updated(user: User, updated_attrs: list[str]) -> None: """ Reports a user has been updated. This includes current_team, current_organization & password. 
""" @@ -217,7 +217,7 @@ def report_user_organization_membership_level_changed( ) -def report_user_action(user: User, event: str, properties: Optional[Dict] = None, team: Optional[Team] = None): +def report_user_action(user: User, event: str, properties: Optional[dict] = None, team: Optional[Team] = None): if properties is None: properties = {} posthoganalytics.capture( @@ -254,8 +254,8 @@ def groups(organization: Optional[Organization] = None, team: Optional[Team] = N def report_team_action( team: Team, event: str, - properties: Optional[Dict] = None, - group_properties: Optional[Dict] = None, + properties: Optional[dict] = None, + group_properties: Optional[dict] = None, ): """ For capturing events where it is unclear which user was the core actor we can use the team instead @@ -271,8 +271,8 @@ def report_team_action( def report_organization_action( organization: Organization, event: str, - properties: Optional[Dict] = None, - group_properties: Optional[Dict] = None, + properties: Optional[dict] = None, + group_properties: Optional[dict] = None, ): """ For capturing events where it is unclear which user was the core actor we can use the organization instead diff --git a/posthog/filters.py b/posthog/filters.py index ac098dea92c68..911edcf4596d7 100644 --- a/posthog/filters.py +++ b/posthog/filters.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Optional, TypeVar, Union from django.db import models from django.db.models import Q @@ -19,7 +19,7 @@ class TermSearchFilterBackend(filters.BaseFilterBackend): # The URL query parameter used for the search. search_param = settings.api_settings.SEARCH_PARAM - def get_search_fields(self, view: APIView) -> Optional[List[str]]: + def get_search_fields(self, view: APIView) -> Optional[list[str]]: """ Search fields are obtained from the view. """ @@ -59,10 +59,10 @@ def filter_queryset( def term_search_filter_sql( - search_fields: List[str], + search_fields: list[str], search_terms: Optional[str] = "", search_extra: Optional[str] = "", -) -> Tuple[str, dict]: +) -> tuple[str, dict]: if not search_fields or not search_terms: return "", {} diff --git a/posthog/gzip_middleware.py b/posthog/gzip_middleware.py index 701f31b5dbe3d..cfd57eea0050b 100644 --- a/posthog/gzip_middleware.py +++ b/posthog/gzip_middleware.py @@ -1,5 +1,4 @@ import re -from typing import List from django.conf import settings from django.middleware.gzip import GZipMiddleware @@ -9,7 +8,7 @@ class InvalidGZipAllowList(Exception): pass -def allowed_path(path: str, allowed_paths: List) -> bool: +def allowed_path(path: str, allowed_paths: list) -> bool: return any(pattern.search(path) for pattern in allowed_paths) diff --git a/posthog/health.py b/posthog/health.py index 1ca35d6fe7308..72012928feb4e 100644 --- a/posthog/health.py +++ b/posthog/health.py @@ -17,7 +17,8 @@ # changes to them are deliberate, as otherwise we could introduce unexpected # behaviour in deployments. 
-from typing import Callable, Dict, List, Literal, cast, get_args +from typing import Literal, cast, get_args +from collections.abc import Callable from django.core.cache import cache from django.db import DEFAULT_DB_ALIAS @@ -35,7 +36,7 @@ ServiceRole = Literal["events", "web", "worker", "decide"] -service_dependencies: Dict[ServiceRole, List[str]] = { +service_dependencies: dict[ServiceRole, list[str]] = { "events": ["http", "kafka_connected"], "web": [ "http", @@ -66,7 +67,7 @@ # if atleast one of the checks is True, then the service is considered healthy # for the given role -service_conditional_dependencies: Dict[ServiceRole, List[str]] = { +service_conditional_dependencies: dict[ServiceRole, list[str]] = { "decide": ["cache", "postgres_flags"], } @@ -110,7 +111,7 @@ def readyz(request: HttpRequest): if role and role not in get_args(ServiceRole): return JsonResponse({"error": "InvalidRole"}, status=400) - available_checks: Dict[str, Callable] = { + available_checks: dict[str, Callable] = { "clickhouse": is_clickhouse_connected, "postgres": is_postgres_connected, "postgres_flags": lambda: is_postgres_connected(DATABASE_FOR_FLAG_MATCHING), diff --git a/posthog/heatmaps/heatmaps_api.py b/posthog/heatmaps/heatmaps_api.py index 35a424d4f1517..f06899e3c4178 100644 --- a/posthog/heatmaps/heatmaps_api.py +++ b/posthog/heatmaps/heatmaps_api.py @@ -1,5 +1,5 @@ from datetime import datetime, date -from typing import Any, Dict, List +from typing import Any, List # noqa: UP035 from rest_framework import viewsets, request, response, serializers, status @@ -80,7 +80,7 @@ def validate_date_from(self, value) -> date: except Exception: raise serializers.ValidationError("Error parsing provided date_from: {}".format(value)) - def validate(self, values) -> Dict: + def validate(self, values) -> dict: url_exact = values.get("url_exact", None) url_pattern = values.get("url_pattern", None) if isinstance(url_exact, str) and isinstance(url_pattern, str): @@ -154,10 +154,10 @@ def _choose_aggregation(self, aggregation, is_scrolldepth_query): return aggregation_count @staticmethod - def _predicate_expressions(placeholders: Dict[str, Expr]) -> List[ast.Expr]: - predicate_expressions: List[ast.Expr] = [] + def _predicate_expressions(placeholders: dict[str, Expr]) -> List[ast.Expr]: # noqa: UP006 + predicate_expressions: list[ast.Expr] = [] - predicate_mapping: Dict[str, str] = { + predicate_mapping: dict[str, str] = { # should always have values "date_from": "timestamp >= {date_from}", "type": "`type` = {type}", diff --git a/posthog/heatmaps/test/test_heatmaps_api.py b/posthog/heatmaps/test/test_heatmaps_api.py index 4f0896c5ef141..e07343a2760b4 100644 --- a/posthog/heatmaps/test/test_heatmaps_api.py +++ b/posthog/heatmaps/test/test_heatmaps_api.py @@ -1,5 +1,4 @@ import math -from typing import Dict import freezegun from django.http import HttpResponse @@ -48,21 +47,21 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest) CLASS_DATA_LEVEL_SETUP = False def _assert_heatmap_no_result_count( - self, params: Dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK + self, params: dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK ) -> None: response = self._get_heatmap(params, expected_status_code) if response.status_code == status.HTTP_200_OK: assert len(response.json()["results"]) == 0 def _assert_heatmap_single_result_count( - self, params: Dict[str, str | int | None] | None, expected_grouped_count: int + self, params: dict[str, 
str | int | None] | None, expected_grouped_count: int ) -> None: response = self._get_heatmap(params) assert len(response.json()["results"]) == 1 assert response.json()["results"][0]["count"] == expected_grouped_count def _get_heatmap( - self, params: Dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK + self, params: dict[str, str | int | None] | None, expected_status_code: int = status.HTTP_200_OK ) -> HttpResponse: if params is None: params = {} diff --git a/posthog/helpers/dashboard_templates.py b/posthog/helpers/dashboard_templates.py index cfaa2bac5e1d1..0e3f8a81f9536 100644 --- a/posthog/helpers/dashboard_templates.py +++ b/posthog/helpers/dashboard_templates.py @@ -1,4 +1,5 @@ -from typing import Callable, Dict, List, Optional +from typing import Optional +from collections.abc import Callable import structlog @@ -28,7 +29,7 @@ from posthog.models.insight import Insight from posthog.models.tag import Tag -DASHBOARD_COLORS: List[str] = ["white", "blue", "green", "purple", "black"] +DASHBOARD_COLORS: list[str] = ["white", "blue", "green", "purple", "black"] logger = structlog.get_logger(__name__) @@ -444,7 +445,7 @@ def _create_default_app_items(dashboard: Dashboard) -> None: create_from_template(dashboard, template) -DASHBOARD_TEMPLATES: Dict[str, Callable] = { +DASHBOARD_TEMPLATES: dict[str, Callable] = { "DEFAULT_APP": _create_default_app_items, "WEBSITE_TRAFFIC": _create_website_dashboard, } @@ -491,7 +492,7 @@ def create_from_template(dashboard: Dashboard, template: DashboardTemplate) -> N logger.error("dashboard_templates.creation.unknown_type", template=template) -def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: Dict, color: Optional[str]) -> None: +def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: dict, color: Optional[str]) -> None: text = Text.objects.create( team=dashboard.team, body=body, @@ -507,11 +508,11 @@ def _create_tile_for_text(dashboard: Dashboard, body: str, layouts: Dict, color: def _create_tile_for_insight( dashboard: Dashboard, name: str, - filters: Dict, + filters: dict, description: str, - layouts: Dict, + layouts: dict, color: Optional[str], - query: Optional[Dict] = None, + query: Optional[dict] = None, ) -> None: filter_test_accounts = filters.get("filter_test_accounts", True) insight = Insight.objects.create( diff --git a/posthog/helpers/multi_property_breakdown.py b/posthog/helpers/multi_property_breakdown.py index edc5fe68f1bfb..94fc538b2957d 100644 --- a/posthog/helpers/multi_property_breakdown.py +++ b/posthog/helpers/multi_property_breakdown.py @@ -1,12 +1,12 @@ import copy -from typing import Any, Dict, List, Union +from typing import Any, Union -funnel_with_breakdown_type = List[List[Dict[str, Any]]] -possible_funnel_results_types = Union[funnel_with_breakdown_type, List[Dict[str, Any]], Dict[str, Any]] +funnel_with_breakdown_type = list[list[dict[str, Any]]] +possible_funnel_results_types = Union[funnel_with_breakdown_type, list[dict[str, Any]], dict[str, Any]] def protect_old_clients_from_multi_property_default( - request_filter: Dict[str, Any], result: possible_funnel_results_types + request_filter: dict[str, Any], result: possible_funnel_results_types ) -> possible_funnel_results_types: """ Implementing multi property breakdown will default breakdown to a list even if it is received as a string. 
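The compatibility shim that follows unwraps single-element breakdown lists so old clients still receive bare strings; stripped of the copying and request checks, its core reduces to:

item = {"breakdown": ["a1"], "breakdown_value": ["a1"]}
for key in ("breakdown", "breakdown_value"):
    if item.get(key) and isinstance(item[key], list):
        item[key] = item[key][0]  # old clients expect a bare string here
assert item == {"breakdown": "a1", "breakdown_value": "a1"}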
@@ -25,7 +25,7 @@ def protect_old_clients_from_multi_property_default( :return: """ - if isinstance(result, Dict) or (len(result) > 1) and isinstance(result[0], Dict): + if isinstance(result, dict) or (len(result) > 1) and isinstance(result[0], dict): return result is_breakdown_request = ( @@ -34,7 +34,7 @@ def protect_old_clients_from_multi_property_default( and "breakdown_type" in request_filter and request_filter["breakdown_type"] in ["person", "event"] ) - is_breakdown_result = isinstance(result, List) and len(result) > 0 and isinstance(result[0], List) + is_breakdown_result = isinstance(result, list) and len(result) > 0 and isinstance(result[0], list) is_single_property_breakdown = ( is_breakdown_request @@ -49,14 +49,14 @@ def protect_old_clients_from_multi_property_default( for series_index in range(len(result)): copied_series = copied_result[series_index] - if isinstance(copied_series, List): + if isinstance(copied_series, list): for data_index in range(len(copied_series)): copied_item = copied_series[data_index] if is_single_property_breakdown: - if copied_item.get("breakdown") and isinstance(copied_item["breakdown"], List): + if copied_item.get("breakdown") and isinstance(copied_item["breakdown"], list): copied_item["breakdown"] = copied_item["breakdown"][0] - if copied_item.get("breakdown_value") and isinstance(copied_item["breakdown_value"], List): + if copied_item.get("breakdown_value") and isinstance(copied_item["breakdown_value"], list): copied_item["breakdown_value"] = copied_item["breakdown_value"][0] if is_multi_property_breakdown: diff --git a/posthog/helpers/tests/test_multi_property_breakdown.py b/posthog/helpers/tests/test_multi_property_breakdown.py index d22675adf84e9..417583ae00965 100644 --- a/posthog/helpers/tests/test_multi_property_breakdown.py +++ b/posthog/helpers/tests/test_multi_property_breakdown.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from unittest import TestCase from posthog.helpers.multi_property_breakdown import ( @@ -8,8 +8,8 @@ class TestMultiPropertyBreakdown(TestCase): def test_handles_empty_inputs(self): - data: Dict[str, Any] = {} - result: List = [] + data: dict[str, Any] = {} + result: list = [] try: protect_old_clients_from_multi_property_default(data, result) @@ -17,12 +17,12 @@ def test_handles_empty_inputs(self): raise AssertionError("should not raise any KeyError") def test_handles_empty_breakdowns_array(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdowns": [], "insight": "FUNNELS", "breakdown_type": "event", } - result: List = [] + result: list = [] try: protect_old_clients_from_multi_property_default(data, result) @@ -30,37 +30,37 @@ def test_handles_empty_breakdowns_array(self): raise AssertionError("should not raise any KeyError") def test_keeps_multi_property_breakdown_for_multi_property_requests(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdowns": ["a", "b"], "insight": "FUNNELS", "breakdown_type": "event", } - result: List[List[Dict[str, Any]]] = [[{"breakdown": ["a1", "b1"], "breakdown_value": ["a1", "b1"]}]] + result: list[list[dict[str, Any]]] = [[{"breakdown": ["a1", "b1"], "breakdown_value": ["a1", "b1"]}]] actual = protect_old_clients_from_multi_property_default(data, result) # to satisfy mypy - assert isinstance(actual, List) + assert isinstance(actual, list) series = actual[0] - assert isinstance(series, List) + assert isinstance(series, list) data = series[0] assert data["breakdowns"] == ["a1", "b1"] assert "breakdown" not in data def 
test_flattens_multi_property_breakdown_for_single_property_requests(self): - data: Dict[str, Any] = { + data: dict[str, Any] = { "breakdown": "a", "insight": "FUNNELS", "breakdown_type": "event", } - result: List[List[Dict[str, Any]]] = [[{"breakdown": ["a1"], "breakdown_value": ["a1", "b1"]}]] + result: list[list[dict[str, Any]]] = [[{"breakdown": ["a1"], "breakdown_value": ["a1", "b1"]}]] actual = protect_old_clients_from_multi_property_default(data, result) # to satisfy mypy - assert isinstance(actual, List) + assert isinstance(actual, list) series = actual[0] - assert isinstance(series, List) + assert isinstance(series, list) data = series[0] assert data["breakdown"] == "a1" assert "breakdowns" not in data diff --git a/posthog/hogql/ai.py b/posthog/hogql/ai.py index 15b03e82e5030..71a565ec77773 100644 --- a/posthog/hogql/ai.py +++ b/posthog/hogql/ai.py @@ -63,7 +63,7 @@ def write_sql_from_prompt(prompt: str, *, current_query: Optional[str] = None, t schema_description = "\n\n".join( ( f"Table {table_name} with fields:\n" - + "\n".join((f'- {field["key"]} ({field["type"]})' for field in table_fields)) + + "\n".join(f'- {field["key"]} ({field["type"]})' for field in table_fields) for table_name, table_fields in serialized_database.items() ) ) diff --git a/posthog/hogql/ast.py b/posthog/hogql/ast.py index ccb3f9f34576d..e3fa80b3f3ee8 100644 --- a/posthog/hogql/ast.py +++ b/posthog/hogql/ast.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from dataclasses import dataclass, field from posthog.hogql.base import Type, Expr, CTE, ConstantType, UnknownType, AST @@ -143,14 +143,14 @@ class SelectQueryType(Type): """Type and new enclosed scope for a select query. 
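The SelectQueryType fields just below all use field(default_factory=...) because dataclasses reject mutable class-level defaults outright; a minimal reproduction with a hypothetical Scope class:

from dataclasses import dataclass, field

@dataclass(kw_only=True)
class Scope:
    # `aliases: dict[str, str] = {}` would raise ValueError at class creation;
    # default_factory builds a fresh dict per instance instead.
    aliases: dict[str, str] = field(default_factory=dict)

s1, s2 = Scope(), Scope()
s1.aliases["x"] = "y"
assert s2.aliases == {}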
Contains information about all tables and columns in the query.""" # all aliases a select query has access to in its scope - aliases: Dict[str, FieldAliasType] = field(default_factory=dict) + aliases: dict[str, FieldAliasType] = field(default_factory=dict) # all types a select query exports - columns: Dict[str, Type] = field(default_factory=dict) + columns: dict[str, Type] = field(default_factory=dict) # all from and join, tables and subqueries with aliases - tables: Dict[str, TableOrSelectType] = field(default_factory=dict) - ctes: Dict[str, CTE] = field(default_factory=dict) + tables: dict[str, TableOrSelectType] = field(default_factory=dict) + ctes: dict[str, CTE] = field(default_factory=dict) # all from and join subqueries without aliases - anonymous_tables: List[Union["SelectQueryType", "SelectUnionQueryType"]] = field(default_factory=list) + anonymous_tables: list[Union["SelectQueryType", "SelectUnionQueryType"]] = field(default_factory=list) # the parent select query, if this is a lambda parent: Optional[Union["SelectQueryType", "SelectUnionQueryType"]] = None @@ -173,7 +173,7 @@ def has_child(self, name: str, context: HogQLContext) -> bool: @dataclass(kw_only=True) class SelectUnionQueryType(Type): - types: List[SelectQueryType] + types: list[SelectQueryType] def get_alias_for_table_type(self, table_type: TableOrSelectType) -> Optional[str]: return self.types[0].get_alias_for_table_type(table_type) @@ -313,7 +313,7 @@ def print_type(self) -> str: @dataclass(kw_only=True) class TupleType(ConstantType): data_type: ConstantDataType = field(default="tuple", init=False) - item_types: List[ConstantType] + item_types: list[ConstantType] def print_type(self) -> str: return "Tuple" @@ -322,8 +322,8 @@ def print_type(self) -> str: @dataclass(kw_only=True) class CallType(Type): name: str - arg_types: List[ConstantType] - param_types: Optional[List[ConstantType]] = None + arg_types: list[ConstantType] + param_types: Optional[list[ConstantType]] = None return_type: ConstantType def resolve_constant_type(self, context: HogQLContext) -> ConstantType: @@ -337,7 +337,7 @@ class AsteriskType(Type): @dataclass(kw_only=True) class FieldTraverserType(Type): - chain: List[str | int] + chain: list[str | int] table_type: TableOrSelectType @@ -400,7 +400,7 @@ def resolve_table_type(self, context: HogQLContext): @dataclass(kw_only=True) class PropertyType(Type): - chain: List[str | int] + chain: list[str | int] field_type: FieldType # The property has been moved into a field we query from a joined subquery @@ -449,12 +449,12 @@ class ArithmeticOperation(Expr): @dataclass(kw_only=True) class And(Expr): type: Optional[ConstantType] = None - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) class Or(Expr): - exprs: List[Expr] + exprs: list[Expr] type: Optional[ConstantType] = None @@ -509,7 +509,7 @@ class ArrayAccess(Expr): @dataclass(kw_only=True) class Array(Expr): - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) @@ -520,12 +520,12 @@ class TupleAccess(Expr): @dataclass(kw_only=True) class Tuple(Expr): - exprs: List[Expr] + exprs: list[Expr] @dataclass(kw_only=True) class Lambda(Expr): - args: List[str] + args: list[str] expr: Expr @@ -536,7 +536,7 @@ class Constant(Expr): @dataclass(kw_only=True) class Field(Expr): - chain: List[str | int] + chain: list[str | int] @dataclass(kw_only=True) @@ -548,8 +548,8 @@ class Placeholder(Expr): class Call(Expr): name: str """Function name""" - args: List[Expr] - params: Optional[List[Expr]] = None + args: list[Expr] + params: 
Optional[list[Expr]] = None """ Parameters apply to some aggregate functions, see ClickHouse docs: https://clickhouse.com/docs/en/sql-reference/aggregate-functions/parametric-functions @@ -569,7 +569,7 @@ class JoinExpr(Expr): join_type: Optional[str] = None table: Optional[Union["SelectQuery", "SelectUnionQuery", Field]] = None - table_args: Optional[List[Expr]] = None + table_args: Optional[list[Expr]] = None alias: Optional[str] = None table_final: Optional[bool] = None constraint: Optional["JoinConstraint"] = None @@ -585,8 +585,8 @@ class WindowFrameExpr(Expr): @dataclass(kw_only=True) class WindowExpr(Expr): - partition_by: Optional[List[Expr]] = None - order_by: Optional[List[OrderExpr]] = None + partition_by: Optional[list[Expr]] = None + order_by: Optional[list[OrderExpr]] = None frame_method: Optional[Literal["ROWS", "RANGE"]] = None frame_start: Optional[WindowFrameExpr] = None frame_end: Optional[WindowFrameExpr] = None @@ -595,7 +595,7 @@ class WindowExpr(Expr): @dataclass(kw_only=True) class WindowFunction(Expr): name: str - args: Optional[List[Expr]] = None + args: Optional[list[Expr]] = None over_expr: Optional[WindowExpr] = None over_identifier: Optional[str] = None @@ -604,20 +604,20 @@ class WindowFunction(Expr): class SelectQuery(Expr): # :TRICKY: When adding new fields, make sure they're handled in visitor.py and resolver.py type: Optional[SelectQueryType] = None - ctes: Optional[Dict[str, CTE]] = None - select: List[Expr] + ctes: Optional[dict[str, CTE]] = None + select: list[Expr] distinct: Optional[bool] = None select_from: Optional[JoinExpr] = None array_join_op: Optional[str] = None - array_join_list: Optional[List[Expr]] = None - window_exprs: Optional[Dict[str, WindowExpr]] = None + array_join_list: Optional[list[Expr]] = None + window_exprs: Optional[dict[str, WindowExpr]] = None where: Optional[Expr] = None prewhere: Optional[Expr] = None having: Optional[Expr] = None - group_by: Optional[List[Expr]] = None - order_by: Optional[List[OrderExpr]] = None + group_by: Optional[list[Expr]] = None + order_by: Optional[list[OrderExpr]] = None limit: Optional[Expr] = None - limit_by: Optional[List[Expr]] = None + limit_by: Optional[list[Expr]] = None limit_with_ties: Optional[bool] = None offset: Optional[Expr] = None settings: Optional[HogQLQuerySettings] = None @@ -627,7 +627,7 @@ class SelectQuery(Expr): @dataclass(kw_only=True) class SelectUnionQuery(Expr): type: Optional[SelectUnionQueryType] = None - select_queries: List[SelectQuery] + select_queries: list[SelectQuery] @dataclass(kw_only=True) @@ -652,7 +652,7 @@ class HogQLXAttribute(AST): @dataclass(kw_only=True) class HogQLXTag(AST): kind: str - attributes: List[HogQLXAttribute] + attributes: list[HogQLXAttribute] def to_dict(self): return { diff --git a/posthog/hogql/autocomplete.py b/posthog/hogql/autocomplete.py index b6d003c1ac88d..c0d4cd8b84f9d 100644 --- a/posthog/hogql/autocomplete.py +++ b/posthog/hogql/autocomplete.py @@ -1,5 +1,6 @@ from copy import copy, deepcopy -from typing import Callable, Dict, List, Optional, cast +from typing import Optional, cast +from collections.abc import Callable from posthog.hogql.context import HogQLContext from posthog.hogql.database.database import Database, create_hogql_database from posthog.hogql.database.models import ( @@ -38,7 +39,7 @@ class GetNodeAtPositionTraverser(TraversingVisitor): start: int end: int - selects: List[ast.SelectQuery] = [] + selects: list[ast.SelectQuery] = [] node: Optional[AST] = None parent_node: Optional[AST] = None last_node: 
Optional[AST] = None @@ -100,13 +101,13 @@ def convert_field_or_table_to_type_string(field_or_table: FieldOrTable) -> str | return "Object" if isinstance(field_or_table, ast.ExpressionField): return "Expression" - if isinstance(field_or_table, (ast.Table, ast.LazyJoin)): + if isinstance(field_or_table, ast.Table | ast.LazyJoin): return "Table" return None -def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[Dict[str, CTE]]) -> None | Table: +def get_table(context: HogQLContext, join_expr: ast.JoinExpr, ctes: Optional[dict[str, CTE]]) -> None | Table: assert context.database is not None def resolve_fields_on_table(table: Table | None, table_query: ast.SelectQuery) -> Table | None: @@ -120,7 +121,7 @@ def resolve_fields_on_table(table: Table | None, table_query: ast.SelectQuery) - return None selected_columns = node.type.columns - new_fields: Dict[str, FieldOrTable] = {} + new_fields: dict[str, FieldOrTable] = {} for name, field in selected_columns.items(): if isinstance(field, ast.FieldAliasType): underlying_field_name = field.alias @@ -145,7 +146,7 @@ def resolve_fields_on_table(table: Table | None, table_query: ast.SelectQuery) - # Return a new table with a reduced field set class AnonTable(Table): - fields: Dict[str, FieldOrTable] = new_fields + fields: dict[str, FieldOrTable] = new_fields def to_printed_hogql(self): # Use the base table name for resolving property definitions later @@ -184,8 +185,8 @@ def to_printed_hogql(self): return None -def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> Dict[str, ast.Table]: - tables: Dict[str, ast.Table] = {} +def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> dict[str, ast.Table]: + tables: dict[str, ast.Table] = {} if query.select_from is not None and query.select_from.alias is not None: table = get_table(context, query.select_from, query.ctes) @@ -207,7 +208,7 @@ def get_tables_aliases(query: ast.SelectQuery, context: HogQLContext) -> Dict[st # Replaces all ast.FieldTraverser with the underlying node def resolve_table_field_traversers(table: Table, context: HogQLContext) -> Table: new_table = deepcopy(table) - new_fields: Dict[str, FieldOrTable] = {} + new_fields: dict[str, FieldOrTable] = {} for key, field in list(new_table.fields.items()): if not isinstance(field, ast.FieldTraverser): new_fields[key] = field @@ -234,9 +235,9 @@ def resolve_table_field_traversers(table: Table, context: HogQLContext) -> Table return new_table -def append_table_field_to_response(table: Table, suggestions: List[AutocompleteCompletionItem]) -> None: - keys: List[str] = [] - details: List[str | None] = [] +def append_table_field_to_response(table: Table, suggestions: list[AutocompleteCompletionItem]) -> None: + keys: list[str] = [] + details: list[str | None] = [] table_fields = list(table.fields.items()) for field_name, field_or_table in table_fields: # Skip over hidden fields @@ -258,11 +259,11 @@ def append_table_field_to_response(table: Table, suggestions: List[AutocompleteC def extend_responses( - keys: List[str], - suggestions: List[AutocompleteCompletionItem], + keys: list[str], + suggestions: list[AutocompleteCompletionItem], kind: Kind = Kind.Variable, insert_text: Optional[Callable[[str], str]] = None, - details: Optional[List[str | None]] = None, + details: Optional[list[str | None]] = None, ) -> None: suggestions.extend( [ diff --git a/posthog/hogql/bytecode.py b/posthog/hogql/bytecode.py index 2be5c206cf327..f1abb9c4be0ed 100644 --- a/posthog/hogql/bytecode.py +++ 
b/posthog/hogql/bytecode.py @@ -1,4 +1,4 @@ -from typing import List, Any +from typing import Any from posthog.hogql import ast from posthog.hogql.errors import NotImplementedError @@ -39,13 +39,13 @@ } -def to_bytecode(expr: str) -> List[Any]: +def to_bytecode(expr: str) -> list[Any]: from posthog.hogql.parser import parse_expr return create_bytecode(parse_expr(expr)) -def create_bytecode(expr: ast.Expr) -> List[Any]: +def create_bytecode(expr: ast.Expr) -> list[Any]: bytecode = [HOGQL_BYTECODE_IDENTIFIER] bytecode.extend(BytecodeBuilder().visit(expr)) return bytecode diff --git a/posthog/hogql/constants.py b/posthog/hogql/constants.py index 3d933bca47eea..45c5b1e034c49 100644 --- a/posthog/hogql/constants.py +++ b/posthog/hogql/constants.py @@ -1,6 +1,6 @@ from datetime import date, datetime from enum import Enum -from typing import Optional, Literal, TypeAlias, Tuple, List +from typing import Optional, Literal, TypeAlias from uuid import UUID from pydantic import ConfigDict, BaseModel @@ -18,7 +18,7 @@ ] ConstantSupportedPrimitive: TypeAlias = int | float | str | bool | date | datetime | UUID | None ConstantSupportedData: TypeAlias = ( - ConstantSupportedPrimitive | List[ConstantSupportedPrimitive] | Tuple[ConstantSupportedPrimitive, ...] + ConstantSupportedPrimitive | list[ConstantSupportedPrimitive] | tuple[ConstantSupportedPrimitive, ...] ) # Keywords passed to ClickHouse without transformation diff --git a/posthog/hogql/context.py b/posthog/hogql/context.py index 68692323e059d..9b5b6092a6911 100644 --- a/posthog/hogql/context.py +++ b/posthog/hogql/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Any +from typing import TYPE_CHECKING, Literal, Optional, Any from posthog.hogql.timings import HogQLTimings from posthog.schema import HogQLNotice, HogQLQueryModifiers @@ -11,7 +11,7 @@ @dataclass class HogQLFieldAccess: - input: List[str] + input: list[str] type: Optional[Literal["event", "event.properties", "person", "person.properties"]] field: Optional[str] sql: str @@ -28,7 +28,7 @@ class HogQLContext: # Virtual database we're querying, will be populated from team_id if not present database: Optional["Database"] = None # If set, will save string constants to this dict. Inlines strings into the query if None. - values: Dict = field(default_factory=dict) + values: dict = field(default_factory=dict) # Are we small part of a non-HogQL query? If so, use custom syntax for accessed person properties. 
within_non_hogql_query: bool = False # Enable full SELECT queries and subqueries in ClickHouse @@ -39,9 +39,9 @@ class HogQLContext: max_view_depth: int = 1 # Warnings returned with the metadata query - warnings: List["HogQLNotice"] = field(default_factory=list) + warnings: list["HogQLNotice"] = field(default_factory=list) # Notices returned with the metadata query - notices: List["HogQLNotice"] = field(default_factory=list) + notices: list["HogQLNotice"] = field(default_factory=list) # Timings in seconds for different parts of the HogQL query timings: HogQLTimings = field(default_factory=HogQLTimings) # Modifications requested by the HogQL client diff --git a/posthog/hogql/database/argmax.py b/posthog/hogql/database/argmax.py index 5872dc77d8b44..b6c8e3d853bf8 100644 --- a/posthog/hogql/database/argmax.py +++ b/posthog/hogql/database/argmax.py @@ -1,10 +1,11 @@ -from typing import Callable, List, Optional, Dict +from typing import Optional +from collections.abc import Callable def argmax_select( table_name: str, - select_fields: Dict[str, List[str | int]], - group_fields: List[str], + select_fields: dict[str, list[str | int]], + group_fields: list[str], argmax_field: str, deleted_field: Optional[str] = None, ): @@ -14,8 +15,8 @@ def argmax_select( name="argMax", args=[field, ast.Field(chain=[table_name, argmax_field])] ) - fields_to_group: List[ast.Expr] = [] - fields_to_select: List[ast.Expr] = [] + fields_to_group: list[ast.Expr] = [] + fields_to_select: list[ast.Expr] = [] for name, chain in select_fields.items(): if name not in group_fields: fields_to_select.append( diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index 74837d21f49b3..fc1f665cf3978 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypedDict from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from pydantic import ConfigDict, BaseModel from sentry_sdk import capture_exception @@ -96,7 +96,7 @@ class Database(BaseModel): numbers: NumbersTable = NumbersTable() # clunky: keep table names in sync with above - _table_names: ClassVar[List[str]] = [ + _table_names: ClassVar[list[str]] = [ "events", "groups", "persons", @@ -109,7 +109,7 @@ class Database(BaseModel): "sessions", ] - _warehouse_table_names: List[str] = [] + _warehouse_table_names: list[str] = [] _timezone: Optional[str] _week_start_day: Optional[WeekStartDay] @@ -136,7 +136,7 @@ def get_table(self, table_name: str) -> Table: return getattr(self, table_name) raise QueryError(f'Unknown table "{table_name}".') - def get_all_tables(self) -> List[str]: + def get_all_tables(self) -> list[str]: return self._table_names + self._warehouse_table_names def add_warehouse_tables(self, **field_definitions: Any): @@ -226,7 +226,7 @@ def create_hogql_database( if database.events.fields.get(mapping.group_type) is None: database.events.fields[mapping.group_type] = FieldTraverser(chain=[f"group_{mapping.group_type_index}"]) - tables: Dict[str, Table] = {} + tables: dict[str, Table] = {} for table in DataWarehouseTable.objects.filter(team_id=team.pk).exclude(deleted=True): tables[table.name] = table.hogql_definition() @@ -362,35 +362,35 @@ class _SerializedFieldBase(TypedDict): class SerializedField(_SerializedFieldBase, total=False): - fields: List[str] + fields: list[str] table: str - chain: List[str | int] + chain: list[str | int] 
-def serialize_database(context: HogQLContext) -> Dict[str, List[SerializedField]]: - tables: Dict[str, List[SerializedField]] = {} +def serialize_database(context: HogQLContext) -> dict[str, list[SerializedField]]: + tables: dict[str, list[SerializedField]] = {} if context.database is None: raise ResolutionError("Must provide database to serialize_database") for table_key in context.database.model_fields.keys(): - field_input: Dict[str, Any] = {} + field_input: dict[str, Any] = {} table = getattr(context.database, table_key, None) if isinstance(table, FunctionCallTable): field_input = table.get_asterisk() elif isinstance(table, Table): field_input = table.fields - field_output: List[SerializedField] = serialize_fields(field_input, context) + field_output: list[SerializedField] = serialize_fields(field_input, context) tables[table_key] = field_output return tables -def serialize_fields(field_input, context: HogQLContext) -> List[SerializedField]: +def serialize_fields(field_input, context: HogQLContext) -> list[SerializedField]: from posthog.hogql.database.models import SavedQuery - field_output: List[SerializedField] = [] + field_output: list[SerializedField] = [] for field_key, field in field_input.items(): if field_key == "team_id": pass diff --git a/posthog/hogql/database/models.py b/posthog/hogql/database/models.py index f6e985d92b4d7..34bec54eca32b 100644 --- a/posthog/hogql/database/models.py +++ b/posthog/hogql/database/models.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING +from collections.abc import Callable from pydantic import ConfigDict, BaseModel from posthog.hogql.base import Expr @@ -65,11 +66,11 @@ class ExpressionField(DatabaseField): class FieldTraverser(FieldOrTable): model_config = ConfigDict(extra="forbid") - chain: List[str | int] + chain: list[str | int] class Table(FieldOrTable): - fields: Dict[str, FieldOrTable] + fields: dict[str, FieldOrTable] model_config = ConfigDict(extra="forbid") def has_field(self, name: str | int) -> bool: @@ -87,12 +88,12 @@ def to_printed_clickhouse(self, context: "HogQLContext") -> str: def to_printed_hogql(self) -> str: raise NotImplementedError("Table.to_printed_hogql not overridden") - def avoid_asterisk_fields(self) -> List[str]: + def avoid_asterisk_fields(self) -> list[str]: return [] def get_asterisk(self): fields_to_avoid = [*self.avoid_asterisk_fields(), "team_id"] - asterisk: Dict[str, FieldOrTable] = {} + asterisk: dict[str, FieldOrTable] = {} for key, field in self.fields.items(): if key in fields_to_avoid: continue @@ -109,10 +110,10 @@ def get_asterisk(self): class LazyJoin(FieldOrTable): model_config = ConfigDict(extra="forbid") - join_function: Callable[[str, str, Dict[str, Any], "HogQLContext", "SelectQuery"], Any] + join_function: Callable[[str, str, dict[str, Any], "HogQLContext", "SelectQuery"], Any] join_table: Table | str - from_field: List[str | int] - to_field: Optional[List[str | int]] = None + from_field: list[str | int] + to_field: Optional[list[str | int]] = None def resolve_table(self, context: "HogQLContext") -> Table: if isinstance(self.join_table, Table): @@ -132,7 +133,7 @@ class LazyTable(Table): model_config = ConfigDict(extra="forbid") def lazy_select( - self, requested_fields: Dict[str, List[str | int]], context: "HogQLContext", node: "SelectQuery" + self, requested_fields: dict[str, list[str | int]], context: "HogQLContext", node: "SelectQuery" ) -> Any: raise NotImplementedError("LazyTable.lazy_select not 
overridden") diff --git a/posthog/hogql/database/schema/cohort_people.py b/posthog/hogql/database/schema/cohort_people.py index c556903d40cdf..255779aef5902 100644 --- a/posthog/hogql/database/schema/cohort_people.py +++ b/posthog/hogql/database/schema/cohort_people.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.hogql.database.models import ( StringDatabaseField, IntegerDatabaseField, @@ -22,7 +20,7 @@ } -def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]], team_id: int): +def select_from_cohort_people_table(requested_fields: dict[str, list[str | int]], team_id: int): from posthog.hogql import ast from posthog.models import Cohort @@ -39,7 +37,7 @@ def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]] if "cohort_id" not in requested_fields: requested_fields = {**requested_fields, "cohort_id": ["cohort_id"]} - fields: List[ast.Expr] = [ + fields: list[ast.Expr] = [ ast.Alias(alias=name, expr=ast.Field(chain=[table_name, *chain])) for name, chain in requested_fields.items() ] @@ -60,7 +58,7 @@ def select_from_cohort_people_table(requested_fields: Dict[str, List[str | int]] class RawCohortPeople(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **COHORT_PEOPLE_FIELDS, "sign": IntegerDatabaseField(name="sign"), "version": IntegerDatabaseField(name="version"), @@ -74,9 +72,9 @@ def to_printed_hogql(self): class CohortPeople(LazyTable): - fields: Dict[str, FieldOrTable] = COHORT_PEOPLE_FIELDS + fields: dict[str, FieldOrTable] = COHORT_PEOPLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_cohort_people_table(requested_fields, context.team_id) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/event_sessions.py b/posthog/hogql/database/schema/event_sessions.py index 31682981ea3ea..fc03357884a6d 100644 --- a/posthog/hogql/database/schema/event_sessions.py +++ b/posthog/hogql/database/schema/event_sessions.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import Any, Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.database.models import ( @@ -14,7 +14,7 @@ class EventsSessionSubTable(VirtualTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="$session_id"), "duration": IntegerDatabaseField(name="session_duration"), } @@ -27,7 +27,7 @@ def to_printed_hogql(self): class GetFieldsTraverser(TraversingVisitor): - fields: List[ast.Field] + fields: list[ast.Field] def __init__(self, expr: ast.Expr): super().__init__() @@ -71,7 +71,7 @@ def visit_field_type(self, node: ast.FieldType): class WhereClauseExtractor: - compare_operators: List[ast.Expr] + compare_operators: list[ast.Expr] def __init__( self, @@ -123,10 +123,10 @@ def _is_field_on_table(self, field: ast.Field) -> bool: return True - def run(self, expr: ast.Expr) -> List[ast.Expr]: - exprs_to_apply: List[ast.Expr] = [] + def run(self, expr: ast.Expr) -> list[ast.Expr]: + exprs_to_apply: list[ast.Expr] = [] - def should_add(expression: ast.Expr, fields: List[ast.Field]) -> bool: + def should_add(expression: ast.Expr, fields: list[ast.Field]) -> bool: for field in fields: on_table = self._is_field_on_table(field) if not on_table: @@ -168,7 +168,7 @@ def should_add(expression: ast.Expr, 
fields: List[ast.Field]) -> bool: def join_with_events_table_session_duration( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: ast.SelectQuery, ): diff --git a/posthog/hogql/database/schema/events.py b/posthog/hogql/database/schema/events.py index 88f59a11fd7ef..34941de0ec92a 100644 --- a/posthog/hogql/database/schema/events.py +++ b/posthog/hogql/database/schema/events.py @@ -1,5 +1,3 @@ -from typing import Dict - from posthog.hogql.database.models import ( VirtualTable, StringDatabaseField, @@ -20,7 +18,7 @@ class EventsPersonSubTable(VirtualTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="person_id"), "created_at": DateTimeDatabaseField(name="person_created_at"), "properties": StringJSONDatabaseField(name="person_properties"), @@ -54,7 +52,7 @@ def to_printed_hogql(self): class EventsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "uuid": StringDatabaseField(name="uuid"), "event": StringDatabaseField(name="event"), "properties": StringJSONDatabaseField(name="properties"), diff --git a/posthog/hogql/database/schema/groups.py b/posthog/hogql/database/schema/groups.py index ad97ff7eb0878..06fc40560b7db 100644 --- a/posthog/hogql/database/schema/groups.py +++ b/posthog/hogql/database/schema/groups.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -24,7 +24,7 @@ } -def select_from_groups_table(requested_fields: Dict[str, List[str | int]]): +def select_from_groups_table(requested_fields: dict[str, list[str | int]]): return argmax_select( table_name="raw_groups", select_fields=requested_fields, @@ -37,7 +37,7 @@ def join_with_group_n_table(group_index: int): def join_with_group_table( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: SelectQuery, ): @@ -70,7 +70,7 @@ def join_with_group_table( class RawGroupsTable(Table): - fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS def to_printed_clickhouse(self, context): return "groups" @@ -80,9 +80,9 @@ def to_printed_hogql(self): class GroupsTable(LazyTable): - fields: Dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = GROUPS_TABLE_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_groups_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/heatmaps.py b/posthog/hogql/database/schema/heatmaps.py index 6041926f5366f..959117baef874 100644 --- a/posthog/hogql/database/schema/heatmaps.py +++ b/posthog/hogql/database/schema/heatmaps.py @@ -1,5 +1,3 @@ -from typing import Dict - from posthog.hogql.database.models import ( StringDatabaseField, DateTimeDatabaseField, @@ -11,7 +9,7 @@ class HeatmapsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "session_id": StringDatabaseField(name="session_id"), "team_id": IntegerDatabaseField(name="team_id"), "distinct_id": StringDatabaseField(name="distinct_id"), diff --git a/posthog/hogql/database/schema/log_entries.py b/posthog/hogql/database/schema/log_entries.py index 
14efaff09ce1f..edd2f761981c3 100644 --- a/posthog/hogql/database/schema/log_entries.py +++ b/posthog/hogql/database/schema/log_entries.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.hogql import ast from posthog.hogql.database.models import ( Table, @@ -10,7 +8,7 @@ FieldOrTable, ) -LOG_ENTRIES_FIELDS: Dict[str, FieldOrTable] = { +LOG_ENTRIES_FIELDS: dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), "log_source": StringDatabaseField(name="log_source"), "log_source_id": StringDatabaseField(name="log_source_id"), @@ -22,7 +20,7 @@ class LogEntriesTable(Table): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS def to_printed_clickhouse(self, context): return "log_entries" @@ -32,10 +30,10 @@ def to_printed_hogql(self): class ReplayConsoleLogsLogEntriesTable(LazyTable): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): - fields: List[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): + fields: list[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] return ast.SelectQuery( select=fields, @@ -55,10 +53,10 @@ def to_printed_hogql(self): class BatchExportLogEntriesTable(LazyTable): - fields: Dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS + fields: dict[str, FieldOrTable] = LOG_ENTRIES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): - fields: List[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): + fields: list[ast.Expr] = [ast.Field(chain=["log_entries", *chain]) for name, chain in requested_fields.items()] return ast.SelectQuery( select=fields, diff --git a/posthog/hogql/database/schema/numbers.py b/posthog/hogql/database/schema/numbers.py index 01c09ac66d797..7590e4041c1d5 100644 --- a/posthog/hogql/database/schema/numbers.py +++ b/posthog/hogql/database/schema/numbers.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from posthog.hogql.database.models import ( IntegerDatabaseField, @@ -12,7 +12,7 @@ class NumbersTable(FunctionCallTable): - fields: Dict[str, FieldOrTable] = NUMBERS_TABLE_FIELDS + fields: dict[str, FieldOrTable] = NUMBERS_TABLE_FIELDS name: str = "numbers" min_args: Optional[int] = 1 diff --git a/posthog/hogql/database/schema/person_distinct_id_overrides.py b/posthog/hogql/database/schema/person_distinct_id_overrides.py index 6045e74ff7679..209c73c346e40 100644 --- a/posthog/hogql/database/schema/person_distinct_id_overrides.py +++ b/posthog/hogql/database/schema/person_distinct_id_overrides.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -27,7 +26,7 @@ } -def select_from_person_distinct_id_overrides_table(requested_fields: Dict[str, List[str | int]]): +def select_from_person_distinct_id_overrides_table(requested_fields: dict[str, list[str | int]]): # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available if "person_id" not in requested_fields: requested_fields = {**requested_fields, "person_id": 
["person_id"]} @@ -43,7 +42,7 @@ def select_from_person_distinct_id_overrides_table(requested_fields: Dict[str, L def join_with_person_distinct_id_overrides_table( from_table: str, to_table: str, - requested_fields: Dict[str, List[str]], + requested_fields: dict[str, list[str]], context: HogQLContext, node: SelectQuery, ): @@ -65,7 +64,7 @@ def join_with_person_distinct_id_overrides_table( class RawPersonDistinctIdOverridesTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_DISTINCT_ID_OVERRIDES_FIELDS, "is_deleted": BooleanDatabaseField(name="is_deleted"), "version": IntegerDatabaseField(name="version"), @@ -79,9 +78,9 @@ def to_printed_hogql(self): class PersonDistinctIdOverridesTable(LazyTable): - fields: Dict[str, FieldOrTable] = PERSON_DISTINCT_ID_OVERRIDES_FIELDS + fields: dict[str, FieldOrTable] = PERSON_DISTINCT_ID_OVERRIDES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context: HogQLContext, node: SelectQuery): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context: HogQLContext, node: SelectQuery): return select_from_person_distinct_id_overrides_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/person_distinct_ids.py b/posthog/hogql/database/schema/person_distinct_ids.py index dde1f97c27922..9fa00c59c2985 100644 --- a/posthog/hogql/database/schema/person_distinct_ids.py +++ b/posthog/hogql/database/schema/person_distinct_ids.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -27,7 +26,7 @@ } -def select_from_person_distinct_ids_table(requested_fields: Dict[str, List[str | int]]): +def select_from_person_distinct_ids_table(requested_fields: dict[str, list[str | int]]): # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available if "person_id" not in requested_fields: requested_fields = {**requested_fields, "person_id": ["person_id"]} @@ -43,7 +42,7 @@ def select_from_person_distinct_ids_table(requested_fields: Dict[str, List[str | def join_with_person_distinct_ids_table( from_table: str, to_table: str, - requested_fields: Dict[str, List[str]], + requested_fields: dict[str, list[str]], context: HogQLContext, node: SelectQuery, ): @@ -65,7 +64,7 @@ def join_with_person_distinct_ids_table( class RawPersonDistinctIdsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_DISTINCT_IDS_FIELDS, "is_deleted": BooleanDatabaseField(name="is_deleted"), "version": IntegerDatabaseField(name="version"), @@ -79,9 +78,9 @@ def to_printed_hogql(self): class PersonDistinctIdsTable(LazyTable): - fields: Dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS + fields: dict[str, FieldOrTable] = PERSON_DISTINCT_IDS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_person_distinct_ids_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/person_overrides.py b/posthog/hogql/database/schema/person_overrides.py index 559ddd3a8013d..366321cf65e41 100644 --- a/posthog/hogql/database/schema/person_overrides.py +++ b/posthog/hogql/database/schema/person_overrides.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from posthog.hogql.ast import 
SelectQuery from posthog.hogql.context import HogQLContext @@ -14,7 +14,7 @@ from posthog.hogql.errors import ResolutionError from posthog.schema import HogQLQueryModifiers -PERSON_OVERRIDES_FIELDS: Dict[str, FieldOrTable] = { +PERSON_OVERRIDES_FIELDS: dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), "old_person_id": StringDatabaseField(name="old_person_id"), "override_person_id": StringDatabaseField(name="override_person_id"), @@ -24,7 +24,7 @@ } -def select_from_person_overrides_table(requested_fields: Dict[str, List[str | int]]): +def select_from_person_overrides_table(requested_fields: dict[str, list[str | int]]): return argmax_select( table_name="raw_person_overrides", select_fields=requested_fields, @@ -36,7 +36,7 @@ def select_from_person_overrides_table(requested_fields: Dict[str, List[str | in def join_with_person_overrides_table( from_table: str, to_table: str, - requested_fields: Dict[str, Any], + requested_fields: dict[str, Any], context: HogQLContext, node: SelectQuery, ): @@ -59,7 +59,7 @@ def join_with_person_overrides_table( class RawPersonOverridesTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSON_OVERRIDES_FIELDS, "version": IntegerDatabaseField(name="version"), } @@ -72,9 +72,9 @@ def to_printed_hogql(self): class PersonOverridesTable(Table): - fields: Dict[str, FieldOrTable] = PERSON_OVERRIDES_FIELDS + fields: dict[str, FieldOrTable] = PERSON_OVERRIDES_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): + def lazy_select(self, requested_fields: dict[str, list[str | int]], modifiers: HogQLQueryModifiers): return select_from_person_overrides_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/persons.py b/posthog/hogql/database/schema/persons.py index 189da1faee068..14884a7008f60 100644 --- a/posthog/hogql/database/schema/persons.py +++ b/posthog/hogql/database/schema/persons.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.constants import HogQLQuerySettings @@ -19,7 +18,7 @@ from posthog.hogql.database.schema.persons_pdi import PersonsPDITable, persons_pdi_join from posthog.schema import HogQLQueryModifiers, PersonsArgMaxVersion -PERSONS_FIELDS: Dict[str, FieldOrTable] = { +PERSONS_FIELDS: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="id"), "created_at": DateTimeDatabaseField(name="created_at"), "team_id": IntegerDatabaseField(name="team_id"), @@ -33,7 +32,7 @@ } -def select_from_persons_table(requested_fields: Dict[str, List[str | int]], modifiers: HogQLQueryModifiers): +def select_from_persons_table(requested_fields: dict[str, list[str | int]], modifiers: HogQLQueryModifiers): version = modifiers.personsArgMaxVersion if version == PersonsArgMaxVersion.auto: version = PersonsArgMaxVersion.v1 @@ -85,7 +84,7 @@ def select_from_persons_table(requested_fields: Dict[str, List[str | int]], modi def join_with_persons_table( from_table: str, to_table: str, - requested_fields: Dict[str, List[str | int]], + requested_fields: dict[str, list[str | int]], context: HogQLContext, node: SelectQuery, ): @@ -107,7 +106,7 @@ def join_with_persons_table( class RawPersonsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **PERSONS_FIELDS, "is_deleted": BooleanDatabaseField(name="is_deleted"), "version": IntegerDatabaseField(name="version"), @@ -121,9 +120,9 @@ def to_printed_hogql(self): 
class PersonsTable(LazyTable): - fields: Dict[str, FieldOrTable] = PERSONS_FIELDS + fields: dict[str, FieldOrTable] = PERSONS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_persons_table(requested_fields, context.modifiers) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/persons_pdi.py b/posthog/hogql/database/schema/persons_pdi.py index 30fdadee67795..0e30b4e62d275 100644 --- a/posthog/hogql/database/schema/persons_pdi.py +++ b/posthog/hogql/database/schema/persons_pdi.py @@ -1,4 +1,3 @@ -from typing import Dict, List from posthog.hogql.ast import SelectQuery from posthog.hogql.context import HogQLContext @@ -14,7 +13,7 @@ # :NOTE: We already have person_distinct_ids.py, which most tables link to. This persons_pdi.py is a hack to # make "select persons.pdi.distinct_id from persons" work while avoiding circular imports. Don't use directly. -def persons_pdi_select(requested_fields: Dict[str, List[str | int]]): +def persons_pdi_select(requested_fields: dict[str, list[str | int]]): # Always include "person_id", as it's the key we use to make further joins, and it'd be great if it's available if "person_id" not in requested_fields: requested_fields = {**requested_fields, "person_id": ["person_id"]} @@ -32,7 +31,7 @@ def persons_pdi_select(requested_fields: Dict[str, List[str | int]]): def persons_pdi_join( from_table: str, to_table: str, - requested_fields: Dict[str, List[str | int]], + requested_fields: dict[str, list[str | int]], context: HogQLContext, node: SelectQuery, ): @@ -56,13 +55,13 @@ def persons_pdi_join( # :NOTE: We already have person_distinct_ids.py, which most tables link to. This persons_pdi.py is a hack to # make "select persons.pdi.distinct_id from persons" work while avoiding circular imports. Don't use directly. 
class PersonsPDITable(LazyTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "team_id": IntegerDatabaseField(name="team_id"), "distinct_id": StringDatabaseField(name="distinct_id"), "person_id": StringDatabaseField(name="person_id"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return persons_pdi_select(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/session_replay_events.py b/posthog/hogql/database/schema/session_replay_events.py index a6f0fbed3bcf5..81f705af378d9 100644 --- a/posthog/hogql/database/schema/session_replay_events.py +++ b/posthog/hogql/database/schema/session_replay_events.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from posthog.hogql.database.models import ( Table, StringDatabaseField, @@ -18,7 +16,7 @@ RAW_ONLY_FIELDS = ["min_first_timestamp", "max_last_timestamp"] -SESSION_REPLAY_EVENTS_COMMON_FIELDS: Dict[str, FieldOrTable] = { +SESSION_REPLAY_EVENTS_COMMON_FIELDS: dict[str, FieldOrTable] = { "session_id": StringDatabaseField(name="session_id"), "team_id": IntegerDatabaseField(name="team_id"), "distinct_id": StringDatabaseField(name="distinct_id"), @@ -46,14 +44,14 @@ class RawSessionReplayEventsTable(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **SESSION_REPLAY_EVENTS_COMMON_FIELDS, "min_first_timestamp": DateTimeDatabaseField(name="min_first_timestamp"), "max_last_timestamp": DateTimeDatabaseField(name="max_last_timestamp"), "first_url": DatabaseField(name="first_url"), } - def avoid_asterisk_fields(self) -> List[str]: + def avoid_asterisk_fields(self) -> list[str]: return ["first_url"] def to_printed_clickhouse(self, context): @@ -63,7 +61,7 @@ def to_printed_hogql(self): return "raw_session_replay_events" -def select_from_session_replay_events_table(requested_fields: Dict[str, List[str | int]]): +def select_from_session_replay_events_table(requested_fields: dict[str, list[str | int]]): from posthog.hogql import ast table_name = "raw_session_replay_events" @@ -85,8 +83,8 @@ def select_from_session_replay_events_table(requested_fields: Dict[str, List[str "message_count": ast.Call(name="sum", args=[ast.Field(chain=[table_name, "message_count"])]), } - select_fields: List[ast.Expr] = [] - group_by_fields: List[ast.Expr] = [] + select_fields: list[ast.Expr] = [] + group_by_fields: list[ast.Expr] = [] for name, chain in requested_fields.items(): if name in RAW_ONLY_FIELDS: @@ -107,14 +105,14 @@ def select_from_session_replay_events_table(requested_fields: Dict[str, List[str class SessionReplayEventsTable(LazyTable): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { **{k: v for k, v in SESSION_REPLAY_EVENTS_COMMON_FIELDS.items() if k not in RAW_ONLY_FIELDS}, "start_time": DateTimeDatabaseField(name="start_time"), "end_time": DateTimeDatabaseField(name="end_time"), "first_url": StringDatabaseField(name="first_url"), } - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node): return select_from_session_replay_events_table(requested_fields) def to_printed_clickhouse(self, context): diff --git a/posthog/hogql/database/schema/sessions.py b/posthog/hogql/database/schema/sessions.py index e1fcaf1a75f06..0bd6bfef09caf 100644 --- a/posthog/hogql/database/schema/sessions.py +++ 
b/posthog/hogql/database/schema/sessions.py @@ -1,4 +1,4 @@ -from typing import Dict, List, cast, Any, TYPE_CHECKING +from typing import cast, Any, TYPE_CHECKING from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -19,7 +19,7 @@ if TYPE_CHECKING: pass -RAW_SESSIONS_FIELDS: Dict[str, FieldOrTable] = { +RAW_SESSIONS_FIELDS: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="session_id"), # TODO remove this, it's a duplicate of the correct session_id field below to get some trends working on a deadline "session_id": StringDatabaseField(name="session_id"), @@ -44,7 +44,7 @@ "autocapture_count": IntegerDatabaseField(name="autocapture_count"), } -LAZY_SESSIONS_FIELDS: Dict[str, FieldOrTable] = { +LAZY_SESSIONS_FIELDS: dict[str, FieldOrTable] = { "id": StringDatabaseField(name="session_id"), # TODO remove this, it's a duplicate of the correct session_id field below to get some trends working on a deadline "session_id": StringDatabaseField(name="session_id"), @@ -75,7 +75,7 @@ class RawSessionsTable(Table): - fields: Dict[str, FieldOrTable] = RAW_SESSIONS_FIELDS + fields: dict[str, FieldOrTable] = RAW_SESSIONS_FIELDS def to_printed_clickhouse(self, context): return "sessions" @@ -83,7 +83,7 @@ def to_printed_clickhouse(self, context): def to_printed_hogql(self): return "raw_sessions" - def avoid_asterisk_fields(self) -> List[str]: + def avoid_asterisk_fields(self) -> list[str]: # our clickhouse driver can't return aggregate states return [ "entry_url", @@ -100,7 +100,7 @@ def avoid_asterisk_fields(self) -> List[str]: def select_from_sessions_table( - requested_fields: Dict[str, List[str | int]], node: ast.SelectQuery, context: HogQLContext + requested_fields: dict[str, list[str | int]], node: ast.SelectQuery, context: HogQLContext ): from posthog.hogql import ast @@ -166,8 +166,8 @@ def select_from_sessions_table( } aggregate_fields["duration"] = aggregate_fields["$session_duration"] - select_fields: List[ast.Expr] = [] - group_by_fields: List[ast.Expr] = [ast.Field(chain=[table_name, "session_id"])] + select_fields: list[ast.Expr] = [] + group_by_fields: list[ast.Expr] = [ast.Field(chain=[table_name, "session_id"])] for name, chain in requested_fields.items(): if name in aggregate_fields: @@ -189,9 +189,9 @@ def select_from_sessions_table( class SessionsTable(LazyTable): - fields: Dict[str, FieldOrTable] = LAZY_SESSIONS_FIELDS + fields: dict[str, FieldOrTable] = LAZY_SESSIONS_FIELDS - def lazy_select(self, requested_fields: Dict[str, List[str | int]], context, node: ast.SelectQuery): + def lazy_select(self, requested_fields: dict[str, list[str | int]], context, node: ast.SelectQuery): return select_from_sessions_table(requested_fields, node, context) def to_printed_clickhouse(self, context): @@ -202,7 +202,7 @@ def to_printed_hogql(self): def join_events_table_to_sessions_table( - from_table: str, to_table: str, requested_fields: Dict[str, Any], context: HogQLContext, node: ast.SelectQuery + from_table: str, to_table: str, requested_fields: dict[str, Any], context: HogQLContext, node: ast.SelectQuery ) -> ast.JoinExpr: from posthog.hogql import ast diff --git a/posthog/hogql/database/schema/static_cohort_people.py b/posthog/hogql/database/schema/static_cohort_people.py index 97d90cbd6dcac..fafbe9459eb99 100644 --- a/posthog/hogql/database/schema/static_cohort_people.py +++ b/posthog/hogql/database/schema/static_cohort_people.py @@ -1,5 +1,3 @@ -from typing import Dict - from posthog.hogql.database.models import ( Table, StringDatabaseField, @@ -11,7 +9,7 @@ 
class StaticCohortPeople(Table): - fields: Dict[str, FieldOrTable] = { + fields: dict[str, FieldOrTable] = { "person_id": StringDatabaseField(name="person_id"), "cohort_id": IntegerDatabaseField(name="cohort_id"), "team_id": IntegerDatabaseField(name="team_id"), diff --git a/posthog/hogql/database/schema/test/test_event_sessions.py b/posthog/hogql/database/schema/test/test_event_sessions.py index 1a31bc3f4720d..914ac471236d6 100644 --- a/posthog/hogql/database/schema/test/test_event_sessions.py +++ b/posthog/hogql/database/schema/test/test_event_sessions.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.database.database import create_hogql_database @@ -21,7 +21,7 @@ def _select(self, query: str) -> ast.SelectQuery: select_query = cast(ast.SelectQuery, clone_expr(parse_select(query), clear_locations=True)) return cast(ast.SelectQuery, resolve_types(select_query, self.context, dialect="clickhouse")) - def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> List[ast.Expr]: + def _compare_operators(self, query: ast.SelectQuery, table_name: str) -> list[ast.Expr]: assert query.where is not None and query.type is not None return WhereClauseExtractor(query.where, table_name, query.type, self.context).compare_operators diff --git a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py index 1e3464c1b9bd6..ea8c55d054cad 100644 --- a/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py +++ b/posthog/hogql/database/schema/util/test/test_session_where_clause_extractor.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Dict +from typing import Union, Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -22,7 +22,7 @@ def f(s: Union[str, ast.Expr, None], placeholders: Optional[dict[str, ast.Expr]] def parse( s: str, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, ) -> ast.SelectQuery | ast.SelectUnionQuery: parsed = parse_select(s, placeholders=placeholders) return parsed diff --git a/posthog/hogql/escape_sql.py b/posthog/hogql/escape_sql.py index 35a563061e169..10f4a413fa60d 100644 --- a/posthog/hogql/escape_sql.py +++ b/posthog/hogql/escape_sql.py @@ -1,6 +1,6 @@ import re from datetime import datetime, date -from typing import Optional, Any, Literal, List, Tuple +from typing import Optional, Any, Literal from uuid import UUID from zoneinfo import ZoneInfo @@ -129,8 +129,8 @@ def visit_fakedate(self, value: date): def visit_date(self, value: date): return f"toDate({self.visit(value.strftime('%Y-%m-%d'))})" - def visit_list(self, value: List): + def visit_list(self, value: list): return f"[{', '.join(str(self.visit(x)) for x in value)}]" - def visit_tuple(self, value: Tuple): + def visit_tuple(self, value: tuple): return f"({', '.join(str(self.visit(x)) for x in value)})" diff --git a/posthog/hogql/filters.py b/posthog/hogql/filters.py index 496cadf8da417..06ea36c1cdd10 100644 --- a/posthog/hogql/filters.py +++ b/posthog/hogql/filters.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TypeVar +from typing import Optional, TypeVar from dateutil.parser import isoparse @@ -23,7 +23,7 @@ def __init__(self, filters: Optional[HogQLFilters], team: Team = None): super().__init__() self.filters = filters self.team = team - self.selects: 
List[ast.SelectQuery] = [] + self.selects: list[ast.SelectQuery] = [] def visit_select_query(self, node): self.selects.append(node) @@ -51,7 +51,7 @@ def visit_placeholder(self, node): "Cannot use 'filters' placeholder in a SELECT clause that does not select from the events table." ) - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] if self.filters.properties is not None: exprs.append(property_to_expr(self.filters.properties, self.team)) diff --git a/posthog/hogql/functions/action.py b/posthog/hogql/functions/action.py index 02888081632f3..5ed8a156e393b 100644 --- a/posthog/hogql/functions/action.py +++ b/posthog/hogql/functions/action.py @@ -1,12 +1,10 @@ -from typing import List - from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.errors import QueryError from posthog.hogql.escape_sql import escape_clickhouse_string -def matches_action(node: ast.Expr, args: List[ast.Expr], context: HogQLContext) -> ast.Expr: +def matches_action(node: ast.Expr, args: list[ast.Expr], context: HogQLContext) -> ast.Expr: arg = args[0] if not isinstance(arg, ast.Constant): raise QueryError("action() takes only constant arguments", node=arg) diff --git a/posthog/hogql/functions/cohort.py b/posthog/hogql/functions/cohort.py index fc5077f610a4f..2b0992c6e7ef9 100644 --- a/posthog/hogql/functions/cohort.py +++ b/posthog/hogql/functions/cohort.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -23,7 +23,7 @@ def cohort_query_node(node: ast.Expr, context: HogQLContext) -> ast.Expr: return cohort(node, [node], context) -def cohort(node: ast.Expr, args: List[ast.Expr], context: HogQLContext) -> ast.Expr: +def cohort(node: ast.Expr, args: list[ast.Expr], context: HogQLContext) -> ast.Expr: arg = args[0] if not isinstance(arg, ast.Constant): raise QueryError("cohort() takes only constant arguments", node=arg) diff --git a/posthog/hogql/functions/mapping.py b/posthog/hogql/functions/mapping.py index 6080face6f675..c4087013c85c4 100644 --- a/posthog/hogql/functions/mapping.py +++ b/posthog/hogql/functions/mapping.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from itertools import chain -from typing import List, Optional, Dict, Tuple, Type +from typing import Optional from posthog.hogql import ast from posthog.hogql.base import ConstantType from posthog.hogql.errors import QueryError def validate_function_args( - args: List[ast.Expr], + args: list[ast.Expr], min_args: int, max_args: Optional[int], function_name: str, @@ -31,7 +31,7 @@ def validate_function_args( ) -Overload = Tuple[Tuple[Type[ConstantType], ...] | Type[ConstantType], str] +Overload = tuple[tuple[type[ConstantType], ...] | type[ConstantType], str] @dataclass() @@ -42,7 +42,7 @@ class HogQLFunctionMeta: min_params: int = 0 max_params: Optional[int] = 0 aggregate: bool = False - overloads: Optional[List[Overload]] = None + overloads: Optional[list[Overload]] = None """Overloads allow for using a different ClickHouse function depending on the type of the first arg.""" tz_aware: bool = False """Whether the function is timezone-aware. This means the project timezone will be appended as the last arg.""" @@ -50,7 +50,7 @@ class HogQLFunctionMeta: """Not all ClickHouse functions are case-insensitive. 
See https://clickhouse.com/docs/en/sql-reference/syntax#keywords.""" -HOGQL_COMPARISON_MAPPING: Dict[str, ast.CompareOperationOp] = { +HOGQL_COMPARISON_MAPPING: dict[str, ast.CompareOperationOp] = { "equals": ast.CompareOperationOp.Eq, "notEquals": ast.CompareOperationOp.NotEq, "less": ast.CompareOperationOp.Lt, @@ -65,7 +65,7 @@ class HogQLFunctionMeta: "notIn": ast.CompareOperationOp.NotIn, } -HOGQL_CLICKHOUSE_FUNCTIONS: Dict[str, HogQLFunctionMeta] = { +HOGQL_CLICKHOUSE_FUNCTIONS: dict[str, HogQLFunctionMeta] = { # arithmetic "plus": HogQLFunctionMeta("plus", 2, 2), "minus": HogQLFunctionMeta("minus", 2, 2), @@ -575,7 +575,7 @@ class HogQLFunctionMeta: "leadInFrame": HogQLFunctionMeta("leadInFrame", 1, 1), } # Permitted HogQL aggregations -HOGQL_AGGREGATIONS: Dict[str, HogQLFunctionMeta] = { +HOGQL_AGGREGATIONS: dict[str, HogQLFunctionMeta] = { # Standard aggregate functions "count": HogQLFunctionMeta("count", 0, 1, aggregate=True, case_sensitive=False), "countIf": HogQLFunctionMeta("countIf", 1, 2, aggregate=True), @@ -747,7 +747,7 @@ class HogQLFunctionMeta: "maxIntersectionsPosition": HogQLFunctionMeta("maxIntersectionsPosition", 2, 2, aggregate=True), "maxIntersectionsPositionIf": HogQLFunctionMeta("maxIntersectionsPositionIf", 3, 3, aggregate=True), } -HOGQL_POSTHOG_FUNCTIONS: Dict[str, HogQLFunctionMeta] = { +HOGQL_POSTHOG_FUNCTIONS: dict[str, HogQLFunctionMeta] = { "matchesAction": HogQLFunctionMeta("matchesAction", 1, 1), "sparkline": HogQLFunctionMeta("sparkline", 1, 1), "hogql_lookupDomainType": HogQLFunctionMeta("hogql_lookupDomainType", 1, 1), @@ -781,7 +781,7 @@ class HogQLFunctionMeta: ) -def _find_function(name: str, functions: Dict[str, HogQLFunctionMeta]) -> Optional[HogQLFunctionMeta]: +def _find_function(name: str, functions: dict[str, HogQLFunctionMeta]) -> Optional[HogQLFunctionMeta]: func = functions.get(name) if func is not None: return func diff --git a/posthog/hogql/functions/sparkline.py b/posthog/hogql/functions/sparkline.py index ddd6c02a7b20e..5bbf9004f4425 100644 --- a/posthog/hogql/functions/sparkline.py +++ b/posthog/hogql/functions/sparkline.py @@ -1,9 +1,7 @@ -from typing import List - from posthog.hogql import ast -def sparkline(node: ast.Expr, args: List[ast.Expr]) -> ast.Expr: +def sparkline(node: ast.Expr, args: list[ast.Expr]) -> ast.Expr: return ast.Tuple( exprs=[ ast.Constant(value="__hogql_chart_type"), diff --git a/posthog/hogql/hogql.py b/posthog/hogql/hogql.py index d3052f58b01a1..2a537bfd7a8d6 100644 --- a/posthog/hogql/hogql.py +++ b/posthog/hogql/hogql.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, cast, Optional +from typing import Literal, cast, Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -18,7 +18,7 @@ def translate_hogql( metadata_source: Optional[ast.SelectQuery] = None, *, events_table_alias: Optional[str] = None, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, ) -> str: """Translate a HogQL expression into a ClickHouse expression.""" if query == "": diff --git a/posthog/hogql/parser.py b/posthog/hogql/parser.py index 0ec619f338909..68637a30a208c 100644 --- a/posthog/hogql/parser.py +++ b/posthog/hogql/parser.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Literal, Optional, cast, Callable +from typing import Literal, Optional, cast +from collections.abc import Callable from antlr4 import CommonTokenStream, InputStream, ParseTreeVisitor, ParserRuleContext from antlr4.error.ErrorListener import ErrorListener @@ -19,7 
+20,7 @@ parse_select as _parse_select_cpp, ) -RULE_TO_PARSE_FUNCTION: Dict[Literal["python", "cpp"], Dict[Literal["expr", "order_expr", "select"], Callable]] = { +RULE_TO_PARSE_FUNCTION: dict[Literal["python", "cpp"], dict[Literal["expr", "order_expr", "select"], Callable]] = { "python": { "expr": lambda string, start: HogQLParseTreeConverter(start=start).visit(get_parser(string).expr()), "order_expr": lambda string: HogQLParseTreeConverter().visit(get_parser(string).orderExpr()), @@ -32,7 +33,7 @@ }, } -RULE_TO_HISTOGRAM: Dict[Literal["expr", "order_expr", "select"], Histogram] = { +RULE_TO_HISTOGRAM: dict[Literal["expr", "order_expr", "select"], Histogram] = { rule: Histogram( f"parse_{rule}_seconds", f"Time to parse {rule} expression", @@ -44,7 +45,7 @@ def parse_expr( expr: str, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, start: Optional[int] = 0, timings: Optional[HogQLTimings] = None, *, @@ -65,7 +66,7 @@ def parse_expr( def parse_order_expr( order_expr: str, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, timings: Optional[HogQLTimings] = None, *, backend: Optional[Literal["python", "cpp"]] = None, @@ -85,7 +86,7 @@ def parse_order_expr( def parse_select( statement: str, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, timings: Optional[HogQLTimings] = None, *, backend: Optional[Literal["python", "cpp"]] = None, @@ -159,10 +160,10 @@ def visitSelect(self, ctx: HogQLParser.SelectContext): return self.visit(ctx.selectUnionStmt() or ctx.selectStmt() or ctx.hogqlxTagElement()) def visitSelectUnionStmt(self, ctx: HogQLParser.SelectUnionStmtContext): - select_queries: List[ast.SelectQuery | ast.SelectUnionQuery] = [ + select_queries: list[ast.SelectQuery | ast.SelectUnionQuery] = [ self.visit(select) for select in ctx.selectStmtWithParens() ] - flattened_queries: List[ast.SelectQuery] = [] + flattened_queries: list[ast.SelectQuery] = [] for query in select_queries: if isinstance(query, ast.SelectQuery): flattened_queries.append(query) @@ -771,7 +772,7 @@ def visitColumnLambdaExpr(self, ctx: HogQLParser.ColumnLambdaExprContext): ) def visitWithExprList(self, ctx: HogQLParser.WithExprListContext): - ctes: Dict[str, ast.CTE] = {} + ctes: dict[str, ast.CTE] = {} for expr in ctx.withExpr(): cte = self.visit(expr) ctes[cte.name] = cte diff --git a/posthog/hogql/placeholders.py b/posthog/hogql/placeholders.py index a09e39fd65680..d0e835fb0d853 100644 --- a/posthog/hogql/placeholders.py +++ b/posthog/hogql/placeholders.py @@ -1,15 +1,15 @@ -from typing import Dict, Optional, List +from typing import Optional from posthog.hogql import ast from posthog.hogql.errors import QueryError from posthog.hogql.visitor import CloningVisitor, TraversingVisitor -def replace_placeholders(node: ast.Expr, placeholders: Optional[Dict[str, ast.Expr]]) -> ast.Expr: +def replace_placeholders(node: ast.Expr, placeholders: Optional[dict[str, ast.Expr]]) -> ast.Expr: return ReplacePlaceholders(placeholders).visit(node) -def find_placeholders(node: ast.Expr) -> List[str]: +def find_placeholders(node: ast.Expr) -> list[str]: finder = FindPlaceholders() finder.visit(node) return list(finder.found) @@ -28,7 +28,7 @@ def visit_placeholder(self, node: ast.Placeholder): class ReplacePlaceholders(CloningVisitor): - def __init__(self, placeholders: Optional[Dict[str, ast.Expr]]): + def __init__(self, placeholders: Optional[dict[str, ast.Expr]]): 
super().__init__() self.placeholders = placeholders @@ -42,5 +42,5 @@ def visit_placeholder(self, node): return new_node raise QueryError( f"Placeholder {{{node.field}}} is not available in this context. You can use the following: " - + ", ".join((f"{placeholder}" for placeholder in self.placeholders)) + + ", ".join(f"{placeholder}" for placeholder in self.placeholders) ) diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index ff4766f86074a..a829697e9007a 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from datetime import datetime, date from difflib import get_close_matches -from typing import List, Literal, Optional, Union, cast +from typing import Literal, Optional, Union, cast from uuid import UUID from posthog.hogql import ast @@ -73,7 +73,7 @@ def print_ast( node: ast.Expr, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, settings: Optional[HogQLGlobalSettings] = None, pretty: bool = False, ) -> str: @@ -92,7 +92,7 @@ def prepare_ast_for_printing( node: ast.Expr, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, settings: Optional[HogQLGlobalSettings] = None, ) -> ast.Expr: with context.timings.measure("create_hogql_database"): @@ -130,7 +130,7 @@ def print_prepared_ast( node: ast.Expr, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, settings: Optional[HogQLGlobalSettings] = None, pretty: bool = False, ) -> str: @@ -158,13 +158,13 @@ def __init__( self, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[AST]] = None, + stack: Optional[list[AST]] = None, settings: Optional[HogQLGlobalSettings] = None, pretty: bool = False, ): self.context = context self.dialect = dialect - self.stack: List[AST] = stack or [] # Keep track of all traversed nodes. + self.stack: list[AST] = stack or [] # Keep track of all traversed nodes. 
self.settings = settings self.pretty = pretty self._indent = -1 @@ -773,7 +773,7 @@ def visit_call(self, node: ast.Call): if self.dialect == "clickhouse": if node.name in FIRST_ARG_DATETIME_FUNCTIONS: - args: List[str] = [] + args: list[str] = [] for idx, arg in enumerate(node.args): if idx == 0: if isinstance(arg, ast.Call) and arg.name in ADD_OR_NULL_DATETIME_FUNCTIONS: @@ -783,7 +783,7 @@ def visit_call(self, node: ast.Call): else: args.append(self.visit(arg)) elif node.name == "concat": - args: List[str] = [] + args: list[str] = [] for arg in node.args: if isinstance(arg, ast.Constant): if arg.value is None: @@ -1002,7 +1002,7 @@ def visit_property_type(self, type: ast.PropertyType): while isinstance(table, ast.TableAliasType): table = table.table_type - args: List[str] = [] + args: list[str] = [] if self.context.modifiers.materializationMode != "disabled": # find a materialized property for the first part of the chain @@ -1094,7 +1094,7 @@ def visit_unknown(self, node: AST): raise ImpossibleASTError(f"Unknown AST node {type(node).__name__}") def visit_window_expr(self, node: ast.WindowExpr): - strings: List[str] = [] + strings: list[str] = [] if node.partition_by is not None: if len(node.partition_by) == 0: raise ImpossibleASTError("PARTITION BY must have at least one argument") @@ -1168,7 +1168,7 @@ def _print_escaped_string(self, name: float | int | str | list | tuple | datetim return escape_clickhouse_string(name, timezone=self._get_timezone()) return escape_hogql_string(name, timezone=self._get_timezone()) - def _unsafe_json_extract_trim_quotes(self, unsafe_field: str, unsafe_args: List[str]) -> str: + def _unsafe_json_extract_trim_quotes(self, unsafe_field: str, unsafe_args: list[str]) -> str: return f"replaceRegexpAll(nullIf(nullIf(JSONExtractRaw({', '.join([unsafe_field, *unsafe_args])}), ''), 'null'), '^\"|\"$', '')" def _get_materialized_column( @@ -1209,7 +1209,7 @@ def _print_settings(self, settings): for key, value in settings: if value is None: continue - if not isinstance(value, (int, float, str)): + if not isinstance(value, int | float | str): raise QueryError(f"Setting {key} must be a string, int, or float") if not re.match(r"^[a-zA-Z0-9_]+$", key): raise QueryError(f"Setting {key} is not supported") diff --git a/posthog/hogql/property.py b/posthog/hogql/property.py index fb5e6d90a459f..824a11bdae94d 100644 --- a/posthog/hogql/property.py +++ b/posthog/hogql/property.py @@ -1,5 +1,5 @@ import re -from typing import List, Optional, Union, cast, Literal +from typing import Optional, Union, cast, Literal from pydantic import BaseModel @@ -382,7 +382,7 @@ def action_to_expr(action: Action) -> ast.Expr: or_queries = [] for step in steps: - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] if step.event: exprs.append(parse_expr("event = {event}", {"event": ast.Constant(value=step.event)})) diff --git a/posthog/hogql/query.py b/posthog/hogql/query.py index 65c0c9d71356f..b42a61b785541 100644 --- a/posthog/hogql/query.py +++ b/posthog/hogql/query.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, Optional, Union, cast +from typing import Optional, Union, cast from posthog.clickhouse.client.connection import Workload from posthog.errors import ExposedCHQueryError @@ -32,7 +32,7 @@ def execute_hogql_query( *, query_type: str = "hogql_query", filters: Optional[HogQLFilters] = None, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, workload: Workload = Workload.ONLINE, settings: 
Optional[HogQLGlobalSettings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -175,7 +175,7 @@ def execute_hogql_query( except Exception as e: if explain: results, types = None, None - if isinstance(e, (ExposedCHQueryError, ExposedHogQLError)): + if isinstance(e, ExposedCHQueryError | ExposedHogQLError): error = str(e) else: error = "Unknown error" diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index fce251dc8a08d..5921e5a6f2d94 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -1,5 +1,5 @@ from datetime import date, datetime -from typing import List, Optional, Any, cast, Literal +from typing import Optional, Any, cast, Literal from uuid import UUID from posthog.hogql import ast @@ -58,7 +58,7 @@ def resolve_types( node: ast.Expr, context: HogQLContext, dialect: Literal["hogql", "clickhouse"], - scopes: Optional[List[ast.SelectQueryType]] = None, + scopes: Optional[list[ast.SelectQueryType]] = None, ) -> ast.Expr: return Resolver(scopes=scopes, context=context, dialect=dialect).visit(node) @@ -66,7 +66,7 @@ def resolve_types( class AliasCollector(TraversingVisitor): def __init__(self): super().__init__() - self.aliases: List[str] = [] + self.aliases: list[str] = [] def visit_alias(self, node: ast.Alias): self.aliases.append(node.alias) @@ -80,11 +80,11 @@ def __init__( self, context: HogQLContext, dialect: Literal["hogql", "clickhouse"] = "clickhouse", - scopes: Optional[List[ast.SelectQueryType]] = None, + scopes: Optional[list[ast.SelectQueryType]] = None, ): super().__init__() # Each SELECT query creates a new scope (type). Store all of them in a list as we traverse the tree. - self.scopes: List[ast.SelectQueryType] = scopes or [] + self.scopes: list[ast.SelectQueryType] = scopes or [] self.current_view_depth: int = 0 self.context = context self.dialect = dialect @@ -214,7 +214,7 @@ def visit_select_query(self, node: ast.SelectQuery): return new_node - def _asterisk_columns(self, asterisk: ast.AsteriskType) -> List[ast.Expr]: + def _asterisk_columns(self, asterisk: ast.AsteriskType) -> list[ast.Expr]: """Expand an asterisk. 
Mutates `select_query.select` and `select_query.type.columns` with the new fields""" if isinstance(asterisk.table_type, ast.BaseTableType): table = asterisk.table_type.resolve_database_table(self.context) @@ -393,13 +393,13 @@ def visit_call(self, node: ast.Call): return self.visit(matches_action(node=node, args=node.args, context=self.context)) node = super().visit_call(node) - arg_types: List[ast.ConstantType] = [] + arg_types: list[ast.ConstantType] = [] for arg in node.args: if arg.type: arg_types.append(arg.type.resolve_constant_type(self.context) or ast.UnknownType()) else: arg_types.append(ast.UnknownType()) - param_types: Optional[List[ast.ConstantType]] = None + param_types: Optional[list[ast.ConstantType]] = None if node.params is not None: param_types = [] for param in node.params: diff --git a/posthog/hogql/resolver_utils.py b/posthog/hogql/resolver_utils.py index 7910a17fdb92e..bfede9538ab64 100644 --- a/posthog/hogql/resolver_utils.py +++ b/posthog/hogql/resolver_utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog import schema from posthog.hogql import ast @@ -27,7 +27,7 @@ def lookup_field_by_name(scope: ast.SelectQueryType, name: str, context: HogQLCo return None -def lookup_cte_by_name(scopes: List[ast.SelectQueryType], name: str) -> Optional[ast.CTE]: +def lookup_cte_by_name(scopes: list[ast.SelectQueryType], name: str) -> Optional[ast.CTE]: for scope in reversed(scopes): if scope and scope.ctes and name in scope.ctes: return scope.ctes[name] diff --git a/posthog/hogql/test/_test_parser.py b/posthog/hogql/test/_test_parser.py index 478958746d601..514914906d015 100644 --- a/posthog/hogql/test/_test_parser.py +++ b/posthog/hogql/test/_test_parser.py @@ -1,4 +1,4 @@ -from typing import Literal, cast, Optional, Dict +from typing import Literal, cast, Optional import math @@ -20,10 +20,10 @@ class TestParser(*base_classes): maxDiff = None - def _expr(self, expr: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr: + def _expr(self, expr: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.Expr: return clear_locations(parse_expr(expr, placeholders=placeholders, backend=backend)) - def _select(self, query: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.Expr: + def _select(self, query: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.Expr: return clear_locations(parse_select(query, placeholders=placeholders, backend=backend)) def test_numbers(self): diff --git a/posthog/hogql/test/test_filters.py b/posthog/hogql/test/test_filters.py index 5aba11a3b28c6..05ac11667ae63 100644 --- a/posthog/hogql/test/test_filters.py +++ b/posthog/hogql/test/test_filters.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional +from typing import Any, Optional from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -18,10 +18,10 @@ class TestFilters(BaseTest): maxDiff = None - def _parse_expr(self, expr: str, placeholders: Optional[Dict[str, Any]] = None): + def _parse_expr(self, expr: str, placeholders: Optional[dict[str, Any]] = None): return clear_locations(parse_expr(expr, placeholders=placeholders)) - def _parse_select(self, select: str, placeholders: Optional[Dict[str, Any]] = None): + def _parse_select(self, select: str, placeholders: Optional[dict[str, Any]] = None): return clear_locations(parse_select(select, placeholders=placeholders)) def _print_ast(self, node: ast.Expr): diff --git a/posthog/hogql/test/test_mapping.py 
b/posthog/hogql/test/test_mapping.py index b13b2d1c744b2..9af0d9a60e44e 100644 --- a/posthog/hogql/test/test_mapping.py +++ b/posthog/hogql/test/test_mapping.py @@ -23,22 +23,22 @@ def _get_hogql_posthog_function(self, name: str) -> HogQLFunctionMeta: return self._return_present_function(find_hogql_posthog_function(name)) def test_find_case_sensitive_function(self): - self.assertEquals(self._get_hogql_function("toString").clickhouse_name, "toString") - self.assertEquals(find_hogql_function("TOString"), None) - self.assertEquals(find_hogql_function("PlUs"), None) + self.assertEqual(self._get_hogql_function("toString").clickhouse_name, "toString") + self.assertEqual(find_hogql_function("TOString"), None) + self.assertEqual(find_hogql_function("PlUs"), None) - self.assertEquals(self._get_hogql_aggregation("countIf").clickhouse_name, "countIf") - self.assertEquals(find_hogql_aggregation("COUNTIF"), None) + self.assertEqual(self._get_hogql_aggregation("countIf").clickhouse_name, "countIf") + self.assertEqual(find_hogql_aggregation("COUNTIF"), None) - self.assertEquals(self._get_hogql_posthog_function("sparkline").clickhouse_name, "sparkline") - self.assertEquals(find_hogql_posthog_function("SPARKLINE"), None) + self.assertEqual(self._get_hogql_posthog_function("sparkline").clickhouse_name, "sparkline") + self.assertEqual(find_hogql_posthog_function("SPARKLINE"), None) def test_find_case_insensitive_function(self): - self.assertEquals(self._get_hogql_function("CoAlesce").clickhouse_name, "coalesce") + self.assertEqual(self._get_hogql_function("CoAlesce").clickhouse_name, "coalesce") - self.assertEquals(self._get_hogql_aggregation("SuM").clickhouse_name, "sum") + self.assertEqual(self._get_hogql_aggregation("SuM").clickhouse_name, "sum") def test_find_non_existent_function(self): - self.assertEquals(find_hogql_function("functionThatDoesntExist"), None) - self.assertEquals(find_hogql_aggregation("functionThatDoesntExist"), None) - self.assertEquals(find_hogql_posthog_function("functionThatDoesntExist"), None) + self.assertEqual(find_hogql_function("functionThatDoesntExist"), None) + self.assertEqual(find_hogql_aggregation("functionThatDoesntExist"), None) + self.assertEqual(find_hogql_posthog_function("functionThatDoesntExist"), None) diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index 1a8a2130c5245..9c7a1fda936f5 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Dict +from typing import Literal, Optional import pytest from django.test import override_settings @@ -35,7 +35,7 @@ def _select( self, query: str, context: Optional[HogQLContext] = None, - placeholders: Optional[Dict[str, ast.Expr]] = None, + placeholders: Optional[dict[str, ast.Expr]] = None, ) -> str: return print_ast( parse_select(query, placeholders=placeholders), diff --git a/posthog/hogql/test/test_property.py b/posthog/hogql/test/test_property.py index 9b07a362bdd3a..4f6ed2e115066 100644 --- a/posthog/hogql/test/test_property.py +++ b/posthog/hogql/test/test_property.py @@ -1,4 +1,4 @@ -from typing import List, Union, cast, Optional, Dict, Any, Literal +from typing import Union, cast, Optional, Any, Literal from unittest.mock import MagicMock, patch from posthog.constants import PropertyOperatorType, TREND_FILTER_TYPE_ACTIONS, TREND_FILTER_TYPE_EVENTS @@ -46,7 +46,7 @@ def _property_to_expr( def _selector_to_expr(self, selector: str): return clear_locations(selector_to_expr(selector)) - def 
_parse_expr(self, expr: str, placeholders: Optional[Dict[str, Any]] = None): + def _parse_expr(self, expr: str, placeholders: Optional[dict[str, Any]] = None): return clear_locations(parse_expr(expr, placeholders=placeholders)) def test_has_aggregation(self): @@ -416,7 +416,7 @@ def test_property_groups_combined(self): PropertyGroup( type=PropertyOperatorType.AND, values=cast( - Union[List[Property], List[PropertyGroup]], + Union[list[Property], list[PropertyGroup]], [ Property(type="person", key="a", value="b", operator="exact"), PropertyGroup( diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py index 7dc2954380700..b7f13f9c07082 100644 --- a/posthog/hogql/test/test_query.py +++ b/posthog/hogql/test/test_query.py @@ -1014,7 +1014,7 @@ def test_property_access_with_arrays(self): f"LIMIT 100 " f"SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1", ) - self.assertEqual(response.results[0], tuple((random_uuid for x in alternatives))) + self.assertEqual(response.results[0], tuple(random_uuid for x in alternatives)) def test_property_access_with_arrays_zero_index_error(self): query = f"SELECT properties.something[0] FROM events" diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py index a5f3b838c39be..7cbd5a60a3245 100644 --- a/posthog/hogql/test/test_resolver.py +++ b/posthog/hogql/test/test_resolver.py @@ -1,5 +1,5 @@ from datetime import timezone, datetime, date -from typing import Optional, Dict, cast +from typing import Optional, cast import pytest from django.test import override_settings from uuid import UUID @@ -28,7 +28,7 @@ class TestResolver(BaseTest): maxDiff = None - def _select(self, query: str, placeholders: Optional[Dict[str, ast.Expr]] = None) -> ast.SelectQuery: + def _select(self, query: str, placeholders: Optional[dict[str, ast.Expr]] = None) -> ast.SelectQuery: return cast( ast.SelectQuery, clone_expr(parse_select(query, placeholders=placeholders), clear_locations=True), diff --git a/posthog/hogql/test/test_timings.py b/posthog/hogql/test/test_timings.py index 02f8392da09ca..cfb2259157afa 100644 --- a/posthog/hogql/test/test_timings.py +++ b/posthog/hogql/test/test_timings.py @@ -26,8 +26,8 @@ def test_basic_timing(self): pass results = timings.to_dict() - self.assertAlmostEquals(results["./test"], 0.05) - self.assertAlmostEquals(results["."], 0.15) + self.assertAlmostEqual(results["./test"], 0.05) + self.assertAlmostEqual(results["."], 0.15) def test_no_timing(self): with patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -45,9 +45,9 @@ def test_nested_timing(self): pass results = timings.to_dict() - self.assertAlmostEquals(results["./outer/inner"], 0.05) - self.assertAlmostEquals(results["./outer"], 0.15) - self.assertAlmostEquals(results["."], 0.25) + self.assertAlmostEqual(results["./outer/inner"], 0.05) + self.assertAlmostEqual(results["./outer"], 0.15) + self.assertAlmostEqual(results["."], 0.25) def test_multiple_top_level_timings(self): with patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -59,9 +59,9 @@ def test_multiple_top_level_timings(self): pass results = timings.to_dict() - self.assertAlmostEquals(results["./first"], 0.05) - self.assertAlmostEquals(results["./second"], 0.05) - self.assertAlmostEquals(results["."], 0.25) + self.assertAlmostEqual(results["./first"], 0.05) + self.assertAlmostEqual(results["./second"], 0.05) + self.assertAlmostEqual(results["."], 0.25) def test_deeply_nested_timing(self): with 
patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -73,10 +73,10 @@ def test_deeply_nested_timing(self): pass results = timings.to_dict() - self.assertAlmostEquals(results["./a/b/c"], 0.05) - self.assertAlmostEquals(results["./a/b"], 0.15) - self.assertAlmostEquals(results["./a"], 0.25) - self.assertAlmostEquals(results["."], 0.35) + self.assertAlmostEqual(results["./a/b/c"], 0.05) + self.assertAlmostEqual(results["./a/b"], 0.15) + self.assertAlmostEqual(results["./a"], 0.25) + self.assertAlmostEqual(results["."], 0.35) def test_overlapping_keys(self): with patch("posthog.hogql.timings.perf_counter", fake_perf_counter): @@ -88,5 +88,5 @@ def test_overlapping_keys(self): pass results = timings.to_dict() - self.assertAlmostEquals(results["./a"], 0.1) - self.assertAlmostEquals(results["."], 0.25) + self.assertAlmostEqual(results["./a"], 0.1) + self.assertAlmostEqual(results["."], 0.25) diff --git a/posthog/hogql/timings.py b/posthog/hogql/timings.py index fca643d640b32..950d0f5bf23ae 100644 --- a/posthog/hogql/timings.py +++ b/posthog/hogql/timings.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from time import perf_counter -from typing import Dict, List from contextlib import contextmanager from sentry_sdk import start_span @@ -11,10 +10,10 @@ @dataclass class HogQLTimings: # Completed time in seconds for different parts of the HogQL query - timings: Dict[str, float] = field(default_factory=dict) + timings: dict[str, float] = field(default_factory=dict) # Used for housekeeping - _timing_starts: Dict[str, float] = field(default_factory=dict) + _timing_starts: dict[str, float] = field(default_factory=dict) _timing_pointer: str = "." def __post_init__(self): @@ -37,11 +36,11 @@ def measure(self, key: str): if span: span.set_tag("duration_seconds", duration) - def to_dict(self) -> Dict[str, float]: + def to_dict(self) -> dict[str, float]: timings = {**self.timings} for key, start in reversed(self._timing_starts.items()): timings[key] = timings.get(key, 0.0) + (perf_counter() - start) return timings - def to_list(self) -> List[QueryTiming]: + def to_list(self) -> list[QueryTiming]: return [QueryTiming(k=key, t=time) for key, time in self.to_dict().items()] diff --git a/posthog/hogql/transforms/in_cohort.py b/posthog/hogql/transforms/in_cohort.py index d10e393f539e3..67fdd57a7df15 100644 --- a/posthog/hogql/transforms/in_cohort.py +++ b/posthog/hogql/transforms/in_cohort.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, cast, Literal +from typing import Optional, cast, Literal from posthog.hogql import ast @@ -13,7 +13,7 @@ def resolve_in_cohorts( node: ast.Expr, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, context: HogQLContext = None, ): InCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) @@ -23,13 +23,13 @@ def resolve_in_cohorts_conjoined( node: ast.Expr, dialect: Literal["hogql", "clickhouse"], context: HogQLContext, - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, ): MultipleInCohortResolver(stack=stack, dialect=dialect, context=context).visit(node) class CohortCompareOperationTraverser(TraversingVisitor): - ops: List[ast.CompareOperation] = [] + ops: list[ast.CompareOperation] = [] def __init__(self, expr: ast.Expr): self.ops = [] @@ -50,10 +50,10 @@ def __init__( self, dialect: Literal["hogql", "clickhouse"], context: HogQLContext, - stack: Optional[List[ast.SelectQuery]] = None, + 
stack: Optional[list[ast.SelectQuery]] = None,
     ):
         super().__init__()
-        self.stack: List[ast.SelectQuery] = stack or []
+        self.stack: list[ast.SelectQuery] = stack or []
         self.context = context
         self.dialect = dialect
@@ -68,7 +68,7 @@ def visit_select_query(self, node: ast.SelectQuery):

         self.stack.pop()

-    def _execute(self, node: ast.SelectQuery, compare_operations: List[ast.CompareOperation]):
+    def _execute(self, node: ast.SelectQuery, compare_operations: list[ast.CompareOperation]):
         if len(compare_operations) == 0:
             return
@@ -81,11 +81,11 @@ def _execute(self, node: ast.SelectQuery, compare_operations: List[ast.CompareOp
             compare_node.right = ast.Constant(value=1)

     def _resolve_cohorts(
-        self, compare_operations: List[ast.CompareOperation]
-    ) -> List[Tuple[int, StaticOrDynamic, int]]:
+        self, compare_operations: list[ast.CompareOperation]
+    ) -> list[tuple[int, StaticOrDynamic, int]]:
         from posthog.models import Cohort

-        cohorts: List[Tuple[int, StaticOrDynamic, int]] = []
+        cohorts: list[tuple[int, StaticOrDynamic, int]] = []

         for node in compare_operations:
             arg = node.right
@@ -132,9 +132,9 @@ def _resolve_cohorts(
     def _add_join(
         self,
-        cohorts: List[Tuple[int, StaticOrDynamic, int]],
+        cohorts: list[tuple[int, StaticOrDynamic, int]],
         select: ast.SelectQuery,
-        compare_operations: List[ast.CompareOperation],
+        compare_operations: list[ast.CompareOperation],
     ):
         must_add_join = True
         last_join = select.select_from
@@ -264,11 +264,11 @@ class InCohortResolver(TraversingVisitor):
     def __init__(
         self,
         dialect: Literal["hogql", "clickhouse"],
-        stack: Optional[List[ast.SelectQuery]] = None,
+        stack: Optional[list[ast.SelectQuery]] = None,
         context: HogQLContext = None,
     ):
         super().__init__()
-        self.stack: List[ast.SelectQuery] = stack or []
+        self.stack: list[ast.SelectQuery] = stack or []
         self.context = context
         self.dialect = dialect
diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py
index bd3a3550034cd..c010fc13ce408 100644
--- a/posthog/hogql/transforms/lazy_tables.py
+++ b/posthog/hogql/transforms/lazy_tables.py
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Dict, List, Optional, cast, Literal
+from typing import Optional, cast, Literal

 from posthog.hogql import ast
 from posthog.hogql.context import HogQLContext
@@ -13,7 +13,7 @@ def resolve_lazy_tables(
     node: ast.Expr,
     dialect: Literal["hogql", "clickhouse"],
-    stack: Optional[List[ast.SelectQuery]] = None,
+    stack: Optional[list[ast.SelectQuery]] = None,
     context: HogQLContext = None,
 ):
     LazyTableResolver(stack=stack, context=context, dialect=dialect).visit(node)
@@ -21,7 +21,7 @@ def resolve_lazy_tables(

 @dataclasses.dataclass
 class JoinToAdd:
-    fields_accessed: Dict[str, List[str | int]]
+    fields_accessed: dict[str, list[str | int]]
     lazy_join: LazyJoin
     from_table: str
     to_table: str
@@ -29,7 +29,7 @@ class TableToAdd:
-    fields_accessed: Dict[str, List[str | int]]
+    fields_accessed: dict[str, list[str | int]]
     lazy_table: LazyTable
@@ -37,13 +37,13 @@ class ConstraintOverride:
     alias: str
     table_name: str
-    chain_to_replace: List[str | int]
+    chain_to_replace: list[str | int]

 class FieldChainReplacer(TraversingVisitor):
-    overrides: List[ConstraintOverride] = {}
+    overrides: list[ConstraintOverride] = []

-    def __init__(self, overrides: List[ConstraintOverride]) -> None:
+    def __init__(self, overrides: list[ConstraintOverride]) -> None:
         super().__init__()
         self.overrides = overrides
@@ -58,7 +58,7 @@ class LazyFinder(TraversingVisitor):
max_type_visits: int = 3 def __init__(self) -> None: - self.visited_field_type_counts: Dict[int, int] = {} + self.visited_field_type_counts: dict[int, int] = {} def visit_lazy_join_type(self, node: ast.LazyJoinType): self.found_lazy = True @@ -80,11 +80,11 @@ class LazyTableResolver(TraversingVisitor): def __init__( self, dialect: Literal["hogql", "clickhouse"], - stack: Optional[List[ast.SelectQuery]] = None, + stack: Optional[list[ast.SelectQuery]] = None, context: HogQLContext = None, ): super().__init__() - self.stack_of_fields: List[List[ast.FieldType | ast.PropertyType]] = [[]] if stack else [] + self.stack_of_fields: list[list[ast.FieldType | ast.PropertyType]] = [[]] if stack else [] self.context = context self.dialect: Literal["hogql", "clickhouse"] = dialect @@ -129,30 +129,30 @@ def visit_select_query(self, node: ast.SelectQuery): assert select_type is not None # Collect each `ast.Field` with `ast.LazyJoinType` - field_collector: List[ast.FieldType | ast.PropertyType] = [] + field_collector: list[ast.FieldType | ast.PropertyType] = [] self.stack_of_fields.append(field_collector) # Collect all visited fields on lazy tables into field_collector super().visit_select_query(node) # Collect all the joins we need to add to the select query - joins_to_add: Dict[str, JoinToAdd] = {} - tables_to_add: Dict[str, TableToAdd] = {} + joins_to_add: dict[str, JoinToAdd] = {} + tables_to_add: dict[str, TableToAdd] = {} # First properties, then fields. This way we always get the smallest units to query first. - matched_properties: List[ast.PropertyType | ast.FieldType] = [ + matched_properties: list[ast.PropertyType | ast.FieldType] = [ property for property in field_collector if isinstance(property, ast.PropertyType) ] - matched_fields: List[ast.PropertyType | ast.FieldType] = [ + matched_fields: list[ast.PropertyType | ast.FieldType] = [ field for field in field_collector if isinstance(field, ast.FieldType) ] - sorted_properties: List[ast.PropertyType | ast.FieldType] = matched_properties + matched_fields + sorted_properties: list[ast.PropertyType | ast.FieldType] = matched_properties + matched_fields # Look for tables without requested fields to support cases like `select count() from table` join = node.select_from while join: if join.table is not None and isinstance(join.table.type, ast.LazyTableType): - fields: List[ast.FieldType | ast.PropertyType] = [] + fields: list[ast.FieldType | ast.PropertyType] = [] for field_or_property in field_collector: if isinstance(field_or_property, ast.FieldType): if isinstance(field_or_property.table_type, ast.TableAliasType): @@ -186,7 +186,7 @@ def visit_select_query(self, node: ast.SelectQuery): # Traverse the lazy tables until we reach a real table, collecting them in a list. # Usually there's just one or two. 
- table_types: List[ast.LazyJoinType | ast.LazyTableType | ast.TableAliasType] = [] + table_types: list[ast.LazyJoinType | ast.LazyTableType | ast.TableAliasType] = [] while ( isinstance(table_type, ast.TableAliasType) or isinstance(table_type, ast.LazyJoinType) @@ -217,12 +217,12 @@ def visit_select_query(self, node: ast.SelectQuery): ) new_join = joins_to_add[to_table] if table_type == field.table_type: - chain: List[str | int] = [] + chain: list[str | int] = [] chain.append(field.name) if property is not None: chain.extend(property.chain) property.joined_subquery_field_name = ( - f"{field.name}___{'___'.join((str(x) for x in property.chain))}" + f"{field.name}___{'___'.join(str(x) for x in property.chain)}" ) new_join.fields_accessed[property.joined_subquery_field_name] = chain else: @@ -241,7 +241,7 @@ def visit_select_query(self, node: ast.SelectQuery): if property is not None: chain.extend(property.chain) property.joined_subquery_field_name = ( - f"{field.name}___{'___'.join((str(x) for x in property.chain))}" + f"{field.name}___{'___'.join(str(x) for x in property.chain)}" ) new_table.fields_accessed[property.joined_subquery_field_name] = chain else: @@ -259,12 +259,12 @@ def visit_select_query(self, node: ast.SelectQuery): ) new_join = joins_to_add[to_table] if table_type == field.table_type: - chain: List[str | int] = [] + chain: list[str | int] = [] chain.append(field.name) if property is not None: chain.extend(property.chain) property.joined_subquery_field_name = ( - f"{field.name}___{'___'.join((str(x) for x in property.chain))}" + f"{field.name}___{'___'.join(str(x) for x in property.chain)}" ) new_join.fields_accessed[property.joined_subquery_field_name] = chain else: @@ -283,7 +283,7 @@ def visit_select_query(self, node: ast.SelectQuery): if property is not None: chain.extend(property.chain) property.joined_subquery_field_name = ( - f"{field.name}___{'___'.join((str(x) for x in property.chain))}" + f"{field.name}___{'___'.join(str(x) for x in property.chain)}" ) new_table.fields_accessed[property.joined_subquery_field_name] = chain else: @@ -291,10 +291,10 @@ def visit_select_query(self, node: ast.SelectQuery): # Make sure we also add fields we will use for the join's "ON" condition into the list of fields accessed. # Without this "pdi.person.id" won't work if you did not ALSO select "pdi.person_id" explicitly for the join. 
- join_constraint_overrides: Dict[str, List[ConstraintOverride]] = {} + join_constraint_overrides: dict[str, list[ConstraintOverride]] = {} - def create_override(table_name: str, field_chain: List[str | int]) -> None: - alias = f"{table_name}___{'___'.join((str(x) for x in field_chain))}" + def create_override(table_name: str, field_chain: list[str | int]) -> None: + alias = f"{table_name}___{'___'.join(str(x) for x in field_chain)}" if table_name in tables_to_add: tables_to_add[table_name].fields_accessed[alias] = field_chain @@ -387,7 +387,7 @@ def create_override(table_name: str, field_chain: List[str | int]) -> None: node.select_from = join_to_add # Collect any fields or properties that may have been added from the join_function with the LazyJoinType - join_field_collector: List[ast.FieldType | ast.PropertyType] = [] + join_field_collector: list[ast.FieldType | ast.PropertyType] = [] self.stack_of_fields.append(join_field_collector) super().visit(join_to_add) self.stack_of_fields.pop() diff --git a/posthog/hogql/transforms/property_types.py b/posthog/hogql/transforms/property_types.py index cc5451bf6bc3a..5627980fa0dfc 100644 --- a/posthog/hogql/transforms/property_types.py +++ b/posthog/hogql/transforms/property_types.py @@ -1,4 +1,4 @@ -from typing import Dict, Set, Literal, Optional, cast +from typing import Literal, Optional, cast from posthog.hogql import ast from posthog.hogql.context import HogQLContext @@ -81,9 +81,9 @@ class PropertyFinder(TraversingVisitor): def __init__(self, context: HogQLContext): super().__init__() - self.person_properties: Set[str] = set() - self.event_properties: Set[str] = set() - self.group_properties: Dict[int, Set[str]] = {} + self.person_properties: set[str] = set() + self.event_properties: set[str] = set() + self.group_properties: dict[int, set[str]] = {} self.found_timestamps = False self.context = context @@ -123,9 +123,9 @@ class PropertySwapper(CloningVisitor): def __init__( self, timezone: str, - event_properties: Dict[str, str], - person_properties: Dict[str, str], - group_properties: Dict[str, str], + event_properties: dict[str, str], + person_properties: dict[str, str], + group_properties: dict[str, str], context: HogQLContext, ): super().__init__(clear_types=False) diff --git a/posthog/hogql_queries/actor_strategies.py b/posthog/hogql_queries/actor_strategies.py index d05661d4eddb6..41cd8d5a1bf3a 100644 --- a/posthog/hogql_queries/actor_strategies.py +++ b/posthog/hogql_queries/actor_strategies.py @@ -1,4 +1,4 @@ -from typing import Dict, List, cast, Literal, Optional +from typing import cast, Literal, Optional from django.db.models import Prefetch @@ -21,19 +21,19 @@ def __init__(self, team: Team, query: ActorsQuery, paginator: HogQLHasMorePagina self.paginator = paginator self.query = query - def get_actors(self, actor_ids) -> Dict[str, Dict]: + def get_actors(self, actor_ids) -> dict[str, dict]: raise NotImplementedError() def get_recordings(self, matching_events) -> dict[str, list[dict]]: return {} - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: raise NotImplementedError() - def filter_conditions(self) -> List[ast.Expr]: + def filter_conditions(self) -> list[ast.Expr]: return [] - def order_by(self) -> Optional[List[ast.OrderExpr]]: + def order_by(self) -> Optional[list[ast.OrderExpr]]: return None @@ -42,7 +42,7 @@ class PersonStrategy(ActorStrategy): origin = "persons" origin_id = "id" - def get_actors(self, actor_ids) -> Dict[str, Dict]: + def get_actors(self, actor_ids) -> dict[str, dict]: return { 
str(p.uuid): { "id": p.uuid, @@ -58,11 +58,11 @@ def get_actors(self, actor_ids) -> Dict[str, Dict]: def get_recordings(self, matching_events) -> dict[str, list[dict]]: return RecordingsHelper(self.team).get_recordings(matching_events) - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: return ["person", "id", "created_at", "person.$delete"] - def filter_conditions(self) -> List[ast.Expr]: - where_exprs: List[ast.Expr] = [] + def filter_conditions(self) -> list[ast.Expr]: + where_exprs: list[ast.Expr] = [] if self.query.properties: where_exprs.append(property_to_expr(self.query.properties, self.team, scope="person")) @@ -98,7 +98,7 @@ def filter_conditions(self) -> List[ast.Expr]: ) return where_exprs - def order_by(self) -> Optional[List[ast.OrderExpr]]: + def order_by(self) -> Optional[list[ast.OrderExpr]]: if self.query.orderBy not in [["person"], ["person DESC"], ["person ASC"]]: return None @@ -125,7 +125,7 @@ def __init__(self, group_type_index: int, **kwargs): self.group_type_index = group_type_index super().__init__(**kwargs) - def get_actors(self, actor_ids) -> Dict[str, Dict]: + def get_actors(self, actor_ids) -> dict[str, dict]: return { str(p["group_key"]): { "id": p["group_key"], @@ -140,11 +140,11 @@ def get_actors(self, actor_ids) -> Dict[str, Dict]: .iterator(chunk_size=self.paginator.limit) } - def input_columns(self) -> List[str]: + def input_columns(self) -> list[str]: return ["group"] - def filter_conditions(self) -> List[ast.Expr]: - where_exprs: List[ast.Expr] = [] + def filter_conditions(self) -> list[ast.Expr]: + where_exprs: list[ast.Expr] = [] if self.query.search is not None and self.query.search != "": where_exprs.append( @@ -166,7 +166,7 @@ def filter_conditions(self) -> List[ast.Expr]: return where_exprs - def order_by(self) -> Optional[List[ast.OrderExpr]]: + def order_by(self) -> Optional[list[ast.OrderExpr]]: if self.query.orderBy not in [["group"], ["group DESC"], ["group ASC"]]: return None diff --git a/posthog/hogql_queries/actors_query_runner.py b/posthog/hogql_queries/actors_query_runner.py index da2e142bf6636..8224067c24d36 100644 --- a/posthog/hogql_queries/actors_query_runner.py +++ b/posthog/hogql_queries/actors_query_runner.py @@ -1,6 +1,7 @@ import itertools from datetime import timedelta -from typing import List, Generator, Sequence, Iterator, Optional +from typing import Optional +from collections.abc import Generator, Sequence, Iterator from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_order_expr from posthog.hogql.property import has_aggregation @@ -53,7 +54,7 @@ def enrich_with_actors( actors_lookup, recordings_column_index: Optional[int], recordings_lookup: Optional[dict[str, list[dict]]], - ) -> Generator[List, None, None]: + ) -> Generator[list, None, None]: for result in results: new_row = list(result) actor_id = str(result[actor_column_index]) @@ -70,9 +71,7 @@ def prepare_recordings(self, column_name, input_columns): return None, None column_index_events = input_columns.index("matched_recordings") - matching_events_list = itertools.chain.from_iterable( - (row[column_index_events] for row in self.paginator.results) - ) + matching_events_list = itertools.chain.from_iterable(row[column_index_events] for row in self.paginator.results) return column_index_events, self.strategy.get_recordings(matching_events_list) def calculate(self) -> ActorsQueryResponse: @@ -85,7 +84,7 @@ def calculate(self) -> ActorsQueryResponse: ) input_columns = self.input_columns() missing_actors_count = None 
-        results: Sequence[List] | Iterator[List] = self.paginator.results
+        results: Sequence[list] | Iterator[list] = self.paginator.results

         enrich_columns = filter(lambda column: column in ("person", "group", "actor"), input_columns)
         for column_name in enrich_columns:
@@ -110,14 +109,14 @@ def calculate(self) -> ActorsQueryResponse:
             **self.paginator.response_params(),
         )

-    def input_columns(self) -> List[str]:
+    def input_columns(self) -> list[str]:
         if self.query.select:
             return self.query.select

         return self.strategy.input_columns()

     # TODO: Figure out a more sure way of getting the actor id than using the alias or chain name
-    def source_id_column(self, source_query: ast.SelectQuery | ast.SelectUnionQuery) -> List[str]:
+    def source_id_column(self, source_query: ast.SelectQuery | ast.SelectUnionQuery) -> list[str]:
         # Figure out the id column of the source query, first column that has id in the name
         if isinstance(source_query, ast.SelectQuery):
             select = source_query.select
diff --git a/posthog/hogql_queries/events_query_runner.py b/posthog/hogql_queries/events_query_runner.py
index fe04ed8aa8563..9dc329e9e464d 100644
--- a/posthog/hogql_queries/events_query_runner.py
+++ b/posthog/hogql_queries/events_query_runner.py
@@ -1,6 +1,6 @@
 import json
 from datetime import timedelta
-from typing import Dict, List, Optional
+from typing import Optional

 from dateutil.parser import isoparse
 from django.db.models import Prefetch
@@ -53,8 +53,8 @@ def to_query(self) -> ast.SelectQuery:
         with self.timings.measure("build_ast"):
             # columns & group_by
             with self.timings.measure("columns"):
-                select_input: List[str] = []
-                person_indices: List[int] = []
+                select_input: list[str] = []
+                person_indices: list[int] = []
                 for index, col in enumerate(self.select_input_raw()):
                     # Selecting a "*" expands the list of columns, resulting in a table that's not what we asked for.
                     # Instead, ask for a tuple with all the columns we want. Later transform this back into a dict.
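For reference, a minimal sketch (illustrative names, not part of this diff) of the import style these hunks migrate to on Python 3.9+: the deprecated typing aliases Generator, Sequence, Iterator, and Callable come from collections.abc, and containers use PEP 585 builtin generics:

from collections.abc import Callable, Generator, Iterator, Sequence
from typing import Optional

def enrich_rows(
    rows: Sequence[list],
    transform: Optional[Callable[[list], list]] = None,
) -> Generator[list, None, None]:
    # Mirrors the Generator[list, None, None] style used above; names are made up.
    for row in rows:
        yield transform(row) if transform else list(row)

def first_ids(rows: Iterator[list]) -> list[str]:
    # Builtin generic list[str] (PEP 585) instead of typing.List[str].
    return [str(row[0]) for row in rows]

print(list(enrich_rows([[1, "a"], [2, "b"]])))  # [[1, 'a'], [2, 'b']]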
@@ -66,11 +66,11 @@ def to_query(self) -> ast.SelectQuery: person_indices.append(index) else: select_input.append(col) - select: List[ast.Expr] = [parse_expr(column, timings=self.timings) for column in select_input] + select: list[ast.Expr] = [parse_expr(column, timings=self.timings) for column in select_input] with self.timings.measure("aggregations"): - group_by: List[ast.Expr] = [column for column in select if not has_aggregation(column)] - aggregations: List[ast.Expr] = [column for column in select if has_aggregation(column)] + group_by: list[ast.Expr] = [column for column in select if not has_aggregation(column)] + aggregations: list[ast.Expr] = [column for column in select if has_aggregation(column)] has_any_aggregation = len(aggregations) > 0 # filters @@ -210,7 +210,7 @@ def calculate(self) -> EventsQueryResponse: ).data self.paginator.results[index][star_idx] = new_result - person_indices: List[int] = [] + person_indices: list[int] = [] for index, col in enumerate(self.select_input_raw()): if col.split("--")[0].strip() == "person": person_indices.append(index) @@ -222,7 +222,7 @@ def calculate(self) -> EventsQueryResponse: distinct_ids = list({event[person_idx] for event in self.paginator.results}) persons = get_persons_by_distinct_ids(self.team.pk, distinct_ids) persons = persons.prefetch_related(Prefetch("persondistinctid_set", to_attr="distinct_ids_cache")) - distinct_to_person: Dict[str, Person] = {} + distinct_to_person: dict[str, Person] = {} for person in persons: if person: for person_distinct_id in person.distinct_ids: @@ -268,7 +268,7 @@ def apply_dashboard_filters(self, dashboard_filter: DashboardFilter): return new_query - def select_input_raw(self) -> List[str]: + def select_input_raw(self) -> list[str]: return ["*"] if len(self.query.select) == 0 else self.query.select def _is_stale(self, cached_result_package): diff --git a/posthog/hogql_queries/hogql_query_runner.py b/posthog/hogql_queries/hogql_query_runner.py index 46b4c105a4336..3a9a0b62efd98 100644 --- a/posthog/hogql_queries/hogql_query_runner.py +++ b/posthog/hogql_queries/hogql_query_runner.py @@ -1,5 +1,6 @@ from datetime import timedelta -from typing import Callable, Dict, Optional, cast +from typing import Optional, cast +from collections.abc import Callable from posthog.clickhouse.client.connection import Workload from posthog.hogql import ast @@ -26,7 +27,7 @@ class HogQLQueryRunner(QueryRunner): def to_query(self) -> ast.SelectQuery: if self.timings is None: self.timings = HogQLTimings() - values: Optional[Dict[str, ast.Expr]] = ( + values: Optional[dict[str, ast.Expr]] = ( {key: ast.Constant(value=value) for key, value in self.query.values.items()} if self.query.values else None ) with self.timings.measure("parse_select"): diff --git a/posthog/hogql_queries/insights/funnels/base.py b/posthog/hogql_queries/insights/funnels/base.py index 1dade0de4b052..40614464f1361 100644 --- a/posthog/hogql_queries/insights/funnels/base.py +++ b/posthog/hogql_queries/insights/funnels/base.py @@ -1,6 +1,6 @@ from abc import ABC from functools import cached_property -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import uuid from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.hogql import ast @@ -37,14 +37,14 @@ class FunnelBase(ABC): context: FunnelQueryContext - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + _extra_event_fields: list[ColumnName] + _extra_event_properties: 
list[PropertyName] def __init__(self, context: FunnelQueryContext): self.context = context - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if ( hasattr(self.context, "actorsQuery") @@ -86,7 +86,7 @@ def get_step_counts_without_aggregation_query(self) -> ast.SelectQuery: raise NotImplementedError() @cached_property - def breakdown_cohorts(self) -> List[Cohort]: + def breakdown_cohorts(self) -> list[Cohort]: team, breakdown = self.context.team, self.context.breakdown if isinstance(breakdown, list): @@ -97,7 +97,7 @@ def breakdown_cohorts(self) -> List[Cohort]: return list(cohorts) @cached_property - def breakdown_cohorts_ids(self) -> List[int]: + def breakdown_cohorts_ids(self) -> list[int]: breakdown = self.context.breakdown ids = [int(cohort.pk) for cohort in self.breakdown_cohorts] @@ -108,7 +108,7 @@ def breakdown_cohorts_ids(self) -> List[int]: return ids @cached_property - def breakdown_values(self) -> List[int] | List[str] | List[List[str]]: + def breakdown_values(self) -> list[int] | list[str] | list[list[str]]: # """ # Returns the top N breakdown prop values for event/person breakdown @@ -169,7 +169,7 @@ def breakdown_values(self) -> List[int] | List[str] | List[List[str]]: else: prop_exprs = [] - where_exprs: List[ast.Expr | None] = [ + where_exprs: list[ast.Expr | None] = [ # entity filter entity_expr, # prop filter @@ -209,7 +209,7 @@ def breakdown_values(self) -> List[int] | List[str] | List[List[str]]: raise ValidationError("Apologies, there has been an error computing breakdown values.") return [row[0] for row in results[0:breakdown_limit_or_default]] - def _get_breakdown_select_prop(self) -> List[ast.Expr]: + def _get_breakdown_select_prop(self) -> list[ast.Expr]: breakdown, breakdownAttributionType, funnelsFilter = ( self.context.breakdown, self.context.breakdownAttributionType, @@ -296,7 +296,7 @@ def _get_breakdown_expr(self) -> ast.Expr: def _format_results( self, results - ) -> Union[FunnelTimeToConvertResults, List[Dict[str, Any]], List[List[Dict[str, Any]]]]: + ) -> Union[FunnelTimeToConvertResults, list[dict[str, Any]], list[list[dict[str, Any]]]]: breakdown = self.context.breakdown if not results or len(results) == 0: @@ -387,9 +387,9 @@ def _serialize_step( step: ActionsNode | EventsNode | DataWarehouseNode, count: int, index: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: action_id: Optional[str | int] if isinstance(step, EventsNode): name = step.event @@ -419,7 +419,7 @@ def extra_event_fields_and_properties(self): def _get_inner_event_query( self, - entities: List[EntityNode] | None = None, + entities: list[EntityNode] | None = None, entity_name="events", skip_entity_filter=False, skip_step_filter=False, @@ -433,7 +433,7 @@ def _get_inner_event_query( ) entities_to_use = entities or query.series - extra_fields: List[str] = [] + extra_fields: list[str] = [] for prop in self.context.includeProperties: extra_fields.append(prop) @@ -450,7 +450,7 @@ def _get_inner_event_query( # extra_event_properties=self._extra_event_properties, # ).get_query(entities_to_use, entity_name, skip_entity_filter=skip_entity_filter) - all_step_cols: List[ast.Expr] = [] + all_step_cols: list[ast.Expr] = [] for index, entity in enumerate(entities_to_use): step_cols = self._get_step_col(entity, 
index, entity_name) all_step_cols.extend(step_cols) @@ -489,7 +489,7 @@ def _get_inner_event_query( def _get_cohort_breakdown_join(self) -> ast.JoinExpr: breakdown = self.context.breakdown - cohort_queries: List[ast.SelectQuery] = [] + cohort_queries: list[ast.SelectQuery] = [] for cohort in self.breakdown_cohorts: query = parse_select( @@ -564,7 +564,7 @@ def _add_breakdown_attribution_subquery(self, inner_query: ast.SelectQuery) -> a return query def _get_steps_conditions(self, length: int) -> ast.Expr: - step_conditions: List[ast.Expr] = [] + step_conditions: list[ast.Expr] = [] for index in range(length): step_conditions.append(parse_expr(f"step_{index} = 1")) @@ -580,10 +580,10 @@ def _get_step_col( index: int, entity_name: str, step_prefix: str = "", - ) -> List[ast.Expr]: + ) -> list[ast.Expr]: # step prefix is used to distinguish actual steps, and exclusion steps # without the prefix, we get the same parameter binding for both, which borks things up - step_cols: List[ast.Expr] = [] + step_cols: list[ast.Expr] = [] condition = self._build_step_query(entity, index, entity_name, step_prefix) step_cols.append( parse_expr(f"if({{condition}}, 1, 0) as {step_prefix}step_{index}", placeholders={"condition": condition}) @@ -626,7 +626,7 @@ def _build_step_query( else: return event_expr - def _get_timestamp_outer_select(self) -> List[ast.Expr]: + def _get_timestamp_outer_select(self) -> list[ast.Expr]: if self.context.includePrecedingTimestamp: return [ast.Field(chain=["max_timestamp"]), ast.Field(chain=["min_timestamp"])] elif self.context.includeTimestamp: @@ -646,7 +646,7 @@ def _get_funnel_person_step_condition(self) -> ast.Expr: funnelCustomSteps = actorsQuery.funnelCustomSteps funnelStepBreakdown = actorsQuery.funnelStepBreakdown - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] if funnelCustomSteps: conditions.append(parse_expr(f"steps IN {funnelCustomSteps}")) @@ -673,7 +673,7 @@ def _get_funnel_person_step_condition(self) -> ast.Expr: return ast.And(exprs=conditions) - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def _get_funnel_person_step_events(self) -> list[ast.Expr]: if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -694,23 +694,23 @@ def _get_funnel_person_step_events(self) -> List[ast.Expr]: return [parse_expr(f"step_{matching_events_step_num}_matching_events as matching_events")] return [] - def _get_count_columns(self, max_steps: int) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_count_columns(self, max_steps: int) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(max_steps): exprs.append(parse_expr(f"countIf(steps = {i + 1}) step_{i + 1}")) return exprs - def _get_step_time_names(self, max_steps: int) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_step_time_names(self, max_steps: int) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for i in range(1, max_steps): exprs.append(parse_expr(f"step_{i}_conversion_time")) return exprs - def _get_final_matching_event(self, max_steps: int) -> List[ast.Expr]: + def _get_final_matching_event(self, max_steps: int) -> list[ast.Expr]: statement = None for i in range(max_steps - 1, -1, -1): if i == max_steps - 1: @@ -721,7 +721,7 @@ def _get_final_matching_event(self, max_steps: int) -> List[ast.Expr]: statement = f"if(isNull(latest_{i}),step_{i-1}_matching_event,{statement})" return [parse_expr(f"{statement} as final_matching_event")] if statement else [] - def _get_matching_events(self, max_steps: int) -> List[ast.Expr]: + def 
_get_matching_events(self, max_steps: int) -> list[ast.Expr]:
         if (
             hasattr(self.context, "actorsQuery")
             and self.context.actorsQuery is not None
@@ -737,8 +737,8 @@ def _get_matching_events(self, max_steps: int) -> List[ast.Expr]:
             return [*events, *self._get_final_matching_event(max_steps)]
         return []

-    def _get_matching_event_arrays(self, max_steps: int) -> List[ast.Expr]:
-        exprs: List[ast.Expr] = []
+    def _get_matching_event_arrays(self, max_steps: int) -> list[ast.Expr]:
+        exprs: list[ast.Expr] = []
         if (
             hasattr(self.context, "actorsQuery")
             and self.context.actorsQuery is not None
@@ -749,8 +749,8 @@ def _get_matching_event_arrays(self, max_steps: int) -> List[ast.Expr]:
             exprs.append(parse_expr(f"groupArray(10)(final_matching_event) as final_matching_events"))
         return exprs

-    def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False) -> List[ast.Expr]:
-        exprs: List[ast.Expr] = []
+    def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False) -> list[ast.Expr]:
+        exprs: list[ast.Expr] = []

         for i in range(1, max_steps):
             exprs.append(
@@ -761,8 +761,8 @@ def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False) -> List
         return exprs

-    def _get_step_time_median(self, max_steps: int, inner_query: bool = False) -> List[ast.Expr]:
-        exprs: List[ast.Expr] = []
+    def _get_step_time_median(self, max_steps: int, inner_query: bool = False) -> list[ast.Expr]:
+        exprs: list[ast.Expr] = []

         for i in range(1, max_steps):
             exprs.append(
@@ -773,7 +773,7 @@ def _get_step_time_median(self, max_steps: int, inner_query: bool = False) -> Li
         return exprs

-    def _get_timestamp_selects(self) -> Tuple[List[ast.Expr], List[ast.Expr]]:
+    def _get_timestamp_selects(self) -> tuple[list[ast.Expr], list[ast.Expr]]:
         """
         Returns timestamp selectors for the target step and optionally the preceding step.
         In the former case, always returns the timestamp for the first and last step as well.
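The same rewrite applies to tuple return types such as the one on _get_timestamp_selects above. A small self-contained sketch (hypothetical helper, not from this codebase) of `Tuple[List[...], List[...]]` becoming `tuple[list[...], list[...]]` under PEP 585:

def split_selects(names: list[str]) -> tuple[list[str], list[str]]:
    # Hypothetical helper: a pair-of-lists return type written with builtin generics.
    inner = [n for n in names if n.startswith("step_")]
    outer = [n for n in names if not n.startswith("step_")]
    return inner, outer

assert split_selects(["step_0", "timestamp"]) == (["step_0"], ["timestamp"])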
@@ -829,11 +829,11 @@ def _get_timestamp_selects(self) -> Tuple[List[ast.Expr], List[ast.Expr]]:
         else:
             return [], []

-    def _get_step_times(self, max_steps: int) -> List[ast.Expr]:
+    def _get_step_times(self, max_steps: int) -> list[ast.Expr]:
         windowInterval = self.context.funnelWindowInterval
         windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit)

-        exprs: List[ast.Expr] = []
+        exprs: list[ast.Expr] = []

         for i in range(1, max_steps):
             exprs.append(
@@ -844,12 +844,12 @@ def _get_step_times(self, max_steps: int) -> List[ast.Expr]:
         return exprs

-    def _get_partition_cols(self, level_index: int, max_steps: int) -> List[ast.Expr]:
+    def _get_partition_cols(self, level_index: int, max_steps: int) -> list[ast.Expr]:
         query, funnelsFilter = self.context.query, self.context.funnelsFilter
         exclusions = funnelsFilter.exclusions
         series = query.series

-        exprs: List[ast.Expr] = []
+        exprs: list[ast.Expr] = []

         for i in range(0, max_steps):
             exprs.append(ast.Field(chain=[f"step_{i}"]))
@@ -894,7 +894,7 @@ def _get_partition_cols(self, level_index: int, max_steps: int) -> List[ast.Expr
         return exprs

-    def _get_breakdown_prop_expr(self, group_remaining=False) -> List[ast.Expr]:
+    def _get_breakdown_prop_expr(self, group_remaining=False) -> list[ast.Expr]:
         # SEE BELOW for a string implementation of the following
         breakdown, breakdownType = self.context.breakdown, self.context.breakdownType
@@ -938,7 +938,7 @@ def _get_breakdown_prop(self, group_remaining=False) -> str:
         else:
             return ""

-    def _get_breakdown_conditions(self) -> Optional[List[int] | List[str] | List[List[str]]]:
+    def _get_breakdown_conditions(self) -> Optional[list[int] | list[str] | list[list[str]]]:
         """
         For people, pagination sets the offset param, which is common across filters
         and gives us the wrong breakdown values here, so we override it.
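The _get_breakdown_conditions annotation above also leans on PEP 604 union syntax. A hedged sketch (hypothetical function, not from this codebase) of `X | Y` unions in annotations and in isinstance(), both available since Python 3.10; note that `Optional[X]` is equivalent to `X | None`:

def describe(value: int | float | str | None) -> str | None:
    # `int | float | str | None` replaces Union[...] / Optional[...] (PEP 604).
    if value is None:
        return None
    if isinstance(value, int | float):  # isinstance() accepts X | Y unions on 3.10+
        return f"number:{value}"
    return f"text:{value}"

assert describe(3.5) == "number:3.5" and describe("a") == "text:a"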
@@ -957,7 +957,7 @@ def _query_has_array_breakdown(self) -> bool: breakdown, breakdownType = self.context.breakdown, self.context.breakdownType return not isinstance(breakdown, str) and breakdownType != "cohort" - def _get_exclusion_condition(self) -> List[ast.Expr]: + def _get_exclusion_condition(self) -> list[ast.Expr]: funnelsFilter = self.context.funnelsFilter windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -965,7 +965,7 @@ def _get_exclusion_condition(self) -> List[ast.Expr]: if not funnelsFilter.exclusions: return [] - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for exclusion_id, exclusion in enumerate(funnelsFilter.exclusions): from_time = f"latest_{exclusion.funnelFromStep}" @@ -995,7 +995,7 @@ def _get_sorting_condition(self, curr_index: int, max_steps: int) -> ast.Expr: if curr_index == 1: return ast.Constant(value=1) - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for i in range(1, curr_index): duplicate_event = is_equal(series[i], series[i - 1]) or is_superset(series[i], series[i - 1]) @@ -1016,8 +1016,8 @@ def _get_sorting_condition(self, curr_index: int, max_steps: int) -> ast.Expr: ], ) - def _get_person_and_group_properties(self, aggregate: bool = False) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def _get_person_and_group_properties(self, aggregate: bool = False) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] for prop in self.context.includeProperties: exprs.append(parse_expr(f"any({prop}) as {prop}") if aggregate else parse_expr(prop)) diff --git a/posthog/hogql_queries/insights/funnels/funnel.py b/posthog/hogql_queries/insights/funnels/funnel.py index b5ce2bb7faf53..1975645d753e1 100644 --- a/posthog/hogql_queries/insights/funnels/funnel.py +++ b/posthog/hogql_queries/insights/funnels/funnel.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -35,7 +33,7 @@ def get_query(self): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -54,13 +52,13 @@ def get_step_counts_query(self): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -74,7 +72,7 @@ def get_step_counts_query(self): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -106,7 +104,7 @@ def get_step_counts_without_aggregation_query(self): formatted_query = self._build_step_subquery(2, max_steps) breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["*"]), ast.Alias(alias="steps", expr=self._get_sorting_condition(max_steps, max_steps)), 
*self._get_exclusion_condition(), @@ -135,7 +133,7 @@ def get_step_counts_without_aggregation_query(self): def _build_step_subquery( self, level_index: int, max_steps: int, event_names_alias: str = "events" ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["timestamp"]), ] @@ -175,12 +173,12 @@ def _build_step_subquery( ), ) - def _get_comparison_cols(self, level_index: int, max_steps: int) -> List[ast.Expr]: + def _get_comparison_cols(self, level_index: int, max_steps: int) -> list[ast.Expr]: """ level_index: The current smallest comparison step. Everything before level index is already at the minimum ordered timestamps. """ - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] funnelsFilter = self.context.funnelsFilter exclusions = funnelsFilter.exclusions @@ -225,7 +223,7 @@ def _get_comparison_cols(self, level_index: int, max_steps: int) -> List[ast.Exp return exprs def _get_comparison_at_step(self, index: int, level_index: int) -> ast.Or: - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(level_index, index + 1): exprs.append(parse_expr(f"latest_{i} < latest_{level_index - 1}")) diff --git a/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py b/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py index 04b1115fd38d2..035339c8e02ad 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py +++ b/posthog/hogql_queries/insights/funnels/funnel_correlation_query_runner.py @@ -1,6 +1,6 @@ import dataclasses from datetime import timedelta -from typing import List, Literal, Optional, Any, Dict, Set, TypedDict, cast +from typing import Literal, Optional, Any, TypedDict, cast from posthog.constants import AUTOCAPTURE_EVENT from posthog.hogql.parser import parse_select @@ -95,7 +95,7 @@ class FunnelCorrelationQueryRunner(QueryRunner): def __init__( self, - query: FunnelCorrelationQuery | Dict[str, Any], + query: FunnelCorrelationQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -132,7 +132,7 @@ def __init__( # Used for generating the funnel persons cte funnel_order_actor_class = get_funnel_actor_class(self.context.funnelsFilter)(context=self.context) assert isinstance( - funnel_order_actor_class, (FunnelActors, FunnelStrictActors, FunnelUnorderedActors) + funnel_order_actor_class, FunnelActors | FunnelStrictActors | FunnelUnorderedActors ) # for typings self._funnel_actors_generator = funnel_order_actor_class @@ -228,7 +228,7 @@ def calculate(self) -> FunnelCorrelationResponse: modifiers=self.modifiers, ) - def _calculate(self) -> tuple[List[EventOddsRatio], bool, str, HogQLQueryResponse]: + def _calculate(self) -> tuple[list[EventOddsRatio], bool, str, HogQLQueryResponse]: query = self.to_query() hogql = to_printed_hogql(query, self.team) @@ -823,8 +823,8 @@ def _get_properties_prop_clause(self): props_str = ", ".join(props) return f"arrayJoin(arrayZip({self.query.funnelCorrelationNames}, [{props_str}])) as prop" - def _get_funnel_step_names(self) -> List[str]: - events: Set[str] = set() + def _get_funnel_step_names(self) -> list[str]: + events: set[str] = set() for entity in self.funnels_query.series: if isinstance(entity, ActionsNode): action = Action.objects.get(pk=int(entity.id), team=self.context.team) @@ -838,8 +838,8 @@ def _get_funnel_step_names(self) -> List[str]: return sorted(events) @property - def 
properties_to_include(self) -> List[str]: - props_to_include: List[str] = [] + def properties_to_include(self) -> list[str]: + props_to_include: list[str] = [] # TODO: implement or remove # if self.query.funnelCorrelationType == FunnelCorrelationResultsType.properties: # assert self.query.funnelCorrelationNames is not None diff --git a/posthog/hogql_queries/insights/funnels/funnel_event_query.py b/posthog/hogql_queries/insights/funnels/funnel_event_query.py index b2fd19083ed75..8acb0f7dea87b 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_event_query.py +++ b/posthog/hogql_queries/insights/funnels/funnel_event_query.py @@ -1,4 +1,4 @@ -from typing import List, Set, Union, Optional +from typing import Union, Optional from posthog.clickhouse.materialized_columns.column import ColumnName from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -13,16 +13,16 @@ class FunnelEventQuery: context: FunnelQueryContext - _extra_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + _extra_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] EVENT_TABLE_ALIAS = "e" def __init__( self, context: FunnelQueryContext, - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, ): if extra_event_properties is None: extra_event_properties = [] @@ -38,12 +38,12 @@ def to_query( # entities=None, # TODO: implement passed in entities when needed skip_entity_filter=False, ) -> ast.SelectQuery: - _extra_fields: List[ast.Expr] = [ + _extra_fields: list[ast.Expr] = [ ast.Alias(alias=field, expr=ast.Field(chain=[self.EVENT_TABLE_ALIAS, field])) for field in self._extra_fields ] - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="timestamp", expr=ast.Field(chain=[self.EVENT_TABLE_ALIAS, "timestamp"])), ast.Alias(alias="aggregation_target", expr=self._aggregation_target_expr()), *_extra_fields, @@ -132,7 +132,7 @@ def _entity_expr(self, skip_entity_filter: bool) -> ast.Expr | None: if skip_entity_filter is True: return None - events: Set[Union[int, str, None]] = set() + events: set[Union[int, str, None]] = set() for node in [*query.series, *exclusions]: if isinstance(node, EventsNode) or isinstance(node, FunnelExclusionEventsNode): @@ -157,5 +157,5 @@ def _entity_expr(self, skip_entity_filter: bool) -> ast.Expr | None: op=ast.CompareOperationOp.In, ) - def _properties_expr(self) -> List[ast.Expr]: + def _properties_expr(self) -> list[ast.Expr]: return Properties(context=self.context).to_exprs() diff --git a/posthog/hogql_queries/insights/funnels/funnel_persons.py b/posthog/hogql_queries/insights/funnels/funnel_persons.py index 68781c6bbd0c8..5fc06a07a7d4d 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql_queries.insights.funnels.funnel import Funnel @@ -7,9 +7,9 @@ class FunnelActors(Funnel): def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git 
a/posthog/hogql_queries/insights/funnels/funnel_query_context.py b/posthog/hogql_queries/insights/funnels/funnel_query_context.py index 3b777e3ff8026..499dc3eb9ed4c 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_query_context.py +++ b/posthog/hogql_queries/insights/funnels/funnel_query_context.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from posthog.hogql.constants import LimitContext from posthog.hogql.timings import HogQLTimings from posthog.hogql_queries.insights.query_context import QueryContext @@ -25,7 +25,7 @@ class FunnelQueryContext(QueryContext): interval: IntervalType - breakdown: List[Union[str, int]] | str | int | None + breakdown: list[Union[str, int]] | str | int | None breakdownType: BreakdownType breakdownAttributionType: BreakdownAttributionType @@ -36,7 +36,7 @@ class FunnelQueryContext(QueryContext): includeTimestamp: Optional[bool] includePrecedingTimestamp: Optional[bool] - includeProperties: List[str] + includeProperties: list[str] includeFinalMatchingEvents: Optional[bool] def __init__( @@ -48,7 +48,7 @@ def __init__( limit_context: Optional[LimitContext] = None, include_timestamp: Optional[bool] = None, include_preceding_timestamp: Optional[bool] = None, - include_properties: Optional[List[str]] = None, + include_properties: Optional[list[str]] = None, include_final_matching_events: Optional[bool] = None, ): super().__init__(query=query, team=team, timings=timings, modifiers=modifiers, limit_context=limit_context) @@ -98,7 +98,7 @@ def __init__( "hogql", None, ]: - boxed_breakdown: List[Union[str, int]] = box_value(self.breakdownFilter.breakdown) + boxed_breakdown: list[Union[str, int]] = box_value(self.breakdownFilter.breakdown) self.breakdown = boxed_breakdown else: self.breakdown = self.breakdownFilter.breakdown # type: ignore diff --git a/posthog/hogql_queries/insights/funnels/funnel_strict.py b/posthog/hogql_queries/insights/funnels/funnel_strict.py index 1bea66772a6f5..1b5bf73ad5033 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_strict.py +++ b/posthog/hogql_queries/insights/funnels/funnel_strict.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -11,7 +9,7 @@ def get_query(self): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -30,13 +28,13 @@ def get_step_counts_query(self): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -50,7 +48,7 @@ def get_step_counts_query(self): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -77,7 +75,7 @@ def get_step_counts_query(self): def 
get_step_counts_without_aggregation_query(self): max_steps = self.context.max_steps - select_inner: List[ast.Expr] = [ + select_inner: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["timestamp"]), *self._get_partition_cols(1, max_steps), @@ -87,7 +85,7 @@ def get_step_counts_without_aggregation_query(self): select_from_inner = self._get_inner_event_query(skip_entity_filter=True, skip_step_filter=True) inner_query = ast.SelectQuery(select=select_inner, select_from=ast.JoinExpr(table=select_from_inner)) - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["*"]), ast.Alias(alias="steps", expr=self._get_sorting_condition(max_steps, max_steps)), *self._get_step_times(max_steps), @@ -101,7 +99,7 @@ def get_step_counts_without_aggregation_query(self): return ast.SelectQuery(select=select, select_from=select_from, where=where) def _get_partition_cols(self, level_index: int, max_steps: int): - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] for i in range(0, max_steps): exprs.append(ast.Field(chain=[f"step_{i}"])) diff --git a/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py b/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py index f55afbd218266..299bd982b972b 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_strict_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql_queries.insights.funnels.funnel_strict import FunnelStrict @@ -7,9 +7,9 @@ class FunnelStrictActors(FunnelStrict): def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git a/posthog/hogql_queries/insights/funnels/funnel_trends.py b/posthog/hogql_queries/insights/funnels/funnel_trends.py index 9d486f1b06196..964f5d05cc6d0 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_trends.py +++ b/posthog/hogql_queries/insights/funnels/funnel_trends.py @@ -1,6 +1,6 @@ from datetime import datetime from itertools import groupby -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.hogql import ast from posthog.hogql.parser import parse_expr from posthog.hogql_queries.insights.funnels.base import FunnelBase @@ -58,7 +58,7 @@ def __init__(self, context: FunnelQueryContext, just_summarize=False): self.just_summarize = just_summarize self.funnel_order = get_funnel_order_class(self.context.funnelsFilter)(context=self.context) - def _format_results(self, results) -> List[Dict[str, Any]]: + def _format_results(self, results) -> list[dict[str, Any]]: query = self.context.query breakdown_clause = self._get_breakdown_prop() @@ -75,7 +75,7 @@ def _format_results(self, results) -> List[Dict[str, Any]]: if breakdown_clause: if isinstance(period_row[-1], str) or ( - isinstance(period_row[-1], List) and all(isinstance(item, str) for item in period_row[-1]) + isinstance(period_row[-1], list) and all(isinstance(item, str) for item in period_row[-1]) ): serialized_result.update({"breakdown_value": (period_row[-1])}) else: @@ -145,7 +145,7 @@ def get_query(self) -> ast.SelectQuery: breakdown_clause = self._get_breakdown_prop_expr() - data_select: List[ast.Expr] = [ + 
data_select: list[ast.Expr] = [ ast.Field(chain=["entrance_period_start"]), parse_expr(f"countIf({reached_from_step_count_condition}) AS reached_from_step_count"), parse_expr(f"countIf({reached_to_step_count_condition}) AS reached_to_step_count"), @@ -163,10 +163,10 @@ def get_query(self) -> ast.SelectQuery: args=[ast.Call(name="toDateTime", args=[(ast.Constant(value=formatted_date_to))])], ) data_select_from = ast.JoinExpr(table=step_counts) - data_group_by: List[ast.Expr] = [ast.Field(chain=["entrance_period_start"]), *breakdown_clause] + data_group_by: list[ast.Expr] = [ast.Field(chain=["entrance_period_start"]), *breakdown_clause] data_query = ast.SelectQuery(select=data_select, select_from=data_select_from, group_by=data_group_by) - fill_select: List[ast.Expr] = [ + fill_select: list[ast.Expr] = [ ast.Alias( alias="entrance_period_start", expr=ast.ArithmeticOperation( @@ -249,7 +249,7 @@ def get_query(self) -> ast.SelectQuery: ), ) - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["fill", "entrance_period_start"]), ast.Field(chain=["reached_from_step_count"]), ast.Field(chain=["reached_to_step_count"]), @@ -263,7 +263,7 @@ def get_query(self) -> ast.SelectQuery: alias="data", next_join=fill_join, ) - order_by: List[ast.OrderExpr] = [ + order_by: list[ast.OrderExpr] = [ ast.OrderExpr(expr=ast.Field(chain=["fill", "entrance_period_start"]), order="ASC") ] @@ -281,7 +281,7 @@ def get_step_counts_without_aggregation_query( steps_per_person_query = self.funnel_order.get_step_counts_without_aggregation_query() - event_select_clause: List[ast.Expr] = [] + event_select_clause: list[ast.Expr] = [] if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -291,7 +291,7 @@ def get_step_counts_without_aggregation_query( breakdown_clause = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Alias(alias="entrance_period_start", expr=get_start_of_interval_hogql(interval.value, team=team)), parse_expr("max(steps) AS steps_completed"), @@ -309,7 +309,7 @@ def get_step_counts_without_aggregation_query( if specific_entrance_period_start else None ) - group_by: List[ast.Expr] = [ + group_by: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["entrance_period_start"]), *breakdown_clause, @@ -317,7 +317,7 @@ def get_step_counts_without_aggregation_query( return ast.SelectQuery(select=select, select_from=select_from, where=where, group_by=group_by) - def get_steps_reached_conditions(self) -> Tuple[str, str, str]: + def get_steps_reached_conditions(self) -> tuple[str, str, str]: funnelsFilter, max_steps = self.context.funnelsFilter, self.context.max_steps # How many steps must have been done to count for the denominator of a funnel trends data point diff --git a/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py b/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py index c90a9ed576270..c124265ba653e 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_trends_persons.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import List from rest_framework.exceptions import ValidationError @@ -39,7 +38,7 @@ def __init__(self, context: FunnelQueryContext, just_summarize=False): self.dropOff = actorsQuery.funnelTrendsDropOff self.entrancePeriodStart = entrancePeriodStart - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def 
_get_funnel_person_step_events(self) -> list[ast.Expr]: if ( hasattr(self.context, "actorsQuery") and self.context.actorsQuery is not None @@ -71,7 +70,7 @@ def actor_query(self) -> ast.SelectQuery: did_not_reach_to_step_count_condition, ) = self.get_steps_reached_conditions() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), ] diff --git a/posthog/hogql_queries/insights/funnels/funnel_unordered.py b/posthog/hogql_queries/insights/funnels/funnel_unordered.py index af3ed18d4f82e..4ac87866d7fcc 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_unordered.py +++ b/posthog/hogql_queries/insights/funnels/funnel_unordered.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional import uuid from rest_framework.exceptions import ValidationError @@ -45,7 +45,7 @@ def get_query(self): breakdown_exprs = self._get_breakdown_prop_expr() - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ *self._get_count_columns(max_steps), *self._get_step_time_avgs(max_steps), *self._get_step_time_median(max_steps), @@ -64,13 +64,13 @@ def get_step_counts_query(self): inner_timestamps, outer_timestamps = self._get_timestamp_selects() person_and_group_properties = self._get_person_and_group_properties(aggregate=True) - group_by_columns: List[ast.Expr] = [ + group_by_columns: list[ast.Expr] = [ ast.Field(chain=["aggregation_target"]), ast.Field(chain=["steps"]), *breakdown_exprs, ] - outer_select: List[ast.Expr] = [ + outer_select: list[ast.Expr] = [ *group_by_columns, *self._get_step_time_avgs(max_steps, inner_query=True), *self._get_step_time_median(max_steps, inner_query=True), @@ -82,7 +82,7 @@ def get_step_counts_query(self): f"max(steps) over (PARTITION BY aggregation_target {self._get_breakdown_prop()}) as max_steps" ) - inner_select: List[ast.Expr] = [ + inner_select: list[ast.Expr] = [ *group_by_columns, max_steps_expr, *self._get_step_time_names(max_steps), @@ -106,7 +106,7 @@ def get_step_counts_query(self): def get_step_counts_without_aggregation_query(self): max_steps = self.context.max_steps - union_queries: List[ast.SelectQuery] = [] + union_queries: list[ast.SelectQuery] = [] entities_to_use = list(self.context.query.series) for i in range(max_steps): @@ -153,11 +153,11 @@ def get_step_counts_without_aggregation_query(self): return ast.SelectUnionQuery(select_queries=union_queries) - def _get_step_times(self, max_steps: int) -> List[ast.Expr]: + def _get_step_times(self, max_steps: int) -> list[ast.Expr]: windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) - exprs: List[ast.Expr] = [] + exprs: list[ast.Expr] = [] conversion_times_elements = [] for i in range(max_steps): @@ -175,7 +175,7 @@ def _get_step_times(self, max_steps: int) -> List[ast.Expr]: return exprs - def get_sorting_condition(self, max_steps: int) -> List[ast.Expr]: + def get_sorting_condition(self, max_steps: int) -> list[ast.Expr]: windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -187,7 +187,7 @@ def get_sorting_condition(self, max_steps: int) -> List[ast.Expr]: conditions.append(parse_expr(f"arraySort([{','.join(event_times_elements)}]) as event_times")) # replacement of latest_i for whatever query part requires it, just like conversion_times - basic_conditions: List[str] 
= [] + basic_conditions: list[str] = [] for i in range(1, max_steps): basic_conditions.append( f"if(latest_0 < latest_{i} AND latest_{i} <= toTimeZone(latest_0, 'UTC') + INTERVAL {windowInterval} {windowIntervalUnit}, 1, 0)" @@ -199,7 +199,7 @@ def get_sorting_condition(self, max_steps: int) -> List[ast.Expr]: else: return [ast.Alias(alias="steps", expr=ast.Constant(value=1))] - def _get_exclusion_condition(self) -> List[ast.Expr]: + def _get_exclusion_condition(self) -> list[ast.Expr]: funnelsFilter = self.context.funnelsFilter windowInterval = self.context.funnelWindowInterval windowIntervalUnit = funnel_window_interval_unit_to_sql(self.context.funnelWindowIntervalUnit) @@ -207,7 +207,7 @@ def _get_exclusion_condition(self) -> List[ast.Expr]: if not funnelsFilter.exclusions: return [] - conditions: List[ast.Expr] = [] + conditions: list[ast.Expr] = [] for exclusion_id, exclusion in enumerate(funnelsFilter.exclusions): from_time = f"latest_{exclusion.funnelFromStep}" @@ -233,9 +233,9 @@ def _serialize_step( step: ActionsNode | EventsNode | DataWarehouseNode, count: int, index: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if isinstance(step, DataWarehouseNode): raise NotImplementedError("Data Warehouse queries are not supported in funnels") diff --git a/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py b/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py index a378f044b5d56..ad1086bdc3324 100644 --- a/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py +++ b/posthog/hogql_queries/insights/funnels/funnel_unordered_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -6,7 +6,7 @@ class FunnelUnorderedActors(FunnelUnordered): - def _get_funnel_person_step_events(self) -> List[ast.Expr]: + def _get_funnel_person_step_events(self) -> list[ast.Expr]: # Unordered funnels does not support matching events (and thereby recordings), # but it simplifies the logic if we return an empty array for matching events if ( @@ -19,9 +19,9 @@ def _get_funnel_person_step_events(self) -> List[ast.Expr]: def actor_query( self, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ) -> ast.SelectQuery: - select: List[ast.Expr] = [ + select: list[ast.Expr] = [ ast.Alias(alias="actor_id", expr=ast.Field(chain=["aggregation_target"])), *self._get_funnel_person_step_events(), *self._get_timestamp_outer_select(), diff --git a/posthog/hogql_queries/insights/funnels/funnels_query_runner.py b/posthog/hogql_queries/insights/funnels/funnels_query_runner.py index d2ec04e3e8489..3e1173b276091 100644 --- a/posthog/hogql_queries/insights/funnels/funnels_query_runner.py +++ b/posthog/hogql_queries/insights/funnels/funnels_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import Optional, Any, Dict +from typing import Optional, Any from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -37,7 +37,7 @@ class FunnelsQueryRunner(QueryRunner): def __init__( self, - query: FunnelsQuery | Dict[str, Any], + query: FunnelsQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git a/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py 
b/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py index db5e882963e9f..2b1b08f444553 100644 --- a/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py +++ b/posthog/hogql_queries/insights/funnels/test/breakdown_cases.py @@ -2,7 +2,8 @@ from datetime import datetime from string import ascii_lowercase -from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast +from collections.abc import Callable from posthog.constants import INSIGHT_FUNNELS, FunnelOrderType from posthog.hogql_queries.insights.funnels.funnels_query_runner import FunnelsQueryRunner @@ -30,7 +31,7 @@ class FunnelStepResult: name: str count: int - breakdown: Union[List[str], str] + breakdown: Union[list[str], str] average_conversion_time: Optional[float] = None median_conversion_time: Optional[float] = None type: Literal["events", "actions"] = "events" @@ -51,8 +52,8 @@ def _get_actor_ids_at_step(self, filter, funnel_step, breakdown_value=None): return [val["id"] for val in serialized_result] - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -2695,8 +2696,8 @@ def _create_groups(self): properties={"industry": "random"}, ) - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -3067,11 +3068,11 @@ def test_funnel_aggregate_by_groups_breakdown_group_person_on_events(self): return TestFunnelBreakdownGroup -def sort_breakdown_funnel_results(results: List[Dict[int, Any]]): +def sort_breakdown_funnel_results(results: list[dict[int, Any]]): return sorted(results, key=lambda r: r[0]["breakdown_value"]) -def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str, Any]]): +def assert_funnel_results_equal(left: list[dict[str, Any]], right: list[dict[str, Any]]): """ Helper to be able to compare two funnel results, but exclude people urls from the comparison, as these include: @@ -3081,7 +3082,7 @@ def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str 2. 
contain timestamps which are not stable across runs """ - def _filter(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _filter(steps: list[dict[str, Any]]) -> list[dict[str, Any]]: return [{**step, "converted_people_url": None, "dropped_people_url": None} for step in steps] assert len(left) == len(right) diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py index 859f3e627aab7..aef262ba22edb 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, cast, Optional +from typing import cast, Optional from posthog.hogql_queries.insights.funnels.funnels_query_runner import FunnelsQueryRunner from posthog.hogql_queries.legacy_compatibility.filter_to_query import filter_to_query @@ -116,7 +116,7 @@ def setUp(self): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py index f69eb3c6977b6..4db744a6d9280 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, cast import unittest from rest_framework.exceptions import ValidationError @@ -77,7 +77,7 @@ def _get_events_for_filters( result, skewed_totals, _, _ = FunnelCorrelationQueryRunner(query=correlation_query, team=self.team)._calculate() return result, skewed_totals - def _get_actors_for_event(self, filters: Dict[str, Any], event_name: str, properties=None, success=True): + def _get_actors_for_event(self, filters: dict[str, Any], event_name: str, properties=None, success=True): serialized_actors = get_actors( filters, self.team, @@ -87,7 +87,7 @@ def _get_actors_for_event(self, filters: Dict[str, Any], event_name: str, proper return [str(row[0]) for row in serialized_actors] def _get_actors_for_property( - self, filters: Dict[str, Any], property_values: list, success=True, funnelCorrelationNames=None + self, filters: dict[str, Any], property_values: list, success=True, funnelCorrelationNames=None ): funnelCorrelationPropertyValues = [ ( diff --git a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py index f324dcfcf7c3a..223b24a949b3e 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_correlations_persons.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, cast +from typing import Any, Optional, cast from datetime import datetime, timedelta from uuid import UUID @@ -37,7 +37,7 @@ def get_actors( - filters: Dict[str, Any], + filters: dict[str, Any], team: Team, funnelCorrelationType: Optional[FunnelCorrelationResultsType] = FunnelCorrelationResultsType.events, funnelCorrelationNames=None, diff --git 
a/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py b/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py index dec7bdd933b3e..37d9b853404b7 100644 --- a/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py +++ b/posthog/hogql_queries/insights/funnels/test/test_funnel_persons.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Dict, List, Optional, cast, Any +from typing import Optional, cast, Any from uuid import UUID from django.utils import timezone @@ -32,11 +32,11 @@ def get_actors( - filters: Dict[str, Any], + filters: dict[str, Any], team: Team, funnelStep: Optional[int] = None, - funnelCustomSteps: Optional[List[int]] = None, - funnelStepBreakdown: Optional[str | float | List[str | float]] = None, + funnelCustomSteps: Optional[list[int]] = None, + funnelStepBreakdown: Optional[str | float | list[str | float]] = None, funnelTrendsDropOff: Optional[bool] = None, funnelTrendsEntrancePeriodStart: Optional[str] = None, offset: Optional[int] = None, diff --git a/posthog/hogql_queries/insights/funnels/utils.py b/posthog/hogql_queries/insights/funnels/utils.py index 95374f179e1af..7aea066883eda 100644 --- a/posthog/hogql_queries/insights/funnels/utils.py +++ b/posthog/hogql_queries/insights/funnels/utils.py @@ -1,4 +1,3 @@ -from typing import List from posthog.constants import FUNNEL_WINDOW_INTERVAL_TYPES from posthog.hogql import ast from posthog.hogql.parser import parse_expr @@ -61,7 +60,7 @@ def funnel_window_interval_unit_to_sql( def get_breakdown_expr( - breakdowns: List[str | int] | str | int, properties_column: str, normalize_url: bool | None = False + breakdowns: list[str | int] | str | int, properties_column: str, normalize_url: bool | None = False ) -> ast.Expr: if isinstance(breakdowns, str) or isinstance(breakdowns, int) or breakdowns is None: return ast.Call( diff --git a/posthog/hogql_queries/insights/lifecycle_query_runner.py b/posthog/hogql_queries/insights/lifecycle_query_runner.py index 42b35d6b4df51..5e11dcdcae0ec 100644 --- a/posthog/hogql_queries/insights/lifecycle_query_runner.py +++ b/posthog/hogql_queries/insights/lifecycle_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import Optional, List +from typing import Optional from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -225,7 +225,7 @@ def query_date_range(self): @cached_property def event_filter(self) -> ast.Expr: - event_filters: List[ast.Expr] = [] + event_filters: list[ast.Expr] = [] with self.timings.measure("date_range"): event_filters.append( parse_expr( diff --git a/posthog/hogql_queries/insights/paths_query_runner.py b/posthog/hogql_queries/insights/paths_query_runner.py index ca7890735f814..8c2bc84d821ad 100644 --- a/posthog/hogql_queries/insights/paths_query_runner.py +++ b/posthog/hogql_queries/insights/paths_query_runner.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from math import ceil from re import escape -from typing import Any, Dict, Literal, cast +from typing import Any, Literal, cast from typing import Optional from posthog.caching.insights_api import BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL, REDUCED_MINIMUM_INSIGHT_REFRESH_INTERVAL @@ -47,7 +47,7 @@ class PathsQueryRunner(QueryRunner): def __init__( self, - query: PathsQuery | Dict[str, Any], + query: PathsQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git 
a/posthog/hogql_queries/insights/retention_query_runner.py b/posthog/hogql_queries/insights/retention_query_runner.py index ac15ded6728b1..f79af288ca665 100644 --- a/posthog/hogql_queries/insights/retention_query_runner.py +++ b/posthog/hogql_queries/insights/retention_query_runner.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from math import ceil -from typing import Any, Dict +from typing import Any from typing import Optional from posthog.caching.insights_api import BASE_MINIMUM_INSIGHT_REFRESH_INTERVAL, REDUCED_MINIMUM_INSIGHT_REFRESH_INTERVAL @@ -39,7 +39,7 @@ class RetentionQueryRunner(QueryRunner): def __init__( self, - query: RetentionQuery | Dict[str, Any], + query: RetentionQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, diff --git a/posthog/hogql_queries/insights/stickiness_query_runner.py b/posthog/hogql_queries/insights/stickiness_query_runner.py index d9096f05853b6..24bb2504de6f2 100644 --- a/posthog/hogql_queries/insights/stickiness_query_runner.py +++ b/posthog/hogql_queries/insights/stickiness_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta from math import ceil -from typing import List, Optional, Any, Dict, cast +from typing import Optional, Any, cast from django.utils.timezone import datetime from posthog.caching.insights_api import ( @@ -47,11 +47,11 @@ def __init__( class StickinessQueryRunner(QueryRunner): query: StickinessQuery query_type = StickinessQuery - series: List[SeriesWithExtras] + series: list[SeriesWithExtras] def __init__( self, - query: StickinessQuery | Dict[str, Any], + query: StickinessQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -134,7 +134,7 @@ def _events_query(self, series_with_extra: SeriesWithExtras) -> ast.SelectQuery: def to_query(self) -> ast.SelectUnionQuery: return ast.SelectUnionQuery(select_queries=self.to_queries()) - def to_queries(self) -> List[ast.SelectQuery]: + def to_queries(self) -> list[ast.SelectQuery]: queries = [] for series in self.series: @@ -174,7 +174,7 @@ def to_queries(self) -> List[ast.SelectQuery]: return queries def to_actors_query(self, interval_num: Optional[int] = None) -> ast.SelectQuery | ast.SelectUnionQuery: - queries: List[ast.SelectQuery] = [] + queries: list[ast.SelectQuery] = [] for series in self.series: events_query = self._events_query(series) @@ -253,7 +253,7 @@ def calculate(self): def where_clause(self, series_with_extra: SeriesWithExtras) -> ast.Expr: date_range = self.date_range(series_with_extra) series = series_with_extra.series - filters: List[ast.Expr] = [] + filters: list[ast.Expr] = [] # Dates filters.extend( @@ -344,7 +344,7 @@ def intervals_num(self): else: return delta.days - def setup_series(self) -> List[SeriesWithExtras]: + def setup_series(self) -> list[SeriesWithExtras]: series_with_extras = [ SeriesWithExtras( series, diff --git a/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py b/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py index bb963cf1f8b62..830ecc3982b6c 100644 --- a/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_insight_actors_query_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional +from typing import Any, Optional from freezegun import freeze_time @@ -69,7 +69,7 @@ def _create_test_events(self): ] ) - def select(self, query: str, placeholders: 
Optional[Dict[str, Any]] = None): + def select(self, query: str, placeholders: Optional[dict[str, Any]] = None): if placeholders is None: placeholders = {} return execute_hogql_query( diff --git a/posthog/hogql_queries/insights/test/test_paths_query_runner.py b/posthog/hogql_queries/insights/test/test_paths_query_runner.py index b74102ba70510..0b82f33ca7e52 100644 --- a/posthog/hogql_queries/insights/test/test_paths_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_paths_query_runner.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -25,7 +24,7 @@ class MockEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict class TestPaths(ClickhouseTestMixin, APIBaseTest): diff --git a/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py b/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py index 6e25827e6ecba..e61f4160276ab 100644 --- a/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py +++ b/posthog/hogql_queries/insights/test/test_stickiness_query_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Optional, Union from unittest.mock import MagicMock, patch from django.test import override_settings @@ -41,18 +41,18 @@ @dataclass class Series: event: str - timestamps: List[str] + timestamps: list[str] @dataclass class SeriesTestData: distinct_id: str - events: List[Series] - properties: Dict[str, str | int] + events: list[Series] + properties: dict[str, str | int] StickinessProperties = Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -74,9 +74,9 @@ class TestStickinessQueryRunner(APIBaseTest): default_date_from = "2020-01-11" default_date_to = "2020-01-20" - def _create_events(self, data: List[SeriesTestData]): + def _create_events(self, data: list[SeriesTestData]): person_result = [] - properties_to_create: Dict[str, str] = {} + properties_to_create: dict[str, str] = {} for person in data: first_timestamp = person.events[0].timestamps[0] @@ -194,7 +194,7 @@ def _create_test_events(self): def _run_query( self, - series: Optional[List[EventsNode | ActionsNode]] = None, + series: Optional[list[EventsNode | ActionsNode]] = None, date_from: Optional[str] = None, date_to: Optional[str] = None, interval: Optional[IntervalType] = None, @@ -203,7 +203,7 @@ def _run_query( filter_test_accounts: Optional[bool] = False, limit_context: Optional[LimitContext] = None, ): - query_series: List[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series + query_series: list[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series query_date_from = date_from or self.default_date_from query_date_to = None if date_to == "now" else date_to or self.default_date_to query_interval = interval or IntervalType.day @@ -223,8 +223,8 @@ def test_stickiness_runs(self): response = self._run_query() assert isinstance(response, StickinessQueryResponse) - assert isinstance(response.results, List) - assert isinstance(response.results[0], Dict) + assert isinstance(response.results, list) + assert isinstance(response.results[0], dict) @override_settings(PERSON_ON_EVENTS_V2_OVERRIDE=True) def test_stickiness_runs_with_poe(self): @@ -232,8 +232,8 @@ def test_stickiness_runs_with_poe(self): response = self._run_query() assert isinstance(response, StickinessQueryResponse) - assert 
isinstance(response.results, List) - assert isinstance(response.results[0], Dict) + assert isinstance(response.results, list) + assert isinstance(response.results[0], dict) def test_days(self): self._create_test_events() @@ -423,7 +423,7 @@ def test_property_filtering_hogql(self): def test_event_filtering(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode( event="$pageview", properties=[EventPropertyFilter(key="$browser", operator=PropertyOperator.exact, value="Chrome")], @@ -450,7 +450,7 @@ def test_event_filtering(self): def test_any_event(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode( event=None, ) @@ -484,7 +484,7 @@ def test_actions(self): properties=[{"key": "$browser", "type": "event", "value": "Chrome", "operator": "exact"}], ) - series: List[EventsNode | ActionsNode] = [ActionsNode(id=action.pk)] + series: list[EventsNode | ActionsNode] = [ActionsNode(id=action.pk)] response = self._run_query(series=series) @@ -541,7 +541,7 @@ def test_group_aggregations(self): self._create_test_groups() self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode(event="$pageview", math="unique_group", math_group_type_index=MathGroupTypeIndex.number_0) ] @@ -565,7 +565,7 @@ def test_group_aggregations(self): def test_hogql_aggregations(self): self._create_test_events() - series: List[EventsNode | ActionsNode] = [ + series: list[EventsNode | ActionsNode] = [ EventsNode(event="$pageview", math="hogql", math_hogql="e.properties.prop") ] diff --git a/posthog/hogql_queries/insights/trends/aggregation_operations.py b/posthog/hogql_queries/insights/trends/aggregation_operations.py index 1c356277548d0..2e716b2b1caea 100644 --- a/posthog/hogql_queries/insights/trends/aggregation_operations.py +++ b/posthog/hogql_queries/insights/trends/aggregation_operations.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast, Union +from typing import Optional, cast, Union from posthog.constants import NON_TIME_SERIES_DISPLAY_TYPES from posthog.hogql import ast from posthog.hogql.parser import parse_expr, parse_select @@ -13,8 +13,8 @@ class QueryAlternator: """Allows query_builder to modify the query without having to expost the whole AST interface""" _query: ast.SelectQuery - _selects: List[ast.Expr] - _group_bys: List[ast.Expr] + _selects: list[ast.Expr] + _group_bys: list[ast.Expr] _select_from: ast.JoinExpr | None def __init__(self, query: ast.SelectQuery | ast.SelectUnionQuery): @@ -143,7 +143,7 @@ def is_count_per_actor_variant(self): "p99_count_per_actor", ] - def _math_func(self, method: str, override_chain: Optional[List[str | int]]) -> ast.Call: + def _math_func(self, method: str, override_chain: Optional[list[str | int]]) -> ast.Call: if override_chain is not None: return ast.Call(name=method, args=[ast.Field(chain=override_chain)]) @@ -167,7 +167,7 @@ def _math_func(self, method: str, override_chain: Optional[List[str | int]]) -> return ast.Call(name=method, args=[ast.Field(chain=chain)]) - def _math_quantile(self, percentile: float, override_chain: Optional[List[str | int]]) -> ast.Call: + def _math_quantile(self, percentile: float, override_chain: Optional[list[str | int]]) -> ast.Call: if self.series.math_property == "$session_duration": chain = ["session_duration"] else: diff --git a/posthog/hogql_queries/insights/trends/breakdown.py 
b/posthog/hogql_queries/insights/trends/breakdown.py index e588ca30353f2..025d181bf81ef 100644 --- a/posthog/hogql_queries/insights/trends/breakdown.py +++ b/posthog/hogql_queries/insights/trends/breakdown.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast from posthog.hogql import ast from posthog.hogql.constants import LimitContext from posthog.hogql.parser import parse_expr @@ -30,7 +30,7 @@ class Breakdown: timings: HogQLTimings modifiers: HogQLQueryModifiers events_filter: ast.Expr - breakdown_values_override: Optional[List[str]] + breakdown_values_override: Optional[list[str]] limit_context: LimitContext def __init__( @@ -42,7 +42,7 @@ def __init__( timings: HogQLTimings, modifiers: HogQLQueryModifiers, events_filter: ast.Expr, - breakdown_values_override: Optional[List[str]] = None, + breakdown_values_override: Optional[list[str]] = None, limit_context: LimitContext = LimitContext.QUERY, ): self.team = team @@ -71,7 +71,7 @@ def is_session_type(self) -> bool: def is_histogram_breakdown(self) -> bool: return self.enabled and self.query.breakdownFilter.breakdown_histogram_bin_count is not None - def placeholders(self) -> Dict[str, ast.Expr]: + def placeholders(self) -> dict[str, ast.Expr]: values = self._breakdown_buckets_ast if self.is_histogram_breakdown else self._breakdown_values_ast return {"cross_join_breakdown_values": ast.Alias(alias="breakdown_value", expr=values)} @@ -106,7 +106,7 @@ def events_where_filter(self) -> ast.Expr | None: if self.query.breakdownFilter.breakdown == "all": return None - if isinstance(self.query.breakdownFilter.breakdown, List): + if isinstance(self.query.breakdownFilter.breakdown, list): or_clause = ast.Or( exprs=[ ast.CompareOperation( @@ -226,10 +226,10 @@ def _breakdown_values_ast(self) -> ast.Array: return ast.Array(exprs=exprs) @cached_property - def _all_breakdown_values(self) -> List[str | int | None]: + def _all_breakdown_values(self) -> list[str | int | None]: # Used in the actors query if self.breakdown_values_override is not None: - return cast(List[str | int | None], self.breakdown_values_override) + return cast(list[str | int | None], self.breakdown_values_override) if self.query.breakdownFilter is None: return [] @@ -245,18 +245,18 @@ def _all_breakdown_values(self) -> List[str | int | None]: modifiers=self.modifiers, limit_context=self.limit_context, ) - return cast(List[str | int | None], breakdown.get_breakdown_values()) + return cast(list[str | int | None], breakdown.get_breakdown_values()) @cached_property - def _breakdown_values(self) -> List[str | int]: + def _breakdown_values(self) -> list[str | int]: values = [BREAKDOWN_NULL_STRING_LABEL if v is None else v for v in self._all_breakdown_values] - return cast(List[str | int], values) + return cast(list[str | int], values) @cached_property def has_breakdown_values(self) -> bool: return len(self._breakdown_values) > 0 - def _get_breakdown_histogram_buckets(self) -> List[Tuple[float, float]]: + def _get_breakdown_histogram_buckets(self) -> list[tuple[float, float]]: buckets = [] values = self._breakdown_values @@ -275,7 +275,7 @@ def _get_breakdown_histogram_buckets(self) -> List[Tuple[float, float]]: return buckets def _get_breakdown_histogram_multi_if(self) -> ast.Expr: - multi_if_exprs: List[ast.Expr] = [] + multi_if_exprs: list[ast.Expr] = [] buckets = self._get_breakdown_histogram_buckets() diff --git a/posthog/hogql_queries/insights/trends/breakdown_values.py 
b/posthog/hogql_queries/insights/trends/breakdown_values.py index 6a9b9a24a22f0..b15897b360fde 100644 --- a/posthog/hogql_queries/insights/trends/breakdown_values.py +++ b/posthog/hogql_queries/insights/trends/breakdown_values.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, Any +from typing import Optional, Union, Any from posthog.hogql import ast from posthog.hogql.constants import LimitContext, get_breakdown_limit_for_context, BREAKDOWN_VALUES_LIMIT_FOR_COUNTRIES from posthog.hogql.parser import parse_expr, parse_select @@ -30,7 +30,7 @@ class BreakdownValues: team: Team series: Union[EventsNode, ActionsNode, DataWarehouseNode] - breakdown_field: Union[str, float, List[Union[str, float]]] + breakdown_field: Union[str, float, list[Union[str, float]]] breakdown_type: BreakdownType events_filter: ast.Expr chart_display_type: ChartDisplayType @@ -76,12 +76,12 @@ def __init__( self.query_date_range = query_date_range self.modifiers = modifiers - def get_breakdown_values(self) -> List[str | int]: + def get_breakdown_values(self) -> list[str | int]: if self.breakdown_type == "cohort": if self.breakdown_field == "all": return [0] - if isinstance(self.breakdown_field, List): + if isinstance(self.breakdown_field, list): return [value if isinstance(value, str) else int(value) for value in self.breakdown_field] return [self.breakdown_field if isinstance(self.breakdown_field, str) else int(self.breakdown_field)] @@ -186,7 +186,7 @@ def get_breakdown_values(self) -> List[str | int]: ): inner_events_query.order_by[0].order = "ASC" - values: List[Any] + values: list[Any] if self.histogram_bin_count is not None: query = parse_select( """ diff --git a/posthog/hogql_queries/insights/trends/test/test_trends.py b/posthog/hogql_queries/insights/trends/test/test_trends.py index 8ba4aea1b3459..f34229e99ded7 100644 --- a/posthog/hogql_queries/insights/trends/test/test_trends.py +++ b/posthog/hogql_queries/insights/trends/test/test_trends.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import patch from zoneinfo import ZoneInfo @@ -68,8 +68,8 @@ from posthog.test.test_journeys import journeys_for -def breakdown_label(entity: Entity, value: Union[str, int]) -> Dict[str, Optional[Union[str, int]]]: - ret_dict: Dict[str, Optional[Union[str, int]]] = {} +def breakdown_label(entity: Entity, value: Union[str, int]) -> dict[str, Optional[Union[str, int]]]: + ret_dict: dict[str, Optional[Union[str, int]]] = {} if not value or not isinstance(value, str) or "cohort_" not in value: label = value if (value or isinstance(value, bool)) and value != "None" and value != "nan" else "Other" ret_dict["label"] = f"{entity.name} - {label}" @@ -103,7 +103,7 @@ def _create_cohort(**kwargs): return cohort -def _props(dict: Dict): +def _props(dict: dict): props = dict.get("properties", None) if not props: return None @@ -125,11 +125,11 @@ def _props(dict: Dict): def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: filter_as_dict = filter.to_dict() - events: List[EventsNode] = [] - actions: List[ActionsNode] = [] + events: list[EventsNode] = [] + actions: list[ActionsNode] = [] for event in filter.events: - if isinstance(event._data.get("properties", None), List): + if isinstance(event._data.get("properties", None), list): properties = clean_entity_properties(event._data.get("properties", None)) elif event._data.get("properties", None) is not None: values = 
event._data.get("properties", None).get("values", None) @@ -151,7 +151,7 @@ def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: ) for action in filter.actions: - if isinstance(action._data.get("properties", None), List): + if isinstance(action._data.get("properties", None), list): properties = clean_entity_properties(action._data.get("properties", None)) elif action._data.get("properties", None) is not None: values = action._data.get("properties", None).get("values", None) @@ -172,7 +172,7 @@ def convert_filter_to_trends_query(filter: Filter) -> TrendsQuery: ) ) - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = [*events, *actions] + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = [*events, *actions] tq = TrendsQuery( series=series, @@ -304,7 +304,7 @@ def _create_group(self, **kwargs): type=PropertyDefinition.Type.GROUP, ) - def _create_events(self, use_time=False) -> Tuple[Action, Person]: + def _create_events(self, use_time=False) -> tuple[Action, Person]: person = self._create_person( team_id=self.team.pk, distinct_ids=["blabla", "anonymous_id"], @@ -2080,7 +2080,7 @@ def test_trends_compare_hour_interval_relative_range(self): ], ) - def _test_events_with_dates(self, dates: List[str], result, query_time=None, **filter_params): + def _test_events_with_dates(self, dates: list[str], result, query_time=None, **filter_params): self._create_person(team_id=self.team.pk, distinct_ids=["person_1"], properties={"name": "John"}) for time in dates: with freeze_time(time): diff --git a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py index 573bbf2c12e13..772d71922727b 100644 --- a/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/test/test_trends_query_runner.py @@ -1,7 +1,7 @@ import zoneinfo from dataclasses import dataclass from datetime import datetime -from typing import Dict, List, Optional +from typing import Optional from unittest.mock import MagicMock, patch from django.test import override_settings from freezegun import freeze_time @@ -49,14 +49,14 @@ @dataclass class Series: event: str - timestamps: List[str] + timestamps: list[str] @dataclass class SeriesTestData: distinct_id: str - events: List[Series] - properties: Dict[str, str | int] + events: list[Series] + properties: dict[str, str | int] @override_settings(IN_UNIT_TESTING=True) @@ -64,9 +64,9 @@ class TestTrendsQueryRunner(ClickhouseTestMixin, APIBaseTest): default_date_from = "2020-01-09" default_date_to = "2020-01-19" - def _create_events(self, data: List[SeriesTestData]): + def _create_events(self, data: list[SeriesTestData]): person_result = [] - properties_to_create: Dict[str, str] = {} + properties_to_create: dict[str, str] = {} for person in data: first_timestamp = person.events[0].timestamps[0] @@ -174,7 +174,7 @@ def _create_query_runner( date_from: str, date_to: Optional[str], interval: IntervalType, - series: Optional[List[EventsNode | ActionsNode]], + series: Optional[list[EventsNode | ActionsNode]], trends_filters: Optional[TrendsFilter] = None, breakdown: Optional[BreakdownFilter] = None, filter_test_accounts: Optional[bool] = None, @@ -182,7 +182,7 @@ def _create_query_runner( limit_context: Optional[LimitContext] = None, explicit_date: Optional[bool] = None, ) -> TrendsQueryRunner: - query_series: List[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series + query_series: 
list[EventsNode | ActionsNode] = [EventsNode(event="$pageview")] if series is None else series query = TrendsQuery( dateRange=DateRange(date_from=date_from, date_to=date_to, explicitDate=explicit_date), interval=interval, @@ -198,7 +198,7 @@ def _run_trends_query( date_from: str, date_to: Optional[str], interval: IntervalType, - series: Optional[List[EventsNode | ActionsNode]], + series: Optional[list[EventsNode | ActionsNode]], trends_filters: Optional[TrendsFilter] = None, breakdown: Optional[BreakdownFilter] = None, *, diff --git a/posthog/hogql_queries/insights/trends/trends_query_builder.py b/posthog/hogql_queries/insights/trends/trends_query_builder.py index 82fbb849ef5d9..072f371e4c058 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_builder.py +++ b/posthog/hogql_queries/insights/trends/trends_query_builder.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Optional, cast from posthog.hogql import ast from posthog.hogql.constants import LimitContext from posthog.hogql.parser import parse_expr, parse_select @@ -98,7 +98,7 @@ def build_actors_query( }, ) - def _get_date_subqueries(self, breakdown: Breakdown, ignore_breakdowns: bool = False) -> List[ast.SelectQuery]: + def _get_date_subqueries(self, breakdown: Breakdown, ignore_breakdowns: bool = False) -> list[ast.SelectQuery]: if not breakdown.enabled or ignore_breakdowns: return [ cast( @@ -473,7 +473,7 @@ def _events_filter( actors_query_time_frame: Optional[str] = None, ) -> ast.Expr: series = self.series - filters: List[ast.Expr] = [] + filters: list[ast.Expr] = [] # Dates if is_actors_query and actors_query_time_frame is not None: diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index 8629d17ec928a..6ceb2dd185739 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -7,7 +7,7 @@ from math import ceil from operator import itemgetter import threading -from typing import List, Optional, Any, Dict +from typing import Optional, Any from dateutil import parser from dateutil.relativedelta import relativedelta from django.conf import settings @@ -70,11 +70,11 @@ class TrendsQueryRunner(QueryRunner): query: TrendsQuery query_type = TrendsQuery - series: List[SeriesWithExtras] + series: list[SeriesWithExtras] def __init__( self, - query: TrendsQuery | Dict[str, Any], + query: TrendsQuery | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -115,7 +115,7 @@ def to_query(self) -> ast.SelectUnionQuery: queries.extend(query.select_queries) return ast.SelectUnionQuery(select_queries=queries) - def to_queries(self) -> List[ast.SelectQuery | ast.SelectUnionQuery]: + def to_queries(self) -> list[ast.SelectQuery | ast.SelectUnionQuery]: queries = [] with self.timings.measure("trends_to_query"): for series in self.series: @@ -184,9 +184,9 @@ def to_actors_query( return query def to_actors_query_options(self) -> InsightActorsQueryOptionsResponse: - res_breakdown: List[BreakdownItem] | None = None - res_series: List[Series] = [] - res_compare: List[CompareItem] | None = None + res_breakdown: list[BreakdownItem] | None = None + res_series: list[Series] = [] + res_compare: list[CompareItem] | None = None # Days res_days: Optional[list[DayItem]] = ( @@ -239,7 +239,7 @@ def to_actors_query_options(self) -> InsightActorsQueryOptionsResponse: is_boolean_breakdown = 
self._is_breakdown_field_boolean() is_histogram_breakdown = breakdown.is_histogram_breakdown - breakdown_values: List[str | int] + breakdown_values: list[str | int] res_breakdown = [] if is_histogram_breakdown: @@ -289,9 +289,9 @@ def calculate(self): with self.timings.measure("printing_hogql_for_response"): response_hogql = to_printed_hogql(response_hogql_query, self.team, self.modifiers) - res_matrix: List[List[Any] | Any | None] = [None] * len(queries) - timings_matrix: List[List[QueryTiming] | None] = [None] * len(queries) - errors: List[Exception] = [] + res_matrix: list[list[Any] | Any | None] = [None] * len(queries) + timings_matrix: list[list[QueryTiming] | None] = [None] * len(queries) + errors: list[Exception] = [] def run(index: int, query: ast.SelectQuery | ast.SelectUnionQuery, is_parallel: bool): try: @@ -342,14 +342,14 @@ def run(index: int, query: ast.SelectQuery | ast.SelectUnionQuery, is_parallel: # Flatten res and timings res = [] for result in res_matrix: - if isinstance(result, List): + if isinstance(result, list): res.extend(result) else: res.append(result) timings = [] for result in timings_matrix: - if isinstance(result, List): + if isinstance(result, list): timings.extend(result) else: timings.append(result) @@ -555,7 +555,7 @@ def update_hogql_modifiers(self) -> None: self.modifiers.inCohortVia == InCohortVia.auto and self.query.breakdownFilter is not None and self.query.breakdownFilter.breakdown_type == "cohort" - and isinstance(self.query.breakdownFilter.breakdown, List) + and isinstance(self.query.breakdownFilter.breakdown, list) and len(self.query.breakdownFilter.breakdown) > 1 and not any(value == "all" for value in self.query.breakdownFilter.breakdown) ): @@ -575,7 +575,7 @@ def update_hogql_modifiers(self) -> None: self.modifiers.dataWarehouseEventsModifiers = datawarehouse_modifiers - def setup_series(self) -> List[SeriesWithExtras]: + def setup_series(self) -> list[SeriesWithExtras]: series_with_extras = [ SeriesWithExtras( series=series, @@ -593,7 +593,7 @@ def setup_series(self) -> List[SeriesWithExtras]: and self.query.breakdownFilter.breakdown_type == "cohort" ): updated_series = [] - if isinstance(self.query.breakdownFilter.breakdown, List): + if isinstance(self.query.breakdownFilter.breakdown, list): cohort_ids = self.query.breakdownFilter.breakdown elif self.query.breakdownFilter.breakdown is not None: cohort_ids = [self.query.breakdownFilter.breakdown] @@ -642,7 +642,7 @@ def setup_series(self) -> List[SeriesWithExtras]: return series_with_extras - def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def apply_formula(self, formula: str, results: list[dict[str, Any]]) -> list[dict[str, Any]]: has_compare = bool(self.query.trendsFilter and self.query.trendsFilter.compare) has_breakdown = bool(self.query.breakdownFilter and self.query.breakdownFilter.breakdown) is_total_value = self._trends_display.should_aggregate_values() @@ -694,8 +694,8 @@ def apply_formula(self, formula: str, results: List[Dict[str, Any]]) -> List[Dic @staticmethod def apply_formula_to_results_group( - results_group: List[Dict[str, Any]], formula: str, aggregate_values: Optional[bool] = False - ) -> Dict[str, Any]: + results_group: list[dict[str, Any]], formula: str, aggregate_values: Optional[bool] = False + ) -> dict[str, Any]: """ Applies the formula to a list of results, resulting in a single, computed result. 
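Review note: the docstring above is where the formula machinery bottoms out — each results group contributes one aligned series, and the formula is evaluated point by point across them. A minimal, self-contained sketch of that idea, mirroring the operator-map-over-zip(*data) approach of FormulaAST further down in this diff (the function name and the A/B labelling here are illustrative, not PostHog's actual API):

    import ast
    import operator

    OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv}

    def apply_formula(formula: str, series: list[list[float]]) -> list[float]:
        # Parse once, then evaluate the expression tree for every aligned data point.
        tree = ast.parse(formula, mode="eval").body

        def evaluate(node: ast.expr, point: dict[str, float]) -> float:
            if isinstance(node, ast.BinOp):
                return OPS[type(node.op)](evaluate(node.left, point), evaluate(node.right, point))
            if isinstance(node, ast.Name):  # "A" is the first series, "B" the second, ...
                return point[node.id]
            if isinstance(node, ast.Constant):
                return node.value
            raise ValueError(f"unsupported formula node: {ast.dump(node)}")

        labels = [chr(ord("A") + i) for i in range(len(series))]
        return [evaluate(tree, dict(zip(labels, values))) for values in zip(*series)]

    # apply_formula("A + B", [[1.0, 2.0], [10.0, 20.0]]) == [11.0, 22.0]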
""" @@ -787,7 +787,7 @@ def _event_property( return "String" # TODO: Move this to posthog/hogql_queries/legacy_compatibility/query_to_filter.py - def _query_to_filter(self) -> Dict[str, Any]: + def _query_to_filter(self) -> dict[str, Any]: filter_dict = { "insight": "TRENDS", "properties": self.query.properties, diff --git a/posthog/hogql_queries/insights/trends/utils.py b/posthog/hogql_queries/insights/trends/utils.py index 61a4252d499f2..b8f6c3989f1fd 100644 --- a/posthog/hogql_queries/insights/trends/utils.py +++ b/posthog/hogql_queries/insights/trends/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from posthog.schema import ActionsNode, DataWarehouseNode, EventsNode, BreakdownType @@ -12,7 +12,7 @@ def get_properties_chain( breakdown_type: BreakdownType | None, breakdown_field: str, group_type_index: Optional[float | int], -) -> List[str | int]: +) -> list[str | int]: if breakdown_type == "person": return ["person", "properties", breakdown_field] diff --git a/posthog/hogql_queries/insights/utils/properties.py b/posthog/hogql_queries/insights/utils/properties.py index ea4770037b78d..41826b28535d8 100644 --- a/posthog/hogql_queries/insights/utils/properties.py +++ b/posthog/hogql_queries/insights/utils/properties.py @@ -1,11 +1,11 @@ -from typing import List, TypeAlias +from typing import TypeAlias from posthog.hogql import ast from posthog.hogql.property import property_to_expr from posthog.hogql_queries.insights.query_context import QueryContext from posthog.schema import PropertyGroupFilter from posthog.types import AnyPropertyFilter -PropertiesType: TypeAlias = List[AnyPropertyFilter] | PropertyGroupFilter | None +PropertiesType: TypeAlias = list[AnyPropertyFilter] | PropertyGroupFilter | None class Properties: @@ -17,8 +17,8 @@ def __init__( ) -> None: self.context = context - def to_exprs(self) -> List[ast.Expr]: - exprs: List[ast.Expr] = [] + def to_exprs(self) -> list[ast.Expr]: + exprs: list[ast.Expr] = [] team, query = self.context.team, self.context.query diff --git a/posthog/hogql_queries/insights/utils/utils.py b/posthog/hogql_queries/insights/utils/utils.py index c3b99c6a3b625..747d7e2b6ca5a 100644 --- a/posthog/hogql_queries/insights/utils/utils.py +++ b/posthog/hogql_queries/insights/utils/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.hogql import ast from posthog.models.team.team import Team, WeekStartDay from posthog.queries.util import get_trunc_func_ch @@ -6,7 +6,7 @@ def get_start_of_interval_hogql(interval: str, *, team: Team, source: Optional[ast.Expr] = None) -> ast.Expr: trunc_func = get_trunc_func_ch(interval) - trunc_func_args: List[ast.Expr] = [source] if source else [ast.Field(chain=["timestamp"])] + trunc_func_args: list[ast.Expr] = [source] if source else [ast.Field(chain=["timestamp"])] if trunc_func == "toStartOfWeek": trunc_func_args.append(ast.Constant(value=int((WeekStartDay(team.week_start_day or 0)).clickhouse_mode))) return ast.Call(name=trunc_func, args=trunc_func_args) diff --git a/posthog/hogql_queries/legacy_compatibility/filter_to_query.py b/posthog/hogql_queries/legacy_compatibility/filter_to_query.py index 382b37fa56db0..fdeb74fc9076f 100644 --- a/posthog/hogql_queries/legacy_compatibility/filter_to_query.py +++ b/posthog/hogql_queries/legacy_compatibility/filter_to_query.py @@ -1,7 +1,7 @@ import copy from enum import Enum import json -from typing import Any, List, Dict, Literal +from typing import Any, Literal from 
posthog.hogql_queries.legacy_compatibility.clean_properties import clean_entity_properties, clean_global_properties from posthog.models.entity.entity import Entity as LegacyEntity from posthog.schema import ( @@ -118,7 +118,7 @@ def exlusion_entity_to_node(entity) -> FunnelExclusionEventsNode | FunnelExclusi # TODO: remove this method that returns legacy entities -def to_base_entity_dict(entity: Dict): +def to_base_entity_dict(entity: dict): return { "type": entity.get("type"), "id": entity.get("id"), @@ -140,7 +140,7 @@ def to_base_entity_dict(entity: Dict): INSIGHT_TYPE = Literal["TRENDS", "FUNNELS", "RETENTION", "PATHS", "LIFECYCLE", "STICKINESS"] -def _date_range(filter: Dict): +def _date_range(filter: dict): date_range = DateRange( date_from=filter.get("date_from"), date_to=filter.get("date_to"), @@ -153,7 +153,7 @@ def _date_range(filter: Dict): return {"dateRange": date_range} -def _interval(filter: Dict): +def _interval(filter: dict): if _insight_type(filter) == "RETENTION" or _insight_type(filter) == "PATHS": return {} @@ -163,7 +163,7 @@ def _interval(filter: Dict): return {"interval": filter.get("interval")} -def _series(filter: Dict): +def _series(filter: dict): if _insight_type(filter) == "RETENTION" or _insight_type(filter) == "PATHS": return {} @@ -188,8 +188,8 @@ def _series(filter: Dict): } -def _entities(filter: Dict): - processed_entities: List[LegacyEntity] = [] +def _entities(filter: dict): + processed_entities: list[LegacyEntity] = [] # add actions actions = filter.get("actions", []) @@ -213,7 +213,7 @@ def _entities(filter: Dict): return processed_entities -def _sampling_factor(filter: Dict): +def _sampling_factor(filter: dict): if isinstance(filter.get("sampling_factor"), str): try: return float(filter.get("sampling_factor")) @@ -223,16 +223,16 @@ def _sampling_factor(filter: Dict): return {"samplingFactor": filter.get("sampling_factor")} -def _properties(filter: Dict): +def _properties(filter: dict): raw_properties = filter.get("properties", None) return {"properties": clean_global_properties(raw_properties)} -def _filter_test_accounts(filter: Dict): +def _filter_test_accounts(filter: dict): return {"filterTestAccounts": filter.get("filter_test_accounts")} -def _breakdown_filter(_filter: Dict): +def _breakdown_filter(_filter: dict): if _insight_type(_filter) != "TRENDS" and _insight_type(_filter) != "FUNNELS": return {} @@ -275,13 +275,13 @@ def _breakdown_filter(_filter: Dict): return {"breakdownFilter": BreakdownFilter(**breakdownFilter)} -def _group_aggregation_filter(filter: Dict): +def _group_aggregation_filter(filter: dict): if _insight_type(filter) == "STICKINESS" or _insight_type(filter) == "LIFECYCLE": return {} return {"aggregation_group_type_index": filter.get("aggregation_group_type_index")} -def _insight_filter(filter: Dict): +def _insight_filter(filter: dict): if _insight_type(filter) == "TRENDS": insight_filter = { "trendsFilter": TrendsFilter( @@ -387,7 +387,7 @@ def _insight_filter(filter: Dict): return insight_filter -def filters_to_funnel_paths_query(filter: Dict[str, Any]) -> FunnelPathsFilter | None: +def filters_to_funnel_paths_query(filter: dict[str, Any]) -> FunnelPathsFilter | None: funnel_paths = filter.get("funnel_paths") funnel_filter = filter.get("funnel_filter") @@ -404,13 +404,13 @@ def filters_to_funnel_paths_query(filter: Dict[str, Any]) -> FunnelPathsFilter | ) -def _insight_type(filter: Dict) -> INSIGHT_TYPE: +def _insight_type(filter: dict) -> INSIGHT_TYPE: if filter.get("insight") == "SESSIONS": return "TRENDS" return 
filter.get("insight", "TRENDS") -def filter_to_query(filter: Dict) -> InsightQueryNode: +def filter_to_query(filter: dict) -> InsightQueryNode: filter = copy.deepcopy(filter) # duplicate to prevent accidental filter alterations Query = insight_to_query_type[_insight_type(filter)] diff --git a/posthog/hogql_queries/query_runner.py b/posthog/hogql_queries/query_runner.py index 1ddd336dee981..0e49ca849637b 100644 --- a/posthog/hogql_queries/query_runner.py +++ b/posthog/hogql_queries/query_runner.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import IntEnum -from typing import Any, Generic, List, Optional, Type, Dict, TypeVar, Union, Tuple, cast, TypeGuard +from typing import Any, Generic, Optional, TypeVar, Union, cast, TypeGuard from django.conf import settings from django.core.cache import cache @@ -76,9 +76,9 @@ class QueryResponse(BaseModel, Generic[DataT]): extra="forbid", ) results: DataT - timings: Optional[List[QueryTiming]] = None - types: Optional[List[Union[Tuple[str, str], str]]] = None - columns: Optional[List[str]] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list[Union[tuple[str, str], str]]] = None + columns: Optional[list[str]] = None hogql: Optional[str] = None hasMore: Optional[bool] = None limit: Optional[int] = None @@ -128,7 +128,7 @@ class CacheMissResponse(BaseModel): def get_query_runner( - query: Dict[str, Any] | RunnableQueryNode | BaseModel, + query: dict[str, Any] | RunnableQueryNode | BaseModel, team: Team, timings: Optional[HogQLTimings] = None, limit_context: Optional[LimitContext] = None, @@ -146,7 +146,7 @@ def get_query_runner( from .insights.trends.trends_query_runner import TrendsQueryRunner return TrendsQueryRunner( - query=cast(TrendsQuery | Dict[str, Any], query), + query=cast(TrendsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -156,7 +156,7 @@ def get_query_runner( from .insights.funnels.funnels_query_runner import FunnelsQueryRunner return FunnelsQueryRunner( - query=cast(FunnelsQuery | Dict[str, Any], query), + query=cast(FunnelsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -166,7 +166,7 @@ def get_query_runner( from .insights.retention_query_runner import RetentionQueryRunner return RetentionQueryRunner( - query=cast(RetentionQuery | Dict[str, Any], query), + query=cast(RetentionQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -176,7 +176,7 @@ def get_query_runner( from .insights.paths_query_runner import PathsQueryRunner return PathsQueryRunner( - query=cast(PathsQuery | Dict[str, Any], query), + query=cast(PathsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -186,7 +186,7 @@ def get_query_runner( from .insights.stickiness_query_runner import StickinessQueryRunner return StickinessQueryRunner( - query=cast(StickinessQuery | Dict[str, Any], query), + query=cast(StickinessQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -196,7 +196,7 @@ def get_query_runner( from .insights.lifecycle_query_runner import LifecycleQueryRunner return LifecycleQueryRunner( - query=cast(LifecycleQuery | Dict[str, Any], query), + query=cast(LifecycleQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -206,7 +206,7 @@ def get_query_runner( from .events_query_runner import EventsQueryRunner return EventsQueryRunner( - query=cast(EventsQuery | 
Dict[str, Any], query), + query=cast(EventsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -216,7 +216,7 @@ def get_query_runner( from .actors_query_runner import ActorsQueryRunner return ActorsQueryRunner( - query=cast(ActorsQuery | Dict[str, Any], query), + query=cast(ActorsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -226,7 +226,7 @@ def get_query_runner( from .insights.insight_actors_query_runner import InsightActorsQueryRunner return InsightActorsQueryRunner( - query=cast(InsightActorsQuery | Dict[str, Any], query), + query=cast(InsightActorsQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -236,7 +236,7 @@ def get_query_runner( from .insights.insight_actors_query_options_runner import InsightActorsQueryOptionsRunner return InsightActorsQueryOptionsRunner( - query=cast(InsightActorsQueryOptions | Dict[str, Any], query), + query=cast(InsightActorsQueryOptions | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -246,7 +246,7 @@ def get_query_runner( from .insights.funnels.funnel_correlation_query_runner import FunnelCorrelationQueryRunner return FunnelCorrelationQueryRunner( - query=cast(FunnelCorrelationQuery | Dict[str, Any], query), + query=cast(FunnelCorrelationQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -256,7 +256,7 @@ def get_query_runner( from .hogql_query_runner import HogQLQueryRunner return HogQLQueryRunner( - query=cast(HogQLQuery | Dict[str, Any], query), + query=cast(HogQLQuery | dict[str, Any], query), team=team, timings=timings, limit_context=limit_context, @@ -266,7 +266,7 @@ def get_query_runner( from .sessions_timeline_query_runner import SessionsTimelineQueryRunner return SessionsTimelineQueryRunner( - query=cast(SessionsTimelineQuery | Dict[str, Any], query), + query=cast(SessionsTimelineQuery | dict[str, Any], query), team=team, timings=timings, modifiers=modifiers, @@ -292,7 +292,7 @@ def get_query_runner( class QueryRunner(ABC, Generic[Q]): query: Q - query_type: Type[Q] + query_type: type[Q] team: Team timings: HogQLTimings modifiers: HogQLQueryModifiers @@ -300,7 +300,7 @@ class QueryRunner(ABC, Generic[Q]): def __init__( self, - query: Q | BaseModel | Dict[str, Any], + query: Q | BaseModel | dict[str, Any], team: Team, timings: Optional[HogQLTimings] = None, modifiers: Optional[HogQLQueryModifiers] = None, @@ -425,7 +425,7 @@ def apply_dashboard_filters(self, dashboard_filter: DashboardFilter) -> Q: # The default logic below applies to all insights and a lot of other queries # Notable exception: `HogQLQuery`, which has `properties` and `dateRange` within `HogQLFilters` if hasattr(self.query, "properties") and hasattr(self.query, "dateRange"): - query_update: Dict[str, Any] = {} + query_update: dict[str, Any] = {} if dashboard_filter.properties: if self.query.properties: query_update["properties"] = PropertyGroupFilter( diff --git a/posthog/hogql_queries/sessions_timeline_query_runner.py b/posthog/hogql_queries/sessions_timeline_query_runner.py index cda9433d63efa..306ec02c93448 100644 --- a/posthog/hogql_queries/sessions_timeline_query_runner.py +++ b/posthog/hogql_queries/sessions_timeline_query_runner.py @@ -1,6 +1,6 @@ from datetime import timedelta import json -from typing import Dict, cast +from typing import cast from posthog.api.element import ElementSerializer @@ -138,7 +138,7 @@ def calculate(self) -> SessionsTimelineQueryResponse: 
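Review note: every branch of get_query_runner above follows the same shape — match on the query kind, then narrow the `query` union with typing.cast before handing it to the concrete runner, since each runner accepts either a parsed model or a raw dict payload. A condensed sketch of the pattern under Python 3.10+ (class and function bodies here are stand-ins, not the real module layout):

    from typing import Any, cast

    class TrendsQuery: ...

    class TrendsQueryRunner:
        def __init__(self, query: TrendsQuery | dict[str, Any]) -> None:
            self.query = query

    def get_query_runner(kind: str, query: object) -> TrendsQueryRunner:
        if kind == "TrendsQuery":
            # cast() is a no-op at runtime; it only narrows the type for the checker.
            return TrendsQueryRunner(query=cast(TrendsQuery | dict[str, Any], query))
        raise ValueError(f"unknown query kind: {kind!r}")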
limit_context=self.limit_context, ) assert query_result.results is not None - timeline_entries_map: Dict[str, TimelineEntry] = {} + timeline_entries_map: dict[str, TimelineEntry] = {} for ( uuid, timestamp_parsed, diff --git a/posthog/hogql_queries/test/test_events_query_runner.py b/posthog/hogql_queries/test/test_events_query_runner.py index 7c8c62c5fb0fc..1617919f984ff 100644 --- a/posthog/hogql_queries/test/test_events_query_runner.py +++ b/posthog/hogql_queries/test/test_events_query_runner.py @@ -1,4 +1,4 @@ -from typing import Tuple, Any, cast +from typing import Any, cast from freezegun import freeze_time @@ -25,7 +25,7 @@ class TestEventsQueryRunner(ClickhouseTestMixin, APIBaseTest): maxDiff = None - def _create_events(self, data: list[Tuple[str, str, Any]], event="$pageview"): + def _create_events(self, data: list[tuple[str, str, Any]], event="$pageview"): person_result = [] for distinct_id, timestamp, event_properties in data: with freeze_time(timestamp): diff --git a/posthog/hogql_queries/test/test_query_runner.py b/posthog/hogql_queries/test/test_query_runner.py index a02cf4fb46cdd..88d6128b00544 100644 --- a/posthog/hogql_queries/test/test_query_runner.py +++ b/posthog/hogql_queries/test/test_query_runner.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, List, Literal, Optional +from typing import Any, Literal, Optional from zoneinfo import ZoneInfo from dateutil.parser import isoparse @@ -21,7 +21,7 @@ class TestQuery(BaseModel): kind: Literal["TestQuery"] = "TestQuery" some_attr: str - other_attr: Optional[List[Any]] = [] + other_attr: Optional[list[Any]] = [] class TestQueryRunner(BaseTest): diff --git a/posthog/hogql_queries/utils/formula_ast.py b/posthog/hogql_queries/utils/formula_ast.py index 28e705827b7f8..922e283362a49 100644 --- a/posthog/hogql_queries/utils/formula_ast.py +++ b/posthog/hogql_queries/utils/formula_ast.py @@ -1,6 +1,6 @@ import ast import operator -from typing import Any, Dict, List +from typing import Any class FormulaAST: @@ -12,9 +12,9 @@ class FormulaAST: ast.Mod: operator.mod, ast.Pow: operator.pow, } - zipped_data: List[tuple[float]] + zipped_data: list[tuple[float]] - def __init__(self, data: List[List[float]]): + def __init__(self, data: list[list[float]]): self.zipped_data = list(zip(*data)) def call(self, node: str): @@ -27,8 +27,8 @@ def call(self, node: str): res.append(result) return res - def _evaluate(self, node, const_map: Dict[str, Any]): - if isinstance(node, (list, tuple)): + def _evaluate(self, node, const_map: dict[str, Any]): + if isinstance(node, list | tuple): return [self._evaluate(sub_node, const_map) for sub_node in node] elif isinstance(node, str): diff --git a/posthog/hogql_queries/utils/query_date_range.py b/posthog/hogql_queries/utils/query_date_range.py index ab1f25fbb376c..ac9636c1e1ce5 100644 --- a/posthog/hogql_queries/utils/query_date_range.py +++ b/posthog/hogql_queries/utils/query_date_range.py @@ -1,7 +1,7 @@ import re from datetime import datetime, timedelta from functools import cached_property -from typing import Literal, Optional, Dict +from typing import Literal, Optional from zoneinfo import ZoneInfo from dateutil.parser import parse @@ -248,7 +248,7 @@ def date_to_with_extra_interval_hogql(self) -> ast.Call: args=[self.date_to_start_of_interval_hogql(self.date_to_as_hogql()), self.one_interval_period()], ) - def to_placeholders(self) -> Dict[str, ast.Expr]: + def to_placeholders(self) -> dict[str, ast.Expr]: return { "interval": 
self.interval_period_string_as_hogql_constant(), "one_interval_period": self.one_interval_period(), diff --git a/posthog/hogql_queries/utils/query_previous_period_date_range.py b/posthog/hogql_queries/utils/query_previous_period_date_range.py index 652a95c835eb7..c6dca63dc7d95 100644 --- a/posthog/hogql_queries/utils/query_previous_period_date_range.py +++ b/posthog/hogql_queries/utils/query_previous_period_date_range.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, Dict, Tuple +from typing import Optional from posthog.hogql_queries.utils.query_date_range import QueryDateRange from posthog.models.team import Team @@ -28,7 +28,7 @@ def __init__( ) -> None: super().__init__(date_range, team, interval, now) - def date_from_delta_mappings(self) -> Dict[str, int] | None: + def date_from_delta_mappings(self) -> dict[str, int] | None: if self._date_range and isinstance(self._date_range.date_from, str) and self._date_range.date_from != "all": date_from = self._date_range.date_from else: @@ -41,7 +41,7 @@ def date_from_delta_mappings(self) -> Dict[str, int] | None: )[1] return delta_mapping - def date_to_delta_mappings(self) -> Dict[str, int] | None: + def date_to_delta_mappings(self) -> dict[str, int] | None: if self._date_range and self._date_range.date_to: delta_mapping = relative_date_parse_with_delta_mapping( self._date_range.date_to, @@ -52,7 +52,7 @@ def date_to_delta_mappings(self) -> Dict[str, int] | None: return delta_mapping return None - def dates(self) -> Tuple[datetime, datetime]: + def dates(self) -> tuple[datetime, datetime]: current_period_date_from = super().date_from() current_period_date_to = super().date_to() diff --git a/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py b/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py index 7ea8e864a3a65..3ea217606522c 100644 --- a/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py +++ b/posthog/hogql_queries/web_analytics/test/test_web_analytics_query_runner.py @@ -1,4 +1,4 @@ -from typing import Union, List +from typing import Union from freezegun import freeze_time @@ -62,7 +62,7 @@ def _create__web_overview_query(self, date_from, date_to, properties): return WebOverviewQueryRunner(team=self.team, query=query) def test_sample_rate_cache_key_is_same_across_subclasses(self): - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), PersonPropertyFilter(key="$initial_utm_source", value="google", operator=PropertyOperator.is_not), ] @@ -75,10 +75,10 @@ def test_sample_rate_cache_key_is_same_across_subclasses(self): self.assertEqual(stats_key, overview_key) def test_sample_rate_cache_key_is_same_with_different_properties(self): - properties_a: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties_a: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), ] - properties_b: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties_b: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/b", operator=PropertyOperator.is_not), ] date_from = "2023-12-08" @@ -90,7 +90,7 @@ def test_sample_rate_cache_key_is_same_with_different_properties(self): self.assertEqual(key_a, key_b) def 
test_sample_rate_cache_key_changes_with_date_range(self): - properties: List[Union[EventPropertyFilter, PersonPropertyFilter]] = [ + properties: list[Union[EventPropertyFilter, PersonPropertyFilter]] = [ EventPropertyFilter(key="$current_url", value="/a", operator=PropertyOperator.is_not), ] date_from_a = "2023-12-08" @@ -100,7 +100,7 @@ def test_sample_rate_cache_key_changes_with_date_range(self): key_a = self._create_web_stats_table_query(date_from_a, date_to, properties)._sample_rate_cache_key() key_b = self._create_web_stats_table_query(date_from_b, date_to, properties)._sample_rate_cache_key() - self.assertNotEquals(key_a, key_b) + self.assertNotEqual(key_a, key_b) def test_sample_rate_from_count(self): self.assertEqual(SamplingRate(numerator=1), _sample_rate_from_count(0)) diff --git a/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py index f91f3c1cff404..fb1288ac1bd26 100644 --- a/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py +++ b/posthog/hogql_queries/web_analytics/web_analytics_query_runner.py @@ -2,7 +2,7 @@ from abc import ABC from datetime import timedelta from math import ceil -from typing import Optional, List, Union, Type +from typing import Optional, Union from django.conf import settings from django.core.cache import cache @@ -32,7 +32,7 @@ class WebAnalyticsQueryRunner(QueryRunner, ABC): query: WebQueryNode - query_type: Type[WebQueryNode] + query_type: type[WebQueryNode] @cached_property def query_date_range(self): @@ -51,7 +51,7 @@ def pathname_property_filter(self) -> Optional[EventPropertyFilter]: return None @cached_property - def property_filters_without_pathname(self) -> List[Union[EventPropertyFilter, PersonPropertyFilter]]: + def property_filters_without_pathname(self) -> list[Union[EventPropertyFilter, PersonPropertyFilter]]: return [p for p in self.query.properties if p.key != "$pathname"] def session_where(self, include_previous_period: Optional[bool] = None): diff --git a/posthog/jwt.py b/posthog/jwt.py index fa458ab2f5e3f..111a85d51df82 100644 --- a/posthog/jwt.py +++ b/posthog/jwt.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Any, Dict +from typing import Any import jwt from django.conf import settings @@ -32,7 +32,7 @@ def encode_jwt(payload: dict, expiry_delta: timedelta, audience: PosthogJwtAudie return encoded_jwt -def decode_jwt(token: str, audience: PosthogJwtAudience) -> Dict[str, Any]: +def decode_jwt(token: str, audience: PosthogJwtAudience) -> dict[str, Any]: info = jwt.decode(token, settings.SECRET_KEY, audience=audience.value, algorithms=["HS256"]) return info diff --git a/posthog/kafka_client/client.py b/posthog/kafka_client/client.py index d29d9e9c0ae0d..3f58e572417b8 100644 --- a/posthog/kafka_client/client.py +++ b/posthog/kafka_client/client.py @@ -1,6 +1,7 @@ import json from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Optional +from collections.abc import Callable from django.conf import settings from kafka import KafkaConsumer as KC @@ -32,7 +33,7 @@ def send( topic: str, value: Any, key: Any = None, - headers: Optional[List[Tuple[str, bytes]]] = None, + headers: Optional[list[tuple[str, bytes]]] = None, ): produce_future = FutureProduceResult(topic_partition=TopicPartition(topic, 1)) future = FutureRecordMetadata( @@ -158,7 +159,7 @@ def produce( data: Any, key: Any = None, value_serializer: 
Optional[Callable[[Any], Any]] = None, - headers: Optional[List[Tuple[str, str]]] = None, + headers: Optional[list[tuple[str, str]]] = None, ): if not value_serializer: value_serializer = self.json_serializer @@ -258,7 +259,7 @@ class ClickhouseProducer: def __init__(self): self.producer = KafkaProducer() if not settings.TEST else None - def produce(self, sql: str, topic: str, data: Dict[str, Any], sync: bool = True): + def produce(self, sql: str, topic: str, data: dict[str, Any], sync: bool = True): if self.producer is not None: # TODO: this should be not sync and self.producer.produce(topic=topic, data=data) else: diff --git a/posthog/kafka_client/helper.py b/posthog/kafka_client/helper.py index 6084e991a100a..39cb9f038560f 100644 --- a/posthog/kafka_client/helper.py +++ b/posthog/kafka_client/helper.py @@ -39,9 +39,11 @@ def get_kafka_ssl_context(): # SSLContext inside the with so when it goes out of scope the files are removed which has them # existing for the shortest amount of time. As extra caution password # protect/encrypt the client key - with NamedTemporaryFile(suffix=".crt") as cert_file, NamedTemporaryFile( - suffix=".key" - ) as key_file, NamedTemporaryFile(suffix=".crt") as trust_file: + with ( + NamedTemporaryFile(suffix=".crt") as cert_file, + NamedTemporaryFile(suffix=".key") as key_file, + NamedTemporaryFile(suffix=".crt") as trust_file, + ): cert_file.write(base64.b64decode(os.environ["KAFKA_CLIENT_CERT_B64"].encode("utf-8"))) cert_file.flush() diff --git a/posthog/management/commands/backfill_distinct_id_overrides.py b/posthog/management/commands/backfill_distinct_id_overrides.py index 507e744a93d0e..4472ec6291658 100644 --- a/posthog/management/commands/backfill_distinct_id_overrides.py +++ b/posthog/management/commands/backfill_distinct_id_overrides.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Sequence +from collections.abc import Sequence import structlog from django.core.management.base import BaseCommand, CommandError diff --git a/posthog/management/commands/create_channel_definitions_file.py b/posthog/management/commands/create_channel_definitions_file.py index 859bbe3c631ce..cab70bf31d360 100644 --- a/posthog/management/commands/create_channel_definitions_file.py +++ b/posthog/management/commands/create_channel_definitions_file.py @@ -4,7 +4,7 @@ from collections import OrderedDict from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple +from typing import Optional from django.core.management.base import BaseCommand @@ -40,7 +40,7 @@ def handle(self, *args, **options): input_arg = options.get("ga_sources") if not input_arg: raise ValueError("No input file specified") - with open(input_arg, "r", encoding="utf-8-sig") as input_file: + with open(input_arg, encoding="utf-8-sig") as input_file: input_str = input_file.read() split_items = re.findall(r"\S+\s+SOURCE_CATEGORY_\S+", input_str) @@ -59,7 +59,7 @@ def handle_entry(entry): base_type, type_if_paid, type_if_organic = types[raw_type] return (domain, EntryKind.source), SourceEntry(base_type, type_if_paid, type_if_organic) - entries: OrderedDict[Tuple[str, str], SourceEntry] = OrderedDict(map(handle_entry, split_items)) + entries: OrderedDict[tuple[str, str], SourceEntry] = OrderedDict(map(handle_entry, split_items)) # add google domains to this, from https://www.google.com/supported_domains for google_domain in [ diff --git a/posthog/management/commands/fix_person_distinct_ids_after_delete.py 
b/posthog/management/commands/fix_person_distinct_ids_after_delete.py index 842a4e5353ec8..4f0853dd001ba 100644 --- a/posthog/management/commands/fix_person_distinct_ids_after_delete.py +++ b/posthog/management/commands/fix_person_distinct_ids_after_delete.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional import structlog from django.core.management.base import BaseCommand @@ -50,7 +50,7 @@ def run(options, sync: bool = False): logger.info("Kafka producer queue flushed.") -def get_distinct_ids_tied_to_deleted_persons(team_id: int) -> List[str]: +def get_distinct_ids_tied_to_deleted_persons(team_id: int) -> list[str]: # find distinct_ids where the person is set to be deleted rows = sync_execute( """ diff --git a/posthog/management/commands/makemigrations.py b/posthog/management/commands/makemigrations.py index 8ff0a37bfaa34..a9e0ea4f98e5a 100644 --- a/posthog/management/commands/makemigrations.py +++ b/posthog/management/commands/makemigrations.py @@ -9,7 +9,7 @@ class Command(MakeMigrationsCommand): def handle(self, *app_labels, **options): # Generate a migrations manifest with latest migration on each app - super(Command, self).handle(*app_labels, **options) + super().handle(*app_labels, **options) loader = MigrationLoader(None, ignore_no_migrations=True) apps = sorted(loader.migrated_apps) diff --git a/posthog/management/commands/partition.py b/posthog/management/commands/partition.py index b17e958b0c1e1..4bb17e68b7851 100644 --- a/posthog/management/commands/partition.py +++ b/posthog/management/commands/partition.py @@ -6,7 +6,7 @@ def load_sql(filename): path = os.path.join(os.path.dirname(__file__), "../sql/", filename) - with open(path, "r", encoding="utf_8") as f: + with open(path, encoding="utf_8") as f: return f.read() diff --git a/posthog/management/commands/run_async_migrations.py b/posthog/management/commands/run_async_migrations.py index 611c6038fd43b..c8ee72ea352c6 100644 --- a/posthog/management/commands/run_async_migrations.py +++ b/posthog/management/commands/run_async_migrations.py @@ -1,5 +1,5 @@ import logging -from typing import List, Sequence +from collections.abc import Sequence import structlog from django.core.exceptions import ImproperlyConfigured @@ -31,7 +31,7 @@ def get_necessary_migrations() -> Sequence[AsyncMigration]: - necessary_migrations: List[AsyncMigration] = [] + necessary_migrations: list[AsyncMigration] = [] for migration_name, definition in sorted(ALL_ASYNC_MIGRATIONS.items()): if is_async_migration_complete(migration_name): continue @@ -144,10 +144,8 @@ def handle_plan(necessary_migrations: Sequence[AsyncMigration]): logger.info("Async migrations up to date!") else: logger.warning( - ( - f"Required async migration{' is' if len(necessary_migrations) == 1 else 's are'} not completed:\n" - "\n".join((f"- {migration.get_name_with_requirements()}" for migration in necessary_migrations)) - ) + f"Required async migration{' is' if len(necessary_migrations) == 1 else 's are'} not completed:\n" + "\n".join(f"- {migration.get_name_with_requirements()}" for migration in necessary_migrations) ) diff --git a/posthog/management/commands/sync_feature_flags.py b/posthog/management/commands/sync_feature_flags.py index df2e8d3257645..4e26061603691 100644 --- a/posthog/management/commands/sync_feature_flags.py +++ b/posthog/management/commands/sync_feature_flags.py @@ -1,4 +1,4 @@ -from typing import Dict, cast +from typing import cast from django.core.management.base import BaseCommand @@ -15,8 +15,8 @@ class 
Command(BaseCommand): help = "Add and enable all feature flags in frontend/src/lib/constants.tsx for all teams" def handle(self, *args, **options): - flags: Dict[str, str] = {} - with open("frontend/src/lib/constants.tsx", "r", encoding="utf_8") as f: + flags: dict[str, str] = {} + with open("frontend/src/lib/constants.tsx", encoding="utf_8") as f: lines = f.readlines() parsing_flags = False for line in lines: diff --git a/posthog/management/commands/sync_replicated_schema.py b/posthog/management/commands/sync_replicated_schema.py index e2c280bd41b39..642eae80d9bbf 100644 --- a/posthog/management/commands/sync_replicated_schema.py +++ b/posthog/management/commands/sync_replicated_schema.py @@ -1,7 +1,6 @@ import logging import re from collections import defaultdict -from typing import Dict, Set import structlog from django.conf import settings @@ -65,8 +64,8 @@ def analyze_cluster_tables(self): }, ) - host_tables: Dict[HostName, Set[TableName]] = defaultdict(set) - create_table_queries: Dict[TableName, Query] = {} + host_tables: dict[HostName, set[TableName]] = defaultdict(set) + create_table_queries: dict[TableName, Query] = {} for host, table_name, create_table_query in rows: host_tables[host].add(table_name) @@ -74,7 +73,7 @@ def analyze_cluster_tables(self): return host_tables, create_table_queries, self.get_out_of_sync_hosts(host_tables) - def get_out_of_sync_hosts(self, host_tables: Dict[HostName, Set[TableName]]) -> Dict[HostName, Set[TableName]]: + def get_out_of_sync_hosts(self, host_tables: dict[HostName, set[TableName]]) -> dict[HostName, set[TableName]]: table_names = list(map(get_table_name, CREATE_TABLE_QUERIES)) out_of_sync = {} @@ -87,8 +86,8 @@ def get_out_of_sync_hosts(self, host_tables: Dict[HostName, Set[TableName]]) -> def create_missing_tables( self, - out_of_sync_hosts: Dict[HostName, Set[TableName]], - create_table_queries: Dict[TableName, Query], + out_of_sync_hosts: dict[HostName, set[TableName]], + create_table_queries: dict[TableName, Query], ): missing_tables = {table for tables in out_of_sync_hosts.values() for table in tables} diff --git a/posthog/management/commands/test_migrations_are_safe.py b/posthog/management/commands/test_migrations_are_safe.py index 566533fd9fe69..a576b982b5089 100644 --- a/posthog/management/commands/test_migrations_are_safe.py +++ b/posthog/management/commands/test_migrations_are_safe.py @@ -1,6 +1,6 @@ import re import sys -from typing import List, Optional +from typing import Optional from django.core.management import call_command from django.core.management.base import BaseCommand, CommandError @@ -20,7 +20,7 @@ def _get_table(search_string: str, operation_sql: str) -> Optional[str]: def validate_migration_sql(sql) -> bool: new_tables = _get_new_tables(sql) operations = sql.split("\n") - tables_created_so_far: List[str] = [] + tables_created_so_far: list[str] = [] for operation_sql in operations: # Extract table name from queries of this format: ALTER TABLE TABLE "posthog_feature" table_being_altered: Optional[str] = ( diff --git a/posthog/middleware.py b/posthog/middleware.py index e43ef3a620f18..87ee128a7268e 100644 --- a/posthog/middleware.py +++ b/posthog/middleware.py @@ -1,6 +1,7 @@ import time from ipaddress import ip_address, ip_network -from typing import Any, Callable, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from django.shortcuts import redirect import structlog @@ -66,7 +67,7 @@ class AllowIPMiddleware: - trusted_proxies: List[str] = [] + trusted_proxies: 
list[str] = [] def __init__(self, get_response): if not settings.ALLOWED_IP_BLOCKS: @@ -411,7 +412,7 @@ class CaptureMiddleware: def __init__(self, get_response): self.get_response = get_response - middlewares: List[Any] = [] + middlewares: list[Any] = [] # based on how we're using these middlewares, only middlewares that # have a process_request and process_response attribute can be valid here. # Or, middlewares that inherit from `middleware.util.deprecation.MiddlewareMixin` which diff --git a/posthog/migrations/0027_move_elements_to_group.py b/posthog/migrations/0027_move_elements_to_group.py index 51a65b1f5da39..1bc55cd985388 100644 --- a/posthog/migrations/0027_move_elements_to_group.py +++ b/posthog/migrations/0027_move_elements_to_group.py @@ -1,7 +1,6 @@ # Generated by Django 3.0.3 on 2020-02-27 18:13 import hashlib import json -from typing import List from django.db import migrations, models, transaction from django.forms.models import model_to_dict @@ -21,7 +20,7 @@ def forwards(apps, schema_editor): ElementGroup = apps.get_model("posthog", "ElementGroup") Element = apps.get_model("posthog", "Element") - hashes_seen: List[str] = [] + hashes_seen: list[str] = [] while Event.objects.filter(element__isnull=False, elements_hash__isnull=True, event="$autocapture").exists(): with transaction.atomic(): events = ( diff --git a/posthog/migrations/0132_team_test_account_filters.py b/posthog/migrations/0132_team_test_account_filters.py index 313de9f3355e4..a1aba896aa287 100644 --- a/posthog/migrations/0132_team_test_account_filters.py +++ b/posthog/migrations/0132_team_test_account_filters.py @@ -22,7 +22,7 @@ class GenericEmails: """ def __init__(self): - with open(get_absolute_path("../helpers/generic_emails.txt"), "r") as f: + with open(get_absolute_path("../helpers/generic_emails.txt")) as f: self.emails = {x.rstrip(): True for x in f} def is_generic(self, email: str) -> bool: diff --git a/posthog/migrations/0219_migrate_tags_v2.py b/posthog/migrations/0219_migrate_tags_v2.py index fef394a5cc0ea..dcd7375511e4f 100644 --- a/posthog/migrations/0219_migrate_tags_v2.py +++ b/posthog/migrations/0219_migrate_tags_v2.py @@ -1,5 +1,5 @@ # Generated by Django 3.2.5 on 2022-03-01 23:41 -from typing import Any, List, Tuple +from typing import Any from django.core.paginator import Paginator from django.db import migrations @@ -19,7 +19,7 @@ def forwards(apps, schema_editor): Insight = apps.get_model("posthog", "Insight") Dashboard = apps.get_model("posthog", "Dashboard") - createables: List[Tuple[Any, Any]] = [] + createables: list[tuple[Any, Any]] = [] batch_size = 1_000 # Collect insight tags and taggeditems diff --git a/posthog/migrations/0259_backfill_team_recording_domains.py b/posthog/migrations/0259_backfill_team_recording_domains.py index 1f0dcba4f08f8..12304cc70fd83 100644 --- a/posthog/migrations/0259_backfill_team_recording_domains.py +++ b/posthog/migrations/0259_backfill_team_recording_domains.py @@ -1,4 +1,3 @@ -from typing import Set from urllib.parse import urlparse import structlog @@ -20,7 +19,7 @@ def backfill_recording_domains(apps, _): teams_in_batch = all_teams[i : i + batch_size] for team in teams_in_batch: - recording_domains: Set[str] = set() + recording_domains: set[str] = set() for app_url in team.app_urls: # Extract just the domain from the URL parsed_url = urlparse(app_url) diff --git a/posthog/models/action/action.py b/posthog/models/action/action.py index bd016535a88e0..49aefe15440f8 100644 --- a/posthog/models/action/action.py +++ b/posthog/models/action/action.py @@ 
-1,5 +1,5 @@ import json -from typing import List, Any +from typing import Any from django.db import models from django.db.models import Q @@ -51,10 +51,10 @@ def get_analytics_metadata(self): "deleted": self.deleted, } - def get_step_events(self) -> List[str]: + def get_step_events(self) -> list[str]: return [action_step.event for action_step in self.steps.all()] - def generate_bytecode(self) -> List[Any]: + def generate_bytecode(self) -> list[Any]: from posthog.hogql.property import action_to_expr from posthog.hogql.bytecode import create_bytecode diff --git a/posthog/models/action/util.py b/posthog/models/action/util.py index 54fda6ef5b95f..95cdca9721ceb 100644 --- a/posthog/models/action/util.py +++ b/posthog/models/action/util.py @@ -1,6 +1,6 @@ from collections import Counter -from typing import Counter as TCounter, Literal, Optional -from typing import Dict, List, Tuple +from typing import Literal, Optional +from collections import Counter as TCounter from posthog.constants import AUTOCAPTURE_EVENT, TREND_FILTER_TYPE_ACTIONS from posthog.hogql.hogql import HogQLContext @@ -15,7 +15,7 @@ def format_action_filter_event_only( action: Action, prepend: str = "action", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Return SQL for prefiltering events by action, i.e. down to only the events and without any other filters.""" events = action.get_step_events() if not events: @@ -37,7 +37,7 @@ def format_action_filter( table_name: str = "", person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_SUBQUERY, person_id_joined_alias: str = "person_id", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Return SQL for filtering events by action.""" # get action steps params = {"team_id": action.team.pk} if filter_by_team else {} @@ -48,7 +48,7 @@ def format_action_filter( or_queries = [] for index, step in enumerate(steps): - conditions: List[str] = [] + conditions: list[str] = [] # filter element if step.event == AUTOCAPTURE_EVENT: from posthog.models.property.util import ( @@ -118,7 +118,7 @@ def format_action_filter( def filter_event( step: ActionStep, prepend: str = "event", index: int = 0, table_name: str = "" -) -> Tuple[List[str], Dict]: +) -> tuple[list[str], dict]: from posthog.models.property.util import get_property_string_expr params = {} @@ -156,7 +156,7 @@ def format_entity_filter( person_id_joined_alias: str, prepend: str = "action", filter_by_team=True, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: if entity.type == TREND_FILTER_TYPE_ACTIONS: action = entity.get_action() entity_filter, params = format_action_filter( diff --git a/posthog/models/activity_logging/activity_log.py b/posthog/models/activity_logging/activity_log.py index 074b53b2dd55b..141130ea4f80e 100644 --- a/posthog/models/activity_logging/activity_log.py +++ b/posthog/models/activity_logging/activity_log.py @@ -2,7 +2,7 @@ import json from datetime import datetime from decimal import Decimal -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import structlog from django.core.paginator import Paginator @@ -52,7 +52,7 @@ class Change: class Trigger: job_type: str job_id: str - payload: Dict + payload: dict @dataclasses.dataclass(frozen=True) @@ -62,13 +62,13 @@ class Detail: # The short_id if it has one short_id: Optional[str] = None type: Optional[str] = None - changes: Optional[List[Change]] = None + changes: Optional[list[Change]] = None trigger: Optional[Trigger] = None class ActivityDetailEncoder(json.JSONEncoder): def default(self, 
obj): - if isinstance(obj, (Detail, Change, Trigger)): + if isinstance(obj, Detail | Change | Trigger): return obj.__dict__ if isinstance(obj, datetime): return obj.isoformat() @@ -132,7 +132,7 @@ class Meta: ] -field_exclusions: Dict[ActivityScope, List[str]] = { +field_exclusions: dict[ActivityScope, list[str]] = { "Notebook": [ "text_content", ], @@ -199,7 +199,7 @@ class Meta: } -def describe_change(m: Any) -> Union[str, Dict]: +def describe_change(m: Any) -> Union[str, dict]: if isinstance(m, Dashboard): return {"id": m.id, "name": m.name} if isinstance(m, DashboardTile): @@ -213,7 +213,7 @@ def describe_change(m: Any) -> Union[str, Dict]: return str(m) -def _read_through_relation(relation: models.Manager) -> List[Union[Dict, str]]: +def _read_through_relation(relation: models.Manager) -> list[Union[dict, str]]: described_models = [describe_change(r) for r in relation.all()] if all(isinstance(elem, str) for elem in described_models): @@ -227,11 +227,11 @@ def changes_between( model_type: ActivityScope, previous: Optional[models.Model], current: Optional[models.Model], -) -> List[Change]: +) -> list[Change]: """ Identifies changes between two models by comparing fields """ - changes: List[Change] = [] + changes: list[Change] = [] if previous is None and current is None: # there are no changes between two things that don't exist @@ -282,14 +282,14 @@ def changes_between( def dict_changes_between( model_type: ActivityScope, - previous: Dict[Any, Any], - new: Dict[Any, Any], + previous: dict[Any, Any], + new: dict[Any, Any], use_field_exclusions: bool = False, -) -> List[Change]: +) -> list[Change]: """ Identifies changes between two dictionaries by comparing fields """ - changes: List[Change] = [] + changes: list[Change] = [] if previous == new: return changes @@ -395,7 +395,7 @@ class ActivityPage: limit: int has_next: bool has_previous: bool - results: List[ActivityLog] + results: list[ActivityLog] def get_activity_page(activity_query: models.QuerySet, limit: int = 10, page: int = 1) -> ActivityPage: @@ -430,7 +430,7 @@ def load_activity( return get_activity_page(activity_query, limit, page) -def load_all_activity(scope_list: List[ActivityScope], team_id: int, limit: int = 10, page: int = 1): +def load_all_activity(scope_list: list[ActivityScope], team_id: int, limit: int = 10, page: int = 1): activity_query = ( ActivityLog.objects.select_related("user").filter(team_id=team_id, scope__in=scope_list).order_by("-created_at") ) diff --git a/posthog/models/async_deletion/delete.py b/posthog/models/async_deletion/delete.py index 9846842b8e0d5..1ab75b353e898 100644 --- a/posthog/models/async_deletion/delete.py +++ b/posthog/models/async_deletion/delete.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List, Tuple import structlog from django.utils import timezone @@ -13,7 +12,7 @@ class AsyncDeletionProcess(ABC): CLICKHOUSE_MUTATION_CHUNK_SIZE = 1_000_000 CLICKHOUSE_VERIFY_CHUNK_SIZE = 1_000 - DELETION_TYPES: List[DeletionType] = [] + DELETION_TYPES: list[DeletionType] = [] def __init__(self) -> None: super().__init__() @@ -60,14 +59,14 @@ def _fetch_unverified_deletions_grouped(self): return result @abstractmethod - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): raise NotImplementedError() @abstractmethod - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, 
async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: raise NotImplementedError() - def _conditions(self, async_deletions: List[AsyncDeletion]) -> Tuple[List[str], Dict]: + def _conditions(self, async_deletions: list[AsyncDeletion]) -> tuple[list[str], dict]: conditions, args = [], {} for i, row in enumerate(async_deletions): condition, arg = self._condition(row, str(i)) @@ -76,5 +75,5 @@ def _conditions(self, async_deletions: List[AsyncDeletion]) -> Tuple[List[str], return conditions, args @abstractmethod - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: raise NotImplementedError() diff --git a/posthog/models/async_deletion/delete_cohorts.py b/posthog/models/async_deletion/delete_cohorts.py index c2d452628ceb2..00f10aac6b82e 100644 --- a/posthog/models/async_deletion/delete_cohorts.py +++ b/posthog/models/async_deletion/delete_cohorts.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Set, Tuple +from typing import Any from posthog.client import sync_execute from posthog.models.async_deletion import AsyncDeletion, DeletionType @@ -9,7 +9,7 @@ class AsyncCohortDeletion(AsyncDeletionProcess): DELETION_TYPES = [DeletionType.Cohort_full, DeletionType.Cohort_stale] - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): if len(deletions) == 0: logger.warn("No AsyncDeletion for cohorts to perform") return @@ -33,7 +33,7 @@ def process(self, deletions: List[AsyncDeletion]): workload=Workload.OFFLINE, ) - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: if deletion_type == DeletionType.Cohort_stale or deletion_type == DeletionType.Cohort_full: cohort_ids_with_data = self._verify_by_column("team_id, cohort_id", async_deletions) return [ @@ -42,7 +42,7 @@ def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeleti else: return [] - def _verify_by_column(self, distinct_columns: str, async_deletions: List[AsyncDeletion]) -> Set[Tuple[Any, ...]]: + def _verify_by_column(self, distinct_columns: str, async_deletions: list[AsyncDeletion]) -> set[tuple[Any, ...]]: conditions, args = self._conditions(async_deletions) clickhouse_result = sync_execute( f""" @@ -62,7 +62,7 @@ def _column_name(self, async_deletion: AsyncDeletion): ) return "cohort_id" - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: team_id_param = f"team_id{suffix}" key_param = f"key{suffix}" version_param = f"version{suffix}" diff --git a/posthog/models/async_deletion/delete_events.py b/posthog/models/async_deletion/delete_events.py index 2486043a5b871..988161336cc56 100644 --- a/posthog/models/async_deletion/delete_events.py +++ b/posthog/models/async_deletion/delete_events.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Set, Tuple +from typing import Any from posthog.client import sync_execute from posthog.models.async_deletion import AsyncDeletion, DeletionType, CLICKHOUSE_ASYNC_DELETION_TABLE @@ -22,7 +22,7 @@ class AsyncEventDeletion(AsyncDeletionProcess): DELETION_TYPES = [DeletionType.Team, DeletionType.Person] - def process(self, deletions: List[AsyncDeletion]): + def process(self, deletions: list[AsyncDeletion]): if 
len(deletions) == 0: logger.debug("No AsyncDeletion to perform") return @@ -87,7 +87,7 @@ def process(self, deletions: List[AsyncDeletion]): workload=Workload.OFFLINE, ) - def _fill_table(self, deletions: List[AsyncDeletion], temp_table_name: str): + def _fill_table(self, deletions: list[AsyncDeletion], temp_table_name: str): sync_execute(f"DROP TABLE IF EXISTS {temp_table_name}", workload=Workload.OFFLINE) sync_execute( CLICKHOUSE_ASYNC_DELETION_TABLE.format(table_name=temp_table_name, cluster=CLICKHOUSE_CLUSTER), @@ -111,7 +111,7 @@ def _fill_table(self, deletions: List[AsyncDeletion], temp_table_name: str): workload=Workload.OFFLINE, ) - def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeletion]) -> List[AsyncDeletion]: + def _verify_by_group(self, deletion_type: int, async_deletions: list[AsyncDeletion]) -> list[AsyncDeletion]: if deletion_type == DeletionType.Team: team_ids_with_data = self._verify_by_column("team_id", async_deletions) return [row for row in async_deletions if (row.team_id,) not in team_ids_with_data] @@ -122,7 +122,7 @@ def _verify_by_group(self, deletion_type: int, async_deletions: List[AsyncDeleti else: return [] - def _verify_by_column(self, distinct_columns: str, async_deletions: List[AsyncDeletion]) -> Set[Tuple[Any, ...]]: + def _verify_by_column(self, distinct_columns: str, async_deletions: list[AsyncDeletion]) -> set[tuple[Any, ...]]: conditions, args = self._conditions(async_deletions) clickhouse_result = sync_execute( f""" @@ -142,7 +142,7 @@ def _column_name(self, async_deletion: AsyncDeletion): else: return f"$group_{async_deletion.group_type_index}" - def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> Tuple[str, Dict]: + def _condition(self, async_deletion: AsyncDeletion, suffix: str) -> tuple[str, dict]: if async_deletion.deletion_type == DeletionType.Team: return f"team_id = %(team_id{suffix})s", {f"team_id{suffix}": async_deletion.team_id} else: diff --git a/posthog/models/async_migration.py b/posthog/models/async_migration.py index 885f7ce397931..92d61fb5e3f33 100644 --- a/posthog/models/async_migration.py +++ b/posthog/models/async_migration.py @@ -1,5 +1,3 @@ -from typing import List - from django.db import models @@ -63,7 +61,7 @@ def get_all_running_async_migrations(): return AsyncMigration.objects.filter(status=MigrationStatus.Running) -def get_async_migrations_by_status(target_statuses: List[int]): +def get_async_migrations_by_status(target_statuses: list[int]): return AsyncMigration.objects.filter(status__in=target_statuses) diff --git a/posthog/models/channel_type/sql.py b/posthog/models/channel_type/sql.py index 15470601c2dfb..d631c276e55dd 100644 --- a/posthog/models/channel_type/sql.py +++ b/posthog/models/channel_type/sql.py @@ -37,7 +37,7 @@ f"TRUNCATE TABLE IF EXISTS {CHANNEL_DEFINITION_TABLE_NAME} ON CLUSTER '{CLICKHOUSE_CLUSTER}'" ) -with open(os.path.join(os.path.dirname(__file__), "channel_definitions.json"), "r") as f: +with open(os.path.join(os.path.dirname(__file__), "channel_definitions.json")) as f: CHANNEL_DEFINITIONS = json.loads(f.read()) @@ -54,7 +54,7 @@ def format_value(value): INSERT INTO channel_definition (domain, kind, domain_type, type_if_paid, type_if_organic) VALUES { ''', -'''.join((f'({" ,".join(map(format_value, x))})' for x in CHANNEL_DEFINITIONS))}, +'''.join(f'({" ,".join(map(format_value, x))})' for x in CHANNEL_DEFINITIONS)}, ; """ diff --git a/posthog/models/cohort/cohort.py b/posthog/models/cohort/cohort.py index a10be159d5702..8f7867127a1fa 100644 --- 
diff --git a/posthog/models/cohort/cohort.py b/posthog/models/cohort/cohort.py
index a10be159d5702..8f7867127a1fa 100644
--- a/posthog/models/cohort/cohort.py
+++ b/posthog/models/cohort/cohort.py
@@ -1,6 +1,6 @@
 import time
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional, Union, cast
+from typing import Any, Literal, Optional, Union, cast

 import structlog
 from django.conf import settings
@@ -37,7 +37,7 @@ class Group:
     def __init__(
         self,
-        properties: Optional[Dict[str, Any]] = None,
+        properties: Optional[dict[str, Any]] = None,
         action_id: Optional[int] = None,
         event_id: Optional[str] = None,
         days: Optional[int] = None,
@@ -59,7 +59,7 @@ def __init__(
         self.start_date = start_date
         self.end_date = end_date

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         dup = self.__dict__.copy()
         dup["start_date"] = self.start_date.isoformat() if self.start_date else self.start_date
         dup["end_date"] = self.end_date.isoformat() if self.end_date else self.end_date
@@ -159,11 +159,11 @@ def properties(self) -> PropertyGroup:
                 )
             else:
                 # invalid state
-                return PropertyGroup(PropertyOperatorType.AND, cast(List[Property], []))
+                return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], []))

             return PropertyGroup(PropertyOperatorType.OR, property_groups)

-        return PropertyGroup(PropertyOperatorType.AND, cast(List[Property], []))
+        return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], []))

     @property
     def has_complex_behavioral_filter(self) -> bool:
@@ -241,7 +241,7 @@ def calculate_people_ch(self, pending_version: int, *, initiating_user_id: Optio

         clear_stale_cohort.delay(self.pk, before_version=pending_version)

-    def insert_users_by_list(self, items: List[str]) -> None:
+    def insert_users_by_list(self, items: list[str]) -> None:
         """
         Items is a list of distinct_ids
         """
@@ -303,7 +303,7 @@ def insert_users_by_list(self, items: List[str]) -> None:
             self.save()
             capture_exception(err)

-    def insert_users_list_by_uuid(self, items: List[str], insert_in_clickhouse: bool = False, batchsize=1000) -> None:
+    def insert_users_list_by_uuid(self, items: list[str], insert_in_clickhouse: bool = False, batchsize=1000) -> None:
         from posthog.models.cohort.util import get_static_cohort_size, insert_static_cohort

         try:
diff --git a/posthog/models/cohort/util.py b/posthog/models/cohort/util.py
index 059bfb3813b8d..2af2c0c66c1ae 100644
--- a/posthog/models/cohort/util.py
+++ b/posthog/models/cohort/util.py
@@ -1,6 +1,6 @@
 import uuid
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Any, Optional, Union, cast

 import structlog
 from dateutil import parser
@@ -44,7 +44,7 @@
 logger = structlog.get_logger(__name__)


-def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) -> Tuple[str, Dict[str, Any]]:
+def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext) -> tuple[str, dict[str, Any]]:
     if cohort.is_static:
         return format_static_cohort_query(cohort, index, prepend="")

@@ -72,7 +72,7 @@ def format_person_query(cohort: Cohort, index: int, hogql_context: HogQLContext)
 def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str:
     from posthog.hogql_queries.query_runner import get_query_runner

-    persons_query = cast(Dict, cohort.query)
+    persons_query = cast(dict, cohort.query)
     persons_query["select"] = ["id as actor_id"]
     query = get_query_runner(
         persons_query, team=cast(Team, cohort.team), limit_context=LimitContext.COHORT_CALCULATION
@@ -81,7 +81,7 @@ def print_cohort_hogql_query(cohort: Cohort, hogql_context: HogQLContext) -> str
     return print_ast(query, context=hogql_context, dialect="clickhouse")


-def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> Tuple[str, Dict[str, Any]]:
+def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> tuple[str, dict[str, Any]]:
     cohort_id = cohort.pk
     return (
         f"SELECT person_id as id FROM {PERSON_STATIC_COHORT_TABLE} WHERE cohort_id = %({prepend}_cohort_id_{index})s AND team_id = %(team_id)s",
@@ -89,7 +89,7 @@ def format_static_cohort_query(cohort: Cohort, index: int, prepend: str) -> Tupl
     )


-def format_precalculated_cohort_query(cohort: Cohort, index: int, prepend: str = "") -> Tuple[str, Dict[str, Any]]:
+def format_precalculated_cohort_query(cohort: Cohort, index: int, prepend: str = "") -> tuple[str, dict[str, Any]]:
     filter_query = GET_PERSON_ID_BY_PRECALCULATED_COHORT_ID.format(index=index, prepend=prepend)
     return (
         filter_query,
@@ -121,7 +121,7 @@ def get_entity_query(
     team_id: int,
     group_idx: Union[int, str],
     hogql_context: HogQLContext,
-) -> Tuple[str, Dict[str, str]]:
+) -> tuple[str, dict[str, str]]:
     if event_id:
         return f"event = %({f'event_{group_idx}'})s", {f"event_{group_idx}": event_id}
     elif action_id:
@@ -139,9 +139,9 @@ def get_entity_query(

 def get_date_query(
     days: Optional[str], start_time: Optional[str], end_time: Optional[str]
-) -> Tuple[str, Dict[str, str]]:
+) -> tuple[str, dict[str, str]]:
     date_query: str = ""
-    date_params: Dict[str, str] = {}
+    date_params: dict[str, str] = {}
     if days:
         date_query, date_params = parse_entity_timestamps_in_days(int(days))
     elif start_time or end_time:
@@ -150,7 +150,7 @@ def get_date_query(
     return date_query, date_params


-def parse_entity_timestamps_in_days(days: int) -> Tuple[str, Dict[str, str]]:
+def parse_entity_timestamps_in_days(days: int) -> tuple[str, dict[str, str]]:
     curr_time = timezone.now()
     start_time = curr_time - timedelta(days=days)

@@ -163,9 +163,9 @@ def parse_entity_timestamps_in_days(days: int) -> Tuple[str, Dict[str, str]]:
     )


-def parse_cohort_timestamps(start_time: Optional[str], end_time: Optional[str]) -> Tuple[str, Dict[str, str]]:
+def parse_cohort_timestamps(start_time: Optional[str], end_time: Optional[str]) -> tuple[str, dict[str, str]]:
     clause = "AND "
-    params: Dict[str, str] = {}
+    params: dict[str, str] = {}

     if start_time:
         clause += "timestamp >= %(date_from)s"
@@ -199,7 +199,7 @@ def format_filter_query(
     hogql_context: HogQLContext,
     id_column: str = "distinct_id",
     custom_match_field="person_id",
-) -> Tuple[str, Dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
     person_query, params = format_cohort_subquery(cohort, index, hogql_context, custom_match_field=custom_match_field)

     person_id_query = CALCULATE_COHORT_PEOPLE_SQL.format(
@@ -215,7 +215,7 @@ def format_cohort_subquery(
     index: int,
     hogql_context: HogQLContext,
     custom_match_field="person_id",
-) -> Tuple[str, Dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
     is_precalculated = is_precalculated_query(cohort)
     if is_precalculated:
         query, params = format_precalculated_cohort_query(cohort, index)
@@ -259,7 +259,7 @@ def get_person_ids_by_cohort_id(
     return [str(row[0]) for row in results]


-def insert_static_cohort(person_uuids: List[Optional[uuid.UUID]], cohort_id: int, team: Team):
+def insert_static_cohort(person_uuids: list[Optional[uuid.UUID]], cohort_id: int, team: Team):
     persons = (
         {
             "id": str(uuid.uuid4()),
@@ -442,17 +442,17 @@ def simplified_cohort_filter_properties(cohort: Cohort, team: Team, is_negated=F
     return cohort.properties

-def _get_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]:
+def _get_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]:
     res = sync_execute(GET_COHORTS_BY_PERSON_UUID, {"person_id": uuid, "team_id": team_id})
     return [row[0] for row in res]


-def _get_static_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]:
+def _get_static_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]:
     res = sync_execute(GET_STATIC_COHORTPEOPLE_BY_PERSON_UUID, {"person_id": uuid, "team_id": team_id})
     return [row[0] for row in res]


-def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]:
+def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> list[int]:
     cohort_ids = _get_cohort_ids_by_person_uuid(uuid, team_id)
     static_cohort_ids = _get_static_cohort_ids_by_person_uuid(uuid, team_id)
     return [*cohort_ids, *static_cohort_ids]
@@ -461,8 +461,8 @@ def get_all_cohort_ids_by_person_uuid(uuid: str, team_id: int) -> List[int]:
 def get_dependent_cohorts(
     cohort: Cohort,
     using_database: str = "default",
-    seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None,
-) -> List[Cohort]:
+    seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None,
+) -> list[Cohort]:
     if seen_cohorts_cache is None:
         seen_cohorts_cache = {}

@@ -508,7 +508,7 @@ def get_dependent_cohorts(
     return cohorts


-def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[int, CohortOrEmpty]) -> List[int]:
+def sort_cohorts_topologically(cohort_ids: set[int], seen_cohorts_cache: dict[int, CohortOrEmpty]) -> list[int]:
     """
     Sorts the given cohorts in an order where cohorts with no dependencies are placed first,
     followed by cohorts that depend on the preceding ones. It ensures that each cohort in the sorted list
@@ -518,7 +518,7 @@ def sort_cohorts_topologically(cohort_ids: Set[int], seen_cohorts_cache: Dict[in
     if not cohort_ids:
         return []

-    dependency_graph: Dict[int, List[int]] = {}
+    dependency_graph: dict[int, list[int]] = {}
     seen = set()

     # build graph (adjacency list)
@@ -553,7 +553,7 @@ def dfs(node, seen, sorted_arr):
             sorted_arr.append(int(node))
         seen.add(node)

-    sorted_cohort_ids: List[int] = []
+    sorted_cohort_ids: list[int] = []
     seen = set()
     for cohort_id in cohort_ids:
         if cohort_id not in seen:
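Note: `sort_cohorts_topologically` is a DFS-based topological sort over the cohort dependency graph. A self-contained sketch of the same idea, reduced to plain ints (the cohort/cache machinery is elided):

    # Sketch only: a cohort id is emitted only after everything it depends on.
    def sort_topologically(ids: set[int], graph: dict[int, list[int]]) -> list[int]:
        sorted_ids: list[int] = []
        seen: set[int] = set()

        def dfs(node: int) -> None:
            if node in seen:
                return
            seen.add(node)
            for neighbor in graph.get(node, []):
                dfs(neighbor)  # dependencies first
            sorted_ids.append(node)

        for cohort_id in ids:
            dfs(cohort_id)
        return sorted_ids

    # 3 depends on 1 and 2; 2 depends on 1 -> [1, 2, 3]
    print(sort_topologically({1, 2, 3}, {3: [1, 2], 2: [1]}))
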
""" diff --git a/posthog/models/dashboard_tile.py b/posthog/models/dashboard_tile.py index 50af2868abf5b..9d39028e49bf0 100644 --- a/posthog/models/dashboard_tile.py +++ b/posthog/models/dashboard_tile.py @@ -1,5 +1,3 @@ -from typing import List - from django.core.exceptions import ValidationError from django.db import models from django.db.models import Q, QuerySet, UniqueConstraint @@ -112,7 +110,7 @@ def save(self, *args, **kwargs) -> None: if "update_fields" in kwargs: kwargs["update_fields"].append("filters_hash") - super(DashboardTile, self).save(*args, **kwargs) + super().save(*args, **kwargs) def copy_to_dashboard(self, dashboard: Dashboard) -> None: DashboardTile.objects.create( @@ -139,7 +137,7 @@ def dashboard_queryset(queryset: QuerySet) -> QuerySet: ) -def get_tiles_ordered_by_position(dashboard: Dashboard, size: str = "xs") -> List[DashboardTile]: +def get_tiles_ordered_by_position(dashboard: Dashboard, size: str = "xs") -> list[DashboardTile]: tiles = list( dashboard.tiles.select_related("insight", "text") .exclude(insight__deleted=True) diff --git a/posthog/models/element/element.py b/posthog/models/element/element.py index c1091932cd4c8..4beeb5400851b 100644 --- a/posthog/models/element/element.py +++ b/posthog/models/element/element.py @@ -1,5 +1,4 @@ import re -from typing import List from django.contrib.postgres.fields import ArrayField from django.db import models @@ -34,7 +33,7 @@ def _escape(input: str) -> str: return input.replace('"', r"\"") -def elements_to_string(elements: List[Element]) -> str: +def elements_to_string(elements: list[Element]) -> str: ret = [] for element in elements: el_string = "" @@ -58,7 +57,7 @@ def elements_to_string(elements: List[Element]) -> str: return ";".join(ret) -def chain_to_elements(chain: str) -> List[Element]: +def chain_to_elements(chain: str) -> list[Element]: elements = [] for idx, el_string in enumerate(re.findall(split_chain_regex, chain)): el_string_split = re.findall(split_class_attributes, el_string)[0] diff --git a/posthog/models/element_group.py b/posthog/models/element_group.py index 3d399f2559844..0a6a2545da0e5 100644 --- a/posthog/models/element_group.py +++ b/posthog/models/element_group.py @@ -1,6 +1,6 @@ import hashlib import json -from typing import Any, Dict, List +from typing import Any from django.db import models, transaction from django.forms.models import model_to_dict @@ -9,8 +9,8 @@ from posthog.models.team import Team -def hash_elements(elements: List) -> str: - elements_list: List[Dict] = [] +def hash_elements(elements: list) -> str: + elements_list: list[dict] = [] for element in elements: el_dict = model_to_dict(element) [el_dict.pop(key) for key in ["event", "id", "group"]] diff --git a/posthog/models/entity/entity.py b/posthog/models/entity/entity.py index 91865f9fa50f9..255edb0db4f3a 100644 --- a/posthog/models/entity/entity.py +++ b/posthog/models/entity/entity.py @@ -1,6 +1,6 @@ import inspect from collections import Counter -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional from django.conf import settings from rest_framework.exceptions import ValidationError @@ -67,7 +67,7 @@ class Entity(PropertyMixin): id_field: Optional[str] timestamp_field: Optional[str] - def __init__(self, data: Dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any]) -> None: self.id = data.get("id") if data.get("type") not in [ TREND_FILTER_TYPE_ACTIONS, @@ -102,7 +102,7 @@ def __init__(self, data: Dict[str, Any]) -> None: if self.type == TREND_FILTER_TYPE_EVENTS and 
diff --git a/posthog/models/entity/entity.py b/posthog/models/entity/entity.py
index 91865f9fa50f9..255edb0db4f3a 100644
--- a/posthog/models/entity/entity.py
+++ b/posthog/models/entity/entity.py
@@ -1,6 +1,6 @@
 import inspect
 from collections import Counter
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal, Optional

 from django.conf import settings
 from rest_framework.exceptions import ValidationError
@@ -67,7 +67,7 @@ class Entity(PropertyMixin):
     id_field: Optional[str]
     timestamp_field: Optional[str]

-    def __init__(self, data: Dict[str, Any]) -> None:
+    def __init__(self, data: dict[str, Any]) -> None:
         self.id = data.get("id")
         if data.get("type") not in [
             TREND_FILTER_TYPE_ACTIONS,
@@ -102,7 +102,7 @@ def __init__(self, data: Dict[str, Any]) -> None:
         if self.type == TREND_FILTER_TYPE_EVENTS and not self.name:
             self.name = "All events" if self.id is None else str(self.id)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             "id": self.id,
             "type": self.type,
@@ -180,10 +180,10 @@ class ExclusionEntity(Entity, FunnelFromToStepsMixin):
     with extra parameters for exclusion semantics.
     """

-    def __init__(self, data: Dict[str, Any]) -> None:
+    def __init__(self, data: dict[str, Any]) -> None:
         super().__init__(data)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         ret = super().to_dict()

         for _, func in inspect.getmembers(self, inspect.ismethod):
diff --git a/posthog/models/entity/util.py b/posthog/models/entity/util.py
index 06abcda5d0167..ffcd8cda671a7 100644
--- a/posthog/models/entity/util.py
+++ b/posthog/models/entity/util.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Sequence, Set, Tuple
+from typing import Any
+from collections.abc import Sequence

 from posthog.constants import TREND_FILTER_TYPE_ACTIONS
 from posthog.hogql.hogql import HogQLContext
@@ -16,17 +17,17 @@ def get_entity_filtering_params(
     person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_PERSON_PROPERTIES_COLUMN,
     person_id_joined_alias: str = "person_id",
     deep_filtering: bool = False,
-) -> Tuple[Dict, Dict]:
+) -> tuple[dict, dict]:
     """Return SQL condition for filtering events by allowed entities (events/actions).

     Events matching _at least one_ entity are included. If no entities are provided, _all_ events are included."""
     if not allowed_entities:
         return {}, {}

-    params: Dict[str, Any] = {}
-    entity_clauses: List[str] = []
-    action_ids_already_included: Set[int] = set()  # Avoid duplicating action conditions
-    events_already_included: Set[str] = set()  # Avoid duplicating event conditions
+    params: dict[str, Any] = {}
+    entity_clauses: list[str] = []
+    action_ids_already_included: set[int] = set()  # Avoid duplicating action conditions
+    events_already_included: set[str] = set()  # Avoid duplicating event conditions
     for entity in allowed_entities:
         if entity.type == TREND_FILTER_TYPE_ACTIONS:
             if entity.id in action_ids_already_included or entity.id is None:
diff --git a/posthog/models/event/event.py b/posthog/models/event/event.py
index 59b2f3c0a032b..184fffb18afa6 100644
--- a/posthog/models/event/event.py
+++ b/posthog/models/event/event.py
@@ -2,7 +2,7 @@
 import datetime
 import re
 from collections import defaultdict
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union

 from dateutil.relativedelta import relativedelta
 from django.db import models
@@ -13,10 +13,10 @@
 SELECTOR_ATTRIBUTE_REGEX = r"([a-zA-Z]*)\[(.*)=[\'|\"](.*)[\'|\"]\]"

-LAST_UPDATED_TEAM_ACTION: Dict[int, datetime.datetime] = {}
-TEAM_EVENT_ACTION_QUERY_CACHE: Dict[int, Dict[str, tuple]] = defaultdict(dict)
+LAST_UPDATED_TEAM_ACTION: dict[int, datetime.datetime] = {}
+TEAM_EVENT_ACTION_QUERY_CACHE: dict[int, dict[str, tuple]] = defaultdict(dict)
 # TEAM_EVENT_ACTION_QUERY_CACHE looks like team_id -> event ex('$pageview') -> query
-TEAM_ACTION_QUERY_CACHE: Dict[int, str] = {}
+TEAM_ACTION_QUERY_CACHE: dict[int, str] = {}
 DEFAULT_EARLIEST_TIME_DELTA = relativedelta(weeks=1)

@@ -26,8 +26,8 @@ class SelectorPart:
     def __init__(self, tag: str, direct_descendant: bool, escape_slashes: bool):
         self.direct_descendant = direct_descendant
-        self.data: Dict[str, Union[str, List]] = {}
-        self.ch_attributes: Dict[str, Union[str, List]] = {}  # attributes for CH
+        self.data: dict[str, Union[str, list]] = {}
+        self.ch_attributes: dict[str, Union[str, list]] = {}  # attributes for CH

         result = re.search(SELECTOR_ATTRIBUTE_REGEX, tag)
         if result and "[id=" in tag:
@@ -58,9 +58,9 @@ def __init__(self, tag: str, direct_descendant: bool, escape_slashes: bool):
             self.data["tag_name"] = tag

     @property
-    def extra_query(self) -> Dict[str, List[Union[str, List[str]]]]:
-        where: List[Union[str, List[str]]] = []
-        params: List[Union[str, List[str]]] = []
+    def extra_query(self) -> dict[str, list[Union[str, list[str]]]]:
+        where: list[Union[str, list[str]]] = []
+        params: list[Union[str, list[str]]] = []
         for key, value in self.data.items():
             if "attr__" in key:
                 where.append(f"(attributes ->> 'attr__{key.split('attr__')[1]}') = %s")
@@ -78,7 +78,7 @@ def _unescape_class(self, class_name):


 class Selector:
-    parts: List[SelectorPart] = []
+    parts: list[SelectorPart] = []

     def __init__(self, selector: str, escape_slashes=True):
         self.parts = []
@@ -98,7 +98,7 @@ def __init__(self, selector: str, escape_slashes=True):
     def _split(self, selector):
         in_attribute_selector = False
         in_quotes: Optional[str] = None
        part: list[str] = []
-        part: List[str] = []
+        part: list[str] = []
         for char in selector:
             if char == "[" and in_quotes is None:
                 in_attribute_selector = True
diff --git a/posthog/models/event/query_event_list.py b/posthog/models/event/query_event_list.py
index ded739c9a81c0..1ecdbee021a7d 100644
--- a/posthog/models/event/query_event_list.py
+++ b/posthog/models/event/query_event_list.py
@@ -1,5 +1,5 @@
 from datetime import timedelta, datetime, time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 from zoneinfo import ZoneInfo

 from dateutil.parser import isoparse
@@ -29,10 +29,10 @@ def parse_timestamp(timestamp: str, tzinfo: ZoneInfo) -> datetime:


 def parse_request_params(
-    conditions: Dict[str, Union[None, str, List[str]]], team: Team, tzinfo: ZoneInfo
-) -> Tuple[str, Dict]:
+    conditions: dict[str, Union[None, str, list[str]]], team: Team, tzinfo: ZoneInfo
+) -> tuple[str, dict]:
     result = ""
-    params: Dict[str, Union[str, List[str]]] = {}
+    params: dict[str, Union[str, list[str]]] = {}
     for k, v in conditions.items():
         if not isinstance(v, str):
             continue
@@ -58,13 +58,13 @@ def parse_request_params(
 def query_events_list(
     filter: Filter,
     team: Team,
-    request_get_query_dict: Dict,
-    order_by: List[str],
+    request_get_query_dict: dict,
+    order_by: list[str],
     action_id: Optional[str],
     unbounded_date_from: bool = False,
     limit: int = DEFAULT_RETURNED_ROWS,
     offset: int = 0,
-) -> List:
+) -> list:
     # Note: This code is inefficient and problematic, see https://github.com/PostHog/posthog/issues/13485 for details.
     # To isolate its impact from rest of the queries its queries are run on different nodes as part of "offline" workloads.
     hogql_context = HogQLContext(within_non_hogql_query=True, team_id=team.pk, enable_select_queries=True)
diff --git a/posthog/models/event/util.py b/posthog/models/event/util.py
index c55094898016d..065d47da33161 100644
--- a/posthog/models/event/util.py
+++ b/posthog/models/event/util.py
@@ -1,7 +1,7 @@
 import datetime as dt
 import json
 import uuid
-from typing import Any, Dict, List, Literal, Optional, Set, Union
+from typing import Any, Literal, Optional, Union
 from zoneinfo import ZoneInfo

 from dateutil.parser import isoparse
@@ -31,16 +31,16 @@ def create_event(
     team: Team,
     distinct_id: str,
     timestamp: Optional[Union[timezone.datetime, str]] = None,
-    properties: Optional[Dict] = None,
-    elements: Optional[List[Element]] = None,
+    properties: Optional[dict] = None,
+    elements: Optional[list[Element]] = None,
     person_id: Optional[uuid.UUID] = None,
-    person_properties: Optional[Dict] = None,
+    person_properties: Optional[dict] = None,
     person_created_at: Optional[Union[timezone.datetime, str]] = None,
-    group0_properties: Optional[Dict] = None,
-    group1_properties: Optional[Dict] = None,
-    group2_properties: Optional[Dict] = None,
-    group3_properties: Optional[Dict] = None,
-    group4_properties: Optional[Dict] = None,
+    group0_properties: Optional[dict] = None,
+    group1_properties: Optional[dict] = None,
+    group2_properties: Optional[dict] = None,
+    group3_properties: Optional[dict] = None,
+    group4_properties: Optional[dict] = None,
     group0_created_at: Optional[Union[timezone.datetime, str]] = None,
     group1_created_at: Optional[Union[timezone.datetime, str]] = None,
     group2_created_at: Optional[Union[timezone.datetime, str]] = None,
@@ -105,8 +105,8 @@ def format_clickhouse_timestamp(


 def bulk_create_events(
-    events: List[Dict[str, Any]],
-    person_mapping: Optional[Dict[str, Person]] = None,
+    events: list[dict[str, Any]],
+    person_mapping: Optional[dict[str, Person]] = None,
 ) -> None:
     """
     TEST ONLY
@@ -121,7 +121,7 @@ def bulk_create_events(
     if not TEST:
         raise Exception("This function is only meant for setting up tests")
     inserts = []
-    params: Dict[str, Any] = {}
+    params: dict[str, Any] = {}
     for index, event in enumerate(events):
         datetime64_default_timestamp = timezone.now().astimezone(ZoneInfo("UTC")).strftime("%Y-%m-%d %H:%M:%S")
         timestamp = event.get("timestamp") or dt.datetime.now()
@@ -287,7 +287,7 @@ class Meta:
     ]


-def parse_properties(properties: str, allow_list: Optional[Set[str]] = None) -> Dict:
+def parse_properties(properties: str, allow_list: Optional[set[str]] = None) -> dict:
     # parse_constants gets called for any NaN, Infinity etc values
     # we just want those to be returned as None
     if allow_list is None:
@@ -349,7 +349,7 @@ def get_elements_chain(self, event):
         return event["elements_chain"]


-def get_agg_event_count_for_teams(team_ids: List[Union[str, int]]) -> int:
+def get_agg_event_count_for_teams(team_ids: list[Union[str, int]]) -> int:
     result = sync_execute(
         """
         SELECT count(1) as count
@@ -362,7 +362,7 @@ def get_agg_event_count_for_teams(team_ids: List[Union[str, int]]) -> int:

 def get_agg_events_with_groups_count_for_teams_and_period(
-    team_ids: List[Union[str, int]], begin: timezone.datetime, end: timezone.datetime
+    team_ids: list[Union[str, int]], begin: timezone.datetime, end: timezone.datetime
 ) -> int:
     result = sync_execute(
         """
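Note: `bulk_create_events` builds one flat `params` dict, keyed by each event's index, so a single INSERT can carry unique `%(...)s` placeholders for every row. A hedged sketch of that shape (column names here are illustrative, not the real events schema):

    from typing import Any

    def build_bulk_insert(events: list[dict[str, Any]]) -> tuple[str, dict[str, Any]]:
        inserts: list[str] = []
        params: dict[str, Any] = {}
        for index, event in enumerate(events):
            # one VALUES tuple per event, placeholders disambiguated by index
            inserts.append(f"(%(uuid_{index})s, %(event_{index})s, %(team_id_{index})s)")
            params[f"uuid_{index}"] = event["uuid"]
            params[f"event_{index}"] = event["event"]
            params[f"team_id_{index}"] = event["team_id"]
        return "INSERT INTO events (uuid, event, team_id) VALUES " + ", ".join(inserts), params
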
diff --git a/posthog/models/exported_asset.py b/posthog/models/exported_asset.py
index ceebb2bc3db03..d07009be45b4e 100644
--- a/posthog/models/exported_asset.py
+++ b/posthog/models/exported_asset.py
@@ -1,6 +1,6 @@
 import secrets
 from datetime import timedelta
-from typing import List, Optional
+from typing import Optional

 import structlog
 from django.conf import settings
@@ -178,7 +178,7 @@ def save_content_to_exported_asset(exported_asset: ExportedAsset, content: bytes

 def save_content_to_object_storage(exported_asset: ExportedAsset, content: bytes) -> None:
-    path_parts: List[str] = [
+    path_parts: list[str] = [
         settings.OBJECT_STORAGE_EXPORTS_FOLDER,
         exported_asset.export_format.split("/")[1],
         f"team-{exported_asset.team.id}",
diff --git a/posthog/models/feature_flag/feature_flag.py b/posthog/models/feature_flag/feature_flag.py
index 67432e0b643eb..0a46a44d53b98 100644
--- a/posthog/models/feature_flag/feature_flag.py
+++ b/posthog/models/feature_flag/feature_flag.py
@@ -1,7 +1,7 @@
 import json
 from django.http import HttpRequest
 import structlog
-from typing import Dict, List, Optional, cast
+from typing import Optional, cast

 from django.core.cache import cache
 from django.db import models
@@ -59,7 +59,7 @@ class Meta:
     # whether a feature is sending us rich analytics, like views & interactions.
     has_enriched_analytics: models.BooleanField = models.BooleanField(default=False, null=True, blank=True)

-    def get_analytics_metadata(self) -> Dict:
+    def get_analytics_metadata(self) -> dict:
         filter_count = sum(len(condition.get("properties", [])) for condition in self.conditions)
         variants_count = len(self.variants)
         payload_count = len(self._payloads)
@@ -135,7 +135,7 @@ def get_filters(self):
     def transform_cohort_filters_for_easy_evaluation(
         self,
         using_database: str = "default",
-        seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None,
+        seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None,
     ):
         """
         Expands cohort filters into person property filters when possible.
@@ -243,7 +243,7 @@ def transform_cohort_filters_for_easy_evaluation(
         if target_properties.type == PropertyOperatorType.AND:
             return self.conditions

-        for prop_group in cast(List[PropertyGroup], target_properties.values):
+        for prop_group in cast(list[PropertyGroup], target_properties.values):
             if (
                 len(prop_group.values) == 0
                 or not isinstance(prop_group.values[0], Property)
@@ -264,9 +264,9 @@ def transform_cohort_filters_for_easy_evaluation(
     def get_cohort_ids(
         self,
         using_database: str = "default",
-        seen_cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None,
+        seen_cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None,
         sort_by_topological_order=False,
-    ) -> List[int]:
+    ) -> list[int]:
         from posthog.models.cohort.util import get_dependent_cohorts, sort_cohorts_topologically

         if seen_cohorts_cache is None:
@@ -398,9 +398,9 @@ class Meta:

 def set_feature_flags_for_team_in_cache(
     team_id: int,
-    feature_flags: Optional[List[FeatureFlag]] = None,
+    feature_flags: Optional[list[FeatureFlag]] = None,
     using_database: str = "default",
-) -> List[FeatureFlag]:
+) -> list[FeatureFlag]:
     from posthog.api.feature_flag import MinimalFeatureFlagSerializer

     if feature_flags is not None:
@@ -422,7 +422,7 @@ def set_feature_flags_for_team_in_cache(
     return all_feature_flags


-def get_feature_flags_for_team_in_cache(team_id: int) -> Optional[List[FeatureFlag]]:
+def get_feature_flags_for_team_in_cache(team_id: int) -> Optional[list[FeatureFlag]]:
     try:
         flag_data = cache.get(f"team_feature_flags_{team_id}")
     except Exception:
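Note: the repeated `seen_cohorts_cache: Optional[dict[...]] = None` signatures above use the None-sentinel idiom. A `= {}` default would be created once at function-definition time and shared across every call; `None` plus an explicit check gives each call its own dict. A minimal illustration:

    from typing import Optional

    def get_ids(seen_cache: Optional[dict[int, list[int]]] = None) -> dict[int, list[int]]:
        if seen_cache is None:
            seen_cache = {}  # fresh per call, never a shared module-level dict
        seen_cache.setdefault(1, []).append(2)
        return seen_cache

    # with a mutable `= {}` default, the second call would return {1: [2, 2]}
    assert get_ids() == {1: [2]}
    assert get_ids() == {1: [2]}
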
diff --git a/posthog/models/feature_flag/flag_analytics.py b/posthog/models/feature_flag/flag_analytics.py
index d5f27d804ac48..f62ed1934eca8 100644
--- a/posthog/models/feature_flag/flag_analytics.py
+++ b/posthog/models/feature_flag/flag_analytics.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING
 from posthog.constants import FlagRequestType
 from posthog.helpers.dashboard_templates import (
     add_enriched_insights_to_feature_flag_dashboard,
@@ -45,7 +45,7 @@ def increment_request_count(
         capture_exception(error)


-def _extract_total_count_for_key_from_redis_hash(client: redis.Redis, key: str) -> Tuple[int, int, int]:
+def _extract_total_count_for_key_from_redis_hash(client: redis.Redis, key: str) -> tuple[int, int, int]:
     total_count = 0
     existing_values = client.hgetall(key)
     time_buckets = existing_values.keys()
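Note: `_extract_total_count_for_key_from_redis_hash` reads a Redis hash whose fields are time buckets and sums them into a total. A hedged sketch of the read side (the cutoff logic and field layout are assumptions, only `hgetall` is visible in the hunk):

    import time

    def extract_total_count(client, key: str, bucket_seconds: int = 60) -> int:
        total_count = 0
        existing_values = client.hgetall(key)  # assumed shape: {bucket_timestamp: count}
        cutoff = int(time.time()) - bucket_seconds
        for bucket, count in existing_values.items():
            if int(bucket) <= cutoff:  # only count buckets that can no longer grow
                total_count += int(count)
        return total_count
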
diff --git a/posthog/models/feature_flag/flag_matching.py b/posthog/models/feature_flag/flag_matching.py
index 134af65dfdad7..0b4a6befebc94 100644
--- a/posthog/models/feature_flag/flag_matching.py
+++ b/posthog/models/feature_flag/flag_matching.py
@@ -3,7 +3,7 @@
 from enum import Enum
 import time
 import structlog
-from typing import Dict, List, Literal, Optional, Tuple, Union, cast
+from typing import Literal, Optional, Union, cast
 from prometheus_client import Counter

 from django.conf import settings
@@ -110,7 +110,7 @@ def __init__(self, team_id: int):
         self.failed_to_fetch_flags = False

     @cached_property
-    def group_types_to_indexes(self) -> Dict[GroupTypeName, GroupTypeIndex]:
+    def group_types_to_indexes(self) -> dict[GroupTypeName, GroupTypeIndex]:
         if self.failed_to_fetch_flags:
             raise DatabaseError("Failed to fetch group type mapping previously, not trying again.")
         try:
@@ -124,7 +124,7 @@ def group_types_to_indexes(self) -> Dict[GroupTypeName, GroupTypeIndex]:
             raise err

     @cached_property
-    def group_type_index_to_name(self) -> Dict[GroupTypeIndex, GroupTypeName]:
+    def group_type_index_to_name(self) -> dict[GroupTypeIndex, GroupTypeName]:
         return {value: key for key, value in self.group_types_to_indexes.items()}

@@ -133,15 +133,15 @@ class FeatureFlagMatcher:

     def __init__(
         self,
-        feature_flags: List[FeatureFlag],
+        feature_flags: list[FeatureFlag],
         distinct_id: str,
-        groups: Optional[Dict[GroupTypeName, str]] = None,
+        groups: Optional[dict[GroupTypeName, str]] = None,
         cache: Optional[FlagsMatcherCache] = None,
-        hash_key_overrides: Optional[Dict[str, str]] = None,
-        property_value_overrides: Optional[Dict[str, Union[str, int]]] = None,
-        group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None,
+        hash_key_overrides: Optional[dict[str, str]] = None,
+        property_value_overrides: Optional[dict[str, Union[str, int]]] = None,
+        group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None,
         skip_database_flags: bool = False,
-        cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None,
+        cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None,
     ):
         if group_property_value_overrides is None:
             group_property_value_overrides = {}
@@ -231,7 +231,7 @@ def get_match(self, feature_flag: FeatureFlag) -> FeatureFlagMatch:
             payload=None,
         )

-    def get_matches(self) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]:
+    def get_matches(self) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]:
         flag_values = {}
         flag_evaluation_reasons = {}
         faced_error_computing_flags = False
@@ -287,7 +287,7 @@ def get_matching_payload(
         else:
             return None

-    def is_super_condition_match(self, feature_flag: FeatureFlag) -> Tuple[bool, bool, FeatureFlagMatchReason]:
+    def is_super_condition_match(self, feature_flag: FeatureFlag) -> tuple[bool, bool, FeatureFlagMatchReason]:
         # TODO: Right now super conditions with property overrides bork when the database is down,
         # because we're still going to the database in the line below. Ideally, we should not go to the database.
         # Don't skip test: test_super_condition_with_override_properties_doesnt_make_database_requests when this is fixed.
@@ -320,8 +320,8 @@ def is_super_condition_match(self, feature_flag: FeatureFlag) -> Tuple[bool, boo
         return False, False, FeatureFlagMatchReason.NO_CONDITION_MATCH

     def is_condition_match(
-        self, feature_flag: FeatureFlag, condition: Dict, condition_index: int
-    ) -> Tuple[bool, FeatureFlagMatchReason]:
+        self, feature_flag: FeatureFlag, condition: dict, condition_index: int
+    ) -> tuple[bool, FeatureFlagMatchReason]:
         rollout_percentage = condition.get("rollout_percentage")
         if len(condition.get("properties", [])) > 0:
             properties = Filter(data=condition).property_groups.flat
@@ -405,12 +405,12 @@ def variant_lookup_table(self, feature_flag: FeatureFlag):
         return lookup_table

     @cached_property
-    def query_conditions(self) -> Dict[str, bool]:
+    def query_conditions(self) -> dict[str, bool]:
         try:
             # Some extra wiggle room here for timeouts because this depends on the number of flags as well,
             # and not just the database query.
             with execute_with_timeout(FLAG_MATCHING_QUERY_TIMEOUT_MS * 2, DATABASE_FOR_FLAG_MATCHING):
-                all_conditions: Dict = {}
+                all_conditions: dict = {}
                 team_id = self.feature_flags[0].team_id
                 person_query: QuerySet = Person.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(
                     team_id=team_id,
@@ -418,7 +418,7 @@
                     persondistinctid__team_id=team_id,
                 )
                 basic_group_query: QuerySet = Group.objects.using(DATABASE_FOR_FLAG_MATCHING).filter(team_id=team_id)
-                group_query_per_group_type_mapping: Dict[GroupTypeIndex, Tuple[QuerySet, List[str]]] = {}
+                group_query_per_group_type_mapping: dict[GroupTypeIndex, tuple[QuerySet, list[str]]] = {}
                 # :TRICKY: Create a queryset for each group type that uniquely identifies a group, based on the groups passed in.
                 # If no groups for a group type are passed in, we can skip querying for that group type,
                 # since the result will always be `false`.
@@ -431,7 +431,7 @@ def query_conditions(self) -> Dict[str, bool]:
                     [],
                 )

-                person_fields: List[str] = []
+                person_fields: list[str] = []

                 for existence_condition_key in self.has_pure_is_not_conditions:
                     if existence_condition_key == PERSON_KEY:
@@ -637,7 +637,7 @@ def get_hash(self, feature_flag: FeatureFlag, salt="") -> float:

     def can_compute_locally(
         self,
-        properties: List[Property],
+        properties: list[Property],
         group_type_index: Optional[GroupTypeIndex] = None,
     ) -> bool:
         target_properties = self.property_value_overrides
@@ -682,10 +682,10 @@ def has_pure_is_not_conditions(self) -> set[Literal["person"] | GroupTypeIndex]:

 def get_feature_flag_hash_key_overrides(
     team_id: int,
-    distinct_ids: List[str],
+    distinct_ids: list[str],
     using_database: str = "default",
-    person_id_to_distinct_id_mapping: Optional[Dict[int, str]] = None,
-) -> Dict[str, str]:
+    person_id_to_distinct_id_mapping: Optional[dict[int, str]] = None,
+) -> dict[str, str]:
     feature_flag_to_key_overrides = {}

     # Priority to the first distinctID's values, to keep this function deterministic
@@ -716,15 +716,15 @@ def get_feature_flag_hash_key_overrides(

 # Return a Dict with all flags and their values
 def _get_all_feature_flags(
-    feature_flags: List[FeatureFlag],
+    feature_flags: list[FeatureFlag],
     team_id: int,
     distinct_id: str,
-    person_overrides: Optional[Dict[str, str]] = None,
-    groups: Optional[Dict[GroupTypeName, str]] = None,
-    property_value_overrides: Optional[Dict[str, Union[str, int]]] = None,
-    group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None,
+    person_overrides: Optional[dict[str, str]] = None,
+    groups: Optional[dict[GroupTypeName, str]] = None,
+    property_value_overrides: Optional[dict[str, Union[str, int]]] = None,
+    group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None,
     skip_database_flags: bool = False,
-) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]:
+) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]:
     if group_property_value_overrides is None:
         group_property_value_overrides = {}
     if property_value_overrides is None:
@@ -752,11 +752,11 @@ def _get_all_feature_flags(
 def get_all_feature_flags(
     team_id: int,
     distinct_id: str,
-    groups: Optional[Dict[GroupTypeName, str]] = None,
+    groups: Optional[dict[GroupTypeName, str]] = None,
     hash_key_override: Optional[str] = None,
-    property_value_overrides: Optional[Dict[str, Union[str, int]]] = None,
-    group_property_value_overrides: Optional[Dict[str, Dict[str, Union[str, int]]]] = None,
-) -> Tuple[Dict[str, Union[str, bool]], Dict[str, dict], Dict[str, object], bool]:
+    property_value_overrides: Optional[dict[str, Union[str, int]]] = None,
+    group_property_value_overrides: Optional[dict[str, dict[str, Union[str, int]]]] = None,
+) -> tuple[dict[str, Union[str, bool]], dict[str, dict], dict[str, object], bool]:
     if group_property_value_overrides is None:
         group_property_value_overrides = {}
     if property_value_overrides is None:
@@ -907,7 +907,7 @@ def get_all_feature_flags(
     )


-def set_feature_flag_hash_key_overrides(team_id: int, distinct_ids: List[str], hash_key_override: str) -> bool:
+def set_feature_flag_hash_key_overrides(team_id: int, distinct_ids: list[str], hash_key_override: str) -> bool:
     # As a product decision, the first override wins, i.e consistency matters for the first walkthrough.
     # Thus, we don't need to do upserts here.
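Note: the `get_hash(feature_flag, salt="") -> float` signature above is the heart of percentage rollouts: a deterministic hash of flag key plus distinct_id mapped into [0, 1), so a user consistently lands on the same side of a rollout threshold. A hedged sketch of that style of bucketing (the exact constants are assumptions, not lifted from this diff):

    import hashlib

    LONG_SCALE = 0xFFFFFFFFFFFFFFF  # 15 hex digits

    def bucket(key: str, distinct_id: str, salt: str = "") -> float:
        hash_key = f"{key}.{distinct_id}{salt}"
        hash_val = int(hashlib.sha1(hash_key.encode()).hexdigest()[:15], 16)
        return hash_val / LONG_SCALE

    # stable: the same user always evaluates the same way against a 30% rollout
    assert bucket("my-flag", "user-1") == bucket("my-flag", "user-1")
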
@@ -1004,7 +1004,7 @@ def parse_exception_for_error_message(err: Exception):
     return reason


-def key_and_field_for_property(property: Property) -> Tuple[str, str]:
+def key_and_field_for_property(property: Property) -> tuple[str, str]:
     column = "group_properties" if property.type == "group" else "properties"
     key = property.key
     sanitized_key = sanitize_property_key(key)
@@ -1016,8 +1016,8 @@ def key_and_field_for_property(property: Property) -> Tuple[str, str]:

 def get_all_properties_with_math_operators(
-    properties: List[Property], cohorts_cache: Dict[int, CohortOrEmpty], team_id: int
-) -> List[Tuple[str, str]]:
+    properties: list[Property], cohorts_cache: dict[int, CohortOrEmpty], team_id: int
+) -> list[tuple[str, str]]:
     all_keys_and_fields = []

     for prop in properties:
diff --git a/posthog/models/filters/base_filter.py b/posthog/models/filters/base_filter.py
index ca2ef9e4c575f..f4d46c9acaf4b 100644
--- a/posthog/models/filters/base_filter.py
+++ b/posthog/models/filters/base_filter.py
@@ -1,6 +1,6 @@
 import inspect
 import json
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from rest_framework import request

@@ -17,14 +17,14 @@


 class BaseFilter(BaseParamMixin):
-    _data: Dict
+    _data: dict
     team: Optional["Team"]
-    kwargs: Dict
+    kwargs: dict
     hogql_context: HogQLContext

     def __init__(
         self,
-        data: Optional[Dict[str, Any]] = None,
+        data: Optional[dict[str, Any]] = None,
         request: Optional[request.Request] = None,
         *,
         team: Optional["Team"] = None,
@@ -69,7 +69,7 @@ def __init__(
             simplified_filter = self.simplify(self.team)
             self._data = simplified_filter._data

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         ret = {}

         for _, func in inspect.getmembers(self, inspect.ismethod):
@@ -78,20 +78,20 @@ def to_dict(self) -> Dict[str, Any]:

         return ret

-    def to_params(self) -> Dict[str, str]:
+    def to_params(self) -> dict[str, str]:
         return encode_get_request_params(data=self.to_dict())

     def toJSON(self):
         return json.dumps(self.to_dict(), default=lambda o: o.__dict__, sort_keys=True, indent=4)

-    def shallow_clone(self, overrides: Dict[str, Any]):
+    def shallow_clone(self, overrides: dict[str, Any]):
         "Clone the filter's data while sharing the HogQL context"
         return type(self)(
             data={**self._data, **overrides},
             **{**self.kwargs, "team": self.team, "hogql_context": self.hogql_context},
         )

-    def query_tags(self) -> Dict[str, Any]:
+    def query_tags(self) -> dict[str, Any]:
         ret = {}

         for _, func in inspect.getmembers(self, inspect.ismethod):
diff --git a/posthog/models/filters/lifecycle_filter.py b/posthog/models/filters/lifecycle_filter.py
index 576cf499f30d9..34775ac98f883 100644
--- a/posthog/models/filters/lifecycle_filter.py
+++ b/posthog/models/filters/lifecycle_filter.py
@@ -1,5 +1,5 @@
 import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 from posthog.models import Filter
 from posthog.utils import relative_date_parse
 from rest_framework.request import Request
@@ -12,7 +12,7 @@ class LifecycleFilter(Filter):

     def __init__(
         self,
-        data: Optional[Dict[str, Any]] = None,
+        data: Optional[dict[str, Any]] = None,
         request: Optional[Request] = None,
         **kwargs,
     ) -> None:
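Note: `BaseFilter.to_dict` and `query_tags` walk `inspect.getmembers(self, inspect.ismethod)` and merge the output of specially tagged mixin methods (the `include_dict` decorator seen on the mixins below). A hedged sketch of that discovery mechanism, with the tagging attribute assumed:

    import inspect

    def include_dict(f):
        f.include_dict = True  # tag the method so to_dict can discover it
        return f

    class FilterSketch:
        def __init__(self, data: dict):
            self._data = data

        @include_dict
        def date_to_dict(self) -> dict:
            return {"date_from": self._data["date_from"]} if "date_from" in self._data else {}

        def to_dict(self) -> dict:
            ret = {}
            for _, func in inspect.getmembers(self, inspect.ismethod):
                if getattr(func, "include_dict", False):
                    ret.update(func())  # each tagged mixin method contributes its params
            return ret

    print(FilterSketch({"date_from": "-7d"}).to_dict())  # {'date_from': '-7d'}
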
Literal["event", "person", "cohort", "group", "session", "hogql"] IntervalType = Literal["hour", "day", "week", "month"] @@ -6,4 +6,4 @@ class BaseParamMixin: - _data: Dict + _data: dict diff --git a/posthog/models/filters/mixins/common.py b/posthog/models/filters/mixins/common.py index 8ab2c1ac7fcdf..65be03514030c 100644 --- a/posthog/models/filters/mixins/common.py +++ b/posthog/models/filters/mixins/common.py @@ -2,7 +2,7 @@ import json import re from math import ceil -from typing import Any, Dict, List, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union, cast from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -142,7 +142,7 @@ def formula_to_dict(self): class BreakdownMixin(BaseParamMixin): @cached_property - def breakdown(self) -> Optional[Union[str, List[Union[str, int]]]]: + def breakdown(self) -> Optional[Union[str, list[Union[str, int]]]]: breakdown = self._data.get(BREAKDOWN) if not isinstance(breakdown, str): @@ -171,11 +171,11 @@ def breakdown_attribution_value(self) -> Optional[int]: return int(attribution_value) if attribution_value is not None else None @cached_property - def breakdowns(self) -> Optional[List[Dict[str, Any]]]: + def breakdowns(self) -> Optional[list[dict[str, Any]]]: breakdowns = self._data.get(BREAKDOWNS) try: - if isinstance(breakdowns, List): + if isinstance(breakdowns, list): return breakdowns elif isinstance(breakdowns, str): return json.loads(breakdowns) @@ -226,7 +226,7 @@ def breakdown_hide_other_aggregation(self) -> Optional[bool]: @include_dict def breakdown_to_dict(self): - result: Dict = {} + result: dict = {} if self.breakdown: result[BREAKDOWN] = self.breakdown if self.breakdowns: @@ -346,8 +346,8 @@ def compare_to_dict(self): class DateMixin(BaseParamMixin): - date_from_delta_mapping: Optional[Dict[str, int]] - date_to_delta_mapping: Optional[Dict[str, int]] + date_from_delta_mapping: Optional[dict[str, int]] + date_to_delta_mapping: Optional[dict[str, int]] @cached_property def _date_from(self) -> Optional[Union[str, datetime.datetime]]: @@ -417,7 +417,7 @@ def use_explicit_dates(self) -> bool: return process_bool(self._data.get(EXPLICIT_DATE)) @include_dict - def date_to_dict(self) -> Dict: + def date_to_dict(self) -> dict: result_dict = {} if self._date_from: result_dict.update( @@ -455,8 +455,8 @@ def query_tags_dates(self): class EntitiesMixin(BaseParamMixin): @cached_property - def entities(self) -> List[Entity]: - processed_entities: List[Entity] = [] + def entities(self) -> list[Entity]: + processed_entities: list[Entity] = [] if self._data.get(ACTIONS): actions = self._data.get(ACTIONS, []) if isinstance(actions, str): @@ -487,20 +487,20 @@ def query_tags_entities(self): return {"number_of_entities": len(self.entities)} @cached_property - def actions(self) -> List[Entity]: + def actions(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_ACTIONS] @cached_property - def events(self) -> List[Entity]: + def events(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_EVENTS] @cached_property - def data_warehouse_entities(self) -> List[Entity]: + def data_warehouse_entities(self) -> list[Entity]: return [entity for entity in self.entities if entity.type == TREND_FILTER_TYPE_DATA_WAREHOUSE] @cached_property - def exclusions(self) -> List[ExclusionEntity]: - _exclusions: List[ExclusionEntity] = [] + def exclusions(self) -> list[ExclusionEntity]: + _exclusions: list[ExclusionEntity] = 
+        _exclusions: list[ExclusionEntity] = []
         if self._data.get(EXCLUSIONS):
             exclusion_list = self._data.get(EXCLUSIONS, [])
             if isinstance(exclusion_list, str):
diff --git a/posthog/models/filters/mixins/funnel.py b/posthog/models/filters/mixins/funnel.py
index 91312a5030478..3baf5f15b50da 100644
--- a/posthog/models/filters/mixins/funnel.py
+++ b/posthog/models/filters/mixins/funnel.py
@@ -1,6 +1,6 @@
 import datetime
 import json
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 from posthog.models.property import Property

@@ -111,7 +111,7 @@ def funnel_window_interval_unit(self) -> Optional[FunnelWindowIntervalType]:

     @include_dict
     def funnel_window_to_dict(self):
-        dict_part: Dict = {}
+        dict_part: dict = {}
         if self.funnel_window_interval is not None:
             dict_part[FUNNEL_WINDOW_INTERVAL] = self.funnel_window_interval
         if self.funnel_window_interval_unit is not None:
@@ -154,7 +154,7 @@ def funnel_step(self) -> Optional[int]:
         return int(_step_as_string)

     @cached_property
-    def funnel_custom_steps(self) -> List[int]:
+    def funnel_custom_steps(self) -> list[int]:
         """
         Custom step numbers to get persons for. This overrides FunnelPersonsStepMixin::funnel_step
         """
@@ -176,7 +176,7 @@ def funnel_step_to_dict(self):

 class FunnelPersonsStepBreakdownMixin(BaseParamMixin):
     @cached_property
-    def funnel_step_breakdown(self) -> Optional[Union[List[str], int, str]]:
+    def funnel_step_breakdown(self) -> Optional[Union[list[str], int, str]]:
         """
         The breakdown value for which to get persons for.

@@ -241,7 +241,7 @@ def funnel_viz_type(self) -> Optional[FunnelVizType]:

     @include_dict
     def funnel_type_to_dict(self):
-        result: Dict[str, str] = {}
+        result: dict[str, str] = {}
         if self.funnel_order_type:
             result[FUNNEL_ORDER_TYPE] = self.funnel_order_type
         if self.funnel_viz_type:
@@ -277,7 +277,7 @@ def drop_off(self) -> Optional[bool]:

     @include_dict
     def funnel_trends_persons_to_dict(self):
-        result_dict: Dict = {}
+        result_dict: dict = {}
         if self.entrance_period_start:
             result_dict[ENTRANCE_PERIOD_START] = self.entrance_period_start.isoformat()
         if self.drop_off is not None:
@@ -298,7 +298,7 @@ def correlation_type(self) -> Optional[FunnelCorrelationType]:
         return None

     @cached_property
-    def correlation_property_names(self) -> List[str]:
+    def correlation_property_names(self) -> list[str]:
         # Person Property names for which to run Person Properties correlation
         property_names = self._data.get(FUNNEL_CORRELATION_NAMES, [])
         if isinstance(property_names, str):
@@ -306,7 +306,7 @@ def correlation_property_names(self) -> List[str]:
         return property_names

     @cached_property
-    def correlation_property_exclude_names(self) -> List[str]:
+    def correlation_property_exclude_names(self) -> list[str]:
         # Person Property names to exclude from Person Properties correlation
         property_names = self._data.get(FUNNEL_CORRELATION_EXCLUDE_NAMES, [])
         if isinstance(property_names, str):
@@ -314,7 +314,7 @@ def correlation_property_exclude_names(self) -> List[str]:
         return property_names

     @cached_property
-    def correlation_event_names(self) -> List[str]:
+    def correlation_event_names(self) -> list[str]:
         # Event names for which to run EventWithProperties correlation
         event_names = self._data.get(FUNNEL_CORRELATION_EVENT_NAMES, [])
         if isinstance(event_names, str):
@@ -322,7 +322,7 @@ def correlation_event_names(self) -> List[str]:
         return event_names

     @cached_property
-    def correlation_event_exclude_names(self) -> List[str]:
+    def correlation_event_exclude_names(self) -> list[str]:
         # Exclude event names from Event correlation
         property_names = self._data.get(FUNNEL_CORRELATION_EXCLUDE_EVENT_NAMES, [])
         if isinstance(property_names, str):
@@ -330,7 +330,7 @@ def correlation_event_exclude_names(self) -> List[str]:
         return property_names

     @cached_property
-    def correlation_event_exclude_property_names(self) -> List[str]:
+    def correlation_event_exclude_property_names(self) -> list[str]:
         # Event Property names to exclude from EventWithProperties correlation
         property_names = self._data.get(FUNNEL_CORRELATION_EVENT_EXCLUDE_PROPERTY_NAMES, [])
         if isinstance(property_names, str):
@@ -339,7 +339,7 @@

     @include_dict
     def funnel_correlation_to_dict(self):
-        result_dict: Dict = {}
+        result_dict: dict = {}
         if self.correlation_type:
             result_dict[FUNNEL_CORRELATION_TYPE] = self.correlation_type
         if self.correlation_property_names:
@@ -370,7 +370,7 @@ def correlation_person_entity(self) -> Optional["Entity"]:
         return Entity(event) if event else None

     @cached_property
-    def correlation_property_values(self) -> Optional[List[Property]]:
+    def correlation_property_values(self) -> Optional[list[Property]]:
         # Used for property correlations persons

         _props = self._data.get(FUNNEL_CORRELATION_PROPERTY_VALUES)
@@ -421,7 +421,7 @@ def correlation_persons_converted(self) -> Optional[bool]:

     @include_dict
     def funnel_correlation_persons_to_dict(self):
-        result_dict: Dict = {}
+        result_dict: dict = {}
         if self.correlation_person_entity:
             result_dict[FUNNEL_CORRELATION_PERSON_ENTITY] = self.correlation_person_entity.to_dict()
         if self.correlation_property_values:
diff --git a/posthog/models/filters/mixins/paths.py b/posthog/models/filters/mixins/paths.py
index 393b1f7140a70..8249a7015d1bc 100644
--- a/posthog/models/filters/mixins/paths.py
+++ b/posthog/models/filters/mixins/paths.py
@@ -1,5 +1,5 @@
 import json
-from typing import Dict, List, Literal, Optional
+from typing import Literal, Optional

 from posthog.constants import (
     CUSTOM_EVENT,
@@ -84,21 +84,21 @@ def paths_hogql_expression_to_dict(self):

 class TargetEventsMixin(BaseParamMixin):
     @cached_property
-    def target_events(self) -> List[str]:
+    def target_events(self) -> list[str]:
         target_events = self._data.get(PATHS_INCLUDE_EVENT_TYPES, [])
         if isinstance(target_events, str):
             return json.loads(target_events)
         return target_events

     @cached_property
-    def custom_events(self) -> List[str]:
+    def custom_events(self) -> list[str]:
         custom_events = self._data.get(PATHS_INCLUDE_CUSTOM_EVENTS, [])
         if isinstance(custom_events, str):
             return json.loads(custom_events)
         return custom_events

     @cached_property
-    def exclude_events(self) -> List[str]:
+    def exclude_events(self) -> list[str]:
         _exclude_events = self._data.get(PATHS_EXCLUDE_EVENTS, [])
         if isinstance(_exclude_events, str):
             return json.loads(_exclude_events)
@@ -160,7 +160,7 @@ def funnel_paths_to_dict(self):

 class PathGroupingMixin(BaseParamMixin):
     @cached_property
-    def path_groupings(self) -> Optional[List[str]]:
+    def path_groupings(self) -> Optional[list[str]]:
         path_groupings = self._data.get(PATH_GROUPINGS, None)
         if isinstance(path_groupings, str):
             return json.loads(path_groupings)
@@ -193,7 +193,7 @@ def path_replacements_to_dict(self):

 class LocalPathCleaningFiltersMixin(BaseParamMixin):
     @cached_property
-    def local_path_cleaning_filters(self) -> Optional[List[Dict[str, str]]]:
+    def local_path_cleaning_filters(self) -> Optional[list[dict[str, str]]]:
         local_path_cleaning_filters = self._data.get(LOCAL_PATH_CLEANING_FILTERS, None)
         if isinstance(local_path_cleaning_filters, str):
             return json.loads(local_path_cleaning_filters)
diff --git a/posthog/models/filters/mixins/property.py b/posthog/models/filters/mixins/property.py
index ff4cb56fee91a..2ffc984754b35 100644
--- a/posthog/models/filters/mixins/property.py
+++ b/posthog/models/filters/mixins/property.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from rest_framework.exceptions import ValidationError

@@ -15,7 +15,7 @@

 class PropertyMixin(BaseParamMixin):
     @cached_property
-    def old_properties(self) -> List[Property]:
+    def old_properties(self) -> list[Property]:
         _props = self._data.get(PROPERTIES)

         if isinstance(_props, str):
@@ -64,7 +64,7 @@ def property_groups(self) -> PropertyGroup:
         # old properties
         return PropertyGroup(type=PropertyOperatorType.AND, values=self.old_properties)

-    def _parse_properties(self, properties: Optional[Any]) -> List[Property]:
+    def _parse_properties(self, properties: Optional[Any]) -> list[Property]:
         if isinstance(properties, list):
             _properties = []
             for prop_params in properties:
@@ -94,19 +94,19 @@ def _parse_properties(self, properties: Optional[Any]) -> List[Property]:
             )
         return ret

-    def _parse_property_group(self, group: Optional[Dict]) -> PropertyGroup:
+    def _parse_property_group(self, group: Optional[dict]) -> PropertyGroup:
         if group and "type" in group and "values" in group:
             return PropertyGroup(
                 PropertyOperatorType(group["type"].upper()),
                 self._parse_property_group_list(group["values"]),
             )

-        return PropertyGroup(PropertyOperatorType.AND, cast(List[Property], []))
+        return PropertyGroup(PropertyOperatorType.AND, cast(list[Property], []))

-    def _parse_property_group_list(self, prop_list: Optional[List]) -> Union[List[Property], List[PropertyGroup]]:
+    def _parse_property_group_list(self, prop_list: Optional[list]) -> Union[list[Property], list[PropertyGroup]]:
         if not prop_list:
             # empty prop list
-            return cast(List[Property], [])
+            return cast(list[Property], [])

         has_property_groups = False
         has_simple_properties = False
diff --git a/posthog/models/filters/mixins/retention.py b/posthog/models/filters/mixins/retention.py
index eeec027c4f817..044278f014275 100644
--- a/posthog/models/filters/mixins/retention.py
+++ b/posthog/models/filters/mixins/retention.py
@@ -1,6 +1,6 @@
 import json
 from datetime import datetime, timedelta
-from typing import Literal, Optional, Tuple, Union
+from typing import Literal, Optional, Union

 from dateutil.relativedelta import relativedelta
 from django.utils import timezone
@@ -112,7 +112,7 @@ def period_increment(self) -> Union[timedelta, relativedelta]:
     @staticmethod
     def determine_time_delta(
         total_intervals: int, period: str
-    ) -> Tuple[Union[timedelta, relativedelta], Union[timedelta, relativedelta]]:
+    ) -> tuple[Union[timedelta, relativedelta], Union[timedelta, relativedelta]]:
         if period == "Hour":
             return timedelta(hours=total_intervals), timedelta(hours=1)
         elif period == "Week":
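Note: `determine_time_delta` returns a `(total window, single-interval increment)` pair for a retention period. Only the Hour and Week branches are visible in the hunk; a hedged sketch of the full contract, with the Month/Day branches assumed:

    from datetime import timedelta
    from typing import Union

    from dateutil.relativedelta import relativedelta

    def determine_time_delta(
        total_intervals: int, period: str
    ) -> tuple[Union[timedelta, relativedelta], Union[timedelta, relativedelta]]:
        if period == "Hour":
            return timedelta(hours=total_intervals), timedelta(hours=1)
        elif period == "Week":
            return timedelta(weeks=total_intervals), timedelta(weeks=1)
        elif period == "Month":
            # months have no fixed length, hence relativedelta instead of timedelta
            return relativedelta(months=total_intervals), relativedelta(months=1)
        return timedelta(days=total_intervals), timedelta(days=1)

    window, step = determine_time_delta(11, "Week")
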
diff --git a/posthog/models/filters/mixins/session_recordings.py b/posthog/models/filters/mixins/session_recordings.py
index 8779ea92e6bec..83d9bb40245a6 100644
--- a/posthog/models/filters/mixins/session_recordings.py
+++ b/posthog/models/filters/mixins/session_recordings.py
@@ -1,5 +1,5 @@
 import json
-from typing import List, Optional, Literal
+from typing import Optional, Literal

 from posthog.constants import PERSON_UUID_FILTER, SESSION_RECORDINGS_FILTER_IDS
 from posthog.models.filters.mixins.common import BaseParamMixin
@@ -19,7 +19,7 @@ def console_search_query(self) -> str | None:
         return self._data.get("console_search_query", None)

     @cached_property
-    def console_logs_filter(self) -> List[Literal["error", "warn", "info"]]:
+    def console_logs_filter(self) -> list[Literal["error", "warn", "info"]]:
         user_value = self._data.get("console_logs", None) or []
         if isinstance(user_value, str):
             user_value = json.loads(user_value)
@@ -43,7 +43,7 @@ def recording_duration_filter(self) -> Optional[Property]:
         return None

     @cached_property
-    def session_ids(self) -> Optional[List[str]]:
+    def session_ids(self) -> Optional[list[str]]:
         # Can be ['a', 'b'] or "['a', 'b']" or "a,b"
         session_ids_str = self._data.get(SESSION_RECORDINGS_FILTER_IDS, None)
diff --git a/posthog/models/filters/mixins/simplify.py b/posthog/models/filters/mixins/simplify.py
index 3b1e0eb426ba1..72d8d184539ef 100644
--- a/posthog/models/filters/mixins/simplify.py
+++ b/posthog/models/filters/mixins/simplify.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast

 from posthog.constants import PropertyOperatorType
 from posthog.models.property import GroupTypeIndex, PropertyGroup
@@ -67,9 +67,9 @@ def _simplify_entity(
         self,
         team: "Team",
         entity_type: Literal["events", "actions", "exclusions"],
-        entity_params: Dict,
+        entity_params: dict,
         **kwargs,
-    ) -> Dict:
+    ) -> dict:
         from posthog.models.entity import Entity, ExclusionEntity

         EntityClass = ExclusionEntity if entity_type == "exclusions" else Entity
@@ -82,7 +82,7 @@ def _simplify_entity(

         return EntityClass({**entity_params, "properties": properties}).to_dict()

-    def _simplify_properties(self, team: "Team", properties: List["Property"], **kwargs) -> "PropertyGroup":
+    def _simplify_properties(self, team: "Team", properties: list["Property"], **kwargs) -> "PropertyGroup":
         simplified_properties_values = []
         for prop in properties:
             simplified_properties_values.append(self._simplify_property(team, prop, **kwargs))
diff --git a/posthog/models/filters/mixins/stickiness.py b/posthog/models/filters/mixins/stickiness.py
index 0dfca1d834c83..1b659481b98ea 100644
--- a/posthog/models/filters/mixins/stickiness.py
+++ b/posthog/models/filters/mixins/stickiness.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
+from collections.abc import Callable

 from rest_framework.exceptions import ValidationError

diff --git a/posthog/models/filters/mixins/utils.py b/posthog/models/filters/mixins/utils.py
index a297cdcfa6320..5b5fe6d422d92 100644
--- a/posthog/models/filters/mixins/utils.py
+++ b/posthog/models/filters/mixins/utils.py
@@ -1,5 +1,6 @@
 from functools import lru_cache
-from typing import Callable, Optional, TypeVar, Union
+from typing import Optional, TypeVar, Union
+from collections.abc import Callable

 from posthog.utils import str_to_bool

diff --git a/posthog/models/filters/path_filter.py b/posthog/models/filters/path_filter.py
index 5ef9395d82da4..df7a0ca928581 100644
--- a/posthog/models/filters/path_filter.py
+++ b/posthog/models/filters/path_filter.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from rest_framework.request import Request

@@ -76,7 +76,7 @@ class PathFilter(
 ):
     def __init__(
         self,
-        data: Optional[Dict[str, Any]] = None,
+        data: Optional[dict[str, Any]] = None,
         request: Optional[Request] = None,
         **kwargs,
     ) -> None:
b/posthog/models/filters/retention_filter.py index 338d3d87e3e64..6f73aeb69d3f5 100644 --- a/posthog/models/filters/retention_filter.py +++ b/posthog/models/filters/retention_filter.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union from rest_framework.request import Request @@ -48,7 +48,7 @@ class RetentionFilter( SampleMixin, BaseFilter, ): - def __init__(self, data: Optional[Dict[str, Any]] = None, request: Optional[Request] = None, **kwargs) -> None: + def __init__(self, data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs) -> None: if data is None: data = {} if data: @@ -58,7 +58,7 @@ def __init__(self, data: Optional[Dict[str, Any]] = None, request: Optional[Requ super().__init__(data, request, **kwargs) @cached_property - def breakdown_values(self) -> Optional[Tuple[Union[str, int], ...]]: + def breakdown_values(self) -> Optional[tuple[Union[str, int], ...]]: raw_value = self._data.get("breakdown_values", None) if raw_value is None: return None diff --git a/posthog/models/filters/stickiness_filter.py b/posthog/models/filters/stickiness_filter.py index 4674c4ceeb3d9..cde6d8020928f 100644 --- a/posthog/models/filters/stickiness_filter.py +++ b/posthog/models/filters/stickiness_filter.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union +from collections.abc import Callable from django.db.models.functions.datetime import ( TruncDay, @@ -62,7 +63,7 @@ class StickinessFilter( def __init__( self, - data: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, request: Optional[Request] = None, **kwargs, ) -> None: diff --git a/posthog/models/filters/test/test_filter.py b/posthog/models/filters/test/test_filter.py index 63a947bca6770..eb99a3ac42941 100644 --- a/posthog/models/filters/test/test_filter.py +++ b/posthog/models/filters/test/test_filter.py @@ -1,6 +1,7 @@ import datetime import json -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from django.db.models import Q, Func, F, CharField from freezegun import freeze_time @@ -993,8 +994,8 @@ def filter_persons_with_annotation(filter: Filter, team: Team): def filter_persons_with_property_group( - filter: Filter, team: Team, property_overrides: Optional[Dict[str, Any]] = None -) -> List[str]: + filter: Filter, team: Team, property_overrides: Optional[dict[str, Any]] = None +) -> list[str]: if property_overrides is None: property_overrides = {} flush_persons_and_events() diff --git a/posthog/models/filters/test/test_path_filter.py b/posthog/models/filters/test/test_path_filter.py index df8ffac45aaec..3f66e0b9b7392 100644 --- a/posthog/models/filters/test/test_path_filter.py +++ b/posthog/models/filters/test/test_path_filter.py @@ -18,7 +18,7 @@ def test_to_dict(self): } ) - self.assertEquals( + self.assertEqual( filter.to_dict(), filter.to_dict() | { @@ -51,7 +51,7 @@ def test_to_dict_hogql(self): } ) - self.assertEquals( + self.assertEqual( filter.to_dict(), filter.to_dict() | { diff --git a/posthog/models/group/util.py b/posthog/models/group/util.py index 427c883a2e920..0b9c0fb9724c3 100644 --- a/posthog/models/group/util.py +++ b/posthog/models/group/util.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Dict, Optional, Union +from typing import Optional, Union from zoneinfo import ZoneInfo from dateutil.parser 
import isoparse @@ -17,7 +17,7 @@ def raw_create_group_ch( team_id: int, group_type_index: GroupTypeIndex, group_key: str, - properties: Dict, + properties: dict, created_at: datetime.datetime, timestamp: Optional[datetime.datetime] = None, sync: bool = False, @@ -44,7 +44,7 @@ def create_group( team_id: int, group_type_index: GroupTypeIndex, group_key: str, - properties: Optional[Dict] = None, + properties: Optional[dict] = None, timestamp: Optional[Union[datetime.datetime, str]] = None, sync: bool = False, ) -> Group: diff --git a/posthog/models/instance_setting.py b/posthog/models/instance_setting.py index 749975e5d5e72..0ad0ca5bde0bf 100644 --- a/posthog/models/instance_setting.py +++ b/posthog/models/instance_setting.py @@ -1,6 +1,6 @@ import json from contextlib import contextmanager -from typing import Any, List +from typing import Any from django.db import models @@ -29,7 +29,7 @@ def get_instance_setting(key: str) -> Any: return CONSTANCE_CONFIG[key][0] # Get the default value -def get_instance_settings(keys: List[str]) -> Any: +def get_instance_settings(keys: list[str]) -> Any: for key in keys: assert key in CONSTANCE_CONFIG, f"Unknown dynamic setting: {repr(key)}" diff --git a/posthog/models/integration.py b/posthog/models/integration.py index 8ce1c9d6ef7c7..6e313ea179ff6 100644 --- a/posthog/models/integration.py +++ b/posthog/models/integration.py @@ -2,7 +2,7 @@ import hmac import time from datetime import timedelta -from typing import Dict, List, Literal +from typing import Literal from django.db import models from rest_framework.request import Request @@ -50,7 +50,7 @@ def __init__(self, integration: Integration) -> None: def client(self) -> WebClient: return WebClient(self.integration.sensitive_config["access_token"]) - def list_channels(self) -> List[Dict]: + def list_channels(self) -> list[dict]: # NOTE: Annoyingly the Slack API has no search so we have to load all channels... 
# We load public and private channels separately as when mixed, the Slack API pagination is buggy public_channels = self._list_channels_by_type("public_channel") @@ -59,7 +59,7 @@ def list_channels(self) -> List[Dict]: return sorted(channels, key=lambda x: x["name"]) - def _list_channels_by_type(self, type: Literal["public_channel", "private_channel"]) -> List[Dict]: + def _list_channels_by_type(self, type: Literal["public_channel", "private_channel"]) -> list[dict]: max_page = 10 channels = [] cursor = None @@ -76,7 +76,7 @@ def _list_channels_by_type(self, type: Literal["public_channel", "private_channe return channels @classmethod - def integration_from_slack_response(cls, team_id: str, created_by: User, params: Dict[str, str]) -> Integration: + def integration_from_slack_response(cls, team_id: str, created_by: User, params: dict[str, str]) -> Integration: client = WebClient() slack_config = cls.slack_config() diff --git a/posthog/models/organization.py b/posthog/models/organization.py index 8740a0f34c453..cdb4ee7ccd926 100644 --- a/posthog/models/organization.py +++ b/posthog/models/organization.py @@ -1,6 +1,6 @@ import json import sys -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypedDict, Union +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union import structlog from django.conf import settings @@ -45,7 +45,7 @@ class OrganizationUsageInfo(TypedDict): events: Optional[OrganizationUsageResource] recordings: Optional[OrganizationUsageResource] rows_synced: Optional[OrganizationUsageResource] - period: Optional[List[str]] + period: Optional[list[str]] class OrganizationManager(models.Manager): @@ -56,9 +56,9 @@ def bootstrap( self, user: Optional["User"], *, - team_fields: Optional[Dict[str, Any]] = None, + team_fields: Optional[dict[str, Any]] = None, **kwargs, - ) -> Tuple["Organization", Optional["OrganizationMembership"], "Team"]: + ) -> tuple["Organization", Optional["OrganizationMembership"], "Team"]: """Instead of doing the legwork of creating an organization yourself, delegate the details with bootstrap.""" from .project import Project # Avoiding circular import @@ -157,7 +157,7 @@ def __str__(self): __repr__ = sane_repr("name") @property - def _billing_plan_details(self) -> Tuple[Optional[str], Optional[str]]: + def _billing_plan_details(self) -> tuple[Optional[str], Optional[str]]: """ Obtains details on the billing plan for the organization. Returns a tuple with (billing_plan_key, billing_realm) @@ -176,7 +176,7 @@ def _billing_plan_details(self) -> Tuple[Optional[str], Optional[str]]: return (license.plan, "ee") return (None, None) - def update_available_features(self) -> List[Union[AvailableFeature, str]]: + def update_available_features(self) -> list[Union[AvailableFeature, str]]: """Updates field `available_features`. 
Does not `save()`.""" if is_cloud() or self.usage: # Since billing V2 we just use the available features which are updated when the billing service is called diff --git a/posthog/models/organization_domain.py b/posthog/models/organization_domain.py index 416b2d560f310..5d49d8a64ac91 100644 --- a/posthog/models/organization_domain.py +++ b/posthog/models/organization_domain.py @@ -1,5 +1,5 @@ import secrets -from typing import Optional, Tuple +from typing import Optional import dns.resolver import structlog @@ -151,13 +151,13 @@ def has_saml(self) -> bool: """ return bool(self.saml_entity_id) and bool(self.saml_acs_url) and bool(self.saml_x509_cert) - def _complete_verification(self) -> Tuple["OrganizationDomain", bool]: + def _complete_verification(self) -> tuple["OrganizationDomain", bool]: self.last_verification_retry = None self.verified_at = timezone.now() self.save() return (self, True) - def attempt_verification(self) -> Tuple["OrganizationDomain", bool]: + def attempt_verification(self) -> tuple["OrganizationDomain", bool]: """ Performs a DNS verification for a specific domain. """ diff --git a/posthog/models/person/person.py b/posthog/models/person/person.py index a04565423335b..20f9dd7675487 100644 --- a/posthog/models/person/person.py +++ b/posthog/models/person/person.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Optional from django.db import models, transaction from django.db.models import F, Q @@ -21,15 +21,15 @@ def create(self, *args: Any, **kwargs: Any): return person @staticmethod - def distinct_ids_exist(team_id: int, distinct_ids: List[str]) -> bool: + def distinct_ids_exist(team_id: int, distinct_ids: list[str]) -> bool: return PersonDistinctId.objects.filter(team_id=team_id, distinct_id__in=distinct_ids).exists() class Person(models.Model): - _distinct_ids: Optional[List[str]] + _distinct_ids: Optional[list[str]] @property - def distinct_ids(self) -> List[str]: + def distinct_ids(self) -> list[str]: if hasattr(self, "distinct_ids_cache"): return [id.distinct_id for id in self.distinct_ids_cache] if hasattr(self, "_distinct_ids") and self._distinct_ids: @@ -46,7 +46,7 @@ def add_distinct_id(self, distinct_id: str) -> None: PersonDistinctId.objects.create(person=self, distinct_id=distinct_id, team_id=self.team_id) # :DEPRECATED: This should happen through the plugin server - def _add_distinct_ids(self, distinct_ids: List[str]) -> None: + def _add_distinct_ids(self, distinct_ids: list[str]) -> None: for distinct_id in distinct_ids: self.add_distinct_id(distinct_id) @@ -274,7 +274,7 @@ class Meta: ] -def get_distinct_ids_for_subquery(person: Person | None, team: Team) -> List[str]: +def get_distinct_ids_for_subquery(person: Person | None, team: Team) -> list[str]: """_summary_ Fetching distinct_ids for a person from CH is slow, so we fetch them from PG for certain queries. 
Therefore we need diff --git a/posthog/models/person/util.py b/posthog/models/person/util.py index f6bcc60ebc333..0e1efa7bdb2c9 100644 --- a/posthog/models/person/util.py +++ b/posthog/models/person/util.py @@ -1,7 +1,7 @@ import datetime import json from contextlib import ExitStack -from typing import Dict, List, Optional, Union +from typing import Optional, Union from uuid import UUID from zoneinfo import ZoneInfo @@ -80,7 +80,7 @@ def person_distinct_id_deleted(sender, instance: PersonDistinctId, **kwargs): except: pass - def bulk_create_persons(persons_list: List[Dict]): + def bulk_create_persons(persons_list: list[dict]): persons = [] person_mapping = {} for _person in persons_list: @@ -127,7 +127,7 @@ def create_person( team_id: int, version: int, uuid: Optional[str] = None, - properties: Optional[Dict] = None, + properties: Optional[dict] = None, sync: bool = False, is_identified: bool = False, is_deleted: bool = False, @@ -217,7 +217,7 @@ def create_person_override( ) -def get_persons_by_distinct_ids(team_id: int, distinct_ids: List[str]) -> QuerySet: +def get_persons_by_distinct_ids(team_id: int, distinct_ids: list[str]) -> QuerySet: return Person.objects.filter( team_id=team_id, persondistinctid__team_id=team_id, @@ -225,7 +225,7 @@ def get_persons_by_distinct_ids(team_id: int, distinct_ids: List[str]) -> QueryS ) -def get_persons_by_uuids(team: Team, uuids: List[str]) -> QuerySet: +def get_persons_by_uuids(team: Team, uuids: list[str]) -> QuerySet: return Person.objects.filter(team_id=team.pk, uuid__in=uuids) @@ -254,7 +254,7 @@ def _delete_person( ) -def _get_distinct_ids_with_version(person: Person) -> Dict[str, int]: +def _get_distinct_ids_with_version(person: Person) -> dict[str, int]: return { distinct_id: int(version or 0) for distinct_id, version in PersonDistinctId.objects.filter(person=person, team_id=person.team_id) diff --git a/posthog/models/personal_api_key.py b/posthog/models/personal_api_key.py index 047471f4fe8a8..23bb04e0b4242 100644 --- a/posthog/models/personal_api_key.py +++ b/posthog/models/personal_api_key.py @@ -1,4 +1,4 @@ -from typing import Optional, Literal, Tuple, get_args +from typing import Optional, Literal, get_args import hashlib from django.contrib.auth.hashers import PBKDF2PasswordHasher @@ -111,5 +111,5 @@ class PersonalAPIKey(models.Model): ] -API_SCOPE_OBJECTS: Tuple[APIScopeObject, ...] = get_args(APIScopeObject) -API_SCOPE_ACTIONS: Tuple[APIScopeActions, ...] = get_args(APIScopeActions) +API_SCOPE_OBJECTS: tuple[APIScopeObject, ...] = get_args(APIScopeObject) +API_SCOPE_ACTIONS: tuple[APIScopeActions, ...] = get_args(APIScopeActions) diff --git a/posthog/models/plugin.py b/posthog/models/plugin.py index 900b1abec7741..06971c1ce7cca 100644 --- a/posthog/models/plugin.py +++ b/posthog/models/plugin.py @@ -3,7 +3,7 @@ import os from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Optional, cast from uuid import UUID from django.conf import settings @@ -52,13 +52,13 @@ def raise_if_plugin_installed(url: str, organization_id: str): raise ValidationError(f'Plugin from URL "{url_without_private_key}" already installed!') -def update_validated_data_from_url(validated_data: Dict[str, Any], url: str) -> Dict[str, Any]: +def update_validated_data_from_url(validated_data: dict[str, Any], url: str) -> dict[str, Any]: """If remote plugin, download the archive and get up-to-date validated_data from there.
Returns plugin.json.""" - plugin_json: Optional[Dict[str, Any]] + plugin_json: Optional[dict[str, Any]] if url.startswith("file:"): plugin_path = url[5:] plugin_json_path = os.path.join(plugin_path, "plugin.json") - plugin_json = cast(Optional[Dict[str, Any]], load_json_file(plugin_json_path)) + plugin_json = cast(Optional[dict[str, Any]], load_json_file(plugin_json_path)) if not plugin_json: raise ValidationError(f"Could not load plugin.json from: {plugin_json_path}") validated_data["plugin_type"] = "local" @@ -81,7 +81,7 @@ def update_validated_data_from_url(validated_data: Dict[str, Any], url: str) -> validated_data["latest_tag"] = parsed_url.get("tag", None) validated_data["archive"] = download_plugin_archive(validated_data["url"], validated_data["tag"]) plugin_json = cast( - Optional[Dict[str, Any]], + Optional[dict[str, Any]], get_file_from_archive(validated_data["archive"], "plugin.json"), ) if not plugin_json: @@ -124,7 +124,7 @@ class PluginManager(models.Manager): def install(self, **kwargs) -> "Plugin": if "organization_id" not in kwargs and "organization" in kwargs: kwargs["organization_id"] = kwargs["organization"].id - plugin_json: Optional[Dict[str, Any]] = None + plugin_json: Optional[dict[str, Any]] = None if kwargs.get("plugin_type", None) != Plugin.PluginType.SOURCE: plugin_json = update_validated_data_from_url(kwargs, kwargs["url"]) raise_if_plugin_installed(kwargs["url"], kwargs["organization_id"]) @@ -204,8 +204,8 @@ class PluginType(models.TextChoices): objects: PluginManager = PluginManager() - def get_default_config(self) -> Dict[str, Any]: - config: Dict[str, Any] = {} + def get_default_config(self) -> dict[str, Any]: + config: dict[str, Any] = {} config_schema = self.config_schema if isinstance(config_schema, dict): for key, config_entry in config_schema.items(): @@ -296,8 +296,8 @@ class PluginLogEntryType(str, Enum): class PluginSourceFileManager(models.Manager): def sync_from_plugin_archive( - self, plugin: Plugin, plugin_json_parsed: Optional[Dict[str, Any]] = None - ) -> Tuple[ + self, plugin: Plugin, plugin_json_parsed: Optional[dict[str, Any]] = None + ) -> tuple[ "PluginSourceFile", Optional["PluginSourceFile"], Optional["PluginSourceFile"], @@ -426,12 +426,12 @@ def fetch_plugin_log_entries( before: Optional[timezone.datetime] = None, search: Optional[str] = None, limit: Optional[int] = None, - type_filter: Optional[List[PluginLogEntryType]] = None, -) -> List[PluginLogEntry]: + type_filter: Optional[list[PluginLogEntryType]] = None, +) -> list[PluginLogEntry]: if type_filter is None: type_filter = [] - clickhouse_where_parts: List[str] = [] - clickhouse_kwargs: Dict[str, Any] = {} + clickhouse_where_parts: list[str] = [] + clickhouse_kwargs: dict[str, Any] = {} if team_id is not None: clickhouse_where_parts.append("team_id = %(team_id)s") clickhouse_kwargs["team_id"] = team_id @@ -457,7 +457,7 @@ def fetch_plugin_log_entries( return [PluginLogEntry(*result) for result in cast(list, sync_execute(clickhouse_query, clickhouse_kwargs))] -def validate_plugin_job_payload(plugin: Plugin, job_type: str, payload: Dict[str, Any], *, is_staff: bool): +def validate_plugin_job_payload(plugin: Plugin, job_type: str, payload: dict[str, Any], *, is_staff: bool): if not plugin.public_jobs: raise ValidationError("Plugin has no public jobs") if job_type not in plugin.public_jobs: diff --git a/posthog/models/project.py b/posthog/models/project.py index c4ead260fb780..030bd4669a6c8 100644 --- a/posthog/models/project.py +++ b/posthog/models/project.py @@ -1,4 +1,4 @@ -from 
typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional from django.db import models from django.db import transaction from django.core.validators import MinLengthValidator @@ -8,7 +8,7 @@ class ProjectManager(models.Manager): - def create_with_team(self, team_fields: Optional[dict] = None, **kwargs) -> Tuple["Project", "Team"]: + def create_with_team(self, team_fields: Optional[dict] = None, **kwargs) -> tuple["Project", "Team"]: from .team import Team with transaction.atomic(): diff --git a/posthog/models/property/property.py b/posthog/models/property/property.py index defd098cd7ef7..74ef611e257a3 100644 --- a/posthog/models/property/property.py +++ b/posthog/models/property/property.py @@ -2,11 +2,8 @@ from enum import Enum from typing import ( Any, - Dict, - List, Literal, Optional, - Tuple, Union, cast, ) @@ -27,7 +24,7 @@ class BehavioralPropertyType(str, Enum): RESTARTED_PERFORMING_EVENT = "restarted_performing_event" -ValueT = Union[str, int, List[str]] +ValueT = Union[str, int, list[str]] PropertyType = Literal[ "event", "feature", @@ -78,7 +75,7 @@ class BehavioralPropertyType(str, Enum): OperatorInterval = Literal["day", "week", "month", "year"] GroupTypeName = str -PropertyIdentifier = Tuple[PropertyName, PropertyType, Optional[GroupTypeIndex]] +PropertyIdentifier = tuple[PropertyName, PropertyType, Optional[GroupTypeIndex]] NEGATED_OPERATORS = ["is_not", "not_icontains", "not_regex", "is_not_set"] CLICKHOUSE_ONLY_PROPERTY_TYPES = [ @@ -187,7 +184,7 @@ class Property: # Type of `key` event_type: Optional[Literal["events", "actions"]] # Any extra filters on the event - event_filters: Optional[List["Property"]] + event_filters: Optional[list["Property"]] # Query people who did event '$pageview' 20 times in the last 30 days # translates into: # key = '$pageview', value = 'performed_event_multiple' @@ -216,7 +213,7 @@ class Property: total_periods: Optional[int] min_periods: Optional[int] negation: Optional[bool] = False - _data: Dict + _data: dict def __init__( self, @@ -239,7 +236,7 @@ def __init__( seq_time_value: Optional[int] = None, seq_time_interval: Optional[OperatorInterval] = None, negation: Optional[bool] = None, - event_filters: Optional[List["Property"]] = None, + event_filters: Optional[list["Property"]] = None, **kwargs, ) -> None: self.key = key @@ -298,7 +295,7 @@ def __repr__(self): params_repr = ", ".join(f"{key}={repr(value)}" for key, value in self.to_dict().items()) return f"Property({params_repr})" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return {key: value for key, value in vars(self).items() if value is not None} @staticmethod @@ -331,17 +328,17 @@ def _parse_value(value: ValueT, convert_to_number: bool = False) -> Any: class PropertyGroup: type: PropertyOperatorType - values: Union[List[Property], List["PropertyGroup"]] + values: Union[list[Property], list["PropertyGroup"]] def __init__( self, type: PropertyOperatorType, - values: Union[List[Property], List["PropertyGroup"]], + values: Union[list[Property], list["PropertyGroup"]], ) -> None: self.type = type self.values = values - def combine_properties(self, operator: PropertyOperatorType, properties: List[Property]) -> "PropertyGroup": + def combine_properties(self, operator: PropertyOperatorType, properties: list[Property]) -> "PropertyGroup": if not properties: return self @@ -375,7 +372,7 @@ def __repr__(self): return f"PropertyGroup(type={self.type}-{params_repr})" @cached_property - def flat(self) -> List[Property]: + def 
flat(self) -> list[Property]: return list(self._property_groups_flat(self)) def _property_groups_flat(self, prop_group: "PropertyGroup"): diff --git a/posthog/models/property/util.py b/posthog/models/property/util.py index cae1be3340eac..b1ce9b6087aaa 100644 --- a/posthog/models/property/util.py +++ b/posthog/models/property/util.py @@ -1,17 +1,15 @@ import re from collections import Counter -from typing import Any, Callable -from typing import Counter as TCounter +from typing import Any +from collections.abc import Callable +from collections import Counter as TCounter from typing import ( - Dict, - Iterable, - List, Literal, Optional, - Tuple, Union, cast, ) +from collections.abc import Iterable from rest_framework import exceptions @@ -88,7 +86,7 @@ def parse_prop_grouped_clauses( person_id_joined_alias: str = "person_id", group_properties_joined: bool = True, _top_level: bool = True, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Translate the given property filter group into an SQL condition clause (+ SQL params).""" if not property_group or len(property_group.values) == 0: return "", {} @@ -119,7 +117,7 @@ def parse_prop_grouped_clauses( _final = f"{property_group.type} ".join(group_clauses) else: _final, final_params = parse_prop_clauses( - filters=cast(List[Property], property_group.values), + filters=cast(list[Property], property_group.values), prepend=f"{prepend}", table_name=table_name, allow_denormalized_props=allow_denormalized_props, @@ -151,7 +149,7 @@ def is_property_group(group: Union[Property, "PropertyGroup"]): def parse_prop_clauses( team_id: int, - filters: List[Property], + filters: list[Property], *, hogql_context: Optional[HogQLContext], prepend: str = "global", @@ -162,10 +160,10 @@ def parse_prop_clauses( person_id_joined_alias: str = "person_id", group_properties_joined: bool = True, property_operator: PropertyOperatorType = PropertyOperatorType.AND, -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: """Translate the given property filter into an SQL condition clause (+ SQL params).""" final = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} table_formatted = table_name if table_formatted != "": @@ -411,7 +409,7 @@ def prop_filter_json_extract( property_operator: str = PropertyOperatorType.AND, table_name: Optional[str] = None, use_event_column: Optional[str] = None, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: # TODO: Once all queries are migrated over we can get rid of allow_denormalized_props if transform_expression is not None: prop_var = transform_expression(prop_var) @@ -433,7 +431,7 @@ def prop_filter_json_extract( if prop.negation: operator = negate_operator(operator or "exact") - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if operator == "is_not": params = { @@ -649,7 +647,7 @@ def get_single_or_multi_property_string_expr( allow_denormalized_props=True, materialised_table_column: str = "properties", normalize_url: bool = False, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: """ When querying for breakdown properties: * If the breakdown provided is a string, we extract the JSON from the properties object stored in the DB @@ -663,7 +661,7 @@ def get_single_or_multi_property_string_expr( no alias will be appended. 
""" - breakdown_params: Dict[str, Any] = {} + breakdown_params: dict[str, Any] = {} if isinstance(breakdown, str) or isinstance(breakdown, int): breakdown_key = f"breakdown_param_{len(breakdown_params) + 1}" breakdown_key = f"breakdown_param_{len(breakdown_params) + 1}" @@ -719,7 +717,7 @@ def get_property_string_expr( allow_denormalized_props: bool = True, table_alias: Optional[str] = None, materialised_table_column: str = "properties", -) -> Tuple[str, bool]: +) -> tuple[str, bool]: """ :param table: @@ -752,8 +750,8 @@ def get_property_string_expr( return trim_quotes_expr(f"JSONExtractRaw({table_string}{column}, {var})"), False -def box_value(value: Any, remove_spaces=False) -> List[Any]: - if not isinstance(value, List): +def box_value(value: Any, remove_spaces=False) -> list[Any]: + if not isinstance(value, list): value = [value] return [str(value).replace(" ", "") if remove_spaces else str(value) for value in value] @@ -764,19 +762,19 @@ def filter_element( *, operator: Optional[OperatorType] = None, prepend: str = "", -) -> Tuple[str, Dict]: +) -> tuple[str, dict]: if operator is None: operator = "exact" params = {} - combination_conditions: List[str] = [] + combination_conditions: list[str] = [] if key == "selector": if operator not in ("exact", "is_not"): raise exceptions.ValidationError( 'Filtering by element selector only supports operators "equals" and "doesn\'t equal" currently.' ) - selectors = cast(List[str | int], value) if isinstance(value, list) else [value] + selectors = cast(list[str | int], value) if isinstance(value, list) else [value] for idx, query in enumerate(selectors): if not query: # Skip empty selectors continue @@ -792,7 +790,7 @@ def filter_element( raise exceptions.ValidationError( 'Filtering by element tag only supports operators "equals" and "doesn\'t equal" currently.' ) - tag_names = cast(List[str | int], value) if isinstance(value, list) else [value] + tag_names = cast(list[str | int], value) if isinstance(value, list) else [value] for idx, tag_name in enumerate(tag_names): if not tag_name: # Skip empty tags continue @@ -824,12 +822,12 @@ def filter_element( return "0 = 191" if operator not in NEGATED_OPERATORS else "", {} -def process_ok_values(ok_values: Any, operator: OperatorType) -> List[str]: +def process_ok_values(ok_values: Any, operator: OperatorType) -> list[str]: if operator.endswith("_set"): return [r'[^"]+'] else: # Make sure ok_values is a list - ok_values = cast(List[str], [str(val) for val in ok_values]) if isinstance(ok_values, list) else [ok_values] + ok_values = cast(list[str], [str(val) for val in ok_values]) if isinstance(ok_values, list) else [ok_values] # Escape double quote characters, since e.g. 
text 'foo="bar"' is represented as text="foo=\"bar\"" # in the elements chain ok_values = [text.replace('"', r"\"") for text in ok_values] @@ -869,8 +867,8 @@ def build_selector_regex(selector: Selector) -> str: class HogQLPropertyChecker(TraversingVisitor): def __init__(self): - self.event_properties: List[str] = [] - self.person_properties: List[str] = [] + self.event_properties: list[str] = [] + self.person_properties: list[str] = [] def visit_field(self, node: ast.Field): if len(node.chain) > 1 and node.chain[0] == "properties": @@ -888,8 +886,8 @@ def visit_field(self, node: ast.Field): self.person_properties.append(node.chain[3]) -def extract_tables_and_properties(props: List[Property]) -> TCounter[PropertyIdentifier]: - counters: List[tuple] = [] +def extract_tables_and_properties(props: list[Property]) -> TCounter[PropertyIdentifier]: + counters: list[tuple] = [] for prop in props: if prop.type == "hogql": counters.extend(count_hogql_properties(prop.key)) @@ -917,7 +915,7 @@ def count_hogql_properties( return counter -def get_session_property_filter_statement(prop: Property, idx: int, prepend: str = "") -> Tuple[str, Dict[str, Any]]: +def get_session_property_filter_statement(prop: Property, idx: int, prepend: str = "") -> tuple[str, dict[str, Any]]: if prop.key == "$session_duration": try: duration = float(prop.value) # type: ignore diff --git a/posthog/models/sharing_configuration.py b/posthog/models/sharing_configuration.py index 48ea711f02a1f..7bbacc453559d 100644 --- a/posthog/models/sharing_configuration.py +++ b/posthog/models/sharing_configuration.py @@ -1,5 +1,5 @@ import secrets -from typing import List, cast +from typing import cast from django.db import models @@ -48,7 +48,7 @@ def can_access_object(self, obj: models.Model): return False - def get_connected_insight_ids(self) -> List[int]: + def get_connected_insight_ids(self) -> list[int]: if self.insight: if self.insight.deleted: return [] diff --git a/posthog/models/subscription.py b/posthog/models/subscription.py index f7b8a90a7e492..a0aa65ed9f668 100644 --- a/posthog/models/subscription.py +++ b/posthog/models/subscription.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Optional +from typing import Any, Optional from dateutil.rrule import ( FR, @@ -134,7 +134,7 @@ def save(self, *args, **kwargs) -> None: self.set_next_delivery_date() if "update_fields" in kwargs: kwargs["update_fields"].append("next_delivery_date") - super(Subscription, self).save(*args, **kwargs) + super().save(*args, **kwargs) @property def url(self): @@ -187,7 +187,7 @@ def summary(self): capture_exception(e) return "sent on a schedule" - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """ Returns serialized information about the object for analytics reporting. 
""" diff --git a/posthog/models/tagged_item.py b/posthog/models/tagged_item.py index 612f2f39399c3..302adcdb24f23 100644 --- a/posthog/models/tagged_item.py +++ b/posthog/models/tagged_item.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Union +from typing import Union +from collections.abc import Iterable from django.core.exceptions import ValidationError from django.db import models @@ -18,7 +19,7 @@ # Checks that exactly one object field is populated def build_check(related_objects: Iterable[str]): - built_check_list: List[Union[Q, Q]] = [] + built_check_list: list[Union[Q, Q]] = [] for field in related_objects: built_check_list.append( Q( @@ -117,7 +118,7 @@ def clean(self): def save(self, *args, **kwargs): self.full_clean() - return super(TaggedItem, self).save(*args, **kwargs) + return super().save(*args, **kwargs) def __str__(self) -> str: return str(self.tag) diff --git a/posthog/models/team/team.py b/posthog/models/team/team.py index 6f5f927fe000a..868b855596469 100644 --- a/posthog/models/team/team.py +++ b/posthog/models/team/team.py @@ -1,7 +1,7 @@ import re from decimal import Decimal from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Optional import posthoganalytics import pydantic @@ -64,7 +64,7 @@ class TeamManager(models.Manager): def get_queryset(self): return super().get_queryset().defer(*DEPRECATED_ATTRS) - def set_test_account_filters(self, organization: Optional[Any]) -> List: + def set_test_account_filters(self, organization: Optional[Any]) -> list: filters = [ { "key": "$host", @@ -150,7 +150,7 @@ def increment_id_sequence(self) -> int: return result[0] -def get_default_data_attributes() -> List[str]: +def get_default_data_attributes() -> list[str]: return ["data-attr"] @@ -477,7 +477,7 @@ def groups_on_events_querying_enabled(): def check_is_feature_available_for_team(team_id: int, feature_key: str, current_usage: Optional[int] = None): - available_product_features: Optional[List[Dict[str, str]]] = ( + available_product_features: Optional[list[dict[str, str]]] = ( Team.objects.select_related("organization") .values_list("organization__available_product_features", flat=True) .get(id=team_id) diff --git a/posthog/models/team/util.py b/posthog/models/team/util.py index a21b75ab80384..5756c8da211aa 100644 --- a/posthog/models/team/util.py +++ b/posthog/models/team/util.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Any, List +from typing import Any from posthog.temporal.common.client import sync_connect from posthog.batch_exports.service import batch_export_delete_schedule @@ -7,7 +7,7 @@ from posthog.models.async_migration import is_async_migration_complete -def delete_bulky_postgres_data(team_ids: List[int]): +def delete_bulky_postgres_data(team_ids: list[int]): "Efficiently delete large tables for teams from postgres. Using normal CASCADE delete here can time out" from posthog.models.cohort import CohortPeople @@ -29,7 +29,7 @@ def _raw_delete(queryset: Any): queryset._raw_delete(queryset.db) -def delete_batch_exports(team_ids: List[int]): +def delete_batch_exports(team_ids: list[int]): """Delete BatchExports for deleted teams. Using normal CASCADE doesn't trigger a delete from Temporal. 
diff --git a/posthog/models/test/test_dashboard_tile_model.py b/posthog/models/test/test_dashboard_tile_model.py index be13ba06975c3..79f4a085a24c7 100644 --- a/posthog/models/test/test_dashboard_tile_model.py +++ b/posthog/models/test/test_dashboard_tile_model.py @@ -1,5 +1,4 @@ import datetime -from typing import Dict, List from django.core.exceptions import ValidationError from django.db.utils import IntegrityError @@ -19,7 +18,7 @@ class TestDashboardTileModel(APIBaseTest): dashboard: Dashboard asset: ExportedAsset - tiles: List[DashboardTile] + tiles: list[DashboardTile] def setUp(self) -> None: self.dashboard = Dashboard.objects.create(team=self.team, name="private dashboard", created_by=self.user) @@ -64,7 +63,7 @@ def test_cannot_add_a_tile_with_insight_and_text_on_validation(self) -> None: DashboardTile.objects.create(dashboard=self.dashboard, insight=insight, text=text) def test_cannot_set_caching_data_for_text_tiles(self) -> None: - tile_fields: List[Dict] = [ + tile_fields: list[dict] = [ {"filters_hash": "123"}, {"refreshing": True}, {"refresh_attempt": 2}, diff --git a/posthog/models/uploaded_media.py b/posthog/models/uploaded_media.py index 0161b71beb4f6..2b31f348263cb 100644 --- a/posthog/models/uploaded_media.py +++ b/posthog/models/uploaded_media.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import structlog from django.conf import settings @@ -72,7 +72,7 @@ def save_content( def save_content_to_object_storage(uploaded_media: UploadedMedia, content: bytes) -> None: - path_parts: List[str] = [ + path_parts: list[str] = [ settings.OBJECT_STORAGE_MEDIA_UPLOADS_FOLDER, f"team-{uploaded_media.team.pk}", f"media-{uploaded_media.pk}", diff --git a/posthog/models/user.py b/posthog/models/user.py index cb4b1063cc961..c2d5b0f8d5551 100644 --- a/posthog/models/user.py +++ b/posthog/models/user.py @@ -1,5 +1,6 @@ from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict +from typing import Any, Optional, TypedDict +from collections.abc import Callable from django.contrib.auth.models import AbstractUser, BaseUserManager from django.db import models, transaction @@ -36,7 +37,7 @@ class UserManager(BaseUserManager): def get_queryset(self): return super().get_queryset().defer(*DEFERED_ATTRS) - model: Type["User"] + model: type["User"] use_in_migrations = True @@ -58,12 +59,12 @@ def bootstrap( email: str, password: Optional[str], first_name: str = "", - organization_fields: Optional[Dict[str, Any]] = None, - team_fields: Optional[Dict[str, Any]] = None, + organization_fields: Optional[dict[str, Any]] = None, + team_fields: Optional[dict[str, Any]] = None, create_team: Optional[Callable[["Organization", "User"], "Team"]] = None, is_staff: bool = False, **user_fields, - ) -> Tuple["Organization", "Team", "User"]: + ) -> tuple["Organization", "Team", "User"]: """Instead of doing the legwork of creating a user from scratch, delegate the details with bootstrap.""" with transaction.atomic(): organization_fields = organization_fields or {} @@ -112,7 +113,7 @@ def get_from_personal_api_key(self, key_value: str) -> Optional["User"]: return personal_api_key.user -def events_column_config_default() -> Dict[str, Any]: +def events_column_config_default() -> dict[str, Any]: return {"active": "DEFAULT"} @@ -124,7 +125,7 @@ class ThemeMode(models.TextChoices): class User(AbstractUser, UUIDClassicModel): USERNAME_FIELD = "email" - REQUIRED_FIELDS: List[str] = [] + REQUIRED_FIELDS: list[str] = [] DISABLED = 
"disabled" TOOLBAR = "toolbar" diff --git a/posthog/models/utils.py b/posthog/models/utils.py index a093cf1e4ebde..c832cc8f044eb 100644 --- a/posthog/models/utils.py +++ b/posthog/models/utils.py @@ -5,7 +5,8 @@ from contextlib import contextmanager from random import Random, choice from time import time -from typing import Any, Callable, Dict, Iterator, Optional, Set, Type, TypeVar +from typing import Any, Optional, TypeVar +from collections.abc import Callable, Iterator from django.db import IntegrityError, connections, models, transaction from django.db.backends.utils import CursorWrapper @@ -40,7 +41,7 @@ class UUIDT(uuid.UUID): (https://blog.twitter.com/engineering/en_us/a/2010/announcing-snowflake.html). """ - current_series_per_ms: Dict[int, int] = defaultdict(int) + current_series_per_ms: dict[int, int] = defaultdict(int) def __init__( self, @@ -205,10 +206,10 @@ def create_with_slug(create_func: Callable[..., T], default_slug: str = "", *arg def get_deferred_field_set_for_model( - model: Type[models.Model], - fields_not_deferred: Optional[Set[str]] = None, + model: type[models.Model], + fields_not_deferred: Optional[set[str]] = None, field_prefix: str = "", -) -> Set[str]: +) -> set[str]: """Return a set of field names to be deferred for a given model. Used with `.defer()` after `select_related` Why? `select_related` fetches the entire related objects - not allowing you to specify which fields diff --git a/posthog/plugins/site.py b/posthog/plugins/site.py index 9cb2b3023f80e..0f5feda2df2c7 100644 --- a/posthog/plugins/site.py +++ b/posthog/plugins/site.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass from hashlib import md5 -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from posthog.models import Team @@ -11,7 +11,7 @@ class WebJsSource: id: int source: str token: str - config_schema: List[dict] + config_schema: list[dict] config: dict @@ -48,7 +48,7 @@ def get_transpiled_site_source(id: int, token: str) -> Optional[WebJsSource]: return WebJsSource(*(list(response))) # type: ignore -def get_decide_site_apps(team: "Team", using_database: str = "default") -> List[dict]: +def get_decide_site_apps(team: "Team", using_database: str = "default") -> list[dict]: from posthog.models import PluginConfig, PluginSourceFile sources = ( @@ -70,13 +70,13 @@ def get_decide_site_apps(team: "Team", using_database: str = "default") -> List[ ) def site_app_url(source: tuple) -> str: - hash = md5(f"{source[2]}-{source[3]}-{source[4]}".encode("utf-8")).hexdigest() + hash = md5(f"{source[2]}-{source[3]}-{source[4]}".encode()).hexdigest() return f"/site_app/{source[0]}/{source[1]}/{hash}/" return [asdict(WebJsUrl(source[0], site_app_url(source))) for source in sources] -def get_site_config_from_schema(config_schema: Optional[List[dict]], config: Optional[dict]): +def get_site_config_from_schema(config_schema: Optional[list[dict]], config: Optional[dict]): if not config or not config_schema: return {} return { diff --git a/posthog/plugins/utils.py b/posthog/plugins/utils.py index 2610d8b2eb17d..602f775447bfa 100644 --- a/posthog/plugins/utils.py +++ b/posthog/plugins/utils.py @@ -4,7 +4,7 @@ import re import tarfile from tarfile import ReadError -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from urllib.parse import parse_qs, quote from zipfile import ZIP_DEFLATED, BadZipFile, Path, ZipFile @@ -12,7 +12,7 @@ from django.conf import settings -def parse_github_url(url: str, 
get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_github_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search( r"^https?://(?:www\.)?github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)(/(commit|tree|releases/tag)/([A-Za-z0-9_.\-]+)/?([A-Za-z0-9_.\-/]+)?)?$", @@ -27,7 +27,7 @@ def parse_github_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "github", "root_url": f"https://github.com/{match.group(1)}/{match.group(2)}", "user": match.group(1), @@ -76,13 +76,13 @@ def parse_github_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O return parsed -def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search(r"^https?://(?:www\.)?gitlab\.com/([A-Za-z0-9_.\-/]+)$", url) if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "gitlab", "project": match.group(1), "tag": None, @@ -127,7 +127,7 @@ def parse_gitlab_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, O return parsed -def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Optional[str]]]: +def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[dict[str, Optional[str]]]: url, private_token = split_url_and_private_token(url) match = re.search( r"^https?://(?:www\.)?npmjs\.com/package/([@a-z0-9_-]+(/[a-z0-9_-]+)?)?/?(v/([A-Za-z0-9_.-]+)/?|)$", @@ -135,7 +135,7 @@ def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Opti ) if not match: return None - parsed: Dict[str, Optional[str]] = { + parsed: dict[str, Optional[str]] = { "type": "npm", "pkg": match.group(1), "tag": match.group(4), @@ -166,7 +166,7 @@ def parse_npm_url(url: str, get_latest_if_none=False) -> Optional[Dict[str, Opti return parsed -def parse_url(url: str, get_latest_if_none=False) -> Dict[str, Optional[str]]: +def parse_url(url: str, get_latest_if_none=False) -> dict[str, Optional[str]]: parsed_url = parse_github_url(url, get_latest_if_none) if parsed_url: return parsed_url @@ -179,7 +179,7 @@ def parse_url(url: str, get_latest_if_none=False) -> Dict[str, Optional[str]]: raise Exception("Must be a GitHub/GitLab repository or npm package URL!") -def split_url_and_private_token(url: str) -> Tuple[str, Optional[str]]: +def split_url_and_private_token(url: str) -> tuple[str, Optional[str]]: private_token = None if "?" 
in url: url, query = url.split("?") @@ -242,7 +242,7 @@ def download_plugin_archive(url: str, tag: Optional[str] = None) -> bytes: def load_json_file(filename: str): try: - with open(filename, "r", encoding="utf_8") as reader: + with open(filename, encoding="utf_8") as reader: return json.loads(reader.read()) except FileNotFoundError: return None @@ -313,8 +313,8 @@ def find_index_ts_in_archive(archive: bytes, main_filename: Optional[str] = None def extract_plugin_code( - archive: bytes, plugin_json_parsed: Optional[Dict[str, Any]] = None -) -> Tuple[str, Optional[str], Optional[str], Optional[str]]: + archive: bytes, plugin_json_parsed: Optional[dict[str, Any]] = None +) -> tuple[str, Optional[str], Optional[str], Optional[str]]: """Extract plugin.json, index.ts (which can be aliased) and frontend.tsx out of an archive. If plugin.json has already been parsed before this is called, its value can be passed in as an optimization.""" diff --git a/posthog/queries/actor_base_query.py b/posthog/queries/actor_base_query.py index 66c476cd814cd..f23b4c4ff05da 100644 --- a/posthog/queries/actor_base_query.py +++ b/posthog/queries/actor_base_query.py @@ -2,12 +2,8 @@ from datetime import datetime, timedelta from typing import ( Any, - Dict, - List, Literal, Optional, - Set, - Tuple, TypedDict, Union, cast, @@ -34,14 +30,14 @@ class EventInfoForRecording(TypedDict): class MatchedRecording(TypedDict): session_id: str - events: List[EventInfoForRecording] + events: list[EventInfoForRecording] class CommonActor(TypedDict): id: Union[uuid.UUID, str] created_at: Optional[str] - properties: Dict[str, Any] - matched_recordings: List[MatchedRecording] + properties: dict[str, Any] + matched_recordings: list[MatchedRecording] value_at_data_point: Optional[float] @@ -50,7 +46,7 @@ class SerializedPerson(CommonActor): uuid: Union[uuid.UUID, str] is_identified: Optional[bool] name: str - distinct_ids: List[str] + distinct_ids: list[str] class SerializedGroup(CommonActor): @@ -81,7 +77,7 @@ def __init__( self.entity = entity self._filter = filter - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: """Implemented by subclasses. Must provide query and params. The query must return list of uuids. Can be group uuids (group_key) or person uuids""" raise NotImplementedError() @@ -96,9 +92,9 @@ def is_aggregating_by_groups(self) -> bool: def get_actors( self, - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], int, ]: """Get actors in data model and dict formats. 
Builds query and executes""" @@ -124,10 +120,10 @@ def get_actors( def query_for_session_ids_with_recordings( self, - session_ids: Set[str], + session_ids: set[str], date_from: datetime | None, date_to: datetime | None, - ) -> Set[str]: + ) -> set[str]: """Filters a list of session_ids to those that actually have recordings""" query = """ SELECT DISTINCT session_id @@ -166,9 +162,9 @@ def query_for_session_ids_with_recordings( def add_matched_recordings_to_serialized_actors( self, - serialized_actors: Union[List[SerializedGroup], List[SerializedPerson]], + serialized_actors: Union[list[SerializedGroup], list[SerializedPerson]], raw_result, - ) -> Union[List[SerializedGroup], List[SerializedPerson]]: + ) -> Union[list[SerializedGroup], list[SerializedPerson]]: all_session_ids = set() session_events_column_index = 2 if self.ACTOR_VALUES_INCLUDED else 1 @@ -192,9 +188,9 @@ def add_matched_recordings_to_serialized_actors( ) session_ids_with_recordings = session_ids_with_all_recordings.difference(session_ids_with_deleted_recordings) - matched_recordings_by_actor_id: Dict[Union[uuid.UUID, str], List[MatchedRecording]] = {} + matched_recordings_by_actor_id: dict[Union[uuid.UUID, str], list[MatchedRecording]] = {} for row in raw_result: - recording_events_by_session_id: Dict[str, List[EventInfoForRecording]] = {} + recording_events_by_session_id: dict[str, list[EventInfoForRecording]] = {} if len(row) > session_events_column_index - 1: for event in row[session_events_column_index]: event_session_id = event[2] @@ -211,7 +207,7 @@ def add_matched_recordings_to_serialized_actors( # Casting Union[SerializedActor, SerializedGroup] as SerializedPerson because mypy yells # when you do an indexed assignment on a Union even if all items in the Union support it - serialized_actors = cast(List[SerializedPerson], serialized_actors) + serialized_actors = cast(list[SerializedPerson], serialized_actors) serialized_actors_with_recordings = [] for actor in serialized_actors: actor["matched_recordings"] = matched_recordings_by_actor_id[actor["id"]] @@ -221,12 +217,12 @@ def add_matched_recordings_to_serialized_actors( def get_actors_from_result( self, raw_result - ) -> Tuple[ + ) -> tuple[ Union[QuerySet[Person], QuerySet[Group]], - Union[List[SerializedGroup], List[SerializedPerson]], + Union[list[SerializedGroup], list[SerializedPerson]], ]: actors: Union[QuerySet[Person], QuerySet[Group]] - serialized_actors: Union[List[SerializedGroup], List[SerializedPerson]] + serialized_actors: Union[list[SerializedGroup], list[SerializedPerson]] actor_ids = [row[0] for row in raw_result] value_per_actor_id = {str(row[0]): row[1] for row in raw_result} if self.ACTOR_VALUES_INCLUDED else None @@ -255,9 +251,9 @@ def get_actors_from_result( def get_groups( team_id: int, group_type_index: int, - group_ids: List[Any], - value_per_actor_id: Optional[Dict[str, float]] = None, -) -> Tuple[QuerySet[Group], List[SerializedGroup]]: + group_ids: list[Any], + value_per_actor_id: Optional[dict[str, float]] = None, +) -> tuple[QuerySet[Group], list[SerializedGroup]]: """Get groups from raw SQL results in data model and dict formats""" groups: QuerySet[Group] = Group.objects.filter( team_id=team_id, group_type_index=group_type_index, group_key__in=group_ids @@ -267,10 +263,10 @@ def get_groups( def get_people( team: Team, - people_ids: List[Any], - value_per_actor_id: Optional[Dict[str, float]] = None, + people_ids: list[Any], + value_per_actor_id: Optional[dict[str, float]] = None, distinct_id_limit=1000, -) -> Tuple[QuerySet[Person], 
List[SerializedPerson]]: +) -> tuple[QuerySet[Person], list[SerializedPerson]]: """Get people from raw SQL results in data model and dict formats""" distinct_id_subquery = Subquery( PersonDistinctId.objects.filter(person_id=OuterRef("person_id")).values_list("id", flat=True)[ @@ -294,9 +290,9 @@ def get_people( def serialize_people( team: Team, - data: Union[QuerySet[Person], List[Person]], - value_per_actor_id: Optional[Dict[str, float]] = None, -) -> List[SerializedPerson]: + data: Union[QuerySet[Person], list[Person]], + value_per_actor_id: Optional[dict[str, float]] = None, +) -> list[SerializedPerson]: from posthog.api.person import get_person_name return [ @@ -316,7 +312,7 @@ def serialize_people( ] -def serialize_groups(data: QuerySet[Group], value_per_actor_id: Optional[Dict[str, float]]) -> List[SerializedGroup]: +def serialize_groups(data: QuerySet[Group], value_per_actor_id: Optional[dict[str, float]]) -> list[SerializedGroup]: return [ SerializedGroup( id=group.group_key, diff --git a/posthog/queries/app_metrics/historical_exports.py b/posthog/queries/app_metrics/historical_exports.py index cbf22d480156b..5fd32a06ec2da 100644 --- a/posthog/queries/app_metrics/historical_exports.py +++ b/posthog/queries/app_metrics/historical_exports.py @@ -1,6 +1,6 @@ import json from datetime import timedelta -from typing import Dict, Optional +from typing import Optional from zoneinfo import ZoneInfo @@ -26,7 +26,7 @@ def historical_exports_activity(team_id: int, plugin_config_id: int, job_id: Opt **({"detail__trigger__job_id": job_id} if job_id is not None else {}), ) - by_category: Dict = {"job_triggered": {}, "export_success": {}, "export_fail": {}} + by_category: dict = {"job_triggered": {}, "export_success": {}, "export_fail": {}} for entry in entries: by_category[entry.activity][entry.detail["trigger"]["job_id"]] = entry diff --git a/posthog/queries/app_metrics/test/test_app_metrics.py b/posthog/queries/app_metrics/test/test_app_metrics.py index e6c50b08ae525..2368961b507cb 100644 --- a/posthog/queries/app_metrics/test/test_app_metrics.py +++ b/posthog/queries/app_metrics/test/test_app_metrics.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Dict, Optional +from typing import Optional from freezegun.api import freeze_time @@ -34,7 +34,7 @@ def create_app_metric( failures=0, error_uuid: Optional[str] = None, error_type: Optional[str] = None, - error_details: Optional[Dict] = None, + error_details: Optional[dict] = None, ): timestamp = cast_timestamp_or_now(timestamp) data = { diff --git a/posthog/queries/base.py b/posthog/queries/base.py index 7dff88f602099..e5cf6e717444b 100644 --- a/posthog/queries/base.py +++ b/posthog/queries/base.py @@ -3,14 +3,12 @@ import re from typing import ( Any, - Callable, - Dict, - List, Optional, TypeVar, Union, cast, ) +from collections.abc import Callable from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta from dateutil import parser @@ -47,7 +45,7 @@ def determine_compared_filter(filter: FilterType) -> FilterType: return filter.shallow_clone({"date_from": date_from.isoformat(), "date_to": date_to.isoformat()}) -def convert_to_comparison(trend_entities: List[Dict[str, Any]], filter, label: str) -> List[Dict[str, Any]]: +def convert_to_comparison(trend_entities: list[dict[str, Any]], filter, label: str) -> list[dict[str, Any]]: for entity in trend_entities: labels = [ "{} {}".format(filter.interval if filter.interval is not None else "day", i) @@ -72,7 +70,7 @@ def 
convert_to_comparison(trend_entities: List[Dict[str, Any]], filter, label: s """ -def handle_compare(filter, func: Callable, team: Team, **kwargs) -> List: +def handle_compare(filter, func: Callable, team: Team, **kwargs) -> list: all_entities = [] base_entitites = func(filter=filter, team=team, **kwargs) if filter.compare: @@ -88,7 +86,7 @@ def handle_compare(filter, func: Callable, team: Team, **kwargs) -> List: return all_entities -def match_property(property: Property, override_property_values: Dict[str, Any]) -> bool: +def match_property(property: Property, override_property_values: dict[str, Any]) -> bool: # only looks for matches where key exists in override_property_values # doesn't support operator is_not_set @@ -276,8 +274,8 @@ def lookup_q(key: str, value: Any) -> Q: def property_to_Q( team_id: int, property: Property, - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: if override_property_values is None: @@ -382,8 +380,8 @@ def property_to_Q( def property_group_to_Q( team_id: int, property_group: PropertyGroup, - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: if override_property_values is None: @@ -426,9 +424,9 @@ def property_group_to_Q( def properties_to_Q( team_id: int, - properties: List[Property], - override_property_values: Optional[Dict[str, Any]] = None, - cohorts_cache: Optional[Dict[int, CohortOrEmpty]] = None, + properties: list[Property], + override_property_values: Optional[dict[str, Any]] = None, + cohorts_cache: Optional[dict[int, CohortOrEmpty]] = None, using_database: str = "default", ) -> Q: """ diff --git a/posthog/queries/breakdown_props.py b/posthog/queries/breakdown_props.py index fffb0aef0f2f0..96cf6afa9596c 100644 --- a/posthog/queries/breakdown_props.py +++ b/posthog/queries/breakdown_props.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from django.forms import ValidationError @@ -50,7 +50,7 @@ def get_breakdown_prop_values( column_optimizer: Optional[ColumnOptimizer] = None, person_properties_mode: PersonPropertiesMode = PersonPropertiesMode.USING_PERSON_PROPERTIES_COLUMN, use_all_funnel_entities: bool = False, -) -> Tuple[List[Any], bool]: +) -> tuple[list[Any], bool]: """ Returns the top N breakdown prop values for event/person breakdown @@ -77,13 +77,13 @@ def get_breakdown_prop_values( props_to_filter = filter.property_groups person_join_clauses = "" - person_join_params: Dict = {} + person_join_params: dict = {} groups_join_clause = "" - groups_join_params: Dict = {} + groups_join_params: dict = {} sessions_join_clause = "" - sessions_join_params: Dict = {} + sessions_join_params: dict = {} null_person_filter = ( f"AND notEmpty(e.person_id)" if team.person_on_events_mode != PersonsOnEventsMode.disabled else "" @@ -248,14 +248,14 @@ def get_breakdown_prop_values( def _to_value_expression( breakdown_type: Optional[BREAKDOWN_TYPES], - breakdown: Union[str, List[Union[str, int]], None], + breakdown: Union[str, list[Union[str, int]], None], breakdown_group_type_index: Optional[GroupTypeIndex], hogql_context: 
HogQLContext, breakdown_normalize_url: bool = False, direct_on_events: bool = False, cast_as_float: bool = False, -) -> Tuple[str, Dict]: - params: Dict[str, Any] = {} +) -> tuple[str, dict]: + params: dict[str, Any] = {} if breakdown_type == "session": if breakdown == "$session_duration": # Return the session duration expression right away because it's already a number, @@ -321,7 +321,7 @@ def _to_bucketing_expression(bin_count: int) -> str: return f"arrayCompact(arrayMap(x -> floor(x, 2), {qunatile_expression}))" -def _format_all_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, Dict]: +def _format_all_query(team: Team, filter: Filter, **kwargs) -> tuple[str, dict]: entity = kwargs.pop("entity", None) date_params = {} @@ -354,7 +354,7 @@ def _format_all_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, Dict]: return query, {**date_params, **prop_filter_params} -def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> Tuple[str, List, Dict]: +def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> tuple[str, list, dict]: entity = kwargs.pop("entity", None) cohorts = ( Cohort.objects.filter(team_id=team.pk, pk__in=[b for b in filter.breakdown if b != "all"]) @@ -371,9 +371,9 @@ def format_breakdown_cohort_join_query(team: Team, filter: Filter, **kwargs) -> return " UNION ALL ".join(cohort_queries), ids, params -def _parse_breakdown_cohorts(cohorts: List[Cohort], hogql_context: HogQLContext) -> Tuple[List[str], Dict]: +def _parse_breakdown_cohorts(cohorts: list[Cohort], hogql_context: HogQLContext) -> tuple[list[str], dict]: queries = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for idx, cohort in enumerate(cohorts): person_id_query, cohort_filter_params = format_filter_query(cohort, idx, hogql_context) diff --git a/posthog/queries/column_optimizer/foss_column_optimizer.py b/posthog/queries/column_optimizer/foss_column_optimizer.py index 98dfb1b54c418..b3e73d3178c5e 100644 --- a/posthog/queries/column_optimizer/foss_column_optimizer.py +++ b/posthog/queries/column_optimizer/foss_column_optimizer.py @@ -1,6 +1,7 @@ from collections import Counter -from typing import Counter as TCounter -from typing import Generator, List, Set, Union, cast +from collections import Counter as TCounter +from typing import Union, cast +from collections.abc import Generator from posthog.clickhouse.materialized_columns import ColumnName, get_materialized_columns from posthog.constants import TREND_FILTER_TYPE_ACTIONS, FunnelCorrelationType @@ -48,19 +49,19 @@ def __init__( self.property_optimizer = PropertyOptimizer() @cached_property - def event_columns_to_query(self) -> Set[ColumnName]: + def event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table columns containing materialized properties that this query needs" return self.columns_to_query("events", set(self.used_properties_with_type("event"))) @cached_property - def person_on_event_columns_to_query(self) -> Set[ColumnName]: + def person_on_event_columns_to_query(self) -> set[ColumnName]: "Returns a list of event table person columns containing materialized properties that this query needs" return self.columns_to_query("events", set(self.used_properties_with_type("person")), "person_properties") @cached_property - def person_columns_to_query(self) -> Set[ColumnName]: + def person_columns_to_query(self) -> set[ColumnName]: "Returns a list of person table columns containing materialized properties that this query needs" return self.columns_to_query("person",
set(self.used_properties_with_type("person"))) @@ -68,9 +69,9 @@ def person_columns_to_query(self) -> Set[ColumnName]: def columns_to_query( self, table: TableWithProperties, - used_properties: Set[PropertyIdentifier], + used_properties: set[PropertyIdentifier], table_column: str = "properties", - ) -> Set[ColumnName]: + ) -> set[ColumnName]: "Transforms a list of property names to what columns are needed for that query" materialized_columns = get_materialized_columns(table) @@ -92,11 +93,11 @@ def is_using_cohort_propertes(self) -> bool: ) @cached_property - def group_types_to_query(self) -> Set[GroupTypeIndex]: + def group_types_to_query(self) -> set[GroupTypeIndex]: return set() @cached_property - def group_on_event_columns_to_query(self) -> Set[ColumnName]: + def group_on_event_columns_to_query(self) -> set[ColumnName]: return set() @cached_property @@ -171,7 +172,7 @@ def properties_used_in_filter(self) -> TCounter[PropertyIdentifier]: counter += get_action_tables_and_properties(entity.get_action()) if ( - not isinstance(self.filter, (StickinessFilter, PropertiesTimelineFilter)) + not isinstance(self.filter, StickinessFilter | PropertiesTimelineFilter) and self.filter.correlation_type == FunnelCorrelationType.PROPERTIES and self.filter.correlation_property_names ): @@ -195,7 +196,7 @@ def used_properties_with_type(self, property_type: PropertyType) -> TCounter[Pro def entities_used_in_filter(self) -> Generator[Entity, None, None]: yield from self.filter.entities - yield from cast(List[Entity], self.filter.exclusions) + yield from cast(list[Entity], self.filter.exclusions) if isinstance(self.filter, RetentionFilter): yield self.filter.target_entity diff --git a/posthog/queries/event_query/event_query.py b/posthog/queries/event_query/event_query.py index 8737876d00116..49d4565a943ec 100644 --- a/posthog/queries/event_query/event_query.py +++ b/posthog/queries/event_query/event_query.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from posthog.clickhouse.materialized_columns import ColumnName from posthog.models import Cohort, Filter, Property @@ -38,9 +38,9 @@ class EventQuery(metaclass=ABCMeta): _should_join_persons = False _should_join_sessions = False _should_round_interval = False - _extra_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] - _extra_person_fields: List[ColumnName] + _extra_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] + _extra_person_fields: list[ColumnName] _person_id_alias: str _session_id_alias: Optional[str] @@ -60,9 +60,9 @@ def __init__( should_join_persons=False, should_join_sessions=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, **kwargs, @@ -79,7 +79,7 @@ def __init__( self._extra_event_properties = extra_event_properties self._column_optimizer = ColumnOptimizer(self._filter, self._team_id) self._extra_person_fields = extra_person_fields - self.params: Dict[str, Any] = { + self.params: dict[str, 
Any] = { "team_id": self._team_id, "timezone": team.timezone, } @@ -118,7 +118,7 @@ def __init__( self._person_id_alias = self._get_person_id_alias(person_on_events_mode) @abstractmethod - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: pass @abstractmethod @@ -206,7 +206,7 @@ def _person_query(self) -> PersonQuery: extra_fields=self._extra_person_fields, ) - def _get_person_query(self) -> Tuple[str, Dict]: + def _get_person_query(self) -> tuple[str, dict]: if self._should_join_persons: person_query, params = self._person_query.get_query() return ( @@ -219,7 +219,7 @@ def _get_person_query(self) -> Tuple[str, Dict]: else: return "", {} - def _get_groups_query(self) -> Tuple[str, Dict]: + def _get_groups_query(self) -> tuple[str, dict]: return "", {} @cached_property @@ -232,7 +232,7 @@ def _sessions_query(self) -> SessionQuery: session_id_alias=self._session_id_alias, ) - def _get_sessions_query(self) -> Tuple[str, Dict]: + def _get_sessions_query(self) -> tuple[str, dict]: if self._should_join_sessions: session_query, session_params = self._sessions_query.get_query() @@ -246,7 +246,7 @@ def _get_sessions_query(self) -> Tuple[str, Dict]: ) return "", {} - def _get_date_filter(self) -> Tuple[str, Dict]: + def _get_date_filter(self) -> tuple[str, dict]: date_params = {} query_date_range = QueryDateRange( filter=self._filter, team=self._team, should_round=self._should_round_interval @@ -270,7 +270,7 @@ def _get_prop_groups( person_id_joined_alias="person_id", prepend="global", allow_denormalized_props=True, - ) -> Tuple[str, Dict]: + ) -> tuple[str, dict]: if not prop_group: return "", {} diff --git a/posthog/queries/foss_cohort_query.py b/posthog/queries/foss_cohort_query.py index 352fc19ee13cf..847f6737c9f3f 100644 --- a/posthog/queries/foss_cohort_query.py +++ b/posthog/queries/foss_cohort_query.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from zoneinfo import ZoneInfo from posthog.clickhouse.materialized_columns import ColumnName @@ -26,8 +26,8 @@ from posthog.schema import PersonsOnEventsMode from posthog.utils import relative_date_parse -Relative_Date = Tuple[int, OperatorInterval] -Event = Tuple[str, Union[str, int]] +Relative_Date = tuple[int, OperatorInterval] +Event = tuple[str, Union[str, int]] INTERVAL_TO_SECONDS = { @@ -40,7 +40,7 @@ } -def relative_date_to_seconds(date: Tuple[Optional[int], Union[OperatorInterval, None]]): +def relative_date_to_seconds(date: tuple[Optional[int], Union[OperatorInterval, None]]): if date[0] is None or date[1] is None: raise ValueError("Time value and time interval must be specified") @@ -66,7 +66,7 @@ def parse_and_validate_positive_integer(value: Optional[int], value_name: str) - return parsed_value -def validate_entity(possible_event: Tuple[Optional[str], Optional[Union[int, str]]]) -> Event: +def validate_entity(possible_event: tuple[Optional[str], Optional[Union[int, str]]]) -> Event: event_type = possible_event[0] event_val = possible_event[1] if event_type is None or event_val is None: @@ -83,7 +83,7 @@ def relative_date_is_greater(date_1: Relative_Date, date_2: Relative_Date) -> bo return relative_date_to_seconds(date_1) > relative_date_to_seconds(date_2) -def convert_to_entity_params(events: List[Event]) -> Tuple[List, List]: +def convert_to_entity_params(events: list[Event]) -> tuple[list, list]: res_events = [] res_actions = [] @@ -124,8 +124,8 @@ class 
FOSSCohortQuery(EventQuery): BEHAVIOR_QUERY_ALIAS = "behavior_query" FUNNEL_QUERY_ALIAS = "funnel_query" SEQUENCE_FIELD_ALIAS = "steps" - _fields: List[str] - _events: List[str] + _fields: list[str] + _events: list[str] _earliest_time_for_event_query: Optional[Relative_Date] _restrict_event_query_by_time: bool @@ -139,9 +139,9 @@ def __init__( should_join_distinct_ids=False, should_join_persons=False, # Extra events/person table columns to fetch since parent query needs them - extra_fields: Optional[List[ColumnName]] = None, - extra_event_properties: Optional[List[PropertyName]] = None, - extra_person_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, + extra_event_properties: Optional[list[PropertyName]] = None, + extra_person_fields: Optional[list[ColumnName]] = None, override_aggregate_users_by_distinct_id: Optional[bool] = None, **kwargs, ) -> None: @@ -187,14 +187,14 @@ def _unwrap(property_group: PropertyGroup, negate_group: bool = False) -> Proper if not negate_group: return PropertyGroup( type=property_group.type, - values=[_unwrap(v) for v in cast(List[PropertyGroup], property_group.values)], + values=[_unwrap(v) for v in cast(list[PropertyGroup], property_group.values)], ) else: return PropertyGroup( type=PropertyOperatorType.AND if property_group.type == PropertyOperatorType.OR else PropertyOperatorType.OR, - values=[_unwrap(v, True) for v in cast(List[PropertyGroup], property_group.values)], + values=[_unwrap(v, True) for v in cast(list[PropertyGroup], property_group.values)], ) elif isinstance(property_group.values[0], Property): @@ -202,7 +202,7 @@ def _unwrap(property_group: PropertyGroup, negate_group: bool = False) -> Proper # if any single one is a cohort property, unwrap it into a property group # which implies converting everything else in the list into a property group too - new_property_group_list: List[PropertyGroup] = [] + new_property_group_list: list[PropertyGroup] = [] for prop in property_group.values: prop = cast(Property, prop) current_negation = prop.negation or False @@ -258,7 +258,7 @@ def _unwrap(property_group: PropertyGroup, negate_group: bool = False) -> Proper return filter.shallow_clone({"properties": new_props.to_dict()}) # Implemented in /ee - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: if not self._outer_property_groups: # everything is pushed down, no behavioral stuff to do # thus, use personQuery directly @@ -294,7 +294,7 @@ def get_query(self) -> Tuple[str, Dict[str, Any]]: return final_query, self.params - def _build_sources(self, subq: List[Tuple[str, str]]) -> Tuple[str, str]: + def _build_sources(self, subq: list[tuple[str, str]]) -> tuple[str, str]: q = "" filtered_queries = [(q, alias) for (q, alias) in subq if q and len(q)] @@ -325,7 +325,7 @@ def _build_sources(self, subq: List[Tuple[str, str]]) -> Tuple[str, str]: return q, fields - def _get_behavior_subquery(self) -> Tuple[str, Dict[str, Any], str]: + def _get_behavior_subquery(self) -> tuple[str, dict[str, Any], str]: # # Get the subquery for the cohort query. 
# @@ -371,7 +371,7 @@ def _get_behavior_subquery(self) -> Tuple[str, Dict[str, Any], str]: return query, params, self.BEHAVIOR_QUERY_ALIAS - def _get_persons_query(self, prepend: str = "") -> Tuple[str, Dict[str, Any], str]: + def _get_persons_query(self, prepend: str = "") -> tuple[str, dict[str, Any], str]: query, params = "", {} if self._should_join_persons: person_query, person_params = self._person_query.get_query(prepend=prepend) @@ -387,9 +387,9 @@ def should_pushdown_persons(self) -> bool: prop.type for prop in getattr(self._outer_property_groups, "flat", []) ] and "static-cohort" not in [prop.type for prop in getattr(self._outer_property_groups, "flat", [])] - def _get_date_condition(self) -> Tuple[str, Dict[str, Any]]: + def _get_date_condition(self) -> tuple[str, dict[str, Any]]: date_query = "" - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} earliest_time_param = f"earliest_time_{self._cohort_pk}" if self._earliest_time_for_event_query and self._restrict_event_query_by_time: @@ -404,7 +404,7 @@ def _check_earliest_date(self, relative_date: Relative_Date) -> None: elif relative_date_is_greater(relative_date, self._earliest_time_for_event_query): self._earliest_time_for_event_query = relative_date - def _get_conditions(self) -> Tuple[str, Dict[str, Any]]: + def _get_conditions(self) -> tuple[str, dict[str, Any]]: def build_conditions(prop: Optional[Union[PropertyGroup, Property]], prepend="level", num=0): if not prop: return "", {} @@ -426,9 +426,9 @@ def build_conditions(prop: Optional[Union[PropertyGroup, Property]], prepend="le return f"AND ({conditions})" if conditions else "", params # Implemented in /ee - def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if prop.type == "behavioral": if prop.value == "performed_event": @@ -446,7 +446,7 @@ def _get_condition_for_property(self, prop: Property, prepend: str, idx: int) -> return res, params - def get_person_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_person_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: if self._outer_property_groups and len(self._outer_property_groups.flat): return prop_filter_json_extract( prop, @@ -459,7 +459,7 @@ def get_person_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[ else: return "", {} - def get_static_cohort_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_static_cohort_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: # If we reach this stage, it means there are no cyclic dependencies # They should've been caught by API update validation # and if not there, `simplifyFilter` would've failed @@ -467,8 +467,8 @@ def get_static_cohort_condition(self, prop: Property, prepend: str, idx: int) -> query, params = format_static_cohort_query(cohort, idx, prepend) return f"id {'NOT' if prop.negation else ''} IN ({query})", params - def _get_entity_event_filters(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: - params: Dict[str, Any] = {} + def _get_entity_event_filters(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: + params: dict[str, Any] = {} if prop.event_filters: prop_query, prop_params = 
parse_prop_grouped_clauses( @@ -491,7 +491,7 @@ def _get_relative_interval_from_explicit_date(self, datetime: datetime, timezone # one extra day for any partial days return (delta.days + 1, "day") - def _get_entity_datetime_filters(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def _get_entity_datetime_filters(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: if prop.explicit_datetime: # Explicit datetime filter, can be a relative or absolute date, follows same convention # as all analytics datetime filters @@ -512,7 +512,7 @@ def _get_entity_datetime_filters(self, prop: Property, prepend: str, idx: int) - return f"timestamp > now() - INTERVAL %({date_param})s {date_interval}", {f"{date_param}": date_value} - def get_performed_event_condition(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_condition(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"performed_event_condition_{prepend}_{idx}" @@ -530,7 +530,7 @@ def get_performed_event_condition(self, prop: Property, prepend: str, idx: int) **entity_filters_params, } - def get_performed_event_multiple(self, prop: Property, prepend: str, idx: int) -> Tuple[str, Dict[str, Any]]: + def get_performed_event_multiple(self, prop: Property, prepend: str, idx: int) -> tuple[str, dict[str, Any]]: event = (prop.event_type, prop.key) column_name = f"performed_event_multiple_condition_{prepend}_{idx}" @@ -591,12 +591,12 @@ def _validate_negations(self) -> None: def _get_entity( self, - event: Tuple[Optional[str], Optional[Union[int, str]]], + event: tuple[Optional[str], Optional[Union[int, str]]], prepend: str, idx: int, - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: res: str = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if event[0] is None or event[1] is None: raise ValueError("Event type and key must be specified") diff --git a/posthog/queries/funnels/base.py b/posthog/queries/funnels/base.py index c4258c6f6eb9f..a6de14b050cfa 100644 --- a/posthog/queries/funnels/base.py +++ b/posthog/queries/funnels/base.py @@ -1,7 +1,7 @@ import urllib.parse import uuid from abc import ABC -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from rest_framework.exceptions import ValidationError @@ -44,9 +44,9 @@ class ClickhouseFunnelBase(ABC): _team: Team _include_timestamp: Optional[bool] _include_preceding_timestamp: Optional[bool] - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] - _include_properties: List[str] + _extra_event_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] + _include_properties: list[str] def __init__( self, @@ -55,7 +55,7 @@ def __init__( include_timestamp: Optional[bool] = None, include_preceding_timestamp: Optional[bool] = None, base_uri: str = "/", - include_properties: Optional[List[str]] = None, + include_properties: Optional[list[str]] = None, ) -> None: self._filter = filter self._team = team @@ -92,8 +92,8 @@ def __init__( self.params.update({OFFSET: self._filter.offset}) - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if self._filter.include_recordings: self._extra_event_fields = ["uuid"] self._extra_event_properties = 
["$session_id", "$window_id"] @@ -111,9 +111,9 @@ def _serialize_step( self, step: Entity, count: int, - people: Optional[List[uuid.UUID]] = None, + people: Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if step.type == TREND_FILTER_TYPE_ACTIONS: name = step.get_action().name else: @@ -135,7 +135,7 @@ def extra_event_fields_and_properties(self): def _update_filters(self): # format default dates - data: Dict[str, Any] = {} + data: dict[str, Any] = {} if not self._filter._date_from: data.update({"date_from": relative_date_parse("-7d", self._team.timezone_info)}) @@ -153,7 +153,7 @@ def _update_filters(self): # # Once multi property breakdown is implemented in Trends this becomes unnecessary - if isinstance(self._filter.breakdowns, List) and self._filter.breakdown_type in [ + if isinstance(self._filter.breakdowns, list) and self._filter.breakdown_type in [ "person", "event", "hogql", @@ -167,7 +167,7 @@ def _update_filters(self): "hogql", None, ]: - boxed_breakdown: List[Union[str, int]] = box_value(self._filter.breakdown) + boxed_breakdown: list[Union[str, int]] = box_value(self._filter.breakdown) data.update({"breakdown": boxed_breakdown}) for exclusion in self._filter.exclusions: @@ -270,7 +270,7 @@ def _format_results(self, results): else: return self._format_single_funnel(results[0]) - def _exec_query(self) -> List[Tuple]: + def _exec_query(self) -> list[tuple]: self._filter.team = self._team query = self.get_query() return insight_sync_execute( @@ -289,7 +289,7 @@ def _get_timestamp_outer_select(self) -> str: else: return "" - def _get_timestamp_selects(self) -> Tuple[str, str]: + def _get_timestamp_selects(self) -> tuple[str, str]: """ Returns timestamp selectors for the target step and optionally the preceding step. In the former case, always returns the timestamp for the first and last step as well. 
@@ -328,7 +328,7 @@ def _get_timestamp_selects(self) -> Tuple[str, str]: return "", "" def _get_step_times(self, max_steps: int): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"if(isNotNull(latest_{i}) AND latest_{i} <= latest_{i-1} + INTERVAL {self._filter.funnel_window_interval} {self._filter.funnel_window_interval_unit_ch()}, " @@ -339,7 +339,7 @@ def _get_step_times(self, max_steps: int): return f", {formatted}" if formatted else "" def _get_partition_cols(self, level_index: int, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: @@ -397,7 +397,7 @@ def _get_sorting_condition(self, curr_index: int, max_steps: int): if curr_index == 1: return "1" - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, curr_index): duplicate_event = ( True @@ -444,7 +444,7 @@ def _get_inner_event_query( else: steps_conditions = self._get_steps_conditions(length=len(entities_to_use)) - all_step_cols: List[str] = [] + all_step_cols: list[str] = [] for index, entity in enumerate(entities_to_use): step_cols = self._get_step_col(entity, index, entity_name) all_step_cols.extend(step_cols) @@ -521,7 +521,7 @@ def _add_breakdown_attribution_subquery(self, inner_query: str) -> str: """ def _get_steps_conditions(self, length: int) -> str: - step_conditions: List[str] = [] + step_conditions: list[str] = [] for index in range(length): step_conditions.append(f"step_{index} = 1") @@ -531,10 +531,10 @@ def _get_steps_conditions(self, length: int) -> str: return " OR ".join(step_conditions) - def _get_step_col(self, entity: Entity, index: int, entity_name: str, step_prefix: str = "") -> List[str]: + def _get_step_col(self, entity: Entity, index: int, entity_name: str, step_prefix: str = "") -> list[str]: # step prefix is used to distinguish actual steps, and exclusion steps # without the prefix, we get the same parameter binding for both, which borks things up - step_cols: List[str] = [] + step_cols: list[str] = [] condition = self._build_step_query(entity, index, entity_name, step_prefix) step_cols.append(f"if({condition}, 1, 0) as {step_prefix}step_{index}") step_cols.append(f"if({step_prefix}step_{index} = 1, timestamp, null) as {step_prefix}latest_{index}") @@ -637,7 +637,7 @@ def _get_funnel_person_step_events(self): return "" def _get_count_columns(self, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(max_steps): cols.append(f"countIf(steps = {i + 1}) step_{i + 1}") @@ -680,7 +680,7 @@ def _get_matching_events(self, max_steps: int): return "" def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"avg(step_{i}_conversion_time) step_{i}_average_conversion_time_inner" @@ -692,7 +692,7 @@ def _get_step_time_avgs(self, max_steps: int, inner_query: bool = False): return f", {formatted}" if formatted else "" def _get_step_time_median(self, max_steps: int, inner_query: bool = False): - conditions: List[str] = [] + conditions: list[str] = [] for i in range(1, max_steps): conditions.append( f"median(step_{i}_conversion_time) step_{i}_median_conversion_time_inner" @@ -720,9 +720,9 @@ def get_step_counts_query(self) -> str: def get_step_counts_without_aggregation_query(self) -> str: raise NotImplementedError() - def _get_breakdown_select_prop(self) -> Tuple[str, Dict[str, Any]]: + def 
_get_breakdown_select_prop(self) -> tuple[str, dict[str, Any]]: basic_prop_selector = "" - basic_prop_params: Dict[str, Any] = {} + basic_prop_params: dict[str, Any] = {} if not self._filter.breakdown: return basic_prop_selector, basic_prop_params @@ -837,7 +837,7 @@ def _get_cohort_breakdown_join(self) -> str: ON events.distinct_id = cohort_join.distinct_id """ - def _get_breakdown_conditions(self) -> Optional[List[str]]: + def _get_breakdown_conditions(self) -> Optional[list[str]]: """ For people, pagination sets the offset param, which is common across filters and gives us the wrong breakdown values here, so we override it. diff --git a/posthog/queries/funnels/funnel.py b/posthog/queries/funnels/funnel.py index e1ac23f00d637..c72a7f1608ea6 100644 --- a/posthog/queries/funnels/funnel.py +++ b/posthog/queries/funnels/funnel.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from posthog.queries.funnels.base import ClickhouseFunnelBase @@ -74,7 +74,7 @@ def get_step_counts_without_aggregation_query(self): """ def _get_comparison_at_step(self, index: int, level_index: int): - or_statements: List[str] = [] + or_statements: list[str] = [] for i in range(level_index, index + 1): or_statements.append(f"latest_{i} < latest_{level_index - 1}") @@ -86,7 +86,7 @@ def get_comparison_cols(self, level_index: int, max_steps: int): level_index: The current smallest comparison step. Everything before level index is already at the minimum ordered timestamps. """ - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: diff --git a/posthog/queries/funnels/funnel_event_query.py b/posthog/queries/funnels/funnel_event_query.py index 2c8ad72524f70..9f0ad134257e8 100644 --- a/posthog/queries/funnels/funnel_event_query.py +++ b/posthog/queries/funnels/funnel_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Set, Tuple, Union +from typing import Any, Union from posthog.constants import TREND_FILTER_TYPE_ACTIONS from posthog.hogql.hogql import translate_hogql @@ -17,7 +17,7 @@ def get_query( entities=None, entity_name="events", skip_entity_filter=False, - ) -> Tuple[str, Dict[str, Any]]: + ) -> tuple[str, dict[str, Any]]: # Aggregating by group if self._filter.aggregation_group_type_index is not None: aggregation_target = get_aggregation_target_field( @@ -81,7 +81,7 @@ def get_query( if skip_entity_filter: entity_query = "" - entity_params: Dict[str, Any] = {} + entity_params: dict[str, Any] = {} else: entity_query, entity_params = self._get_entity_query(entities, entity_name) @@ -145,8 +145,8 @@ def _determine_should_join_persons(self) -> None: if self._person_on_events_mode != PersonsOnEventsMode.disabled: self._should_join_persons = False - def _get_entity_query(self, entities=None, entity_name="events") -> Tuple[str, Dict[str, Any]]: - events: Set[Union[int, str, None]] = set() + def _get_entity_query(self, entities=None, entity_name="events") -> tuple[str, dict[str, Any]]: + events: set[Union[int, str, None]] = set() entities_to_use = entities or self._filter.entities for entity in entities_to_use: diff --git a/posthog/queries/funnels/funnel_persons.py b/posthog/queries/funnels/funnel_persons.py index 5cebef5fb7dcd..c221727866e3a 100644 --- a/posthog/queries/funnels/funnel_persons.py +++ b/posthog/queries/funnels/funnel_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import 
cached_property @@ -18,7 +18,7 @@ def aggregation_group_type_index(self): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/funnel_strict.py b/posthog/queries/funnels/funnel_strict.py index 38b5d3a4c6a09..cb9f97d191870 100644 --- a/posthog/queries/funnels/funnel_strict.py +++ b/posthog/queries/funnels/funnel_strict.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.queries.funnels.base import ClickhouseFunnelBase @@ -57,7 +55,7 @@ def get_step_counts_without_aggregation_query(self): return formatted_query def _get_partition_cols(self, level_index: int, max_steps: int): - cols: List[str] = [] + cols: list[str] = [] for i in range(0, max_steps): cols.append(f"step_{i}") if i < level_index: diff --git a/posthog/queries/funnels/funnel_strict_persons.py b/posthog/queries/funnels/funnel_strict_persons.py index cca6f8e598dc8..2ad13822f5464 100644 --- a/posthog/queries/funnels/funnel_strict_persons.py +++ b/posthog/queries/funnels/funnel_strict_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import cached_property @@ -18,7 +18,7 @@ def aggregation_group_type_index(self): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/funnel_trends.py b/posthog/queries/funnels/funnel_trends.py index d67b24ae78bbc..cb8ecbe7c8227 100644 --- a/posthog/queries/funnels/funnel_trends.py +++ b/posthog/queries/funnels/funnel_trends.py @@ -1,6 +1,6 @@ from datetime import datetime from itertools import groupby -from typing import List, Optional, Tuple +from typing import Optional from posthog.models.cohort import Cohort from posthog.models.filters.filter import Filter @@ -147,7 +147,7 @@ def get_query(self) -> str: return query - def get_steps_reached_conditions(self) -> Tuple[str, str, str]: + def get_steps_reached_conditions(self) -> tuple[str, str, str]: # How many steps must have been done to count for the denominator of a funnel trends data point from_step = self._filter.funnel_from_step or 0 # How many steps must have been done to count for the numerator of a funnel trends data point @@ -180,7 +180,7 @@ def _summarize_data(self, results): if breakdown_clause: if isinstance(period_row[-1], str) or ( - isinstance(period_row[-1], List) and all(isinstance(item, str) for item in period_row[-1]) + isinstance(period_row[-1], list) and all(isinstance(item, str) for item in period_row[-1]) ): serialized_result.update({"breakdown_value": (period_row[-1])}) else: diff --git a/posthog/queries/funnels/funnel_unordered.py b/posthog/queries/funnels/funnel_unordered.py index ac3a6d939b09f..ee984b9462a75 100644 --- a/posthog/queries/funnels/funnel_unordered.py +++ b/posthog/queries/funnels/funnel_unordered.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast from rest_framework.exceptions import ValidationError @@ -40,9 +40,9 @@ def _serialize_step( self, step: Entity, count: int, - people: Optional[List[uuid.UUID]] = None, + people: 
Optional[list[uuid.UUID]] = None, sampling_factor: Optional[float] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: return { "action_id": None, "name": f"Completed {step.index+1} step{'s' if step.index != 0 else ''}", @@ -119,7 +119,7 @@ def get_step_counts_without_aggregation_query(self): return " UNION ALL ".join(union_queries) def _get_step_times(self, max_steps: int): - conditions: List[str] = [] + conditions: list[str] = [] conversion_times_elements = [] for i in range(max_steps): @@ -146,7 +146,7 @@ def get_sorting_condition(self, max_steps: int): conditions.append(f"arraySort([{','.join(event_times_elements)}]) as event_times") # replacement of latest_i for whatever query part requires it, just like conversion_times - basic_conditions: List[str] = [] + basic_conditions: list[str] = [] for i in range(1, max_steps): basic_conditions.append( f"if(latest_0 < latest_{i} AND latest_{i} <= latest_0 + INTERVAL {self._filter.funnel_window_interval} {self._filter.funnel_window_interval_unit_ch()}, 1, 0)" diff --git a/posthog/queries/funnels/funnel_unordered_persons.py b/posthog/queries/funnels/funnel_unordered_persons.py index 334798c990208..fc1e953bfb58e 100644 --- a/posthog/queries/funnels/funnel_unordered_persons.py +++ b/posthog/queries/funnels/funnel_unordered_persons.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from posthog.models.filters.filter import Filter from posthog.models.filters.mixins.utils import cached_property @@ -25,7 +25,7 @@ def _get_funnel_person_step_events(self): def actor_query( self, limit_actors: Optional[bool] = True, - extra_fields: Optional[List[str]] = None, + extra_fields: Optional[list[str]] = None, ): extra_fields_string = ", ".join([self._get_timestamp_outer_select()] + (extra_fields or [])) return ( diff --git a/posthog/queries/funnels/test/breakdown_cases.py b/posthog/queries/funnels/test/breakdown_cases.py index b38384c745e90..2bc977c974afc 100644 --- a/posthog/queries/funnels/test/breakdown_cases.py +++ b/posthog/queries/funnels/test/breakdown_cases.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from string import ascii_lowercase -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from posthog.constants import INSIGHT_FUNNELS from posthog.models.cohort import Cohort @@ -20,7 +20,7 @@ class FunnelStepResult: name: str count: int - breakdown: Union[List[str], str] + breakdown: Union[list[str], str] average_conversion_time: Optional[float] = None median_conversion_time: Optional[float] = None type: Literal["events", "actions"] = "events" @@ -35,8 +35,8 @@ def _get_actor_ids_at_step(self, filter, funnel_step, breakdown_value=None): return [val["id"] for val in serialized_result] - def _assert_funnel_breakdown_result_is_correct(self, result, steps: List[FunnelStepResult]): - def funnel_result(step: FunnelStepResult, order: int) -> Dict[str, Any]: + def _assert_funnel_breakdown_result_is_correct(self, result, steps: list[FunnelStepResult]): + def funnel_result(step: FunnelStepResult, order: int) -> dict[str, Any]: return { "action_id": step.name if step.type == "events" else step.action_id, "name": step.name, @@ -2646,11 +2646,11 @@ def test_funnel_breakdown_correct_breakdown_props_are_chosen_for_step(self): return TestFunnelBreakdown -def sort_breakdown_funnel_results(results: List[Dict[int, Any]]): +def sort_breakdown_funnel_results(results: list[dict[int, Any]]): return sorted(results, key=lambda r: 
r[0]["breakdown_value"]) -def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str, Any]]): +def assert_funnel_results_equal(left: list[dict[str, Any]], right: list[dict[str, Any]]): """ Helper to be able to compare two funnel results, but exclude people urls from the comparison, as these include: @@ -2660,7 +2660,7 @@ def assert_funnel_results_equal(left: List[Dict[str, Any]], right: List[Dict[str 2. contain timestamps which are not stable across runs """ - def _filter(steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _filter(steps: list[dict[str, Any]]) -> list[dict[str, Any]]: return [{**step, "converted_people_url": None, "dropped_people_url": None} for step in steps] assert len(left) == len(right) diff --git a/posthog/queries/funnels/test/test_breakdowns_by_current_url.py b/posthog/queries/funnels/test/test_breakdowns_by_current_url.py index 7994b195fca94..800cd9f46dca0 100644 --- a/posthog/queries/funnels/test/test_breakdowns_by_current_url.py +++ b/posthog/queries/funnels/test/test_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.models import Filter from posthog.queries.funnels import ClickhouseFunnel @@ -115,7 +115,7 @@ def setUp(self): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/funnels/utils.py b/posthog/queries/funnels/utils.py index 68f93c2d4542e..b2c0df300ce8c 100644 --- a/posthog/queries/funnels/utils.py +++ b/posthog/queries/funnels/utils.py @@ -1,11 +1,9 @@ -from typing import Type - from posthog.constants import FunnelOrderType from posthog.models.filters import Filter from posthog.queries.funnels import ClickhouseFunnelBase -def get_funnel_order_class(filter: Filter) -> Type[ClickhouseFunnelBase]: +def get_funnel_order_class(filter: Filter) -> type[ClickhouseFunnelBase]: from posthog.queries.funnels import ( ClickhouseFunnel, ClickhouseFunnelStrict, diff --git a/posthog/queries/groups_join_query/groups_join_query.py b/posthog/queries/groups_join_query/groups_join_query.py index 2cc62849cacc3..6499d39ce1e94 100644 --- a/posthog/queries/groups_join_query/groups_join_query.py +++ b/posthog/queries/groups_join_query/groups_join_query.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from posthog.models import Filter from posthog.models.filters.path_filter import PathFilter @@ -31,5 +31,5 @@ def __init__( self._join_key = join_key self._person_on_events_mode = person_on_events_mode - def get_join_query(self) -> Tuple[str, Dict]: + def get_join_query(self) -> tuple[str, dict]: return "", {} diff --git a/posthog/queries/paths/paths.py b/posthog/queries/paths/paths.py index 6a98857e3927d..21438ee6ea79f 100644 --- a/posthog/queries/paths/paths.py +++ b/posthog/queries/paths/paths.py @@ -1,6 +1,6 @@ import dataclasses from collections import defaultdict -from typing import Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Literal, Optional, Union, cast from rest_framework.exceptions import ValidationError @@ -35,8 +35,8 @@ class Paths: _filter: PathFilter _funnel_filter: Optional[Filter] _team: Team - _extra_event_fields: List[ColumnName] - _extra_event_properties: List[PropertyName] + 
_extra_event_fields: list[ColumnName] + _extra_event_properties: list[PropertyName] def __init__(self, filter: PathFilter, team: Team, funnel_filter: Optional[Filter] = None) -> None: self._filter = filter @@ -50,8 +50,8 @@ def __init__(self, filter: PathFilter, team: Team, funnel_filter: Optional[Filte } self._funnel_filter = funnel_filter - self._extra_event_fields: List[ColumnName] = [] - self._extra_event_properties: List[PropertyName] = [] + self._extra_event_fields: list[ColumnName] = [] + self._extra_event_properties: list[PropertyName] = [] if self._filter.include_recordings: self._extra_event_fields = ["uuid", "timestamp"] self._extra_event_properties = ["$session_id", "$window_id"] @@ -93,7 +93,7 @@ def _format_results(self, results): ) return resp - def _exec_query(self) -> List[Tuple]: + def _exec_query(self) -> list[tuple]: query = self.get_query() return insight_sync_execute( query, @@ -225,7 +225,7 @@ def get_path_query_funnel_cte(self, funnel_filter: Filter): return "", {} # Implemented in /ee - def get_edge_weight_clause(self) -> Tuple[str, Dict]: + def get_edge_weight_clause(self) -> tuple[str, dict]: return "", {} # Implemented in /ee @@ -240,8 +240,8 @@ def get_session_threshold_clause(self) -> str: return "arraySplit(x -> if(x.3 < %(session_time_threshold)s, 0, 1), paths_tuple)" # Implemented in /ee - def get_target_clause(self) -> Tuple[str, Dict]: - params: Dict[str, Union[str, None]] = { + def get_target_clause(self) -> tuple[str, dict]: + params: dict[str, Union[str, None]] = { "target_point": None, "secondary_target_point": None, } @@ -276,7 +276,7 @@ def get_array_compacting_function(self) -> Literal["arrayResize", "arraySlice"]: return "arraySlice" # Implemented in /ee - def get_filtered_path_ordering(self) -> Tuple[str, ...]: + def get_filtered_path_ordering(self) -> tuple[str, ...]: fields_to_include = ["filtered_path", "filtered_timings"] + [ f"filtered_{field}s" for field in self.extra_event_fields_and_properties ] diff --git a/posthog/queries/paths/paths_actors.py b/posthog/queries/paths/paths_actors.py index e39a01dfee34c..ec739271795e0 100644 --- a/posthog/queries/paths/paths_actors.py +++ b/posthog/queries/paths/paths_actors.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, cast +from typing import Optional, cast from posthog.models.filters.filter import Filter from posthog.queries.actor_base_query import ActorBaseQuery @@ -27,7 +27,7 @@ class PathsActors(Paths, ActorBaseQuery): # type: ignore QUERY_TYPE = "paths" - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: paths_per_person_query = self.get_paths_per_person_query() person_path_filter = self.get_person_path_filter() paths_funnel_cte = "" diff --git a/posthog/queries/paths/paths_event_query.py b/posthog/queries/paths/paths_event_query.py index 61b032aa663ec..31241cea64919 100644 --- a/posthog/queries/paths/paths_event_query.py +++ b/posthog/queries/paths/paths_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any from posthog.constants import ( FUNNEL_PATH_AFTER_STEP, @@ -21,7 +21,7 @@ class PathEventQuery(EventQuery): FUNNEL_PERSONS_ALIAS = "funnel_actors" _filter: PathFilter - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: funnel_paths_timestamp = "" funnel_paths_join = "" funnel_paths_filter = "" @@ -151,7 +151,7 @@ def _determine_should_join_persons(self) -> None: if 
self._person_on_events_mode != PersonsOnEventsMode.disabled: self._should_join_persons = False - def _get_grouping_fields(self) -> Tuple[List[str], Dict[str, Any]]: + def _get_grouping_fields(self) -> tuple[list[str], dict[str, Any]]: _fields = [] params = {} @@ -188,8 +188,8 @@ def _get_grouping_fields(self) -> Tuple[List[str], Dict[str, Any]]: return _fields, params - def _get_event_query(self, deep_filtering: bool) -> Tuple[str, Dict[str, Any]]: - params: Dict[str, Any] = {} + def _get_event_query(self, deep_filtering: bool) -> tuple[str, dict[str, Any]]: + params: dict[str, Any] = {} conditions = [] or_conditions = [] diff --git a/posthog/queries/person_query.py b/posthog/queries/person_query.py index cffcce890c80b..b785fcb7442e0 100644 --- a/posthog/queries/person_query.py +++ b/posthog/queries/person_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Optional, Union from uuid import UUID from posthog.clickhouse.materialized_columns import ColumnName @@ -45,7 +45,7 @@ class PersonQuery: _filter: Union[Filter, PathFilter, RetentionFilter, StickinessFilter] _team_id: int _column_optimizer: ColumnOptimizer - _extra_fields: Set[ColumnName] + _extra_fields: set[ColumnName] _inner_person_properties: Optional[PropertyGroup] _cohort: Optional[Cohort] _include_distinct_ids: Optional[bool] = False @@ -58,10 +58,10 @@ def __init__( cohort: Optional[Cohort] = None, *, entity: Optional[Entity] = None, - extra_fields: Optional[List[ColumnName]] = None, + extra_fields: Optional[list[ColumnName]] = None, # A sub-optimal version of the `cohort` parameter above, the difference being that # this supports multiple cohort filters, but is not as performant as the above. - cohort_filters: Optional[List[Property]] = None, + cohort_filters: Optional[list[Property]] = None, include_distinct_ids: Optional[bool] = False, ) -> None: self._filter = filter @@ -90,7 +90,7 @@ def get_query( prepend: Optional[Union[str, int]] = None, paginate: bool = False, filter_future_persons: bool = False, - ) -> Tuple[str, Dict]: + ) -> tuple[str, dict]: prepend = str(prepend) if prepend is not None else "" fields = "id" + " ".join( @@ -175,7 +175,7 @@ def get_query( ) @property - def fields(self) -> List[ColumnName]: + def fields(self) -> list[ColumnName]: "Returns person table fields this query exposes" return [alias for column_name, alias in self._get_fields()] @@ -194,7 +194,7 @@ def is_used(self): def _uses_person_id(self, prop: Property) -> bool: return prop.type in ("person", "static-cohort", "precalculated-cohort") - def _get_fields(self) -> List[Tuple[str, str]]: + def _get_fields(self) -> list[tuple[str, str]]: # :TRICKY: Figure out what fields we want to expose - minimizing this set is good for performance. # We use the result from column_optimizer to figure out counts of all properties to be filtered and queried. # Here, we remove the ones only to be used for filtering. 
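The PersonQuery helpers below share one convention, visible in their new tuple[str, dict] signatures: each returns a SQL fragment paired with its bound parameters, or empty values when the clause does not apply, and the caller concatenates fragments while merging the parameter dicts. A sketch of that composition, using hypothetical clause builders rather than the real PersonQuery internals:

    from typing import Optional

    def email_clause(email: Optional[str]) -> tuple[str, dict]:
        # Hypothetical stand-in for a helper like _get_email_clause.
        return ("AND email = %(email)s", {"email": email}) if email else ("", {})

    def limit_clause(limit: Optional[int]) -> tuple[str, dict]:
        # Hypothetical stand-in for a helper like _get_limit_offset_clause.
        return ("LIMIT %(limit)s", {"limit": limit}) if limit is not None else ("", {})

    def build(email: Optional[str], limit: Optional[int]) -> tuple[str, dict]:
        sql = "SELECT id FROM person WHERE 1 = 1"
        params: dict = {}
        for clause, clause_params in (email_clause(email), limit_clause(limit)):
            if clause:
                sql += f" {clause}"
                params.update(clause_params)
        return sql, params

    sql, params = build("a@b.com", 10)
    # sql    == "SELECT id FROM person WHERE 1 = 1 AND email = %(email)s LIMIT %(limit)s"
    # params == {"email": "a@b.com", "limit": 10}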
@@ -207,7 +207,7 @@ def _get_fields(self) -> List[Tuple[str, str]]: return [(column_name, self.ALIASES.get(column_name, column_name)) for column_name in sorted(columns)] - def _get_person_filter_clauses(self, prepend: str = "") -> Tuple[str, str, Dict]: + def _get_person_filter_clauses(self, prepend: str = "") -> tuple[str, str, dict]: finalization_conditions, params = parse_prop_grouped_clauses( self._team_id, self._inner_person_properties, @@ -231,7 +231,7 @@ def _get_person_filter_clauses(self, prepend: str = "") -> Tuple[str, str, Dict] params.update(prefiltering_params) return prefiltering_conditions, finalization_conditions, params - def _get_fast_single_cohort_clause(self) -> Tuple[str, Dict]: + def _get_fast_single_cohort_clause(self) -> tuple[str, dict]: if self._cohort: cohort_table = ( GET_STATIC_COHORTPEOPLE_BY_COHORT_ID if self._cohort.is_static else GET_COHORTPEOPLE_BY_COHORT_ID @@ -252,10 +252,10 @@ def _get_fast_single_cohort_clause(self) -> Tuple[str, Dict]: else: return "", {} - def _get_multiple_cohorts_clause(self, prepend: str = "") -> Tuple[str, Dict]: + def _get_multiple_cohorts_clause(self, prepend: str = "") -> tuple[str, dict]: if self._cohort_filters: query = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} # TODO: doesn't support non-calculated cohorts for index, property in enumerate(self._cohort_filters): @@ -274,7 +274,7 @@ def _get_multiple_cohorts_clause(self, prepend: str = "") -> Tuple[str, Dict]: else: return "", {} - def _get_limit_offset_clause(self) -> Tuple[str, Dict]: + def _get_limit_offset_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -295,7 +295,7 @@ return clause, params - def _get_search_clauses(self, prepend: str = "") -> Tuple[str, str, Dict]: + def _get_search_clauses(self, prepend: str = "") -> tuple[str, str, dict]: """ Return - respectively - the prefiltering search clause (not aggregated by is_deleted or version, which is great for memory usage), the final search clause (aggregated for true results, more expensive), and new params.
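The docstring above describes a two-phase search: a cheap prefiltering clause evaluated on raw person rows, then a final clause re-checked after aggregating persons by version. A hedged sketch of that tuple[str, str, dict] shape, with illustrative SQL fragments rather than the actual clauses:

    def search_clauses(search: str) -> tuple[str, str, dict]:
        if not search:
            return "", "", {}
        # Prefilter: runs before deduplication by version, so it may over-match,
        # but it shrinks the aggregation's working set cheaply (good for memory).
        prefilter = "AND properties ILIKE %(search)s"
        # Finalization: re-checked against the latest version of each person after
        # aggregation, dropping rows the cheap prefilter matched incorrectly.
        finalization = "HAVING argMax(properties, version) ILIKE %(search)s"
        return prefilter, finalization, {"search": f"%{search}%"}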
@@ -365,7 +365,7 @@ def _get_search_clauses(self, prepend: str = "") -> Tuple[str, str, Dict]: return "", "", {} - def _get_distinct_id_clause(self) -> Tuple[str, Dict]: + def _get_distinct_id_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -378,7 +378,7 @@ def _get_distinct_id_clause(self) -> Tuple[str, Dict]: return distinct_id_clause, {"distinct_id_filter": self._filter.distinct_id} return "", {} - def _add_distinct_id_join_if_needed(self, query: str, params: Dict[Any, Any]) -> Tuple[str, Dict[Any, Any]]: + def _add_distinct_id_join_if_needed(self, query: str, params: dict[Any, Any]) -> tuple[str, dict[Any, Any]]: if not self._include_distinct_ids: return query, params return ( @@ -395,7 +395,7 @@ def _add_distinct_id_join_if_needed(self, query: str, params: Dict[Any, Any]) -> params, ) - def _get_email_clause(self) -> Tuple[str, Dict]: + def _get_email_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} @@ -407,7 +407,7 @@ def _get_email_clause(self) -> Tuple[str, Dict]: ) return "", {} - def _get_updated_after_clause(self) -> Tuple[str, Dict]: + def _get_updated_after_clause(self) -> tuple[str, dict]: if not isinstance(self._filter, Filter): return "", {} diff --git a/posthog/queries/properties_timeline/properties_timeline.py b/posthog/queries/properties_timeline/properties_timeline.py index 34c392353098a..fe9191e5e1c15 100644 --- a/posthog/queries/properties_timeline/properties_timeline.py +++ b/posthog/queries/properties_timeline/properties_timeline.py @@ -1,6 +1,6 @@ import datetime import json -from typing import Any, Dict, List, Set, TypedDict, Union, cast +from typing import Any, TypedDict, Union, cast from posthog.models.filters.properties_timeline_filter import PropertiesTimelineFilter from posthog.models.group.group import Group @@ -18,13 +18,13 @@ class PropertiesTimelinePoint(TypedDict): timestamp: str - properties: Dict[str, Any] + properties: dict[str, Any] relevant_event_count: int class PropertiesTimelineResult(TypedDict): - points: List[PropertiesTimelinePoint] - crucial_property_keys: List[str] + points: list[PropertiesTimelinePoint] + crucial_property_keys: list[str] effective_date_from: str effective_date_to: str @@ -56,7 +56,7 @@ class PropertiesTimelineResult(TypedDict): class PropertiesTimeline: - def extract_crucial_property_keys(self, filter: PropertiesTimelineFilter) -> Set[str]: + def extract_crucial_property_keys(self, filter: PropertiesTimelineFilter) -> set[str]: is_filter_relevant = lambda property_type, property_group_type_index: ( (property_type == "person") if filter.aggregation_group_type_index is None @@ -76,7 +76,7 @@ def extract_crucial_property_keys(self, filter: PropertiesTimelineFilter) -> Set if filter.breakdown and filter.breakdown_type == "person": if isinstance(filter.breakdown, list): - crucial_property_keys.update(cast(List[str], filter.breakdown)) + crucial_property_keys.update(cast(list[str], filter.breakdown)) else: crucial_property_keys.add(filter.breakdown) diff --git a/posthog/queries/properties_timeline/properties_timeline_event_query.py b/posthog/queries/properties_timeline/properties_timeline_event_query.py index d3ca17eb70091..b5e9a87d07c82 100644 --- a/posthog/queries/properties_timeline/properties_timeline_event_query.py +++ b/posthog/queries/properties_timeline/properties_timeline_event_query.py @@ -1,5 +1,5 @@ import datetime as dt -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from zoneinfo import ZoneInfo from 
posthog.models.entity.util import get_entity_filtering_params @@ -20,7 +20,7 @@ def __init__(self, filter: PropertiesTimelineFilter, *args, **kwargs): super().__init__(filter, *args, **kwargs) self._group_type_index = filter.aggregation_group_type_index - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: real_fields = [f"{self.EVENT_TABLE_ALIAS}.timestamp AS timestamp"] sentinel_fields = ["NULL AS timestamp"] @@ -72,8 +72,8 @@ def _determine_should_join_persons(self) -> None: def _determine_should_join_sessions(self) -> None: self._should_join_sessions = False - def _get_date_filter(self) -> Tuple[str, Dict]: - query_params: Dict[str, Any] = {} + def _get_date_filter(self) -> tuple[str, dict]: + query_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team) effective_timezone = ZoneInfo(self._team.timezone) # Get effective date range from QueryDateRange @@ -92,7 +92,7 @@ def _get_date_filter(self) -> Tuple[str, Dict]: return date_filter, query_params - def _get_entity_query(self) -> Tuple[str, Dict]: + def _get_entity_query(self) -> tuple[str, dict]: entity_params, entity_format_params = get_entity_filtering_params( allowed_entities=self._filter.entities, team_id=self._team_id, diff --git a/posthog/queries/property_optimizer.py b/posthog/queries/property_optimizer.py index d69cadfe5e82b..b11be666fd684 100644 --- a/posthog/queries/property_optimizer.py +++ b/posthog/queries/property_optimizer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, cast +from typing import Optional, cast from rest_framework.exceptions import ValidationError @@ -94,7 +94,7 @@ def using_only_person_properties(property_group: PropertyGroup) -> bool: elif isinstance(property_group.values[0], PropertyGroup): return all( PropertyOptimizer.using_only_person_properties(group) - for group in cast(List[PropertyGroup], property_group.values) + for group in cast(list[PropertyGroup], property_group.values) ) else: diff --git a/posthog/queries/property_values.py b/posthog/queries/property_values.py index a8b943f25d1d2..0e79d15146af6 100644 --- a/posthog/queries/property_values.py +++ b/posthog/queries/property_values.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from django.utils import timezone @@ -16,7 +16,7 @@ def get_property_values_for_key( key: str, team: Team, - event_names: Optional[List[str]] = None, + event_names: Optional[list[str]] = None, value: Optional[str] = None, ): property_field, mat_column_exists = get_property_string_expr("events", key, "%(key)s", "properties") diff --git a/posthog/queries/query_date_range.py b/posthog/queries/query_date_range.py index 2825e4e0360b0..578f2ccf041c3 100644 --- a/posthog/queries/query_date_range.py +++ b/posthog/queries/query_date_range.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta from functools import cached_property -from typing import Dict, Literal, Optional, Tuple +from typing import Literal, Optional from zoneinfo import ZoneInfo from dateutil.relativedelta import relativedelta @@ -117,7 +117,7 @@ def date_from_clause(self): return self._get_timezone_aware_date_condition("date_from") @cached_property - def date_to(self) -> Tuple[str, Dict]: + def date_to(self) -> tuple[str, dict]: date_to_query = self.date_to_clause date_to = self.date_to_param @@ -129,7 +129,7 @@ def date_to(self) -> Tuple[str, Dict]: return date_to_query, date_to_param @cached_property - def date_from(self) -> Tuple[str, Dict]: + 
def date_from(self) -> tuple[str, dict]: date_from_query = self.date_from_clause date_from = self.date_from_param diff --git a/posthog/queries/retention/actors_query.py b/posthog/queries/retention/actors_query.py index 5a49c510a3240..e087b88e44fc8 100644 --- a/posthog/queries/retention/actors_query.py +++ b/posthog/queries/retention/actors_query.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.models.filters.retention_filter import RetentionFilter from posthog.models.team import Team @@ -19,7 +19,7 @@ class AppearanceRow: actor_id: str appearance_count: int # This is actually the number of days from first event to the current event. - appearances: List[float] + appearances: list[float] # Note: This class does not respect the entire flow from ActorBaseQuery because the result shape differs from other actor queries @@ -98,7 +98,7 @@ def build_actor_activity_query( selected_interval: Optional[int] = None, aggregate_users_by_distinct_id: Optional[bool] = None, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: from posthog.queries.retention import ( build_returning_event_query, build_target_event_query, @@ -150,7 +150,7 @@ def _build_actor_query( filter_by_breakdown: Optional[BreakdownValues] = None, selected_interval: Optional[int] = None, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: actor_activity_query, actor_activity_query_params = build_actor_activity_query( filter=filter, team=team, diff --git a/posthog/queries/retention/retention.py b/posthog/queries/retention/retention.py index 8f8b0d89254bf..d3b9f43ca5c60 100644 --- a/posthog/queries/retention/retention.py +++ b/posthog/queries/retention/retention.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from urllib.parse import urlencode from zoneinfo import ZoneInfo @@ -24,7 +24,7 @@ class Retention: def __init__(self, base_uri="/"): self._base_uri = base_uri - def run(self, filter: RetentionFilter, team: Team, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: RetentionFilter, team: Team, *args, **kwargs) -> list[dict[str, Any]]: filter.team = team retention_by_breakdown = self._get_retention_by_breakdown_values(filter, team) if filter.breakdowns: @@ -34,7 +34,7 @@ def run(self, filter: RetentionFilter, team: Team, *args, **kwargs) -> List[Dict def _get_retention_by_breakdown_values( self, filter: RetentionFilter, team: Team - ) -> Dict[CohortKey, Dict[str, Any]]: + ) -> dict[CohortKey, dict[str, Any]]: actor_query, actor_query_params = build_actor_activity_query( filter=filter, team=team, retention_events_query=self.event_query ) @@ -77,7 +77,7 @@ def _construct_people_url_for_trend_breakdown_interval( ).to_params() return f"{self._base_uri}api/person/retention/?{urlencode(params)}" - def process_breakdown_table_result(self, resultset: Dict[CohortKey, Dict[str, Any]], filter: RetentionFilter): + def process_breakdown_table_result(self, resultset: dict[CohortKey, dict[str, Any]], filter: RetentionFilter): result = [ { "values": [ @@ -101,7 +101,7 @@ def process_breakdown_table_result(self, resultset: Dict[CohortKey, Dict[str, An def process_table_result( self, - resultset: Dict[CohortKey, Dict[str, Any]], + resultset: dict[CohortKey, dict[str, Any]], filter: RetentionFilter, team: Team, ): @@ -140,7 +140,7 @@ def construct_url(first_day): return result -
def actors_in_period(self, filter: RetentionFilter, team: Team) -> Tuple[list, int]: + def actors_in_period(self, filter: RetentionFilter, team: Team) -> tuple[list, int]: """ Creates a response of the form @@ -168,7 +168,7 @@ def build_returning_event_query( aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: returning_event_query_templated, returning_event_params = retention_events_query( filter=filter.shallow_clone({"breakdowns": []}), # Avoid pulling in breakdown values from returning event query team=team, @@ -186,7 +186,7 @@ def build_target_event_query( aggregate_users_by_distinct_id: Optional[bool] = None, person_on_events_mode: PersonsOnEventsMode = PersonsOnEventsMode.disabled, retention_events_query=RetentionEventsQuery, -) -> Tuple[str, Dict[str, Any]]: +) -> tuple[str, dict[str, Any]]: target_event_query_templated, target_event_params = retention_events_query( filter=filter, team=team, diff --git a/posthog/queries/retention/retention_events_query.py b/posthog/queries/retention/retention_events_query.py index e84e4bc1e91cc..9e64b758be6e8 100644 --- a/posthog/queries/retention/retention_events_query.py +++ b/posthog/queries/retention/retention_events_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Literal, Optional, Tuple, Union, cast +from typing import Any, Literal, Optional, Union, cast from posthog.constants import ( PAGEVIEW_EVENT, @@ -37,7 +37,7 @@ def __init__( person_on_events_mode=person_on_events_mode, ) - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: _fields = [ self.get_timestamp_field(), self.target_field(), diff --git a/posthog/queries/retention/types.py b/posthog/queries/retention/types.py index d3f77fab7f51e..0a9e630da6a85 100644 --- a/posthog/queries/retention/types.py +++ b/posthog/queries/retention/types.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Tuple, Union +from typing import NamedTuple, Union -BreakdownValues = Tuple[Union[str, int], ...] +BreakdownValues = tuple[Union[str, int], ...] 
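Every hunk in this patch applies the same mechanical rewrite: PEP 585 (Python 3.9) made the built-in container types generic, so the deprecated typing.Dict, typing.List and typing.Tuple aliases can be dropped from annotations, while Union and Optional still come from typing. A minimal before/after sketch of the pattern; BreakdownValues is taken from the types.py hunk above, and the function with its names is illustrative only:

# Old spelling: container generics imported from typing (deprecated since 3.9).
from typing import Dict, List, Tuple, Union

BreakdownValues = Tuple[Union[str, int], ...]

def tally_rows(rows: List[Dict[str, int]]) -> Tuple[int, Dict[str, int]]:
    ...

# New spelling: built-in generics; only Union still needs a typing import.
from typing import Union

BreakdownValues = tuple[Union[str, int], ...]

def tally_rows(rows: list[dict[str, int]]) -> tuple[int, dict[str, int]]:
    # Merge the per-row counters into one tally, returned with the row count.
    totals: dict[str, int] = {}
    for row in rows:
        for key, count in row.items():
            totals[key] = totals.get(key, 0) + count
    return len(rows), totals

The two halves stand for the pre- and post-change revision of one module. Runtime behaviour is identical under either spelling, which is why no hunk in this diff touches logic.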
CohortKey = NamedTuple("CohortKey", (("breakdown_values", BreakdownValues), ("period", int))) diff --git a/posthog/queries/stickiness/stickiness.py b/posthog/queries/stickiness/stickiness.py index 50c2ff81ad987..26204b8e9964f 100644 --- a/posthog/queries/stickiness/stickiness.py +++ b/posthog/queries/stickiness/stickiness.py @@ -1,6 +1,6 @@ import copy import urllib.parse -from typing import Any, Dict, List +from typing import Any from posthog.constants import TREND_FILTER_TYPE_ACTIONS from posthog.models.action import Action @@ -19,7 +19,7 @@ class Stickiness: event_query_class = StickinessEventsQuery actor_query_class = StickinessActors - def run(self, filter: StickinessFilter, team: Team, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: StickinessFilter, team: Team, *args, **kwargs) -> list[dict[str, Any]]: response = [] for entity in filter.entities: if entity.type == TREND_FILTER_TYPE_ACTIONS and entity.id is not None: @@ -29,7 +29,7 @@ def run(self, filter: StickinessFilter, team: Team, *args, **kwargs) -> List[Dic response.extend(entity_resp) return response - def stickiness(self, entity: Entity, filter: StickinessFilter, team: Team) -> Dict[str, Any]: + def stickiness(self, entity: Entity, filter: StickinessFilter, team: Team) -> dict[str, Any]: events_query, event_params = self.event_query_class( entity, filter, team, person_on_events_mode=team.person_on_events_mode ).get_query() @@ -66,8 +66,8 @@ def people( _, serialized_actors, _ = self.actor_query_class(entity=target_entity, filter=filter, team=team).get_actors() return serialized_actors - def process_result(self, counts: List, filter: StickinessFilter, entity: Entity) -> Dict[str, Any]: - response: Dict[int, int] = {} + def process_result(self, counts: list, filter: StickinessFilter, entity: Entity) -> dict[str, Any]: + response: dict[int, int] = {} for result in counts: response[result[1]] = result[0] @@ -92,8 +92,8 @@ def process_result(self, counts: List, filter: StickinessFilter, entity: Entity) "persons_urls": self._get_persons_url(filter, entity), } - def _serialize_entity(self, entity: Entity, filter: StickinessFilter, team: Team) -> List[Dict[str, Any]]: - serialized: Dict[str, Any] = { + def _serialize_entity(self, entity: Entity, filter: StickinessFilter, team: Team) -> list[dict[str, Any]]: + serialized: dict[str, Any] = { "action": entity.to_dict(), "label": entity.name, "count": 0, @@ -107,7 +107,7 @@ def _serialize_entity(self, entity: Entity, filter: StickinessFilter, team: Team response.append(new_dict) return response - def _get_persons_url(self, filter: StickinessFilter, entity: Entity) -> List[Dict[str, Any]]: + def _get_persons_url(self, filter: StickinessFilter, entity: Entity) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for interval_idx in range(1, filter.total_intervals): @@ -119,7 +119,7 @@ def _get_persons_url(self, filter: StickinessFilter, entity: Entity) -> List[Dic "entity_math": entity.math, "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, diff --git a/posthog/queries/stickiness/stickiness_actors.py b/posthog/queries/stickiness/stickiness_actors.py index 625d3852ce536..c6c20301f2bf7 100644 --- a/posthog/queries/stickiness/stickiness_actors.py +++ b/posthog/queries/stickiness/stickiness_actors.py @@ -1,4 +1,4 @@ -from 
typing import Dict, Optional, Tuple +from typing import Optional from posthog.models.entity import Entity from posthog.models.filters.mixins.utils import cached_property @@ -22,7 +22,7 @@ def __init__(self, team: Team, entity: Entity, filter: StickinessFilter, **kwarg def aggregation_group_type_index(self): return None - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: events_query, event_params = self.event_query_class( entity=self.entity, filter=self._filter, diff --git a/posthog/queries/stickiness/stickiness_event_query.py b/posthog/queries/stickiness/stickiness_event_query.py index 25d68b1d6bfdf..7c8c92222ef95 100644 --- a/posthog/queries/stickiness/stickiness_event_query.py +++ b/posthog/queries/stickiness/stickiness_event_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.constants import TREND_FILTER_TYPE_ACTIONS, PropertyOperatorType from posthog.models import Entity @@ -20,7 +20,7 @@ def __init__(self, entity: Entity, *args, **kwargs): super().__init__(*args, **kwargs) self._should_round_interval = True - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: prop_query, prop_params = self._get_prop_groups( self._filter.property_groups.combine_property_group(PropertyOperatorType.AND, self._entity.property_groups), person_properties_mode=get_person_properties_mode(self._team), @@ -95,7 +95,7 @@ def _determine_should_join_persons(self) -> None: def aggregation_target(self): return self._person_id_alias - def get_entity_query(self) -> Tuple[str, Dict[str, Any]]: + def get_entity_query(self) -> tuple[str, dict[str, Any]]: if self._entity.type == TREND_FILTER_TYPE_ACTIONS: condition, params = format_action_filter( team_id=self._team_id, diff --git a/posthog/queries/test/test_paths.py b/posthog/queries/test/test_paths.py index 45f09a9ca5787..4be8e9789810c 100644 --- a/posthog/queries/test/test_paths.py +++ b/posthog/queries/test/test_paths.py @@ -1,5 +1,4 @@ import dataclasses -from typing import Dict from dateutil.relativedelta import relativedelta from django.utils.timezone import now @@ -26,7 +25,7 @@ class MockEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict class TestPaths(ClickhouseTestMixin, APIBaseTest): diff --git a/posthog/queries/test/test_trends.py b/posthog/queries/test/test_trends.py index abb32426dd68d..333babb6ccfbf 100644 --- a/posthog/queries/test/test_trends.py +++ b/posthog/queries/test/test_trends.py @@ -1,7 +1,7 @@ import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import patch, ANY from urllib.parse import parse_qsl, urlparse @@ -56,8 +56,8 @@ from posthog.utils import generate_cache_key -def breakdown_label(entity: Entity, value: Union[str, int]) -> Dict[str, Optional[Union[str, int]]]: - ret_dict: Dict[str, Optional[Union[str, int]]] = {} +def breakdown_label(entity: Entity, value: Union[str, int]) -> dict[str, Optional[Union[str, int]]]: + ret_dict: dict[str, Optional[Union[str, int]]] = {} if not value or not isinstance(value, str) or "cohort_" not in value: label = ( value @@ -112,7 +112,7 @@ def _get_trend_people(self, filter: Filter, entity: Entity): ).json() return response["results"][0]["people"] - def _create_events(self, use_time=False) -> Tuple[Action, Person]: + def _create_events(self, 
use_time=False) -> tuple[Action, Person]: person = _create_person( team_id=self.team.pk, distinct_ids=["blabla", "anonymous_id"], @@ -1788,7 +1788,7 @@ def test_trends_compare_hour_interval_relative_range(self): ], ) - def _test_events_with_dates(self, dates: List[str], result, query_time=None, **filter_params): + def _test_events_with_dates(self, dates: list[str], result, query_time=None, **filter_params): _create_person(team_id=self.team.pk, distinct_ids=["person_1"], properties={"name": "John"}) for time in dates: with freeze_time(time): diff --git a/posthog/queries/time_to_see_data/hierarchy.py b/posthog/queries/time_to_see_data/hierarchy.py index b4b686b612405..260a1fad0efbb 100644 --- a/posthog/queries/time_to_see_data/hierarchy.py +++ b/posthog/queries/time_to_see_data/hierarchy.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List class NodeType(Enum): @@ -24,7 +23,7 @@ class NodeType(Enum): class Node: type: NodeType data: dict - children: List["Node"] = field(default_factory=list) + children: list["Node"] = field(default_factory=list) def to_dict(self): return { @@ -39,7 +38,7 @@ def construct_hierarchy(session, interactions_and_events, queries) -> dict: Constructs a tree-like hierarchy for session based on interactions and queries, to expose triggered-by relationships. """ - nodes: List[Node] = [] + nodes: list[Node] = [] nodes.extend(make_empty_node(interaction_type, data) for data in interactions_and_events) nodes.extend(make_empty_node(query_type, data) for data in queries) diff --git a/posthog/queries/time_to_see_data/sessions.py b/posthog/queries/time_to_see_data/sessions.py index 8ebeeb8db36a6..709d253d5b78a 100644 --- a/posthog/queries/time_to_see_data/sessions.py +++ b/posthog/queries/time_to_see_data/sessions.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Optional from posthog.client import query_with_columns from posthog.queries.time_to_see_data.hierarchy import construct_hierarchy @@ -58,7 +58,7 @@ def get_sessions(query: SessionsQuerySerializer) -> SessionResponseSerializer: return response_serializer -def get_session_events(query: SessionEventsQuerySerializer) -> Optional[Dict]: +def get_session_events(query: SessionEventsQuerySerializer) -> Optional[dict]: params = { "team_id": query.validated_data["team_id"], "session_id": query.validated_data["session_id"], @@ -82,12 +82,12 @@ def get_session_events(query: SessionEventsQuerySerializer) -> Optional[Dict]: return construct_hierarchy(sessions[0], events, queries) -def _fetch_sessions(query: SessionsQuerySerializer) -> List[Dict]: +def _fetch_sessions(query: SessionsQuerySerializer) -> list[dict]: condition, params = _sessions_condition(query) return query_with_columns(GET_SESSIONS.format(condition=condition), params) -def _sessions_condition(query: SessionsQuerySerializer) -> Tuple[str, Dict]: +def _sessions_condition(query: SessionsQuerySerializer) -> tuple[str, dict]: conditions = [] if "team_id" in query.validated_data: diff --git a/posthog/queries/trends/breakdown.py b/posthog/queries/trends/breakdown.py index 444f045384a14..0f06984bac083 100644 --- a/posthog/queries/trends/breakdown.py +++ b/posthog/queries/trends/breakdown.py @@ -2,7 +2,8 @@ import re import urllib.parse from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union +from collections.abc import Callable from zoneinfo import ZoneInfo from django.forms import ValidationError @@ 
-104,7 +105,7 @@ def __init__( self.filter = filter self.team = team self.team_id = team.pk - self.params: Dict[str, Any] = {"team_id": team.pk} + self.params: dict[str, Any] = {"team_id": team.pk} self.column_optimizer = column_optimizer or ColumnOptimizer(self.filter, self.team_id) self.add_person_urls = add_person_urls self.person_on_events_mode = person_on_events_mode @@ -122,7 +123,7 @@ def actor_aggregator(self) -> str: return self._person_id_alias @cached_property - def _props_to_filter(self) -> Tuple[str, Dict]: + def _props_to_filter(self) -> tuple[str, dict]: props_to_filter = self.filter.property_groups.combine_property_group( PropertyOperatorType.AND, self.entity.property_groups ) @@ -140,7 +141,7 @@ def _props_to_filter(self) -> Tuple[str, Dict]: hogql_context=self.filter.hogql_context, ) - def get_query(self) -> Tuple[str, Dict, Callable]: + def get_query(self) -> tuple[str, dict, Callable]: date_params = {} query_date_range = QueryDateRange(filter=self.filter, team=self.team) @@ -165,7 +166,7 @@ def get_query(self) -> Tuple[str, Dict, Callable]: ) action_query = "" - action_params: Dict = {} + action_params: dict = {} if self.entity.type == TREND_FILTER_TYPE_ACTIONS: action = self.entity.get_action() action_query, action_params = format_action_filter( @@ -439,7 +440,7 @@ def _breakdown_cohort_params(self): return params, breakdown_filter, breakdown_filter_params, "value" - def _breakdown_prop_params(self, aggregate_operation: str, math_params: Dict): + def _breakdown_prop_params(self, aggregate_operation: str, math_params: dict): values_arr, has_more_values = get_breakdown_prop_values( self.filter, self.entity, @@ -564,7 +565,7 @@ def _get_breakdown_value(self, breakdown: str) -> str: return breakdown_value - def _get_histogram_breakdown_values(self, raw_breakdown_value: str, buckets: List[int]): + def _get_histogram_breakdown_values(self, raw_breakdown_value: str, buckets: list[int]): multi_if_conditionals = [] values_arr = [] @@ -607,9 +608,9 @@ def breakdown_sort_function(self, value): return count_or_aggregated_value * -1, value.get("label") # reverse it def _parse_single_aggregate_result( - self, filter: Filter, entity: Entity, additional_values: Dict[str, Any] + self, filter: Filter, entity: Entity, additional_values: dict[str, Any] ) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] cache_invalidation_key = generate_short_id() for stats in result: @@ -623,7 +624,7 @@ def _parse(result: List) -> List: "breakdown_value": result_descriptors["breakdown_value"], "breakdown_type": filter.breakdown_type or "event", } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) parsed_result = { "aggregated_value": float( correct_result_for_sampling(aggregated_value, filter.sampling_factor, entity.math) @@ -647,7 +648,7 @@ def _parse(result: List) -> List: return _parse def _parse_trend_result(self, filter: Filter, entity: Entity) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] for stats in result: result_descriptors = self._breakdown_result_descriptors(stats[2], filter, entity) @@ -679,9 +680,9 @@ def _get_persons_url( filter: Filter, entity: Entity, team: Team, - point_dates: List[datetime], + point_dates: list[datetime], breakdown_value: Union[str, int], - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: persons_url = [] 
cache_invalidation_key = generate_short_id() for point_date in point_dates: @@ -705,7 +706,7 @@ def _get_persons_url( "breakdown_value": breakdown_value, "breakdown_type": filter.breakdown_type or "event", } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, @@ -744,7 +745,7 @@ def _determine_breakdown_label( else: return str(value) or BREAKDOWN_NULL_DISPLAY - def _person_join_condition(self) -> Tuple[str, Dict]: + def _person_join_condition(self) -> tuple[str, dict]: if self.person_on_events_mode == PersonsOnEventsMode.person_id_no_override_properties_on_events: return "", {} @@ -780,7 +781,7 @@ def _person_join_condition(self) -> Tuple[str, Dict]: else: return "", {} - def _groups_join_condition(self) -> Tuple[str, Dict]: + def _groups_join_condition(self) -> tuple[str, dict]: return GroupsJoinQuery( self.filter, self.team_id, @@ -788,7 +789,7 @@ def _groups_join_condition(self) -> Tuple[str, Dict]: person_on_events_mode=self.person_on_events_mode, ).get_join_query() - def _sessions_join_condition(self) -> Tuple[str, Dict]: + def _sessions_join_condition(self) -> tuple[str, dict]: session_query = SessionQuery(filter=self.filter, team=self.team) if session_query.is_used: query, session_params = session_query.get_query() diff --git a/posthog/queries/trends/formula.py b/posthog/queries/trends/formula.py index 4f59e5b0cd794..b2fd1bcd8062c 100644 --- a/posthog/queries/trends/formula.py +++ b/posthog/queries/trends/formula.py @@ -2,7 +2,7 @@ from itertools import accumulate import re from string import ascii_uppercase -from typing import Any, Dict, List +from typing import Any from sentry_sdk import push_scope @@ -22,7 +22,7 @@ class TrendsFormula: def _run_formula_query(self, filter: Filter, team: Team): letters = [ascii_uppercase[i] for i in range(0, len(filter.entities))] queries = [] - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for idx, entity in enumerate(filter.entities): _, sql, entity_params, _ = self._get_sql_for_entity(filter, team, entity) # type: ignore sql = PARAM_DISAMBIGUATION_REGEX.sub(f"%({idx}_", sql) @@ -96,7 +96,7 @@ def _run_formula_query(self, filter: Filter, team: Team): ) response = [] for item in result: - additional_values: Dict[str, Any] = {"label": self._label(filter, item)} + additional_values: dict[str, Any] = {"label": self._label(filter, item)} if filter.breakdown: additional_values["breakdown_value"] = additional_values["label"] @@ -113,7 +113,7 @@ def _run_formula_query(self, filter: Filter, team: Team): response.append(parse_response(item, filter, additional_values=additional_values)) return response - def _label(self, filter: Filter, item: List) -> str: + def _label(self, filter: Filter, item: list) -> str: if filter.breakdown: if filter.breakdown_type == "cohort": return get_breakdown_cohort_name(item[2]) diff --git a/posthog/queries/trends/lifecycle.py b/posthog/queries/trends/lifecycle.py index 2629672879e7a..199e3c57973b6 100644 --- a/posthog/queries/trends/lifecycle.py +++ b/posthog/queries/trends/lifecycle.py @@ -1,5 +1,6 @@ import urllib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any +from collections.abc import Callable from posthog.models.entity import Entity from posthog.models.entity.util import get_entity_filtering_params @@ -28,7 +29,7 @@ class Lifecycle: - def _format_lifecycle_query(self, entity: Entity, filter: 
Filter, team: Team) -> Tuple[str, Dict, Callable]: + def _format_lifecycle_query(self, entity: Entity, filter: Filter, team: Team) -> tuple[str, dict, Callable]: event_query, event_params = LifecycleEventQuery( team=team, filter=filter, person_on_events_mode=team.person_on_events_mode ).get_query() @@ -40,7 +41,7 @@ def _format_lifecycle_query(self, entity: Entity, filter: Filter, team: Team) -> ) def _parse_result(self, filter: Filter, entity: Entity, team: Team) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: res = [] for val in result: label = "{} - {}".format(entity.name, val[2]) @@ -61,7 +62,7 @@ def get_people(self, filter: LifecycleFilter, team: Team): _, serialized_actors, _ = LifecycleActors(filter=filter, team=team, limit_actors=True).get_actors() return serialized_actors - def _get_persons_urls(self, filter: Filter, entity: Entity, times: List[str], status) -> List[Dict[str, Any]]: + def _get_persons_urls(self, filter: Filter, entity: Entity, times: list[str], status) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for target_date in times: @@ -75,7 +76,7 @@ def _get_persons_urls(self, filter: Filter, entity: Entity, times: List[str], st "lifecycle_type": status, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, @@ -167,7 +168,7 @@ def _person_query(self): ) def _get_date_filter(self): - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team, should_round=False) _, date_from_params = query_date_range.date_from _, date_to_params = query_date_range.date_to diff --git a/posthog/queries/trends/lifecycle_actors.py b/posthog/queries/trends/lifecycle_actors.py index 2b83dbb364ddb..0e4b7446cda52 100644 --- a/posthog/queries/trends/lifecycle_actors.py +++ b/posthog/queries/trends/lifecycle_actors.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional from posthog.queries.actor_base_query import ActorBaseQuery from posthog.queries.trends.lifecycle import LifecycleEventQuery @@ -13,7 +13,7 @@ class LifecycleActors(ActorBaseQuery): QUERY_TYPE = "lifecycle" - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: events_query, event_params = self.event_query_class( filter=self._filter, team=self._team, diff --git a/posthog/queries/trends/test/test_breakdowns.py b/posthog/queries/trends/test/test_breakdowns.py index 78b5a01e45aaa..3b8651d541512 100644 --- a/posthog/queries/trends/test/test_breakdowns.py +++ b/posthog/queries/trends/test/test_breakdowns.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.constants import TRENDS_TABLE from posthog.models import Filter @@ -104,7 +104,7 @@ def setUp(self): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/trends/test/test_breakdowns_by_current_url.py b/posthog/queries/trends/test/test_breakdowns_by_current_url.py index 
26e0c40ae6404..8474d7a27bb23 100644 --- a/posthog/queries/trends/test/test_breakdowns_by_current_url.py +++ b/posthog/queries/trends/test/test_breakdowns_by_current_url.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Dict, Optional +from typing import Optional from posthog.models import Filter from posthog.queries.trends.trends import Trends @@ -99,7 +99,7 @@ def setUp(self): journeys_for(journey, team=self.team, create_people=True) - def _run(self, extra: Optional[Dict] = None, events_extra: Optional[Dict] = None): + def _run(self, extra: Optional[dict] = None, events_extra: Optional[dict] = None): if events_extra is None: events_extra = {} if extra is None: diff --git a/posthog/queries/trends/test/test_formula.py b/posthog/queries/trends/test/test_formula.py index 01e838336e5c8..d711bbff6f827 100644 --- a/posthog/queries/trends/test/test_formula.py +++ b/posthog/queries/trends/test/test_formula.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from freezegun.api import freeze_time @@ -129,7 +129,7 @@ def setUp(self): }, ) - def _run(self, extra: Optional[Dict] = None, run_at: Optional[str] = None): + def _run(self, extra: Optional[dict] = None, run_at: Optional[str] = None): if extra is None: extra = {} with freeze_time(run_at or "2020-01-04T13:01:01Z"): diff --git a/posthog/queries/trends/test/test_paging_breakdowns.py b/posthog/queries/trends/test/test_paging_breakdowns.py index b4040fee61897..47ea447005c1a 100644 --- a/posthog/queries/trends/test/test_paging_breakdowns.py +++ b/posthog/queries/trends/test/test_paging_breakdowns.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional from freezegun import freeze_time @@ -38,7 +38,7 @@ def setUp(self): create_people=True, ) - def _run(self, extra: Optional[Dict] = None, run_at: Optional[str] = None): + def _run(self, extra: Optional[dict] = None, run_at: Optional[str] = None): if extra is None: extra = {} with freeze_time(run_at or "2020-01-04T13:01:01Z"): diff --git a/posthog/queries/trends/total_volume.py b/posthog/queries/trends/total_volume.py index e36f6d2de7313..5e91d9272cf18 100644 --- a/posthog/queries/trends/total_volume.py +++ b/posthog/queries/trends/total_volume.py @@ -1,6 +1,7 @@ import urllib.parse from datetime import date, datetime, timedelta -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Union +from collections.abc import Callable from posthog.clickhouse.query_tagging import tag_queries from posthog.constants import ( @@ -48,7 +49,7 @@ class TrendsTotalVolume: EVENT_TABLE_ALIAS = EventQuery.EVENT_TABLE_ALIAS PERSON_ID_OVERRIDES_TABLE_ALIAS = EventQuery.PERSON_ID_OVERRIDES_TABLE_ALIAS - def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> Tuple[str, Dict, Callable]: + def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> tuple[str, dict, Callable]: interval_func = get_interval_func_ch(filter.interval) person_id_alias = f"{self.DISTINCT_ID_TABLE_ALIAS}.person_id" @@ -82,7 +83,7 @@ def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> Tup "timestamp": "e.timestamp", "interval_func": interval_func, } - params: Dict = {"team_id": team.id, "timezone": team.timezone} + params: dict = {"team_id": team.id, "timezone": team.timezone} params = {**params, **math_params, **event_query_params} if filter.display in NON_TIME_SERIES_DISPLAY_TYPES: @@ -219,14 +220,14 @@ def _total_volume_query(self, entity: Entity, filter: Filter, team: Team) -> Tup 
return final_query, params, self._parse_total_volume_result(filter, entity, team) def _parse_total_volume_result(self, filter: Filter, entity: Entity, team: Team) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: parsed_results = [] if result is not None: for stats in result: parsed_result = parse_response(stats, filter, entity=entity) - point_dates: List[Union[datetime, date]] = stats[0] + point_dates: list[Union[datetime, date]] = stats[0] # Ensure we have datetimes for all points - point_datetimes: List[datetime] = [ + point_datetimes: list[datetime] = [ datetime.combine(d, datetime.min.time()) if not isinstance(d, datetime) else d for d in point_dates ] @@ -238,7 +239,7 @@ def _parse(result: List) -> List: return _parse def _parse_aggregate_volume_result(self, filter: Filter, entity: Entity, team_id: int) -> Callable: - def _parse(result: List) -> List: + def _parse(result: list) -> list: aggregated_value = result[0][0] if result else 0 seconds_in_interval = TIME_IN_SECONDS[filter.interval] time_range = enumerate_time_range(filter, seconds_in_interval) @@ -249,7 +250,7 @@ def _parse(result: List) -> List: "entity_math": entity.math, "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) cache_invalidation_key = generate_short_id() return [ @@ -286,8 +287,8 @@ def _get_persons_url( filter: Filter, entity: Entity, team: Team, - point_datetimes: List[datetime], - ) -> List[Dict[str, Any]]: + point_datetimes: list[datetime], + ) -> list[dict[str, Any]]: persons_url = [] cache_invalidation_key = generate_short_id() for point_datetime in point_datetimes: @@ -301,7 +302,7 @@ def _get_persons_url( "entity_order": entity.order, } - parsed_params: Dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) + parsed_params: dict[str, str] = encode_get_request_params({**filter_params, **extra_params}) persons_url.append( { "filter": extra_params, diff --git a/posthog/queries/trends/trends.py b/posthog/queries/trends/trends.py index 81e35336138bf..da8e0ff80e1c7 100644 --- a/posthog/queries/trends/trends.py +++ b/posthog/queries/trends/trends.py @@ -2,7 +2,8 @@ import threading from datetime import datetime, timedelta from itertools import accumulate -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Optional, cast +from collections.abc import Callable from zoneinfo import ZoneInfo from dateutil import parser @@ -33,7 +34,7 @@ class Trends(TrendsTotalVolume, Lifecycle, TrendsFormula): - def _get_sql_for_entity(self, filter: Filter, team: Team, entity: Entity) -> Tuple[str, str, Dict, Callable]: + def _get_sql_for_entity(self, filter: Filter, team: Team, entity: Entity) -> tuple[str, str, dict, Callable]: if filter.breakdown and filter.display not in NON_BREAKDOWN_DISPLAY_TYPES: query_type = "trends_breakdown" sql, params, parse_function = TrendsBreakdown( @@ -53,7 +54,7 @@ def _get_sql_for_entity(self, filter: Filter, team: Team, entity: Entity) -> Tup return query_type, sql, params, parse_function # Use cached result even on refresh if team has strict caching enabled - def get_cached_result(self, filter: Filter, team: Team) -> Optional[List[Dict[str, Any]]]: + def get_cached_result(self, filter: Filter, team: Team) -> Optional[list[dict[str, Any]]]: if not team.strict_caching_enabled or filter.breakdown or filter.display != TRENDS_LINEAR: 
return None @@ -73,7 +74,7 @@ def get_cached_result(self, filter: Filter, team: Team) -> Optional[List[Dict[st return cached_result if _is_present else None # Determine if the current timerange is present in the cache - def is_present_timerange(self, cached_result: List[Dict[str, Any]], filter: Filter, team: Team) -> bool: + def is_present_timerange(self, cached_result: list[dict[str, Any]], filter: Filter, team: Team) -> bool: if ( len(cached_result) > 0 and cached_result[0].get("days") @@ -92,7 +93,7 @@ def is_present_timerange(self, cached_result: List[Dict[str, Any]], filter: Filt return _is_present # Use a condensed filter if a cached result exists in the current timerange - def adjusted_filter(self, filter: Filter, team: Team) -> Tuple[Filter, Optional[Dict[str, Any]]]: + def adjusted_filter(self, filter: Filter, team: Team) -> tuple[Filter, Optional[dict[str, Any]]]: cached_result = self.get_cached_result(filter, team) new_filter = filter.shallow_clone({"date_from": interval_unit(filter.interval)}) if cached_result else filter @@ -107,7 +108,7 @@ def adjusted_filter(self, filter: Filter, team: Team) -> Tuple[Filter, Optional[ def merge_results( self, result, - cached_result: Optional[Dict[str, Any]], + cached_result: Optional[dict[str, Any]], entity_order: int, filter: Filter, team: Team, @@ -129,7 +130,7 @@ def merge_results( else: return result, {} - def _run_query(self, filter: Filter, team: Team, entity: Entity) -> List[Dict[str, Any]]: + def _run_query(self, filter: Filter, team: Team, entity: Entity) -> list[dict[str, Any]]: adjusted_filter, cached_result = self.adjusted_filter(filter, team) with push_scope() as scope: query_type, sql, params, parse_function = self._get_sql_for_entity(adjusted_filter, team, entity) @@ -163,12 +164,12 @@ def _run_query(self, filter: Filter, team: Team, entity: Entity) -> List[Dict[st def _run_query_for_threading( self, - result: List, + result: list, index: int, query_type, sql, params, - query_tags: Dict, + query_tags: dict, filter: Filter, team_id: int, ): @@ -177,10 +178,10 @@ def _run_query_for_threading( scope.set_context("query", {"sql": sql, "params": params}) result[index] = insight_sync_execute(sql, params, query_type=query_type, filter=filter, team_id=team_id) - def _run_parallel(self, filter: Filter, team: Team) -> List[Dict[str, Any]]: - result: List[Optional[List[Dict[str, Any]]]] = [None] * len(filter.entities) - parse_functions: List[Optional[Callable]] = [None] * len(filter.entities) - sql_statements_with_params: List[Tuple[Optional[str], Dict]] = [(None, {})] * len(filter.entities) + def _run_parallel(self, filter: Filter, team: Team) -> list[dict[str, Any]]: + result: list[Optional[list[dict[str, Any]]]] = [None] * len(filter.entities) + parse_functions: list[Optional[Callable]] = [None] * len(filter.entities) + sql_statements_with_params: list[tuple[Optional[str], dict]] = [(None, {})] * len(filter.entities) cached_result = None jobs = [] @@ -225,7 +226,7 @@ def _run_parallel(self, filter: Filter, team: Team) -> List[Dict[str, Any]]: "params": sql_statements_with_params[i][1], }, ) - serialized_data = cast(List[Callable], parse_functions)[entity.index](result[entity.index]) + serialized_data = cast(list[Callable], parse_functions)[entity.index](result[entity.index]) serialized_data = self._format_serialized(entity, serialized_data) merged_results, cached_result = self.merge_results( serialized_data, @@ -237,9 +238,9 @@ def _run_parallel(self, filter: Filter, team: Team) -> List[Dict[str, Any]]: result[entity.index] = 
merged_results # flatten results - flat_results: List[Dict[str, Any]] = [] + flat_results: list[dict[str, Any]] = [] for item in result: - for flat in cast(List[Dict[str, Any]], item): + for flat in cast(list[dict[str, Any]], item): flat_results.append(flat) if cached_result: @@ -248,7 +249,7 @@ def _run_parallel(self, filter: Filter, team: Team) -> List[Dict[str, Any]]: return flat_results - def run(self, filter: Filter, team: Team, is_csv_export: bool = False, *args, **kwargs) -> List[Dict[str, Any]]: + def run(self, filter: Filter, team: Team, is_csv_export: bool = False, *args, **kwargs) -> list[dict[str, Any]]: self.is_csv_export = is_csv_export actions = Action.objects.filter(team_id=team.pk).order_by("-id") if len(filter.actions) > 0: @@ -274,10 +275,10 @@ def run(self, filter: Filter, team: Team, is_csv_export: bool = False, *args, ** return result - def _format_serialized(self, entity: Entity, result: List[Dict[str, Any]]): + def _format_serialized(self, entity: Entity, result: list[dict[str, Any]]): serialized_data = [] - serialized: Dict[str, Any] = { + serialized: dict[str, Any] = { "action": entity.to_dict(), "label": entity.name, "count": 0, @@ -293,7 +294,7 @@ def _format_serialized(self, entity: Entity, result: List[Dict[str, Any]]): return serialized_data - def _handle_cumulative(self, entity_metrics: List) -> List[Dict[str, Any]]: + def _handle_cumulative(self, entity_metrics: list) -> list[dict[str, Any]]: for metrics in entity_metrics: metrics.update(data=list(accumulate(metrics["data"]))) return entity_metrics diff --git a/posthog/queries/trends/trends_actors.py b/posthog/queries/trends/trends_actors.py index 9c4afa89c41a6..f7db8b36d8ac3 100644 --- a/posthog/queries/trends/trends_actors.py +++ b/posthog/queries/trends/trends_actors.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from posthog.constants import PropertyOperatorType from posthog.models.cohort import Cohort @@ -37,7 +37,7 @@ def aggregation_group_type_index(self): return self.entity.math_group_type_index return None - def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: + def actor_query(self, limit_actors: Optional[bool] = True) -> tuple[str, dict]: if self._filter.breakdown_type == "cohort" and self._filter.breakdown_value != "all": cohort = Cohort.objects.get(pk=self._filter.breakdown_value, team_id=self._team.pk) self._filter = self._filter.shallow_clone( @@ -95,7 +95,7 @@ def actor_query(self, limit_actors: Optional[bool] = True) -> Tuple[str, Dict]: } ) - extra_fields: List[str] = ["distinct_id", "team_id"] if not self.is_aggregating_by_groups else [] + extra_fields: list[str] = ["distinct_id", "team_id"] if not self.is_aggregating_by_groups else [] if self._filter.include_recordings: extra_fields += ["uuid"] @@ -147,7 +147,7 @@ def _aggregation_actor_field(self) -> str: return "person_id" @cached_property - def _aggregation_actor_value_expression_with_params(self) -> Tuple[str, Dict[str, Any]]: + def _aggregation_actor_value_expression_with_params(self) -> tuple[str, dict[str, Any]]: if self.entity.math in PROPERTY_MATH_FUNCTIONS: math_aggregate_operation, _, math_params = process_math( self.entity, self._team, filter=self._filter, event_table_alias="e" diff --git a/posthog/queries/trends/trends_event_query.py b/posthog/queries/trends/trends_event_query.py index bc9e9b979bd00..b856cb6a035e5 100644 --- a/posthog/queries/trends/trends_event_query.py +++ b/posthog/queries/trends/trends_event_query.py @@ -1,4 
+1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.models.property.util import get_property_string_expr from posthog.queries.trends.trends_event_query_base import TrendsEventQueryBase @@ -6,7 +6,7 @@ class TrendsEventQuery(TrendsEventQueryBase): - def get_query(self) -> Tuple[str, Dict[str, Any]]: + def get_query(self) -> tuple[str, dict[str, Any]]: person_id_field = "" if self._should_join_distinct_ids: person_id_field = f", {self._person_id_alias} as person_id" diff --git a/posthog/queries/trends/trends_event_query_base.py b/posthog/queries/trends/trends_event_query_base.py index dbeb9f17cdc3d..8fb17d3579e8f 100644 --- a/posthog/queries/trends/trends_event_query_base.py +++ b/posthog/queries/trends/trends_event_query_base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from posthog.constants import ( MONTHLY_ACTIVE, @@ -29,7 +29,7 @@ def __init__(self, entity: Entity, *args, **kwargs): self._entity = entity super().__init__(*args, **kwargs) - def get_query_base(self) -> Tuple[str, Dict[str, Any]]: + def get_query_base(self) -> tuple[str, dict[str, Any]]: """ Returns part of the event query with only FROM, JOINs and WHERE clauses. """ @@ -114,9 +114,9 @@ def _get_not_null_actor_condition(self) -> str: # If aggregating by group, exclude events that aren't associated with a group return f"""AND "$group_{self._entity.math_group_type_index}" != ''""" - def _get_date_filter(self) -> Tuple[str, Dict]: + def _get_date_filter(self) -> tuple[str, dict]: date_query = "" - date_params: Dict[str, Any] = {} + date_params: dict[str, Any] = {} query_date_range = QueryDateRange(self._filter, self._team) parsed_date_from, date_from_params = query_date_range.date_from parsed_date_to, date_to_params = query_date_range.date_to @@ -145,7 +145,7 @@ def _get_date_filter(self) -> Tuple[str, Dict]: return date_query, date_params - def _get_entity_query(self, *, deep_filtering: bool) -> Tuple[str, Dict]: + def _get_entity_query(self, *, deep_filtering: bool) -> tuple[str, dict]: entity_params, entity_format_params = get_entity_filtering_params( allowed_entities=[self._entity], team_id=self._team_id, diff --git a/posthog/queries/trends/util.py b/posthog/queries/trends/util.py index e002145de9957..3558640602e48 100644 --- a/posthog/queries/trends/util.py +++ b/posthog/queries/trends/util.py @@ -1,6 +1,6 @@ import datetime from datetime import timedelta -from typing import Any, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Optional, TypeVar from zoneinfo import ZoneInfo import structlog @@ -60,10 +60,10 @@ def process_math( filter: Filter, event_table_alias: Optional[str] = None, person_id_alias: str = "person_id", -) -> Tuple[str, str, Dict[str, Any]]: +) -> tuple[str, str, dict[str, Any]]: aggregate_operation = "count(*)" join_condition = "" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if entity.math in (UNIQUE_USERS, WEEKLY_ACTIVE, MONTHLY_ACTIVE): if team.aggregate_users_by_distinct_id: @@ -100,11 +100,11 @@ def process_math( def parse_response( - stats: Dict, + stats: dict, filter: Filter, - additional_values: Optional[Dict] = None, + additional_values: Optional[dict] = None, entity: Optional[Entity] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: if additional_values is None: additional_values = {} counts = stats[1] @@ -122,7 +122,7 @@ def parse_response( } -def get_active_user_params(filter: Filter, entity: Entity, team_id: int) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def get_active_user_params(filter: Filter, 
entity: Entity, team_id: int) -> tuple[dict[str, Any], dict[str, Any]]: diff = timedelta(days=7 if entity.math == WEEKLY_ACTIVE else 30) date_from: datetime.datetime @@ -155,11 +155,11 @@ def get_active_user_params(filter: Filter, entity: Entity, team_id: int) -> Tupl return format_params, query_params -def enumerate_time_range(filter: Filter, seconds_in_interval: int) -> List[str]: +def enumerate_time_range(filter: Filter, seconds_in_interval: int) -> list[str]: date_from = filter.date_from date_to = filter.date_to delta = timedelta(seconds=seconds_in_interval) - time_range: List[str] = [] + time_range: list[str] = [] if not date_from or not date_to: return time_range diff --git a/posthog/queries/util.py b/posthog/queries/util.py index e366fb1cc7833..e0d2cb9896f02 100644 --- a/posthog/queries/util.py +++ b/posthog/queries/util.py @@ -1,7 +1,7 @@ import json from datetime import datetime, timedelta from enum import Enum, auto -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from zoneinfo import ZoneInfo from django.utils import timezone @@ -46,21 +46,21 @@ class PersonPropertiesMode(Enum): SELECT timestamp from events WHERE team_id = %(team_id)s AND timestamp > %(earliest_timestamp)s order by timestamp limit 1 """ -TIME_IN_SECONDS: Dict[str, Any] = { +TIME_IN_SECONDS: dict[str, Any] = { "hour": 3600, "day": 3600 * 24, "week": 3600 * 24 * 7, "month": 3600 * 24 * 30, # TODO: Let's get rid of this lie! Months are not all 30 days long } -PERIOD_TO_TRUNC_FUNC: Dict[str, str] = { +PERIOD_TO_TRUNC_FUNC: dict[str, str] = { "hour": "toStartOfHour", "week": "toStartOfWeek", "day": "toStartOfDay", "month": "toStartOfMonth", } -PERIOD_TO_INTERVAL_FUNC: Dict[str, str] = { +PERIOD_TO_INTERVAL_FUNC: dict[str, str] = { "hour": "toIntervalHour", "week": "toIntervalWeek", "day": "toIntervalDay", @@ -141,7 +141,7 @@ def get_time_in_seconds_for_period(period: Optional[str]) -> str: return seconds_in_period -def deep_dump_object(params: Dict[str, Any]) -> Dict[str, Any]: +def deep_dump_object(params: dict[str, Any]) -> dict[str, Any]: for key in params: if isinstance(params[key], dict) or isinstance(params[key], list): params[key] = json.dumps(params[key]) diff --git a/posthog/rate_limit.py b/posthog/rate_limit.py index 856d1b6cceb32..d85238c3d491a 100644 --- a/posthog/rate_limit.py +++ b/posthog/rate_limit.py @@ -2,7 +2,7 @@ import re import time from functools import lru_cache -from typing import List, Optional +from typing import Optional from prometheus_client import Counter from rest_framework.throttling import SimpleRateThrottle, BaseThrottle, UserRateThrottle @@ -36,7 +36,7 @@ @lru_cache(maxsize=1) -def get_team_allow_list(_ttl: int) -> List[str]: +def get_team_allow_list(_ttl: int) -> list[str]: """ The "allow list" will change way less frequently than it will be called _ttl is passed an infrequently changing value to ensure the cache is invalidated after some delay diff --git a/posthog/renderers.py b/posthog/renderers.py index fa2d532fdce70..2c7853497ea57 100644 --- a/posthog/renderers.py +++ b/posthog/renderers.py @@ -1,10 +1,8 @@ -from typing import Dict - import orjson from rest_framework.renderers import JSONRenderer from rest_framework.utils.encoders import JSONEncoder -CleaningMarker = bool | Dict[int, "CleaningMarker"] +CleaningMarker = bool | dict[int, "CleaningMarker"] class SafeJSONRenderer(JSONRenderer): diff --git a/posthog/schema.py b/posthog/schema.py index 5673db2a3bf54..46ad0beb9a11a 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -4,7 
+4,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel @@ -165,8 +165,8 @@ class DatabaseSchemaQueryResponseField(BaseModel): model_config = ConfigDict( extra="forbid", ) - chain: Optional[List[str]] = None - fields: Optional[List[str]] = None + chain: Optional[list[str]] = None + fields: Optional[list[str]] = None key: str table: Optional[str] = None type: str @@ -203,9 +203,9 @@ class ElementType(BaseModel): model_config = ConfigDict( extra="forbid", ) - attr_class: Optional[List[str]] = None + attr_class: Optional[list[str]] = None attr_id: Optional[str] = None - attributes: Dict[str, str] + attributes: dict[str, str] href: Optional[str] = None nth_child: Optional[float] = None nth_of_type: Optional[float] = None @@ -232,9 +232,9 @@ class EventDefinition(BaseModel): model_config = ConfigDict( extra="forbid", ) - elements: List + elements: list event: str - properties: Dict[str, Any] + properties: dict[str, Any] class CorrelationType(str, Enum): @@ -257,9 +257,9 @@ class Person(BaseModel): model_config = ConfigDict( extra="forbid", ) - distinct_ids: List[str] + distinct_ids: list[str] is_identified: Optional[bool] = None - properties: Dict[str, Any] + properties: dict[str, Any] class EventType(BaseModel): @@ -267,12 +267,12 @@ class EventType(BaseModel): extra="forbid", ) distinct_id: str - elements: List[ElementType] + elements: list[ElementType] elements_chain: Optional[str] = None event: str id: str person: Optional[Person] = None - properties: Dict[str, Any] + properties: dict[str, Any] timestamp: str uuid: Optional[str] = None @@ -282,7 +282,7 @@ class Response(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] class Properties(BaseModel): @@ -321,7 +321,7 @@ class FunnelCorrelationResult(BaseModel): model_config = ConfigDict( extra="forbid", ) - events: List[EventOddsRatioSerialized] + events: list[EventOddsRatioSerialized] skewed: bool @@ -374,7 +374,7 @@ class FunnelTimeToConvertResults(BaseModel): extra="forbid", ) average_conversion_time: Optional[float] = None - bins: List[List[int]] + bins: list[list[int]] class FunnelVizType(str, Enum): @@ -432,7 +432,7 @@ class HogQLQueryModifiers(BaseModel): model_config = ConfigDict( extra="forbid", ) - dataWarehouseEventsModifiers: Optional[List[DataWarehouseEventsModifier]] = None + dataWarehouseEventsModifiers: Optional[list[DataWarehouseEventsModifier]] = None inCohortVia: Optional[InCohortVia] = None materializationMode: Optional[MaterializationMode] = None personsArgMaxVersion: Optional[PersonsArgMaxVersion] = None @@ -496,12 +496,12 @@ class InsightActorsQueryOptionsResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - breakdown: Optional[List[BreakdownItem]] = None - compare: Optional[List[CompareItem]] = None - day: Optional[List[DayItem]] = None - interval: Optional[List[IntervalItem]] = None - series: Optional[List[Series]] = None - status: Optional[List[StatusItem]] = None + breakdown: Optional[list[BreakdownItem]] = None + compare: Optional[list[CompareItem]] = None + day: Optional[list[DayItem]] = None + interval: Optional[list[IntervalItem]] = None + series: Optional[list[Series]] = None + status: Optional[list[StatusItem]] = None class InsightFilterProperty(str, Enum): @@ -604,14 +604,14 @@ class PathsFilter(BaseModel): ) edgeLimit: Optional[int] = None 
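The trends hunks above (breakdown.py, lifecycle.py, total_volume.py, trends.py) make one further swap in the same spirit: Callable moves from typing to collections.abc, its non-deprecated home since PEP 585. A sketch of the resulting import, with a hypothetical parser factory shaped like the _parse_* closures in those hunks:

# Callable is now imported from collections.abc rather than typing.
from collections.abc import Callable
from typing import Any

def make_parser(label: str) -> Callable[[list], list[dict[str, Any]]]:
    # Build a parser that tags each raw result row with the given label.
    # Purely illustrative; it only mirrors the shape of the factories above.
    def _parse(result: list) -> list[dict[str, Any]]:
        return [{"label": label, "data": row} for row in result]

    return _parse

collections.abc.Callable subscripts exactly like its typing counterpart on 3.9+, so the relocation is, again, annotation-only.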
endPoint: Optional[str] = None - excludeEvents: Optional[List[str]] = None - includeEventTypes: Optional[List[PathType]] = None - localPathCleaningFilters: Optional[List[PathCleaningFilter]] = None + excludeEvents: Optional[list[str]] = None + includeEventTypes: Optional[list[PathType]] = None + localPathCleaningFilters: Optional[list[PathCleaningFilter]] = None maxEdgeWeight: Optional[int] = None minEdgeWeight: Optional[int] = None pathDropoffKey: Optional[str] = Field(default=None, description="Relevant only within actors query") pathEndKey: Optional[str] = Field(default=None, description="Relevant only within actors query") - pathGroupings: Optional[List[str]] = None + pathGroupings: Optional[list[str]] = None pathReplacements: Optional[bool] = None pathStartKey: Optional[str] = Field(default=None, description="Relevant only within actors query") pathsHogQLExpression: Optional[str] = None @@ -625,14 +625,14 @@ class PathsFilterLegacy(BaseModel): ) edge_limit: Optional[int] = None end_point: Optional[str] = None - exclude_events: Optional[List[str]] = None - funnel_filter: Optional[Dict[str, Any]] = None + exclude_events: Optional[list[str]] = None + funnel_filter: Optional[dict[str, Any]] = None funnel_paths: Optional[FunnelPathType] = None - include_event_types: Optional[List[PathType]] = None - local_path_cleaning_filters: Optional[List[PathCleaningFilter]] = None + include_event_types: Optional[list[PathType]] = None + local_path_cleaning_filters: Optional[list[PathCleaningFilter]] = None max_edge_weight: Optional[int] = None min_edge_weight: Optional[int] = None - path_groupings: Optional[List[str]] = None + path_groupings: Optional[list[str]] = None path_replacements: Optional[bool] = None path_type: Optional[PathType] = None paths_hogql_expression: Optional[str] = None @@ -693,39 +693,39 @@ class QueryResponseAlternative1(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] class QueryResponseAlternative2(BaseModel): model_config = ConfigDict( extra="forbid", ) - results: List[Dict[str, Any]] + results: list[dict[str, Any]] class QueryResponseAlternative5(BaseModel): model_config = ConfigDict( extra="forbid", ) - breakdown: Optional[List[BreakdownItem]] = None - compare: Optional[List[CompareItem]] = None - day: Optional[List[DayItem]] = None - interval: Optional[List[IntervalItem]] = None - series: Optional[List[Series]] = None - status: Optional[List[StatusItem]] = None + breakdown: Optional[list[BreakdownItem]] = None + compare: Optional[list[CompareItem]] = None + day: Optional[list[DayItem]] = None + interval: Optional[list[IntervalItem]] = None + series: Optional[list[Series]] = None + status: Optional[list[StatusItem]] = None class QueryResponseAlternative8(BaseModel): model_config = ConfigDict( extra="forbid", ) - errors: List[HogQLNotice] + errors: list[HogQLNotice] inputExpr: Optional[str] = None inputSelect: Optional[str] = None isValid: Optional[bool] = None isValidView: Optional[bool] = None - notices: List[HogQLNotice] - warnings: List[HogQLNotice] + notices: list[HogQLNotice] + warnings: list[HogQLNotice] class QueryStatus(BaseModel): @@ -822,7 +822,7 @@ class SessionPropertyFilter(BaseModel): label: Optional[str] = None operator: PropertyOperator type: Literal["session"] = "session" - value: Optional[Union[str, float, List[Union[str, float]]]] = None + value: Optional[Union[str, float, list[Union[str, float]]]] = None class StepOrderValue(str, Enum): @@ -837,7 +837,7 @@ class StickinessFilter(BaseModel): ) 
compare: Optional[bool] = None display: Optional[ChartDisplayType] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None showLegend: Optional[bool] = None showValuesOnSeries: Optional[bool] = None @@ -848,7 +848,7 @@ class StickinessFilterLegacy(BaseModel): ) compare: Optional[bool] = None display: Optional[ChartDisplayType] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None show_legend: Optional[bool] = None show_values_on_series: Optional[bool] = None @@ -862,8 +862,8 @@ class StickinessQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class TimeToSeeDataQuery(BaseModel): @@ -874,7 +874,7 @@ class TimeToSeeDataQuery(BaseModel): modifiers: Optional[HogQLQueryModifiers] = Field( default=None, description="Modifiers used when performing the query" ) - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") sessionEnd: Optional[str] = None sessionId: Optional[str] = Field(default=None, description="Project to filter on. Defaults to current session") sessionStart: Optional[str] = Field( @@ -887,7 +887,7 @@ class TimeToSeeDataSessionsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - results: List[Dict[str, Any]] + results: list[dict[str, Any]] class TimeToSeeDataWaterfallNode(BaseModel): @@ -902,7 +902,7 @@ class TimelineEntry(BaseModel): model_config = ConfigDict( extra="forbid", ) - events: List[EventType] + events: list[EventType] recording_duration_s: Optional[float] = Field(default=None, description="Duration of the recording in seconds.") sessionId: Optional[str] = Field(default=None, description="Session ID. 
None means out-of-session events") @@ -919,7 +919,7 @@ class TrendsFilter(BaseModel): decimalPlaces: Optional[float] = None display: Optional[ChartDisplayType] = None formula: Optional[str] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None showLabelsOnSeries: Optional[bool] = None showLegend: Optional[bool] = None showPercentStackView: Optional[bool] = None @@ -939,7 +939,7 @@ class TrendsFilterLegacy(BaseModel): decimal_places: Optional[float] = None display: Optional[ChartDisplayType] = None formula: Optional[str] = None - hidden_legend_indexes: Optional[List[float]] = None + hidden_legend_indexes: Optional[list[float]] = None show_labels_on_series: Optional[bool] = None show_legend: Optional[bool] = None show_percent_stack_view: Optional[bool] = None @@ -956,8 +956,8 @@ class TrendsQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[Dict[str, Any]] - timings: Optional[List[QueryTiming]] = None + results: list[dict[str, Any]] + timings: Optional[list[QueryTiming]] = None class ActionsPie(BaseModel): @@ -1020,9 +1020,9 @@ class WebOverviewQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[WebOverviewItem] + results: list[WebOverviewItem] samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None + timings: Optional[list[QueryTiming]] = None class WebStatsBreakdown(str, Enum): @@ -1047,7 +1047,7 @@ class WebStatsTableQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hasMore: Optional[bool] = None hogql: Optional[str] = None is_cached: Optional[bool] = None @@ -1056,42 +1056,42 @@ class WebStatsTableQueryResponse(BaseModel): modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None offset: Optional[int] = None - results: List + results: list samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class WebTopClicksQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: Optional[List] = None + columns: Optional[list] = None hogql: Optional[str] = None is_cached: Optional[bool] = None last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List + results: list samplingRate: Optional[SamplingRate] = None - timings: Optional[List[QueryTiming]] = None - types: Optional[List] = None + timings: Optional[list[QueryTiming]] = None + types: Optional[list] = None class ActorsQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", ) - columns: List + columns: list hasMore: Optional[bool] = None hogql: str limit: int missing_actors_count: Optional[int] = None modifiers: Optional[HogQLQueryModifiers] = None offset: int - results: List[List] - timings: Optional[List[QueryTiming]] = None - types: List[str] + results: list[list] + timings: Optional[list[QueryTiming]] = None + types: list[str] class AnyResponseType1(BaseModel): @@ -1099,7 +1099,7 @@ class AnyResponseType1(BaseModel): extra="forbid", ) next: Optional[str] = None - results: List[EventType] + results: list[EventType] 
 
 
 class Breakdown(BaseModel):
@@ -1115,14 +1115,14 @@ class BreakdownFilter(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    breakdown: Optional[Union[str, float, List[Union[str, float]]]] = None
+    breakdown: Optional[Union[str, float, list[Union[str, float]]]] = None
     breakdown_group_type_index: Optional[int] = None
     breakdown_hide_other_aggregation: Optional[bool] = None
     breakdown_histogram_bin_count: Optional[int] = None
     breakdown_limit: Optional[int] = None
     breakdown_normalize_url: Optional[bool] = None
     breakdown_type: Optional[BreakdownType] = None
-    breakdowns: Optional[List[Breakdown]] = None
+    breakdowns: Optional[list[Breakdown]] = None
 
 
 class DataNode(BaseModel):
@@ -1133,16 +1133,16 @@ class DataNode(BaseModel):
     modifiers: Optional[HogQLQueryModifiers] = Field(
         default=None, description="Modifiers used when performing the query"
     )
-    response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response")
+    response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response")
 
 
 class ChartSettings(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    goalLines: Optional[List[GoalLine]] = None
+    goalLines: Optional[list[GoalLine]] = None
     xAxis: Optional[ChartAxis] = None
-    yAxis: Optional[List[ChartAxis]] = None
+    yAxis: Optional[list[ChartAxis]] = None
 
 
 class DataWarehousePersonPropertyFilter(BaseModel):
@@ -1153,7 +1153,7 @@ class DataWarehousePersonPropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["data_warehouse_person_property"] = "data_warehouse_person_property"
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class DataWarehousePropertyFilter(BaseModel):
@@ -1164,7 +1164,7 @@ class DataWarehousePropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["data_warehouse"] = "data_warehouse"
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class ElementPropertyFilter(BaseModel):
@@ -1175,7 +1175,7 @@ class ElementPropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["element"] = "element"
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class EventPropertyFilter(BaseModel):
@@ -1186,22 +1186,22 @@ class EventPropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: Optional[PropertyOperator] = PropertyOperator("exact")
     type: Literal["event"] = Field(default="event", description="Event properties")
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class EventsQueryResponse(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: List
+    columns: list
     hasMore: Optional[bool] = None
     hogql: str
     limit: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     offset: Optional[int] = None
-    results: List[List]
-    timings: Optional[List[QueryTiming]] = None
-    types: List[str]
+    results: list[list]
+    timings: Optional[list[QueryTiming]] = None
+    types: list[str]
 
 
 class FeaturePropertyFilter(BaseModel):
@@ -1212,22 +1212,22 @@ class FeaturePropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["feature"] = Field(default="feature", description='Event property with "$feature/" prepended')
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class FunnelCorrelationResponse(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: Optional[List] = None
+    columns: Optional[list] = None
     hasMore: Optional[bool] = None
     hogql: Optional[str] = None
     limit: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     offset: Optional[int] = None
     results: FunnelCorrelationResult
-    timings: Optional[List[QueryTiming]] = None
-    types: Optional[List] = None
+    timings: Optional[list[QueryTiming]] = None
+    types: Optional[list] = None
 
 
 class FunnelsFilterLegacy(BaseModel):
@@ -1237,7 +1237,7 @@ class FunnelsFilterLegacy(BaseModel):
     bin_count: Optional[Union[float, str]] = None
     breakdown_attribution_type: Optional[BreakdownAttributionType] = None
     breakdown_attribution_value: Optional[float] = None
-    exclusions: Optional[List[FunnelExclusionLegacy]] = None
+    exclusions: Optional[list[FunnelExclusionLegacy]] = None
     funnel_aggregate_by_hogql: Optional[str] = None
     funnel_from_step: Optional[float] = None
     funnel_order_type: Optional[StepOrderValue] = None
@@ -1246,7 +1246,7 @@ class FunnelsFilterLegacy(BaseModel):
     funnel_viz_type: Optional[FunnelVizType] = None
     funnel_window_interval: Optional[float] = None
     funnel_window_interval_unit: Optional[FunnelConversionWindowTimeUnit] = None
-    hidden_legend_breakdowns: Optional[List[str]] = None
+    hidden_legend_breakdowns: Optional[list[str]] = None
     layout: Optional[FunnelLayout] = None
 
 
@@ -1259,8 +1259,8 @@ class FunnelsQueryResponse(BaseModel):
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: Union[FunnelTimeToConvertResults, List[Dict[str, Any]], List[List[Dict[str, Any]]]]
-    timings: Optional[List[QueryTiming]] = None
+    results: Union[FunnelTimeToConvertResults, list[dict[str, Any]], list[list[dict[str, Any]]]]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class GroupPropertyFilter(BaseModel):
@@ -1272,7 +1272,7 @@ class GroupPropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["group"] = "group"
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class HogQLAutocompleteResponse(BaseModel):
@@ -1280,8 +1280,8 @@ class HogQLAutocompleteResponse(BaseModel):
         extra="forbid",
     )
     incomplete_list: bool = Field(..., description="Whether or not the suggestions returned are complete")
-    suggestions: List[AutocompleteCompletionItem]
-    timings: Optional[List[QueryTiming]] = Field(
+    suggestions: list[AutocompleteCompletionItem]
+    timings: Optional[list[QueryTiming]] = Field(
         default=None, description="Measured timings for different parts of the query generation process"
     )
 
@@ -1290,13 +1290,13 @@ class HogQLMetadataResponse(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    errors: List[HogQLNotice]
+    errors: list[HogQLNotice]
     inputExpr: Optional[str] = None
     inputSelect: Optional[str] = None
     isValid: Optional[bool] = None
     isValidView: Optional[bool] = None
-    notices: List[HogQLNotice]
-    warnings: List[HogQLNotice]
+    notices: list[HogQLNotice]
+    warnings: list[HogQLNotice]
 
 
 class HogQLPropertyFilter(BaseModel):
@@ -1306,7 +1306,7 @@ class HogQLPropertyFilter(BaseModel):
     key: str
     label: Optional[str] = None
     type: Literal["hogql"] = "hogql"
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class HogQLQueryResponse(BaseModel):
@@ -1314,11 +1314,11 @@ class HogQLQueryResponse(BaseModel):
         extra="forbid",
     )
     clickhouse: Optional[str] = Field(default=None, description="Executed ClickHouse query")
-    columns: Optional[List] = Field(default=None, description="Returned columns")
+    columns: Optional[list] = Field(default=None, description="Returned columns")
     error: Optional[str] = Field(
         default=None, description="Query error. Returned only if 'explain' is true. Throws an error otherwise."
     )
-    explain: Optional[List[str]] = Field(default=None, description="Query explanation output")
+    explain: Optional[list[str]] = Field(default=None, description="Query explanation output")
     hasMore: Optional[bool] = None
     hogql: Optional[str] = Field(default=None, description="Generated HogQL query")
     limit: Optional[int] = None
@@ -1328,11 +1328,11 @@ class HogQLQueryResponse(BaseModel):
     )
     offset: Optional[int] = None
     query: Optional[str] = Field(default=None, description="Input query string")
-    results: Optional[List] = Field(default=None, description="Query results")
-    timings: Optional[List[QueryTiming]] = Field(
+    results: Optional[list] = Field(default=None, description="Query results")
+    timings: Optional[list[QueryTiming]] = Field(
         default=None, description="Measured timings for different parts of the query generation process"
     )
-    types: Optional[List] = Field(default=None, description="Types of returned columns")
+    types: Optional[list] = Field(default=None, description="Types of returned columns")
 
 
 class InsightActorsQueryBase(BaseModel):
@@ -1349,7 +1349,7 @@ class LifecycleFilter(BaseModel):
         extra="forbid",
     )
     showValuesOnSeries: Optional[bool] = None
-    toggledLifecycles: Optional[List[LifecycleToggle]] = None
+    toggledLifecycles: Optional[list[LifecycleToggle]] = None
 
 
 class LifecycleFilterLegacy(BaseModel):
@@ -1357,7 +1357,7 @@ class LifecycleFilterLegacy(BaseModel):
         extra="forbid",
     )
     show_values_on_series: Optional[bool] = None
-    toggledLifecycles: Optional[List[LifecycleToggle]] = None
+    toggledLifecycles: Optional[list[LifecycleToggle]] = None
 
 
 class LifecycleQueryResponse(BaseModel):
@@ -1369,8 +1369,8 @@ class LifecycleQueryResponse(BaseModel):
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: List[Dict[str, Any]]
-    timings: Optional[List[QueryTiming]] = None
+    results: list[dict[str, Any]]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class Node(BaseModel):
@@ -1389,8 +1389,8 @@ class PathsQueryResponse(BaseModel):
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: List[Dict[str, Any]]
-    timings: Optional[List[QueryTiming]] = None
+    results: list[dict[str, Any]]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class PersonPropertyFilter(BaseModel):
@@ -1401,7 +1401,7 @@ class PersonPropertyFilter(BaseModel):
     label: Optional[str] = None
     operator: PropertyOperator
     type: Literal["person"] = Field(default="person", description="Person properties")
-    value: Optional[Union[str, float, List[Union[str, float]]]] = None
+    value: Optional[Union[str, float, list[Union[str, float]]]] = None
 
 
 class QueryResponse(BaseModel):
@@ -1414,38 +1414,38 @@ class QueryResponse(BaseModel):
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
     results: Any
-    timings: Optional[List[QueryTiming]] = None
+    timings: Optional[list[QueryTiming]] = None
 
 
 class QueryResponseAlternative3(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: List
+    columns: list
     hasMore: Optional[bool] = None
     hogql: str
     limit: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     offset: Optional[int] = None
-    results: List[List]
-    timings: Optional[List[QueryTiming]] = None
-    types: List[str]
+    results: list[list]
+    timings: Optional[list[QueryTiming]] = None
+    types: list[str]
 
 
 class QueryResponseAlternative4(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: List
+    columns: list
     hasMore: Optional[bool] = None
     hogql: str
     limit: int
     missing_actors_count: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     offset: int
-    results: List[List]
-    timings: Optional[List[QueryTiming]] = None
-    types: List[str]
+    results: list[list]
+    timings: Optional[list[QueryTiming]] = None
+    types: list[str]
 
 
 class QueryResponseAlternative6(BaseModel):
@@ -1454,8 +1454,8 @@ class QueryResponseAlternative6(BaseModel):
     )
     hasMore: Optional[bool] = None
     hogql: Optional[str] = None
-    results: List[TimelineEntry]
-    timings: Optional[List[QueryTiming]] = None
+    results: list[TimelineEntry]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class QueryResponseAlternative7(BaseModel):
@@ -1463,11 +1463,11 @@ class QueryResponseAlternative7(BaseModel):
         extra="forbid",
     )
     clickhouse: Optional[str] = Field(default=None, description="Executed ClickHouse query")
-    columns: Optional[List] = Field(default=None, description="Returned columns")
+    columns: Optional[list] = Field(default=None, description="Returned columns")
     error: Optional[str] = Field(
         default=None, description="Query error. Returned only if 'explain' is true. Throws an error otherwise."
     )
-    explain: Optional[List[str]] = Field(default=None, description="Query explanation output")
+    explain: Optional[list[str]] = Field(default=None, description="Query explanation output")
     hasMore: Optional[bool] = None
     hogql: Optional[str] = Field(default=None, description="Generated HogQL query")
     limit: Optional[int] = None
@@ -1477,11 +1477,11 @@ class QueryResponseAlternative7(BaseModel):
     )
     offset: Optional[int] = None
     query: Optional[str] = Field(default=None, description="Input query string")
-    results: Optional[List] = Field(default=None, description="Query results")
-    timings: Optional[List[QueryTiming]] = Field(
+    results: Optional[list] = Field(default=None, description="Query results")
+    timings: Optional[list[QueryTiming]] = Field(
         default=None, description="Measured timings for different parts of the query generation process"
     )
-    types: Optional[List] = Field(default=None, description="Types of returned columns")
+    types: Optional[list] = Field(default=None, description="Types of returned columns")
 
 
 class QueryResponseAlternative9(BaseModel):
@@ -1489,8 +1489,8 @@ class QueryResponseAlternative9(BaseModel):
         extra="forbid",
     )
     incomplete_list: bool = Field(..., description="Whether or not the suggestions returned are complete")
-    suggestions: List[AutocompleteCompletionItem]
-    timings: Optional[List[QueryTiming]] = Field(
+    suggestions: list[AutocompleteCompletionItem]
+    timings: Optional[list[QueryTiming]] = Field(
         default=None, description="Measured timings for different parts of the query generation process"
     )
 
@@ -1504,16 +1504,16 @@ class QueryResponseAlternative10(BaseModel):
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: List[WebOverviewItem]
+    results: list[WebOverviewItem]
     samplingRate: Optional[SamplingRate] = None
-    timings: Optional[List[QueryTiming]] = None
+    timings: Optional[list[QueryTiming]] = None
 
 
 class QueryResponseAlternative11(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: Optional[List] = None
+    columns: Optional[list] = None
     hasMore: Optional[bool] = None
     hogql: Optional[str] = None
     is_cached: Optional[bool] = None
@@ -1522,26 +1522,26 @@ class QueryResponseAlternative11(BaseModel):
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
     offset: Optional[int] = None
-    results: List
+    results: list
     samplingRate: Optional[SamplingRate] = None
-    timings: Optional[List[QueryTiming]] = None
-    types: Optional[List] = None
+    timings: Optional[list[QueryTiming]] = None
+    types: Optional[list] = None
 
 
 class QueryResponseAlternative12(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: Optional[List] = None
+    columns: Optional[list] = None
     hogql: Optional[str] = None
     is_cached: Optional[bool] = None
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: List
+    results: list
     samplingRate: Optional[SamplingRate] = None
-    timings: Optional[List[QueryTiming]] = None
-    types: Optional[List] = None
+    timings: Optional[list[QueryTiming]] = None
+    types: Optional[list] = None
 
 
 class QueryResponseAlternative13(BaseModel):
@@ -1553,23 +1553,23 @@ class QueryResponseAlternative13(BaseModel):
     last_refresh: Optional[str] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     next_allowed_client_refresh: Optional[str] = None
-    results: List[Dict[str, Any]]
-    timings: Optional[List[QueryTiming]] = None
+    results: list[dict[str, Any]]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class QueryResponseAlternative17(BaseModel):
     model_config = ConfigDict(
         extra="forbid",
     )
-    columns: Optional[List] = None
+    columns: Optional[list] = None
     hasMore: Optional[bool] = None
     hogql: Optional[str] = None
     limit: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
     offset: Optional[int] = None
     results: FunnelCorrelationResult
-    timings: Optional[List[QueryTiming]] = None
-    types: Optional[List] = None
+    timings: Optional[list[QueryTiming]] = None
+    types: Optional[list] = None
 
 
 class RetentionFilter(BaseModel):
@@ -1602,7 +1602,7 @@ class RetentionResult(BaseModel):
     )
     date: AwareDatetime
     label: str
-    values: List[RetentionValue]
+    values: list[RetentionValue]
 
 
 class SavedInsightNode(BaseModel):
@@ -1664,8 +1664,8 @@ class SessionsTimelineQueryResponse(BaseModel):
     )
     hasMore: Optional[bool] = None
     hogql: Optional[str] = None
-    results: List[TimelineEntry]
-    timings: Optional[List[QueryTiming]] = None
+    results: list[TimelineEntry]
+    timings: Optional[list[QueryTiming]] = None
 
 
 class TimeToSeeDataJSONNode(BaseModel):
@@ -1699,7 +1699,7 @@ class WebAnalyticsQueryBase(BaseModel):
     )
     dateRange: Optional[DateRange] = None
     modifiers: Optional[HogQLQueryModifiers] = None
-    properties: List[Union[EventPropertyFilter, PersonPropertyFilter]]
+    properties: list[Union[EventPropertyFilter, PersonPropertyFilter]]
     sampling: Optional[Sampling] = None
     useSessionsTable: Optional[bool] = None
 
@@ -1712,7 +1712,7 @@ class WebOverviewQuery(BaseModel):
     dateRange: Optional[DateRange] = None
     kind: Literal["WebOverviewQuery"] = "WebOverviewQuery"
     modifiers: Optional[HogQLQueryModifiers] = None
-    properties: List[Union[EventPropertyFilter, PersonPropertyFilter]]
+    properties: list[Union[EventPropertyFilter, PersonPropertyFilter]]
     response: Optional[WebOverviewQueryResponse] = None
     sampling: Optional[Sampling] = None
     useSessionsTable: Optional[bool] = None
@@ -1730,7 +1730,7 @@ class WebStatsTableQuery(BaseModel):
     kind: Literal["WebStatsTableQuery"] = "WebStatsTableQuery"
     limit: Optional[int] = None
     modifiers: Optional[HogQLQueryModifiers] = None
-    properties: List[Union[EventPropertyFilter, PersonPropertyFilter]]
+    properties: list[Union[EventPropertyFilter, PersonPropertyFilter]]
     response: Optional[WebStatsTableQueryResponse] = None
     sampling: Optional[Sampling] = None
     useSessionsTable: Optional[bool] = None
@@ -1743,7 +1743,7 @@ class WebTopClicksQuery(BaseModel):
     dateRange: Optional[DateRange] = None
     kind: Literal["WebTopClicksQuery"] = "WebTopClicksQuery"
     modifiers: Optional[HogQLQueryModifiers] = None
-    properties: List[Union[EventPropertyFilter, PersonPropertyFilter]]
+    properties: list[Union[EventPropertyFilter, PersonPropertyFilter]]
     response: Optional[WebTopClicksQueryResponse] = None
     sampling: Optional[Sampling] = None
     useSessionsTable: Optional[bool] = None
@@ -1752,7 +1752,7 @@ class WebTopClicksQuery(BaseModel):
 
 class AnyResponseType(
     RootModel[
         Union[
-            Dict[str, Any],
+            dict[str, Any],
             HogQLQueryResponse,
             HogQLMetadataResponse,
             HogQLAutocompleteResponse,
@@ -1762,7 +1762,7 @@ class AnyResponseType(
     ]
 ):
     root: Union[
-        Dict[str, Any],
+        dict[str, Any],
         HogQLQueryResponse,
         HogQLMetadataResponse,
         HogQLAutocompleteResponse,
@@ -1778,7 +1778,7 @@ class DashboardFilter(BaseModel):
     date_from: Optional[str] = None
     date_to: Optional[str] = None
     properties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1804,7 +1804,7 @@ class DataWarehouseNode(BaseModel):
     custom_name: Optional[str] = None
     distinct_id_field: str
     fixedProperties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1838,7 +1838,7 @@ class DataWarehouseNode(BaseModel):
     )
     name: Optional[str] = None
     properties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1855,7 +1855,7 @@ class DataWarehouseNode(BaseModel):
             ]
         ]
     ] = Field(default=None, description="Properties configurable in the interface")
-    response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response")
+    response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response")
     table_name: str
     timestamp_field: str
 
@@ -1868,7 +1868,7 @@ class DatabaseSchemaQuery(BaseModel):
     modifiers: Optional[HogQLQueryModifiers] = Field(
         default=None, description="Modifiers used when performing the query"
     )
-    response: Optional[Dict[str, List[DatabaseSchemaQueryResponseField]]] = Field(
+    response: Optional[dict[str, list[DatabaseSchemaQueryResponseField]]] = Field(
         default=None, description="Cached query response"
     )
 
@@ -1879,7 +1879,7 @@ class EntityNode(BaseModel):
     )
     custom_name: Optional[str] = None
     fixedProperties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1911,7 +1911,7 @@ class EntityNode(BaseModel):
     )
     name: Optional[str] = None
     properties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1928,7 +1928,7 @@ class EntityNode(BaseModel):
             ]
         ]
     ] = Field(default=None, description="Properties configurable in the interface")
-    response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response")
+    response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response")
 
 
 class EventsNode(BaseModel):
@@ -1938,7 +1938,7 @@ class EventsNode(BaseModel):
     custom_name: Optional[str] = None
     event: Optional[str] = Field(default=None, description="The event or `null` for all events.")
     fixedProperties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -1970,9 +1970,9 @@ class EventsNode(BaseModel):
         default=None, description="Modifiers used when performing the query"
     )
     name: Optional[str] = None
-    orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by")
+    orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by")
     properties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -2002,7 +2002,7 @@ class EventsQuery(BaseModel):
     event: Optional[str] = Field(default=None, description="Limit to events matching this string")
     filterTestAccounts: Optional[bool] = Field(default=None, description="Filter test accounts")
     fixedProperties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -2028,10 +2028,10 @@ class EventsQuery(BaseModel):
         default=None, description="Modifiers used when performing the query"
    )
     offset: Optional[int] = Field(default=None, description="Number of rows to skip before returning rows")
-    orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by")
+    orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by")
     personId: Optional[str] = Field(default=None, description="Show events for a given person")
     properties: Optional[
-        List[
+        list[
             Union[
                 EventPropertyFilter,
                 PersonPropertyFilter,
@@ -2049,8 +2049,8 @@ class EventsQuery(BaseModel):
         ]
     ] = Field(default=None, description="Properties configurable in the interface")
     response: Optional[EventsQueryResponse] = Field(default=None, description="Cached query response")
Required.") + where: Optional[list[str]] = Field(default=None, description="HogQL filters to apply on returned data") class FunnelExclusionActionsNode(BaseModel): @@ -2059,7 +2059,7 @@ class FunnelExclusionActionsNode(BaseModel): ) custom_name: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2094,7 +2094,7 @@ class FunnelExclusionActionsNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2111,7 +2111,7 @@ class FunnelExclusionActionsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class FunnelExclusionEventsNode(BaseModel): @@ -2121,7 +2121,7 @@ class FunnelExclusionEventsNode(BaseModel): custom_name: Optional[str] = None event: Optional[str] = Field(default=None, description="The event or `null` for all events.") fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2155,9 +2155,9 @@ class FunnelExclusionEventsNode(BaseModel): default=None, description="Modifiers used when performing the query" ) name: Optional[str] = None - orderBy: Optional[List[str]] = Field(default=None, description="Columns to order by") + orderBy: Optional[list[str]] = Field(default=None, description="Columns to order by") properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2184,7 +2184,7 @@ class HogQLFilters(BaseModel): dateRange: Optional[DateRange] = None filterTestAccounts: Optional[bool] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2215,7 +2215,7 @@ class HogQLQuery(BaseModel): ) query: str response: Optional[HogQLQueryResponse] = Field(default=None, description="Cached query response") - values: Optional[Dict[str, Any]] = Field( + values: Optional[dict[str, Any]] = Field( default=None, description="Constant values that can be referenced with the {placeholder} syntax in the query" ) @@ -2227,7 +2227,7 @@ class PersonsNode(BaseModel): cohort: Optional[int] = None distinctId: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2254,7 +2254,7 @@ class PersonsNode(BaseModel): ) offset: Optional[int] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2271,7 +2271,7 @@ class PersonsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") search: Optional[str] = None @@ -2280,7 +2280,7 @@ class PropertyGroupFilterValue(BaseModel): extra="forbid", ) type: FilterLogicalOperator - values: List[ + values: list[ Union[ PropertyGroupFilterValue, Union[ @@ -2310,15 +2310,15 @@ class QueryResponseAlternative14(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[RetentionResult] - timings: Optional[List[QueryTiming]] = None + results: list[RetentionResult] + timings: Optional[list[QueryTiming]] = None class QueryResponseAlternative( RootModel[ 
Union[ QueryResponseAlternative1, - Dict[str, Any], + dict[str, Any], QueryResponseAlternative2, QueryResponseAlternative3, QueryResponseAlternative4, @@ -2333,13 +2333,13 @@ class QueryResponseAlternative( QueryResponseAlternative13, QueryResponseAlternative14, QueryResponseAlternative17, - Dict[str, List[DatabaseSchemaQueryResponseField]], + dict[str, list[DatabaseSchemaQueryResponseField]], ] ] ): root: Union[ QueryResponseAlternative1, - Dict[str, Any], + dict[str, Any], QueryResponseAlternative2, QueryResponseAlternative3, QueryResponseAlternative4, @@ -2354,7 +2354,7 @@ class QueryResponseAlternative( QueryResponseAlternative13, QueryResponseAlternative14, QueryResponseAlternative17, - Dict[str, List[DatabaseSchemaQueryResponseField]], + dict[str, list[DatabaseSchemaQueryResponseField]], ] @@ -2367,8 +2367,8 @@ class RetentionQueryResponse(BaseModel): last_refresh: Optional[str] = None modifiers: Optional[HogQLQueryModifiers] = None next_allowed_client_refresh: Optional[str] = None - results: List[RetentionResult] - timings: Optional[List[QueryTiming]] = None + results: list[RetentionResult] + timings: Optional[list[QueryTiming]] = None class SessionsTimelineQuery(BaseModel): @@ -2395,7 +2395,7 @@ class ActionsNode(BaseModel): ) custom_name: Optional[str] = None fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2428,7 +2428,7 @@ class ActionsNode(BaseModel): ) name: Optional[str] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2445,7 +2445,7 @@ class ActionsNode(BaseModel): ] ] ] = Field(default=None, description="Properties configurable in the interface") - response: Optional[Dict[str, Any]] = Field(default=None, description="Cached query response") + response: Optional[dict[str, Any]] = Field(default=None, description="Cached query response") class DataVisualizationNode(BaseModel): @@ -2465,7 +2465,7 @@ class FunnelsFilter(BaseModel): binCount: Optional[int] = None breakdownAttributionType: Optional[BreakdownAttributionType] = None breakdownAttributionValue: Optional[int] = None - exclusions: Optional[List[Union[FunnelExclusionEventsNode, FunnelExclusionActionsNode]]] = None + exclusions: Optional[list[Union[FunnelExclusionEventsNode, FunnelExclusionActionsNode]]] = None funnelAggregateByHogQL: Optional[str] = None funnelFromStep: Optional[int] = None funnelOrderType: Optional[StepOrderValue] = None @@ -2474,7 +2474,7 @@ class FunnelsFilter(BaseModel): funnelVizType: Optional[FunnelVizType] = None funnelWindowInterval: Optional[int] = None funnelWindowIntervalUnit: Optional[FunnelConversionWindowTimeUnit] = None - hidden_legend_breakdowns: Optional[List[str]] = None + hidden_legend_breakdowns: Optional[list[str]] = None layout: Optional[FunnelLayout] = None @@ -2508,7 +2508,7 @@ class PropertyGroupFilter(BaseModel): extra="forbid", ) type: FilterLogicalOperator - values: List[PropertyGroupFilterValue] + values: list[PropertyGroupFilterValue] class RetentionQuery(BaseModel): @@ -2526,7 +2526,7 @@ class RetentionQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2567,7 +2567,7 @@ class StickinessQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2587,7 +2587,7 @@ class StickinessQuery(BaseModel): ] ] = Field(default=None, description="Property filters for all series") samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - 
series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) stickinessFilter: Optional[StickinessFilter] = Field( @@ -2614,7 +2614,7 @@ class TrendsQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2635,7 +2635,7 @@ class TrendsQuery(BaseModel): ] = Field(default=None, description="Property filters for all series") response: Optional[TrendsQueryResponse] = None samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) trendsFilter: Optional[TrendsFilter] = Field(default=None, description="Properties specific to the trends insight") @@ -2645,22 +2645,22 @@ class FilterType(BaseModel): model_config = ConfigDict( extra="forbid", ) - actions: Optional[List[Dict[str, Any]]] = None + actions: Optional[list[dict[str, Any]]] = None aggregation_group_type_index: Optional[float] = None - breakdown: Optional[Union[str, float, List[Union[str, float]]]] = None + breakdown: Optional[Union[str, float, list[Union[str, float]]]] = None breakdown_group_type_index: Optional[float] = None breakdown_hide_other_aggregation: Optional[bool] = None breakdown_limit: Optional[int] = None breakdown_normalize_url: Optional[bool] = None breakdown_type: Optional[BreakdownType] = None - breakdowns: Optional[List[Breakdown]] = None - data_warehouse: Optional[List[Dict[str, Any]]] = None + breakdowns: Optional[list[Breakdown]] = None + data_warehouse: Optional[list[dict[str, Any]]] = None date_from: Optional[str] = None date_to: Optional[str] = None entity_id: Optional[Union[str, float]] = None entity_math: Optional[str] = None entity_type: Optional[EntityType] = None - events: Optional[List[Dict[str, Any]]] = None + events: Optional[list[dict[str, Any]]] = None explicit_date: Optional[Union[bool, str]] = Field( default=None, description='Whether the `date_from` and `date_to` should be used verbatim. Disables rounding to the start and end of period. Strings are cast to bools, e.g. 
"true" -> true.', @@ -2669,10 +2669,10 @@ class FilterType(BaseModel): from_dashboard: Optional[Union[bool, float]] = None insight: Optional[InsightType] = None interval: Optional[IntervalType] = None - new_entity: Optional[List[Dict[str, Any]]] = None + new_entity: Optional[list[dict[str, Any]]] = None properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2716,7 +2716,7 @@ class FunnelsQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2736,7 +2736,7 @@ class FunnelsQuery(BaseModel): ] ] = Field(default=None, description="Property filters for all series") samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) @@ -2756,7 +2756,7 @@ class InsightsQueryBase(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2798,7 +2798,7 @@ class LifecycleQuery(BaseModel): ) properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2819,7 +2819,7 @@ class LifecycleQuery(BaseModel): ] = Field(default=None, description="Property filters for all series") response: Optional[LifecycleQueryResponse] = None samplingFactor: Optional[float] = Field(default=None, description="Sampling rate") - series: List[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( + series: list[Union[EventsNode, ActionsNode, DataWarehouseNode]] = Field( ..., description="Events and actions to include" ) @@ -2844,7 +2844,7 @@ class FunnelsActorsQuery(BaseModel): model_config = ConfigDict( extra="forbid", ) - funnelCustomSteps: Optional[List[int]] = Field( + funnelCustomSteps: Optional[list[int]] = Field( default=None, description="Custom step numbers to get persons for. This overrides `funnelStep`. Primarily for correlation use.", ) @@ -2852,7 +2852,7 @@ class FunnelsActorsQuery(BaseModel): default=None, description="Index of the step for which we want to get the timestamp for, per person. Positive for converted persons, negative for dropped of persons.", ) - funnelStepBreakdown: Optional[Union[str, float, List[Union[str, float]]]] = Field( + funnelStepBreakdown: Optional[Union[str, float, list[Union[str, float]]]] = Field( default=None, description="The breakdown value for which to get persons for. 
This is an array for person and event properties, a string for groups and an integer for cohorts.", ) @@ -2887,7 +2887,7 @@ class PathsQuery(BaseModel): pathsFilter: PathsFilter = Field(..., description="Properties specific to the paths insight") properties: Optional[ Union[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -2914,11 +2914,11 @@ class FunnelCorrelationQuery(BaseModel): model_config = ConfigDict( extra="forbid", ) - funnelCorrelationEventExcludePropertyNames: Optional[List[str]] = None - funnelCorrelationEventNames: Optional[List[str]] = None - funnelCorrelationExcludeEventNames: Optional[List[str]] = None - funnelCorrelationExcludeNames: Optional[List[str]] = None - funnelCorrelationNames: Optional[List[str]] = None + funnelCorrelationEventExcludePropertyNames: Optional[list[str]] = None + funnelCorrelationEventNames: Optional[list[str]] = None + funnelCorrelationExcludeEventNames: Optional[list[str]] = None + funnelCorrelationExcludeNames: Optional[list[str]] = None + funnelCorrelationNames: Optional[list[str]] = None funnelCorrelationType: FunnelCorrelationResultsType kind: Literal["FunnelCorrelationQuery"] = "FunnelCorrelationQuery" response: Optional[FunnelCorrelationResponse] = None @@ -2956,7 +2956,7 @@ class FunnelCorrelationActorsQuery(BaseModel): funnelCorrelationPersonConverted: Optional[bool] = None funnelCorrelationPersonEntity: Optional[Union[EventsNode, ActionsNode, DataWarehouseNode]] = None funnelCorrelationPropertyValues: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3015,7 +3015,7 @@ class ActorsQuery(BaseModel): extra="forbid", ) fixedProperties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3038,9 +3038,9 @@ class ActorsQuery(BaseModel): default=None, description="Modifiers used when performing the query" ) offset: Optional[int] = None - orderBy: Optional[List[str]] = None + orderBy: Optional[list[str]] = None properties: Optional[ - List[ + list[ Union[ EventPropertyFilter, PersonPropertyFilter, @@ -3059,7 +3059,7 @@ class ActorsQuery(BaseModel): ] = None response: Optional[ActorsQueryResponse] = Field(default=None, description="Cached query response") search: Optional[str] = None - select: Optional[List[str]] = None + select: Optional[list[str]] = None source: Optional[Union[InsightActorsQuery, FunnelsActorsQuery, FunnelCorrelationActorsQuery, HogQLQuery]] = None @@ -3070,7 +3070,7 @@ class DataTableNode(BaseModel): allowSorting: Optional[bool] = Field( default=None, description="Can the user click on column headers to sort the table? (default: true)" ) - columns: Optional[List[str]] = Field( + columns: Optional[list[str]] = Field( default=None, description="Columns shown in the table, unless the `source` provides them." ) embedded: Optional[bool] = Field(default=None, description="Uses the embedded version of LemonTable") @@ -3078,7 +3078,7 @@ class DataTableNode(BaseModel): default=None, description="Can expand row to show raw event data (default: true)" ) full: Optional[bool] = Field(default=None, description="Show with most visual options enabled. 
Used in scenes.") - hiddenColumns: Optional[List[str]] = Field( + hiddenColumns: Optional[list[str]] = Field( default=None, description="Columns that aren't shown in the table, even if in columns or returned data" ) kind: Literal["DataTableNode"] = "DataTableNode" diff --git a/posthog/session_recordings/models/metadata.py b/posthog/session_recordings/models/metadata.py index 4d75c70dae4ed..dd26fde6a3b32 100644 --- a/posthog/session_recordings/models/metadata.py +++ b/posthog/session_recordings/models/metadata.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Dict, List, Optional, TypedDict, Union, Literal +from typing import Optional, TypedDict, Union, Literal -SnapshotData = Dict +SnapshotData = dict WindowId = Optional[str] @@ -22,7 +22,7 @@ class SessionRecordingEventSummary(TypedDict): timestamp: int type: int # keys of this object should be any of EVENT_SUMMARY_DATA_INCLUSIONS - data: Dict[str, Union[int, str]] + data: dict[str, Union[int, str]] # NOTE: MatchingSessionRecordingEvent is a minimal version of full events that is used to display events matching a filter on the frontend @@ -35,7 +35,7 @@ class MatchingSessionRecordingEvent(TypedDict): class DecompressedRecordingData(TypedDict): has_next: bool - snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]] + snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]] class RecordingMetadata(TypedDict): @@ -55,10 +55,10 @@ class RecordingMetadata(TypedDict): class RecordingMatchingEvents(TypedDict): - events: List[MatchingSessionRecordingEvent] + events: list[MatchingSessionRecordingEvent] class PersistedRecordingV1(TypedDict): version: str # "2022-12-22" - snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]] + snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]] distinct_id: str diff --git a/posthog/session_recordings/models/session_recording.py b/posthog/session_recordings/models/session_recording.py index c217d41cef8e7..d5ac114f8a216 100644 --- a/posthog/session_recordings/models/session_recording.py +++ b/posthog/session_recordings/models/session_recording.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal, Optional +from typing import Any, Literal, Optional from django.conf import settings from django.db import models @@ -136,7 +136,7 @@ def check_viewed_for_user(self, user: Any, save_viewed=False) -> None: def build_object_storage_path(self, version: Literal["2023-08-01", "2022-12-22"]) -> str: if version == "2022-12-22": - path_parts: List[str] = [ + path_parts: list[str] = [ settings.OBJECT_STORAGE_SESSION_RECORDING_LTS_FOLDER, f"team-{self.team_id}", f"session-{self.session_id}", @@ -161,7 +161,7 @@ def get_or_build(session_id: str, team: Team) -> "SessionRecording": return SessionRecording(session_id=session_id, team=team) @staticmethod - def get_or_build_from_clickhouse(team: Team, ch_recordings: List[dict]) -> "List[SessionRecording]": + def get_or_build_from_clickhouse(team: Team, ch_recordings: list[dict]) -> "list[SessionRecording]": session_ids = sorted([recording["session_id"] for recording in ch_recordings]) recordings_by_id = { @@ -193,7 +193,7 @@ def get_or_build_from_clickhouse(team: Team, ch_recordings: List[dict]) -> "List return recordings - def set_start_url_from_urls(self, urls: Optional[List[str]] = None, first_url: Optional[str] = None): + def set_start_url_from_urls(self, urls: Optional[list[str]] = 
diff --git a/posthog/session_recordings/models/metadata.py b/posthog/session_recordings/models/metadata.py
index 4d75c70dae4ed..dd26fde6a3b32 100644
--- a/posthog/session_recordings/models/metadata.py
+++ b/posthog/session_recordings/models/metadata.py
@@ -1,7 +1,7 @@
 from datetime import datetime
-from typing import Dict, List, Optional, TypedDict, Union, Literal
+from typing import Optional, TypedDict, Union, Literal
 
-SnapshotData = Dict
+SnapshotData = dict
 
 WindowId = Optional[str]
 
@@ -22,7 +22,7 @@ class SessionRecordingEventSummary(TypedDict):
     timestamp: int
     type: int
     # keys of this object should be any of EVENT_SUMMARY_DATA_INCLUSIONS
-    data: Dict[str, Union[int, str]]
+    data: dict[str, Union[int, str]]
 
 
 # NOTE: MatchingSessionRecordingEvent is a minimal version of full events that is used to display events matching a filter on the frontend
@@ -35,7 +35,7 @@ class MatchingSessionRecordingEvent(TypedDict):
 
 class DecompressedRecordingData(TypedDict):
     has_next: bool
-    snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]]
+    snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]]
 
 
 class RecordingMetadata(TypedDict):
@@ -55,10 +55,10 @@ class RecordingMetadata(TypedDict):
 
 
 class RecordingMatchingEvents(TypedDict):
-    events: List[MatchingSessionRecordingEvent]
+    events: list[MatchingSessionRecordingEvent]
 
 
 class PersistedRecordingV1(TypedDict):
     version: str  # "2022-12-22"
-    snapshot_data_by_window_id: Dict[WindowId, List[Union[SnapshotData, SessionRecordingEventSummary]]]
+    snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, SessionRecordingEventSummary]]]
     distinct_id: str
diff --git a/posthog/session_recordings/models/session_recording.py b/posthog/session_recordings/models/session_recording.py
index c217d41cef8e7..d5ac114f8a216 100644
--- a/posthog/session_recordings/models/session_recording.py
+++ b/posthog/session_recordings/models/session_recording.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, Optional
+from typing import Any, Literal, Optional
 
 from django.conf import settings
 from django.db import models
@@ -136,7 +136,7 @@ def check_viewed_for_user(self, user: Any, save_viewed=False) -> None:
 
     def build_object_storage_path(self, version: Literal["2023-08-01", "2022-12-22"]) -> str:
         if version == "2022-12-22":
-            path_parts: List[str] = [
+            path_parts: list[str] = [
                 settings.OBJECT_STORAGE_SESSION_RECORDING_LTS_FOLDER,
                 f"team-{self.team_id}",
                 f"session-{self.session_id}",
@@ -161,7 +161,7 @@ def get_or_build(session_id: str, team: Team) -> "SessionRecording":
         return SessionRecording(session_id=session_id, team=team)
 
     @staticmethod
-    def get_or_build_from_clickhouse(team: Team, ch_recordings: List[dict]) -> "List[SessionRecording]":
+    def get_or_build_from_clickhouse(team: Team, ch_recordings: list[dict]) -> "list[SessionRecording]":
         session_ids = sorted([recording["session_id"] for recording in ch_recordings])
 
         recordings_by_id = {
@@ -193,7 +193,7 @@ def get_or_build_from_clickhouse(team: Team, ch_recordings: List[dict]) -> "List
 
         return recordings
 
-    def set_start_url_from_urls(self, urls: Optional[List[str]] = None, first_url: Optional[str] = None):
+    def set_start_url_from_urls(self, urls: Optional[list[str]] = None, first_url: Optional[str] = None):
         if first_url:
             self.start_url = first_url[:512]
             return
diff --git a/posthog/session_recordings/queries/session_query.py b/posthog/session_recordings/queries/session_query.py
index d0ff7b32afb4e..eb856194806de 100644
--- a/posthog/session_recordings/queries/session_query.py
+++ b/posthog/session_recordings/queries/session_query.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Union
 
 from posthog.models import Filter
 from posthog.models.filters.path_filter import PathFilter
@@ -29,7 +29,7 @@ def __init__(
         self._team = team
         self._session_id_alias = session_id_alias
 
-    def get_query(self) -> Tuple[str, Dict]:
+    def get_query(self) -> tuple[str, dict]:
         params = {"team_id": self._team.pk}
         query_date_range = QueryDateRange(filter=self._filter, team=self._team, should_round=False)
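metadata.py shows the same substitution inside TypedDicts and bare aliases (`SnapshotData = Dict` becoming `SnapshotData = dict`), while session_query.py converts `Tuple[str, Dict]` return annotations to `tuple[str, dict]`. A small sketch of the TypedDict side of this, with hypothetical names rather than the real definitions:

from typing import Optional, TypedDict, Union

SnapshotData = dict  # a bare builtin works as an alias target, just as typing.Dict did
WindowId = Optional[str]


class EventSummary(TypedDict):
    timestamp: int
    type: int
    data: dict[str, Union[int, str]]  # was: Dict[str, Union[int, str]]


class DecompressedData(TypedDict):
    has_next: bool
    snapshot_data_by_window_id: dict[WindowId, list[Union[SnapshotData, EventSummary]]]


chunk: DecompressedData = {
    "has_next": False,
    "snapshot_data_by_window_id": {None: [{"timestamp": 1, "type": 3, "data": {}}]},
}
print(chunk["snapshot_data_by_window_id"][None])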
diff --git a/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py b/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py
index 4f64fff7f8ab3..b9458c597c9fc 100644
--- a/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py
+++ b/posthog/session_recordings/queries/session_recording_list_from_replay_summary.py
@@ -1,7 +1,7 @@
 import dataclasses
 import re
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Literal, NamedTuple, Tuple, Union
+from typing import Any, Literal, NamedTuple, Union
 
 from django.conf import settings
 from sentry_sdk import capture_exception
@@ -25,15 +25,15 @@ class SummaryEventFiltersSQL:
     having_conditions: str
     having_select: str
     where_conditions: str
-    params: Dict[str, Any]
+    params: dict[str, Any]
 
 
 class SessionRecordingQueryResult(NamedTuple):
-    results: List
+    results: list
     has_more_recording: bool
 
 
-def _get_recording_start_time_clause(recording_filters: SessionRecordingsFilter) -> Tuple[str, Dict[str, Any]]:
+def _get_recording_start_time_clause(recording_filters: SessionRecordingsFilter) -> tuple[str, dict[str, Any]]:
     start_time_clause = ""
     start_time_params = {}
     if recording_filters.date_from:
@@ -52,7 +52,7 @@ def _get_order_by_clause(filter_order: str | None) -> str:
 
 
 def _get_filter_by_log_text_session_ids_clause(
     team: Team, recording_filters: SessionRecordingsFilter, column_name="session_id"
-) -> Tuple[str, Dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
     if not recording_filters.console_search_query:
         return "", {}
 
@@ -66,7 +66,7 @@ def _get_filter_by_log_text_session_ids_clause(
 
 
 def _get_filter_by_provided_session_ids_clause(
     recording_filters: SessionRecordingsFilter, column_name="session_id"
-) -> Tuple[str, Dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
     if recording_filters.session_ids is None:
         return "", {}
 
@@ -111,7 +111,7 @@ def ttl_days(self):
     # a recording spans the time boundaries
     # TODO This is just copied from below
     @cached_property
-    def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
+    def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]:
         timestamp_clause = ""
         timestamp_params = {}
         if self._filter.date_from:
@@ -124,8 +124,8 @@ def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
 
     @staticmethod
     def _get_console_log_clause(
-        console_logs_filter: List[Literal["error", "warn", "info"]],
-    ) -> Tuple[str, Dict[str, Any]]:
+        console_logs_filter: list[Literal["error", "warn", "info"]],
+    ) -> tuple[str, dict[str, Any]]:
         return (
             (
                 f"AND level in %(console_logs_levels)s",
@@ -135,7 +135,7 @@ def _get_console_log_clause(
             else ("", {})
         )
 
-    def get_query(self) -> Tuple[str, Dict]:
+    def get_query(self) -> tuple[str, dict]:
         if not self._filter.console_search_query:
             return "", {}
 
@@ -177,7 +177,7 @@ def _determine_should_join_distinct_ids(self) -> None:
         pass
 
     # we have to implement this from EventQuery but don't need it
-    def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]:
+    def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]:
         pass
 
     _raw_persons_query = """
@@ -195,7 +195,7 @@ def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]:
         {filter_by_person_uuid_condition}
     """
 
-    def get_query(self) -> Tuple[str, Dict[str, Any]]:
+    def get_query(self) -> tuple[str, dict[str, Any]]:
         # we don't support PoE V1 - hopefully that's ok
         if self._person_on_events_mode == PersonsOnEventsMode.person_id_override_properties_on_events:
             return "", {}
@@ -280,7 +280,7 @@ def _determine_should_join_distinct_ids(self) -> None:
         pass
 
     # we have to implement this from EventQuery but don't need it
-    def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]:
+    def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]:
         pass
 
     def _determine_should_join_events(self):
@@ -354,7 +354,7 @@ def ttl_days(self):
         HAVING 1=1 {event_filter_having_events_condition}
     """
 
-    def format_event_filter(self, entity: Entity, prepend: str, team_id: int) -> Tuple[str, Dict[str, Any]]:
+    def format_event_filter(self, entity: Entity, prepend: str, team_id: int) -> tuple[str, dict[str, Any]]:
         filter_sql, params = format_entity_filter(
             team_id=team_id,
             entity=entity,
@@ -382,8 +382,8 @@ def format_event_filter(self, entity: Entity, prepend: str, team_id: int) -> Tup
 
     @cached_property
     def build_event_filters(self) -> SummaryEventFiltersSQL:
-        event_names_to_filter: List[Union[int, str]] = []
-        params: Dict = {}
+        event_names_to_filter: list[Union[int, str]] = []
+        params: dict = {}
         condition_sql = ""
 
         for index, entity in enumerate(self._filter.entities):
@@ -432,7 +432,7 @@ def build_event_filters(self) -> SummaryEventFiltersSQL:
             params=params,
         )
 
-    def _get_groups_query(self) -> Tuple[str, Dict]:
+    def _get_groups_query(self) -> tuple[str, dict]:
         try:
             from ee.clickhouse.queries.groups_join_query import GroupsJoinQuery
         except ImportError:
@@ -449,7 +449,7 @@ def _get_groups_query(self) -> Tuple[str, Dict]:
     # We want to select events beyond the range of the recording to handle the case where
     # a recording spans the time boundaries
     @cached_property
-    def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
+    def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]:
         timestamp_clause = ""
         timestamp_params = {}
         if self._filter.date_from:
@@ -460,7 +460,7 @@ def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
             timestamp_params["event_end_time"] = self._filter.date_to + timedelta(hours=12)
         return timestamp_clause, timestamp_params
 
-    def get_query(self, select_event_ids: bool = False) -> Tuple[str, Dict[str, Any]]:
+    def get_query(self, select_event_ids: bool = False) -> tuple[str, dict[str, Any]]:
         if not self._determine_should_join_events():
             return "", {}
 
@@ -564,7 +564,7 @@ def _persons_join_or_subquery(self, event_filters, prop_query):
         return persons_join, persons_select_params, persons_sub_query
 
     @cached_property
-    def _get_person_id_clause(self) -> Tuple[str, Dict[str, Any]]:
+    def _get_person_id_clause(self) -> tuple[str, dict[str, Any]]:
         person_id_clause = ""
         person_id_params = {}
         if self._filter.person_uuid:
@@ -572,7 +572,7 @@ def _get_person_id_clause(self) -> Tuple[str, Dict[str, Any]]:
             person_id_params = {"person_uuid": self._filter.person_uuid}
         return person_id_clause, person_id_params
 
-    def matching_events(self) -> List[str]:
+    def matching_events(self) -> list[str]:
         self._filter.hogql_context.modifiers.personsOnEventsMode = self._person_on_events_mode
         query, query_params = self.get_query(select_event_ids=True)
         query_results = sync_execute(query, {**query_params, **self._filter.hogql_context.values})
@@ -644,7 +644,7 @@ def ttl_days(self):
     """
 
     @staticmethod
-    def _data_to_return(results: List[Any]) -> List[Dict[str, Any]]:
+    def _data_to_return(results: list[Any]) -> list[dict[str, Any]]:
         default_columns = [
             "session_id",
             "team_id",
@@ -694,7 +694,7 @@ def run(self) -> SessionRecordingQueryResult:
     def limit(self):
         return self._filter.limit or self.SESSION_RECORDINGS_DEFAULT_LIMIT
 
-    def get_query(self) -> Tuple[str, Dict[str, Any]]:
+    def get_query(self) -> tuple[str, dict[str, Any]]:
         offset = self._filter.offset or 0
 
         base_params = {
@@ -758,7 +758,7 @@ def get_query(self) -> Tuple[str, Dict[str, Any]]:
     def duration_clause(
         self,
         duration_filter_type: Literal["duration", "active_seconds", "inactive_seconds"],
-    ) -> Tuple[str, Dict[str, Any]]:
+    ) -> tuple[str, dict[str, Any]]:
         duration_clause = ""
         duration_params = {}
         if self._filter.recording_duration_filter:
@@ -775,7 +775,7 @@ def duration_clause(
         return duration_clause, duration_params
 
     @staticmethod
-    def _get_console_log_clause(console_logs_filter: List[Literal["error", "warn", "info"]]) -> str:
+    def _get_console_log_clause(console_logs_filter: list[Literal["error", "warn", "info"]]) -> str:
         # to avoid a CH migration we map from info to log when constructing the query here
         filters = [f"console_{'log' if log == 'info' else log}_count > 0" for log in console_logs_filter]
         return f"AND ({' OR '.join(filters)})" if filters else ""
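The recurring signature throughout these query classes is a clause builder that returns a SQL fragment together with its bound parameters, now annotated `tuple[str, dict[str, Any]]` instead of `Tuple[str, Dict[str, Any]]`. A stripped-down sketch of that pattern (a hypothetical standalone reduction, not the class method itself):

from typing import Any, Literal


def console_log_clause(
    console_logs_filter: list[Literal["error", "warn", "info"]],
) -> tuple[str, dict[str, Any]]:
    # Return an empty pair when there is nothing to filter on, otherwise a
    # parametrized fragment, mirroring the ("", {}) convention used above.
    if not console_logs_filter:
        return "", {}
    return (
        "AND level in %(console_logs_levels)s",
        {"console_logs_levels": console_logs_filter},
    )


clause, params = console_log_clause(["error", "warn"])
print(clause, params)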
diff --git a/posthog/session_recordings/queries/session_recording_properties.py b/posthog/session_recordings/queries/session_recording_properties.py
index e7c5544f14fe7..2d2ef187c0407 100644
--- a/posthog/session_recordings/queries/session_recording_properties.py
+++ b/posthog/session_recordings/queries/session_recording_properties.py
@@ -1,5 +1,5 @@
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, NamedTuple
 
 from posthog.client import sync_execute
 from posthog.models.event.util import parse_properties
@@ -14,12 +14,12 @@ class EventFiltersSQL(NamedTuple):
     aggregate_select_clause: str
     aggregate_having_clause: str
     where_conditions: str
-    params: Dict[str, Any]
+    params: dict[str, Any]
 
 
 class SessionRecordingProperties(EventQuery):
     _filter: SessionRecordingsFilter
-    _session_ids: List[str]
+    _session_ids: list[str]
 
     SESSION_RECORDING_PROPERTIES_ALLOWLIST = {
         "$os",
@@ -47,7 +47,7 @@ class SessionRecordingProperties(EventQuery):
         GROUP BY session_id
     """
 
-    def __init__(self, team: "Team", session_ids: List[str], filter: SessionRecordingsFilter):
+    def __init__(self, team: "Team", session_ids: list[str], filter: SessionRecordingsFilter):
         super().__init__(team=team, filter=filter)
         self._session_ids = sorted(session_ids)  # Sort for stable queries
 
@@ -56,7 +56,7 @@ def _determine_should_join_distinct_ids(self) -> None:
 
     # We want to select events beyond the range of the recording to handle the case where
     # a recording spans the time boundaries
-    def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
+    def _get_events_timestamp_clause(self) -> tuple[str, dict[str, Any]]:
         timestamp_clause = ""
         timestamp_params = {}
         if self._filter.date_from:
@@ -67,11 +67,11 @@ def _get_events_timestamp_clause(self) -> Tuple[str, Dict[str, Any]]:
             timestamp_params["event_end_time"] = self._filter.date_to + timedelta(hours=12)
         return timestamp_clause, timestamp_params
 
-    def format_session_recording_id_filters(self) -> Tuple[str, Dict]:
+    def format_session_recording_id_filters(self) -> tuple[str, dict]:
         where_conditions = "AND session_id IN %(session_ids)s"
         return where_conditions, {"session_ids": self._session_ids}
 
-    def get_query(self) -> Tuple[str, Dict[str, Any]]:
+    def get_query(self) -> tuple[str, dict[str, Any]]:
         base_params = {"team_id": self._team_id}
         (
             events_timestamp_clause,
@@ -90,7 +90,7 @@ def get_query(self) -> Tuple[str, Dict[str, Any]]:
             {**base_params, **events_timestamp_params, **session_ids_params},
         )
 
-    def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]:
+    def _data_to_return(self, results: list[Any]) -> list[dict[str, Any]]:
         return [
             {
                 "session_id": row[0],
@@ -99,7 +99,7 @@ def _data_to_return(self, results: List[Any]) -> List[Dict[str, Any]]:
             for row in results
         ]
 
-    def run(self) -> List:
+    def run(self) -> list:
         query, query_params = self.get_query()
         query_results = sync_execute(query, query_params)
         session_recording_properties = self._data_to_return(query_results)
diff --git a/posthog/session_recordings/queries/session_replay_events.py b/posthog/session_recordings/queries/session_replay_events.py
index fbb3577bf03d7..226d27154fd85 100644
--- a/posthog/session_recordings/queries/session_replay_events.py
+++ b/posthog/session_recordings/queries/session_replay_events.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta
-from typing import Optional, Tuple, List
+from typing import Optional
 
 from django.conf import settings
 
@@ -75,7 +75,7 @@ def get_metadata(
             )
         )
 
-        replay_response: List[Tuple] = sync_execute(
+        replay_response: list[tuple] = sync_execute(
             query,
             {
                 "team_id": team.pk,
@@ -107,8 +107,8 @@ def get_metadata(
         )
 
     def get_events(
-        self, session_id: str, team: Team, metadata: RecordingMetadata, events_to_ignore: List[str] | None
-    ) -> Tuple[List | None, List | None]:
+        self, session_id: str, team: Team, metadata: RecordingMetadata, events_to_ignore: list[str] | None
+    ) -> tuple[list | None, list | None]:
         from posthog.schema import HogQLQuery, HogQLQueryResponse
         from posthog.hogql_queries.hogql_query_runner import HogQLQueryRunner
diff --git a/posthog/session_recordings/queries/test/session_replay_sql.py b/posthog/session_recordings/queries/test/session_replay_sql.py
index b72c64dbc0f68..fbec2ea065036 100644
--- a/posthog/session_recordings/queries/test/session_replay_sql.py
+++ b/posthog/session_recordings/queries/test/session_replay_sql.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional, List, Dict
+from typing import Optional
 from uuid import uuid4
 
 from dateutil.parser import parse
@@ -113,7 +113,7 @@ def produce_replay_summary(
     console_log_count: Optional[int] = None,
     console_warn_count: Optional[int] = None,
     console_error_count: Optional[int] = None,
-    log_messages: Dict[str, List[str]] | None = None,
+    log_messages: dict[str, list[str]] | None = None,
     snapshot_source: str | None = None,
 ):
     if log_messages is None:
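The fixture helper in session_replay_sql.py mixes both modernizations in one annotation: builtin generics plus a PEP 604 union for the default, `dict[str, list[str]] | None`. Note that `X | None` at annotation position needs Python 3.10+ (or `from __future__ import annotations`). A hypothetical reduction of that keyword-default pattern:

def summarize_logs(log_messages: dict[str, list[str]] | None = None) -> dict[str, int]:
    # Normalize the None default to an empty mapping before counting,
    # as produce_replay_summary does for its log_messages argument.
    if log_messages is None:
        log_messages = {}
    return {level: len(lines) for level, lines in log_messages.items()}


print(summarize_logs({"info": ["loaded", "ready"], "error": ["boom"]}))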
a/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py +++ b/posthog/session_recordings/queries/test/test_session_recording_list_from_session_replay.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Dict from uuid import uuid4 from dateutil.relativedelta import relativedelta @@ -76,7 +75,7 @@ def create_event( properties=properties, ) - def _filter_recordings_by(self, recordings_filter: Dict) -> SessionRecordingQueryResult: + def _filter_recordings_by(self, recordings_filter: dict) -> SessionRecordingQueryResult: the_filter = SessionRecordingsFilter(team=self.team, data=recordings_filter) session_recording_list_instance = SessionRecordingListFromReplaySummary(filter=the_filter, team=self.team) return session_recording_list_instance.run() diff --git a/posthog/session_recordings/realtime_snapshots.py b/posthog/session_recordings/realtime_snapshots.py index d6890c63517e1..8e943db34a600 100644 --- a/posthog/session_recordings/realtime_snapshots.py +++ b/posthog/session_recordings/realtime_snapshots.py @@ -1,6 +1,6 @@ import json from time import sleep -from typing import Dict, List, Optional +from typing import Optional import structlog from prometheus_client import Counter @@ -54,7 +54,7 @@ def publish_subscription(team_id: str, session_id: str) -> None: raise e -def get_realtime_snapshots(team_id: str, session_id: str, attempt_count=0) -> Optional[List[Dict]]: +def get_realtime_snapshots(team_id: str, session_id: str, attempt_count=0) -> Optional[list[dict]]: try: redis = get_client(settings.SESSION_RECORDING_REDIS_URL) key = get_key(team_id, session_id) diff --git a/posthog/session_recordings/session_recording_api.py b/posthog/session_recordings/session_recording_api.py index e7f4ac7769696..d9a9fc303d1bb 100644 --- a/posthog/session_recordings/session_recording_api.py +++ b/posthog/session_recordings/session_recording_api.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone import json -from typing import Any, List, Type, cast, Dict, Tuple +from typing import Any, cast from django.conf import settings @@ -191,7 +191,7 @@ class SessionRecordingSnapshotsSerializer(serializers.Serializer): def list_recordings_response( - filter: SessionRecordingsFilter, request: request.Request, serializer_context: Dict[str, Any] + filter: SessionRecordingsFilter, request: request.Request, serializer_context: dict[str, Any] ) -> Response: (recordings, timings) = list_recordings(filter, request, context=serializer_context) response = Response(recordings) @@ -211,7 +211,7 @@ class SessionRecordingViewSet(TeamAndOrgViewSetMixin, viewsets.GenericViewSet): sharing_enabled_actions = ["retrieve", "snapshots", "snapshot_file"] - def get_serializer_class(self) -> Type[serializers.Serializer]: + def get_serializer_class(self) -> type[serializers.Serializer]: if isinstance(self.request.successful_authenticator, SharingAccessTokenAuthentication): return SessionRecordingSharedSerializer else: @@ -252,7 +252,7 @@ def matching_events(self, request: request.Request, *args: Any, **kwargs: Any) - "Must specify at least one event or action filter", ) - matching_events: List[str] = SessionIdEventsQuery(filter=filter, team=self.team).matching_events() + matching_events: list[str] = SessionIdEventsQuery(filter=filter, team=self.team).matching_events() return JsonResponse(data={"results": matching_events}) # Returns metadata about the recording @@ -342,9 +342,9 @@ def snapshots(self, request: request.Request, **kwargs): 
SNAPSHOT_SOURCE_REQUESTED.labels(source=source).inc() if not source: - sources: List[dict] = [] + sources: list[dict] = [] - blob_keys: List[str] | None = None + blob_keys: list[str] | None = None if recording.object_storage_path: if recording.storage_version == "2023-08-01": blob_prefix = recording.object_storage_path @@ -603,8 +603,8 @@ def error_clusters(self, request: request.Request, **kwargs): def list_recordings( - filter: SessionRecordingsFilter, request: request.Request, context: Dict[str, Any] -) -> Tuple[Dict, Dict]: + filter: SessionRecordingsFilter, request: request.Request, context: dict[str, Any] +) -> tuple[dict, dict]: """ As we can store recordings in S3 or in Clickhouse we need to do a few things here @@ -617,7 +617,7 @@ def list_recordings( all_session_ids = filter.session_ids - recordings: List[SessionRecording] = [] + recordings: list[SessionRecording] = [] more_recordings_available = False team = context["get_team"]() @@ -655,7 +655,7 @@ def list_recordings( if all_session_ids: recordings = sorted( recordings, - key=lambda x: cast(List[str], all_session_ids).index(x.session_id), + key=lambda x: cast(list[str], all_session_ids).index(x.session_id), ) if not request.user.is_authenticated: # for mypy diff --git a/posthog/session_recordings/session_recording_helpers.py b/posthog/session_recordings/session_recording_helpers.py index 1eccc2be26e32..8dfb1c0ad2396 100644 --- a/posthog/session_recordings/session_recording_helpers.py +++ b/posthog/session_recordings/session_recording_helpers.py @@ -3,7 +3,8 @@ import json from collections import defaultdict from datetime import datetime, timezone -from typing import Any, Callable, Dict, Generator, List, Tuple +from typing import Any +from collections.abc import Callable, Generator from dateutil.parser import parse from prometheus_client import Counter @@ -89,10 +90,10 @@ class RRWEB_MAP_EVENT_DATA_TYPE: ] -Event = Dict[str, Any] +Event = dict[str, Any] -def split_replay_events(events: List[Event]) -> Tuple[List[Event], List[Event]]: +def split_replay_events(events: list[Event]) -> tuple[list[Event], list[Event]]: replay, other = [], [] for event in events: @@ -102,12 +103,12 @@ def split_replay_events(events: List[Event]) -> Tuple[List[Event], List[Event]]: # TODO is this covered by enough tests post-blob ingester rollout -def preprocess_replay_events_for_blob_ingestion(events: List[Event], max_size_bytes=1024 * 1024) -> List[Event]: +def preprocess_replay_events_for_blob_ingestion(events: list[Event], max_size_bytes=1024 * 1024) -> list[Event]: return _process_windowed_events(events, lambda x: preprocess_replay_events(x, max_size_bytes=max_size_bytes)) def preprocess_replay_events( - _events: List[Event] | Generator[Event, None, None], max_size_bytes=1024 * 1024 + _events: list[Event] | Generator[Event, None, None], max_size_bytes=1024 * 1024 ) -> Generator[Event, None, None]: """ The events going to blob ingestion are uncompressed (the compression happens in the Kafka producer) @@ -135,7 +136,7 @@ def preprocess_replay_events( window_id = events[0]["properties"].get("$window_id") snapshot_source = events[0]["properties"].get("$snapshot_source", "web") - def new_event(items: List[dict] | None = None) -> Event: + def new_event(items: list[dict] | None = None) -> Event: return { **events[0], "event": "$snapshot_items", # New event name to avoid confusion with the old $snapshot event @@ -151,7 +152,7 @@ def new_event(items: List[dict] | None = None) -> Event: # 1. 
Group by $snapshot_bytes if any of the events have it if events[0]["properties"].get("$snapshot_bytes"): - current_event: Dict | None = None + current_event: dict | None = None current_event_size = 0 for event in events: @@ -208,13 +209,13 @@ def new_event(items: List[dict] | None = None) -> Event: def _process_windowed_events( - events: List[Event], fn: Callable[[List[Any]], Generator[Event, None, None]] -) -> List[Event]: + events: list[Event], fn: Callable[[list[Any]], Generator[Event, None, None]] +) -> list[Event]: """ Helper method to simplify grouping events by window_id and session_id, processing them with the given function, and then returning the flattened list """ - result: List[Event] = [] + result: list[Event] = [] snapshots_by_session_and_window_id = defaultdict(list) for event in events: @@ -228,7 +229,7 @@ def _process_windowed_events( return result -def is_unprocessed_snapshot_event(event: Dict) -> bool: +def is_unprocessed_snapshot_event(event: dict) -> bool: try: is_snapshot = event["event"] == "$snapshot" except KeyError: @@ -274,5 +275,5 @@ def convert_to_timestamp(source: str) -> int: return int(parse(source).timestamp() * 1000) -def byte_size_dict(x: Dict | List) -> int: +def byte_size_dict(x: dict | list) -> int: return len(json.dumps(x)) diff --git a/posthog/session_recordings/snapshots/convert_legacy_snapshots.py b/posthog/session_recordings/snapshots/convert_legacy_snapshots.py index 963016d0e869a..d2d4ba2c4b4bd 100644 --- a/posthog/session_recordings/snapshots/convert_legacy_snapshots.py +++ b/posthog/session_recordings/snapshots/convert_legacy_snapshots.py @@ -1,5 +1,4 @@ import json -from typing import Dict import structlog from prometheus_client import Histogram @@ -67,7 +66,7 @@ def _prepare_legacy_content(content: str) -> str: return _convert_legacy_format_from_lts_storage(json_content) -def _convert_legacy_format_from_lts_storage(lts_formatted_data: Dict) -> str: +def _convert_legacy_format_from_lts_storage(lts_formatted_data: dict) -> str: """ The latest version is JSONL formatted data. Each line is json containing a window_id and a data array. 
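The hunks above all apply the same PEP 585 migration: on Python 3.9+ the built-in containers are themselves generic, so the typing.List/Dict/Tuple aliases can be dropped, and on 3.10+ the PEP 604 union spelling (e.g. `list[str] | None`) replaces Optional where the codebase already uses it. A minimal sketch of the target style, with illustrative names that are not taken from this diff:

from typing import Optional


def summarize(rows: list[tuple[int, str]]) -> dict[str, int]:
    # Built-in generics (PEP 585, Python 3.9+): no typing.List/Dict/Tuple needed.
    counts: dict[str, int] = {}
    for _row_id, label in rows:
        counts[label] = counts.get(label, 0) + 1
    return counts


def first_label(rows: list[tuple[int, str]]) -> Optional[str]:
    # Optional still comes from typing; on Python 3.10+ the PEP 604 spelling
    # `str | None` also works, as several hunks above already use.
    return rows[0][1] if rows else None
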
diff --git a/posthog/session_recordings/test/test_lts_session_recordings.py b/posthog/session_recordings/test/test_lts_session_recordings.py index 7d60d07defb2c..bd6dfc39d246d 100644 --- a/posthog/session_recordings/test/test_lts_session_recordings.py +++ b/posthog/session_recordings/test/test_lts_session_recordings.py @@ -1,5 +1,4 @@ import uuid -from typing import List from unittest.mock import patch, MagicMock, call, Mock from rest_framework import status @@ -32,7 +31,7 @@ def test_2023_08_01_version_stored_snapshots_can_be_gathered( session_id = str(uuid.uuid4()) lts_storage_path = "purposefully/not/what/we/would/calculate/to/prove/this/is/used" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == lts_storage_path: @@ -88,7 +87,7 @@ def test_original_version_stored_snapshots_can_be_gathered( session_id = str(uuid.uuid4()) lts_storage_path = "1234-5678" - def list_objects_func(_path: str) -> List[str]: + def list_objects_func(_path: str) -> list[str]: return [] mock_list_objects.side_effect = list_objects_func @@ -138,7 +137,7 @@ def test_2023_08_01_version_stored_snapshots_can_be_loaded( session_id = str(uuid.uuid4()) lts_storage_path = "purposefully/not/what/we/would/calculate/to/prove/this/is/used" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == lts_storage_path: @@ -208,7 +207,7 @@ def test_original_version_stored_snapshots_can_be_loaded_without_upversion( session_id = str(uuid.uuid4()) lts_storage_path = "1234-5678" - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: return [] mock_list_objects.side_effect = list_objects_func diff --git a/posthog/session_recordings/test/test_session_recording_helpers.py b/posthog/session_recordings/test/test_session_recording_helpers.py index b6b83e02c28d9..a13b131fb3160 100644 --- a/posthog/session_recordings/test/test_session_recording_helpers.py +++ b/posthog/session_recordings/test/test_session_recording_helpers.py @@ -3,7 +3,7 @@ import random import string from datetime import datetime -from typing import Any, List, Tuple +from typing import Any import pytest from pytest_mock import MockerFixture @@ -27,7 +27,7 @@ def create_activity_data(timestamp: datetime, is_active: bool): ) -def mock_capture_flow(events: List[dict], max_size_bytes=512 * 1024) -> Tuple[List[dict], List[dict]]: +def mock_capture_flow(events: list[dict], max_size_bytes=512 * 1024) -> tuple[list[dict], list[dict]]: """ Returns the legacy events and the new flow ones """ @@ -422,7 +422,7 @@ def test_new_ingestion_groups_using_snapshot_bytes_if_possible(raw_snapshot_even "something": "small", } - events: List[Any] = [ + events: list[Any] = [ { "event": "$snapshot", "properties": { diff --git a/posthog/session_recordings/test/test_session_recordings.py b/posthog/session_recordings/test/test_session_recordings.py index 78f92c24a0a30..6b0721d673083 100644 --- a/posthog/session_recordings/test/test_session_recordings.py +++ b/posthog/session_recordings/test/test_session_recordings.py @@ -1,7 +1,6 @@ import time import uuid from datetime import datetime, timedelta, timezone -from typing import List from unittest.mock import ANY, patch, MagicMock, call from urllib.parse import 
urlencode @@ -65,7 +64,7 @@ def create_snapshot( # because we use `now()` in the CH queries which don't know about any frozen time # @snapshot_clickhouse_queries def test_get_session_recordings(self): - twelve_distinct_ids: List[str] = [f"user_one_{i}" for i in range(12)] + twelve_distinct_ids: list[str] = [f"user_one_{i}" for i in range(12)] user = Person.objects.create( team=self.team, @@ -132,7 +131,7 @@ def test_can_list_recordings_even_when_the_person_has_multiple_distinct_ids(self # almost duplicate of test_get_session_recordings above # but if we have multiple distinct ids on a recording the snapshot # varies which makes the snapshot useless - twelve_distinct_ids: List[str] = [f"user_one_{i}" for i in range(12)] + twelve_distinct_ids: list[str] = [f"user_one_{i}" for i in range(12)] Person.objects.create( team=self.team, @@ -577,7 +576,7 @@ def test_get_snapshots_v2_from_lts(self, mock_list_objects: MagicMock, _mock_exi object_storage_path="an lts stored object path", ) - def list_objects_func(path: str) -> List[str]: + def list_objects_func(path: str) -> list[str]: # this mock simulates a recording whose blob storage has been deleted by TTL # but which has been stored in LTS blob storage if path == "an lts stored object path": diff --git a/posthog/settings/__init__.py b/posthog/settings/__init__.py index 455b7e8dc34a1..faf2e466764d2 100644 --- a/posthog/settings/__init__.py +++ b/posthog/settings/__init__.py @@ -13,7 +13,6 @@ # isort: skip_file import os -from typing import Dict, List # :TRICKY: Imported before anything else to support overloads from posthog.settings.overrides import * @@ -68,7 +67,7 @@ DISABLE_MMDB = get_from_env( "DISABLE_MMDB", TEST, type_cast=str_to_bool ) # plugin server setting disabling GeoIP feature -PLUGINS_PREINSTALLED_URLS: List[str] = ( +PLUGINS_PREINSTALLED_URLS: list[str] = ( os.getenv( "PLUGINS_PREINSTALLED_URLS", "https://www.npmjs.com/package/@posthog/geoip-plugin", ) @@ -100,7 +99,7 @@ # Whether to use insight queries converted to HogQL. HOGQL_INSIGHTS_OVERRIDE = get_from_env("HOGQL_INSIGHTS_OVERRIDE", optional=True, type_cast=str_to_bool) -HOOK_EVENTS: Dict[str, str] = {} +HOOK_EVENTS: dict[str, str] = {} # Support creating multiple organizations in a single instance. Requires a premium license. MULTI_ORG_ENABLED = get_from_env("MULTI_ORG_ENABLED", False, type_cast=str_to_bool) diff --git a/posthog/settings/data_stores.py b/posthog/settings/data_stores.py index f3402a748111f..d175f04f07c2a 100644 --- a/posthog/settings/data_stores.py +++ b/posthog/settings/data_stores.py @@ -1,6 +1,5 @@ import json import os -from typing import List from urllib.parse import urlparse import dj_database_url @@ -173,7 +172,7 @@ def postgres_config(host: str) -> dict: READONLY_CLICKHOUSE_PASSWORD = os.getenv("READONLY_CLICKHOUSE_PASSWORD", None) -def _parse_kafka_hosts(hosts_string: str) -> List[str]: +def _parse_kafka_hosts(hosts_string: str) -> list[str]: hosts = [] for host in hosts_string.split(","): if "://" in host: diff --git a/posthog/settings/logs.py b/posthog/settings/logs.py index 8f41f3e6c21e6..f8f21294e37a3 100644 --- a/posthog/settings/logs.py +++ b/posthog/settings/logs.py @@ -1,7 +1,6 @@ import logging import os import threading -from typing import List import structlog @@ -27,7 +26,7 @@ def add_pid_and_tid( # To enable standard library logs to be formatted via structlog, we add this # `foreign_pre_chain` to both formatters.
-foreign_pre_chain: List[structlog.types.Processor] = [ +foreign_pre_chain: list[structlog.types.Processor] = [ structlog.contextvars.merge_contextvars, structlog.processors.TimeStamper(fmt="iso"), structlog.stdlib.add_logger_name, diff --git a/posthog/settings/session_replay.py b/posthog/settings/session_replay.py index 4cd8a429aa028..429f3207dccf7 100644 --- a/posthog/settings/session_replay.py +++ b/posthog/settings/session_replay.py @@ -1,5 +1,3 @@ -from typing import List - from posthog.settings import get_from_env, get_list from posthog.utils import str_to_bool @@ -18,7 +16,7 @@ "REALTIME_SNAPSHOTS_FROM_REDIS_ATTEMPT_TIMEOUT_SECONDS", 0.2, type_cast=float ) -REPLAY_EMBEDDINGS_ALLOWED_TEAMS: List[str] = get_list(get_from_env("REPLAY_EMBEDDINGS_ALLOWED_TEAM", "", type_cast=str)) +REPLAY_EMBEDDINGS_ALLOWED_TEAMS: list[str] = get_list(get_from_env("REPLAY_EMBEDDINGS_ALLOWED_TEAM", "", type_cast=str)) REPLAY_EMBEDDINGS_BATCH_SIZE = get_from_env("REPLAY_EMBEDDINGS_BATCH_SIZE", 10, type_cast=int) REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS = get_from_env("REPLAY_EMBEDDINGS_MIN_DURATION_SECONDS", 30, type_cast=int) REPLAY_EMBEDDINGS_CALCULATION_CELERY_INTERVAL_SECONDS = get_from_env( diff --git a/posthog/settings/temporal.py b/posthog/settings/temporal.py index ccb5fbfb0db3f..ce0e72172eabb 100644 --- a/posthog/settings/temporal.py +++ b/posthog/settings/temporal.py @@ -1,5 +1,4 @@ import os -from typing import Dict from posthog.settings.utils import get_list, get_from_env @@ -24,6 +23,6 @@ CLICKHOUSE_MAX_EXECUTION_TIME = get_from_env("CLICKHOUSE_MAX_EXECUTION_TIME", 0, type_cast=int) CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT = get_from_env("CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT", 10000, type_cast=int) # Comma separated list of overrides in the format "team_id:block_size" -CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES: Dict[int, int] = dict( +CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES: dict[int, int] = dict( [map(int, o.split(":")) for o in os.getenv("CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES", "").split(",") if o] # type: ignore ) diff --git a/posthog/settings/utils.py b/posthog/settings/utils.py index 6dd22dbf97cf8..eead270c7bd7d 100644 --- a/posthog/settings/utils.py +++ b/posthog/settings/utils.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, List, Optional, Set +from typing import Any, Optional +from collections.abc import Callable from django.core.exceptions import ImproperlyConfigured @@ -28,13 +29,13 @@ def get_from_env( return value -def get_list(text: str) -> List[str]: +def get_list(text: str) -> list[str]: if not text: return [] return [item.strip() for item in text.split(",")] -def get_set(text: str) -> Set[str]: +def get_set(text: str) -> set[str]: if not text: return set() return {item.strip() for item in text.split(",")} diff --git a/posthog/settings/web.py b/posthog/settings/web.py index ee6961de70e79..b80c1baab02d6 100644 --- a/posthog/settings/web.py +++ b/posthog/settings/web.py @@ -1,7 +1,6 @@ # Web app specific settings/middleware/apps setup import os from datetime import timedelta -from typing import List from corsheaders.defaults import default_headers @@ -160,7 +159,7 @@ SOCIAL_AUTH_USER_MODEL = "posthog.User" SOCIAL_AUTH_REDIRECT_IS_HTTPS = get_from_env("SOCIAL_AUTH_REDIRECT_IS_HTTPS", not DEBUG, type_cast=str_to_bool) -AUTHENTICATION_BACKENDS: List[str] = [ +AUTHENTICATION_BACKENDS: list[str] = [ "axes.backends.AxesBackend", "social_core.backends.github.GithubOAuth2", "social_core.backends.gitlab.GitLabOAuth2", diff --git a/posthog/storage/object_storage.py 
b/posthog/storage/object_storage.py index a1ff639b1c293..147b02436fa6e 100644 --- a/posthog/storage/object_storage.py +++ b/posthog/storage/object_storage.py @@ -1,5 +1,5 @@ import abc -from typing import Optional, Union, List, Dict +from typing import Optional, Union import structlog from boto3 import client @@ -26,7 +26,7 @@ def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) pass @abc.abstractmethod - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> Optional[list[str]]: pass @abc.abstractmethod @@ -38,11 +38,11 @@ def read_bytes(self, bucket: str, key: str) -> Optional[bytes]: pass @abc.abstractmethod - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: pass @abc.abstractmethod - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: pass @abc.abstractmethod @@ -60,7 +60,7 @@ def head_bucket(self, bucket: str): def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]: pass - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> Optional[list[str]]: pass def read(self, bucket: str, key: str) -> Optional[str]: @@ -69,10 +69,10 @@ def read(self, bucket: str, key: str) -> Optional[str]: def read_bytes(self, bucket: str, key: str) -> Optional[bytes]: pass - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: pass - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: pass def copy_objects(self, bucket: str, source_prefix: str, target_prefix: str) -> int | None: @@ -103,7 +103,7 @@ def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) capture_exception(e) return None - def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]: + def list_objects(self, bucket: str, prefix: str) -> Optional[list[str]]: try: s3_response = self.aws_client.list_objects_v2(Bucket=bucket, Prefix=prefix) if s3_response.get("Contents"): @@ -143,7 +143,7 @@ def read_bytes(self, bucket: str, key: str) -> Optional[bytes]: capture_exception(e) raise ObjectStorageError("read failed") from e - def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: + def tag(self, bucket: str, key: str, tags: dict[str, str]) -> None: try: self.aws_client.put_object_tagging( Bucket=bucket, @@ -155,7 +155,7 @@ def tag(self, bucket: str, key: str, tags: Dict[str, str]) -> None: capture_exception(e) raise ObjectStorageError("tag failed") from e - def write(self, bucket: str, key: str, content: Union[str, bytes], extras: Dict | None) -> None: + def write(self, bucket: str, key: str, content: Union[str, bytes], extras: dict | None) -> None: s3_response = {} try: s3_response = self.aws_client.put_object(Bucket=bucket, Body=content, Key=key, **(extras or {})) @@ -218,7 +218,7 @@ def object_storage_client() -> ObjectStorageClient: return _client -def write(file_name: str, content: Union[str, bytes], extras: Dict | None = None) -> None: +def write(file_name: str, content: Union[str, bytes], extras: dict | None = None) -> None: return 
object_storage_client().write( bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name, @@ -227,7 +227,7 @@ def write(file_name: str, content: Union[str, bytes], extras: Dict | None = None ) -def tag(file_name: str, tags: Dict[str, str]) -> None: +def tag(file_name: str, tags: dict[str, str]) -> None: return object_storage_client().tag(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name, tags=tags) @@ -239,7 +239,7 @@ def read_bytes(file_name: str) -> Optional[bytes]: return object_storage_client().read_bytes(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name) -def list_objects(prefix: str) -> Optional[List[str]]: +def list_objects(prefix: str) -> Optional[list[str]]: return object_storage_client().list_objects(bucket=settings.OBJECT_STORAGE_BUCKET, prefix=prefix) diff --git a/posthog/storage/test/test_object_storage.py b/posthog/storage/test/test_object_storage.py index f24114911ba9e..3737ca155ee6f 100644 --- a/posthog/storage/test/test_object_storage.py +++ b/posthog/storage/test/test_object_storage.py @@ -57,7 +57,7 @@ def test_write_and_read_works_with_known_byte_content(self) -> None: chunk_id = uuid.uuid4() name = f"{session_id}/{0}-{chunk_id}" file_name = f"{TEST_BUCKET}/test_write_and_read_works_with_known_content/{name}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") self.assertEqual(read(file_name), "my content") def test_can_generate_presigned_url_for_existing_file(self) -> None: @@ -66,7 +66,7 @@ def test_can_generate_presigned_url_for_existing_file(self) -> None: chunk_id = uuid.uuid4() name = f"{session_id}/{0}-{chunk_id}" file_name = f"{TEST_BUCKET}/test_can_generate_presigned_url_for_existing_file/{name}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") presigned_url = get_presigned_url(file_name) assert presigned_url is not None @@ -93,7 +93,7 @@ def test_can_list_objects_with_prefix(self) -> None: for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") listing = list_objects(prefix=f"{TEST_BUCKET}/{shared_prefix}") @@ -117,7 +117,7 @@ def test_can_copy_objects_between_prefixes(self) -> None: for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") copied_count = copy_objects( source_prefix=f"{TEST_BUCKET}/{shared_prefix}", @@ -142,7 +142,7 @@ def test_can_safely_copy_objects_from_unknown_prefix(self) -> None: for file in ["a", "b", "c"]: file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}" - write(file_name, "my content".encode("utf-8")) + write(file_name, b"my content") copied_count = copy_objects( source_prefix=f"nothing_here", diff --git a/posthog/tasks/calculate_cohort.py b/posthog/tasks/calculate_cohort.py index 7dba512d6c86c..35ccc8fe9ece6 100644 --- a/posthog/tasks/calculate_cohort.py +++ b/posthog/tasks/calculate_cohort.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List, Optional +from typing import Any, Optional import structlog from celery import shared_task @@ -53,7 +53,7 @@ def calculate_cohort_ch(cohort_id: int, pending_version: int, initiating_user_id @shared_task(ignore_result=True, max_retries=1) -def calculate_cohort_from_list(cohort_id: int, items: List[str]) -> None: +def calculate_cohort_from_list(cohort_id: int, items: list[str]) -> None: start_time = time.time() cohort = Cohort.objects.get(pk=cohort_id) @@ -62,7 +62,7 @@ def 
calculate_cohort_from_list(cohort_id: int, items: List[str]) -> None: @shared_task(ignore_result=True, max_retries=1) -def insert_cohort_from_insight_filter(cohort_id: int, filter_data: Dict[str, Any]) -> None: +def insert_cohort_from_insight_filter(cohort_id: int, filter_data: dict[str, Any]) -> None: from posthog.api.cohort import ( insert_cohort_actors_into_ch, insert_cohort_people_into_pg, ) diff --git a/posthog/tasks/email.py b/posthog/tasks/email.py index d06d15ee12ace..2d7198dc2d8ca 100644 --- a/posthog/tasks/email.py +++ b/posthog/tasks/email.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import List, Optional +from typing import Optional import posthoganalytics import structlog @@ -281,7 +281,7 @@ def send_async_migration_errored_email(migration_key: str, time: str, error: str send_message_to_all_staff_users(message) -def get_users_for_orgs_with_no_ingested_events(org_created_from: datetime, org_created_to: datetime) -> List[User]: +def get_users_for_orgs_with_no_ingested_events(org_created_from: datetime, org_created_to: datetime) -> list[User]: # Get all users for organizations that haven't ingested any events users = [] recently_created_organizations = Organization.objects.filter( diff --git a/posthog/tasks/exports/csv_exporter.py b/posthog/tasks/exports/csv_exporter.py index 22cf1004ac07e..489bf64e74036 100644 --- a/posthog/tasks/exports/csv_exporter.py +++ b/posthog/tasks/exports/csv_exporter.py @@ -1,6 +1,7 @@ import datetime import io -from typing import Any, Dict, List, Optional, Tuple, Generator +from typing import Any, Optional +from collections.abc import Generator from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse import requests @@ -53,14 +54,14 @@ # 5. We save the final blob output and update the ExportedAsset -def add_query_params(url: str, params: Dict[str, str]) -> str: +def add_query_params(url: str, params: dict[str, str]) -> str: """ Uses parse_qsl because parse_qs turns all values into lists but doesn't unbox them when re-encoded """ parsed = urlparse(url) query_params = parse_qsl(parsed.query, keep_blank_values=True) - update_params: List[Tuple[str, Any]] = [] + update_params: list[tuple[str, Any]] = [] for param, value in query_params: if param in params: update_params.append((param, params.pop(param))) @@ -265,7 +266,7 @@ def get_from_hogql_query(exported_asset: ExportedAsset, limit: int, resource: di def _export_to_dict(exported_asset: ExportedAsset, limit: int) -> Any: resource = exported_asset.export_context - columns: List[str] = resource.get("columns", []) + columns: list[str] = resource.get("columns", []) returned_rows: Generator[Any, None, None] if resource.get("source"): @@ -310,7 +311,7 @@ def _export_to_excel(exported_asset: ExportedAsset, limit: int) -> None: for row_num, row_data in enumerate(renderer.tablize(all_csv_rows, header=render_context.get("header"))): for col_num, value in enumerate(row_data): - if value is not None and not isinstance(value, (str, int, float, bool)): + if value is not None and not isinstance(value, str | int | float | bool): value = str(value) worksheet.cell(row=row_num + 1, column=col_num + 1, value=value) diff --git a/posthog/tasks/exports/ordered_csv_renderer.py b/posthog/tasks/exports/ordered_csv_renderer.py index d183ee874b2bc..5b70e9bed911c 100644 --- a/posthog/tasks/exports/ordered_csv_renderer.py +++ b/posthog/tasks/exports/ordered_csv_renderer.py @@ -1,6 +1,7 @@ import itertools from collections import OrderedDict -from typing import Any, Dict, Generator +from
typing import Any +from collections.abc import Generator from more_itertools import unique_everseen from rest_framework_csv.renderers import CSVRenderer @@ -28,7 +29,7 @@ def tablize(self, data: Any, header: Any = None, labels: Any = None) -> Generato # Get the set of all unique headers, and sort them. unique_fields = list(unique_everseen(itertools.chain(*(item.keys() for item in data)))) - ordered_fields: Dict[str, Any] = OrderedDict() + ordered_fields: dict[str, Any] = OrderedDict() for item in unique_fields: field = item.split(".") field = field[0] diff --git a/posthog/tasks/exports/test/test_csv_exporter.py b/posthog/tasks/exports/test/test_csv_exporter.py index 87d731dd6a192..d1c03ea5a3eeb 100644 --- a/posthog/tasks/exports/test/test_csv_exporter.py +++ b/posthog/tasks/exports/test/test_csv_exporter.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest import mock from unittest.mock import MagicMock, Mock, patch, ANY @@ -97,7 +97,7 @@ def patched_request(self): patched_request.return_value = mock_response yield patched_request - def _create_asset(self, extra_context: Optional[Dict] = None) -> ExportedAsset: + def _create_asset(self, extra_context: Optional[dict] = None) -> ExportedAsset: if extra_context is None: extra_context = {} @@ -588,7 +588,7 @@ def test_csv_exporter_empty_result(self, mocked_uuidt: Any) -> None: self.assertEqual(lines[0], "error") self.assertEqual(lines[1], "No data available or unable to format for export.") - def _split_to_dict(self, url: str) -> Dict[str, Any]: + def _split_to_dict(self, url: str) -> dict[str, Any]: first_split_parts = url.split("?") assert len(first_split_parts) == 2 return {bits[0]: bits[1] for bits in [param.split("=") for param in first_split_parts[1].split("&")]} diff --git a/posthog/tasks/sync_all_organization_available_features.py b/posthog/tasks/sync_all_organization_available_features.py index 87e425fa5ca81..ec16a0e0a5a91 100644 --- a/posthog/tasks/sync_all_organization_available_features.py +++ b/posthog/tasks/sync_all_organization_available_features.py @@ -1,4 +1,5 @@ -from typing import Sequence, cast +from typing import cast +from collections.abc import Sequence from posthog.models.organization import Organization diff --git a/posthog/tasks/test/test_calculate_cohort.py b/posthog/tasks/test/test_calculate_cohort.py index 0c81076c8fa81..ff2c534a91039 100644 --- a/posthog/tasks/test/test_calculate_cohort.py +++ b/posthog/tasks/test/test_calculate_cohort.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from unittest.mock import MagicMock, patch from freezegun import freeze_time diff --git a/posthog/tasks/test/test_email.py b/posthog/tasks/test/test_email.py index 447d0d442bfc8..b89127b48b7a5 100644 --- a/posthog/tasks/test/test_email.py +++ b/posthog/tasks/test/test_email.py @@ -1,5 +1,4 @@ import datetime as dt -from typing import Tuple from unittest.mock import MagicMock, patch import pytest @@ -28,7 +27,7 @@ from posthog.test.base import APIBaseTest, ClickhouseTestMixin -def create_org_team_and_user(creation_date: str, email: str, ingested_event: bool = False) -> Tuple[Organization, User]: +def create_org_team_and_user(creation_date: str, email: str, ingested_event: bool = False) -> tuple[Organization, User]: with freeze_time(creation_date): org = Organization.objects.create(name="too_late_org") Team.objects.create(organization=org, name="Default Project", ingested_event=ingested_event) diff --git 
a/posthog/tasks/test/test_usage_report.py b/posthog/tasks/test/test_usage_report.py index d977f27560b51..286e1a623f834 100644 --- a/posthog/tasks/test/test_usage_report.py +++ b/posthog/tasks/test/test_usage_report.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List +from typing import Any from unittest.mock import ANY, MagicMock, Mock, call, patch from uuid import uuid4 @@ -324,14 +324,14 @@ def _create_sample_usage_data(self) -> None: flush_persons_and_events() - def _select_report_by_org_id(self, org_id: str, reports: List[Dict]) -> Dict: + def _select_report_by_org_id(self, org_id: str, reports: list[dict]) -> dict: return next(report for report in reports if report["organization_id"] == org_id) def _create_plugin(self, name: str, enabled: bool) -> None: plugin = Plugin.objects.create(organization_id=self.team.organization.pk, name=name) PluginConfig.objects.create(plugin=plugin, enabled=enabled, order=1) - def _test_usage_report(self) -> List[dict]: + def _test_usage_report(self) -> list[dict]: with self.settings(SITE_URL="http://test.posthog.com"): self._create_sample_usage_data() self._create_plugin("Installed but not enabled", False) diff --git a/posthog/tasks/test/utils_email_tests.py b/posthog/tasks/test/utils_email_tests.py index d9be8cdd3bc7e..ccb998b3dc38d 100644 --- a/posthog/tasks/test/utils_email_tests.py +++ b/posthog/tasks/test/utils_email_tests.py @@ -1,12 +1,12 @@ import os -from typing import Any, List +from typing import Any from unittest.mock import MagicMock from posthog.email import EmailMessage from posthog.utils import get_absolute_path -def mock_email_messages(MockEmailMessage: MagicMock, path: str = "tasks/test/__emails__/") -> List[Any]: +def mock_email_messages(MockEmailMessage: MagicMock, path: str = "tasks/test/__emails__/") -> list[Any]: """ Takes a mocked EmailMessage class and returns a list of all subsequently created EmailMessage instances. The "send" method is spied on to write the generated email to a file diff --git a/posthog/tasks/usage_report.py b/posthog/tasks/usage_report.py index 958601d1ec3ca..727a93f1e0cb3 100644 --- a/posthog/tasks/usage_report.py +++ b/posthog/tasks/usage_report.py @@ -4,16 +4,13 @@ from datetime import datetime from typing import ( Any, - Dict, - List, Literal, Optional, - Sequence, - Tuple, TypedDict, Union, cast, ) +from collections.abc import Sequence import requests import structlog @@ -52,8 +49,15 @@ logger = structlog.get_logger(__name__) -Period = TypedDict("Period", {"start_inclusive": str, "end_inclusive": str}) -TableSizes = TypedDict("TableSizes", {"posthog_event": int, "posthog_sessionrecordingevent": int}) + + +class Period(TypedDict): + start_inclusive: str + end_inclusive: str + + +class TableSizes(TypedDict): + posthog_event: int + posthog_sessionrecordingevent: int CH_BILLING_SETTINGS = { @@ -133,13 +137,13 @@ class InstanceMetadata: product: str helm: Optional[dict] clickhouse_version: Optional[str] - users_who_logged_in: Optional[List[Dict[str, Union[str, int]]]] + users_who_logged_in: Optional[list[dict[str, Union[str, int]]]] users_who_logged_in_count: Optional[int] - users_who_signed_up: Optional[List[Dict[str, Union[str, int]]]] + users_who_signed_up: Optional[list[dict[str, Union[str, int]]]] users_who_signed_up_count: Optional[int] table_sizes: Optional[TableSizes] - plugins_installed: Optional[Dict] - plugins_enabled: Optional[Dict] + plugins_installed: Optional[dict] + plugins_enabled: Optional[dict] instance_tag: str @dataclasses.dataclass @@ -151,7 +155,7 @@ class
OrgReport(UsageReportCounters): organization_created_at: str organization_user_count: int team_count: int - teams: Dict[str, UsageReportCounters] + teams: dict[str, UsageReportCounters] @dataclasses.dataclass @@ -163,7 +167,7 @@ def fetch_table_size(table_name: str) -> int: return fetch_sql("SELECT pg_total_relation_size(%s) as size", (table_name,))[0].size -def fetch_sql(sql_: str, params: Tuple[Any, ...]) -> List[Any]: +def fetch_sql(sql_: str, params: tuple[Any, ...]) -> list[Any]: with connection.cursor() as cursor: cursor.execute(sql.SQL(sql_), params) return namedtuplefetchall(cursor) @@ -178,7 +182,7 @@ def get_product_name(realm: str, has_license: bool) -> str: return "unknown" -def get_instance_metadata(period: Tuple[datetime, datetime]) -> InstanceMetadata: +def get_instance_metadata(period: tuple[datetime, datetime]) -> InstanceMetadata: has_license = False if settings.EE_AVAILABLE: @@ -288,7 +292,7 @@ def get_org_owner_or_first_user(organization_id: str) -> Optional[User]: @shared_task(**USAGE_REPORT_TASK_KWARGS, max_retries=3) -def send_report_to_billing_service(org_id: str, report: Dict[str, Any]) -> None: +def send_report_to_billing_service(org_id: str, report: dict[str, Any]) -> None: if not settings.EE_AVAILABLE: return @@ -340,7 +344,7 @@ def capture_event( pha_client: Client, name: str, organization_id: str, - properties: Dict[str, Any], + properties: dict[str, Any], timestamp: Optional[Union[datetime, str]] = None, ) -> None: if timestamp and isinstance(timestamp, str): @@ -373,7 +377,7 @@ def capture_event( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_lifetime() -> List[Tuple[int, int]]: +def get_teams_with_event_count_lifetime() -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(1) as count @@ -390,7 +394,7 @@ def get_teams_with_event_count_lifetime() -> List[Tuple[int, int]]: @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_billable_event_count_in_period( begin: datetime, end: datetime, count_distinct: bool = False -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # count only unique events # Duplicate events will be eventually removed by ClickHouse and likely came from our library or pipeline. # We shouldn't bill for these. However counting unique events is more expensive, and likely to fail on longer time ranges. @@ -420,7 +424,7 @@ def get_teams_with_billable_event_count_in_period( @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_billable_enhanced_persons_event_count_in_period( begin: datetime, end: datetime, count_distinct: bool = False -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # count only unique events # Duplicate events will be eventually removed by ClickHouse and likely came from our library or pipeline. # We shouldn't bill for these. However counting unique events is more expensive, and likely to fail on longer time ranges. 
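Besides the mechanical alias swaps, usage_report.py also rewrites its functional TypedDict declarations (Period, TableSizes) into class syntax. Both forms are equivalent at runtime; the class form gives type checkers and editors real attribute declarations instead of quoted keys. A sketch of the same transformation, reusing Period's field names:

from typing import TypedDict

# Functional form: keys are plain strings, so tools can't autocomplete
# them and typos only surface when the dict is used.
Window = TypedDict("Window", {"start_inclusive": str, "end_inclusive": str})


# Class form: identical runtime behaviour, but the fields are ordinary
# annotated attributes that mypy and IDEs understand directly.
class WindowDecl(TypedDict):
    start_inclusive: str
    end_inclusive: str


w: WindowDecl = {"start_inclusive": "2024-01-01", "end_inclusive": "2024-01-31"}
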
@@ -448,7 +452,7 @@ def get_teams_with_billable_enhanced_persons_event_count_in_period( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(1) as count @@ -466,7 +470,7 @@ def get_teams_with_event_count_with_groups_in_period(begin: datetime, end: datet @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> List[Tuple[int, str, int]]: +def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> list[tuple[int, str, int]]: results = sync_execute( """ SELECT team_id, JSONExtractString(properties, '$lib') as lib, COUNT(1) as count @@ -483,7 +487,7 @@ def get_teams_with_event_count_by_lib(begin: datetime, end: datetime) -> List[Tu @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> List[Tuple[int, str, int]]: +def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> list[tuple[int, str, int]]: results = sync_execute( """ SELECT team_id, event, COUNT(1) as count @@ -500,7 +504,7 @@ def get_teams_with_event_count_by_name(begin: datetime, end: datetime) -> List[T @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: previous_begin = begin - (end - begin) result = sync_execute( @@ -531,7 +535,7 @@ def get_teams_with_recording_count_in_period(begin: datetime, end: datetime) -> @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_recording_count_total() -> List[Tuple[int, int]]: +def get_teams_with_recording_count_total() -> list[tuple[int, int]]: result = sync_execute( """ SELECT team_id, count(distinct session_id) as count @@ -549,10 +553,10 @@ def get_teams_with_recording_count_total() -> List[Tuple[int, int]]: def get_teams_with_hogql_metric( begin: datetime, end: datetime, - query_types: List[str], + query_types: list[str], access_method: str = "", metric: Literal["read_bytes", "read_rows", "query_duration_ms"] = "read_bytes", -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: if metric not in ["read_bytes", "read_rows", "query_duration_ms"]: # :TRICKY: Inlined into the query below. 
raise ValueError(f"Invalid metric {metric}") @@ -586,7 +590,7 @@ def get_teams_with_hogql_metric( @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) def get_teams_with_feature_flag_requests_count_in_period( begin: datetime, end: datetime, request_type: FlagRequestType -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: # depending on the region, events are stored in different teams team_to_query = 1 if get_instance_region() == "EU" else 2 validity_token = settings.DECIDE_BILLING_ANALYTICS_TOKEN @@ -620,7 +624,7 @@ def get_teams_with_feature_flag_requests_count_in_period( def get_teams_with_survey_responses_count_in_period( begin: datetime, end: datetime, -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: results = sync_execute( """ SELECT team_id, COUNT() as count @@ -638,7 +642,7 @@ def get_teams_with_survey_responses_count_in_period( @timed_log() @retry(tries=QUERY_RETRIES, delay=QUERY_RETRY_DELAY, backoff=QUERY_RETRY_BACKOFF) -def get_teams_with_rows_synced_in_period(begin: datetime, end: datetime) -> List[Tuple[int, int]]: +def get_teams_with_rows_synced_in_period(begin: datetime, end: datetime) -> list[tuple[int, int]]: team_to_query = 1 if get_instance_region() == "EU" else 2 # dedup by job id in case there were duplicates sent @@ -668,7 +672,7 @@ def get_teams_with_rows_synced_in_period(begin: datetime, end: datetime) -> List def capture_report( capture_event_name: str, org_id: str, - full_report_dict: Dict[str, Any], + full_report_dict: dict[str, Any], at_date: Optional[datetime] = None, ) -> None: pha_client = Client("sTMFPsFhdP1Ssg") @@ -695,7 +699,7 @@ def has_non_zero_usage(report: FullUsageReport) -> bool: ) -def convert_team_usage_rows_to_dict(rows: List[Union[dict, Tuple[int, int]]]) -> Dict[int, int]: +def convert_team_usage_rows_to_dict(rows: list[Union[dict, tuple[int, int]]]) -> dict[int, int]: team_id_map = {} for row in rows: if isinstance(row, dict) and "team_id" in row: @@ -708,7 +712,7 @@ def convert_team_usage_rows_to_dict(rows: List[Union[dict, Tuple[int, int]]]) -> return team_id_map -def _get_all_usage_data(period_start: datetime, period_end: datetime) -> Dict[str, Any]: +def _get_all_usage_data(period_start: datetime, period_end: datetime) -> dict[str, Any]: """ Gets all usage data for the specified period. Clickhouse is good at counting things so we count across all teams rather than doing it one by one @@ -867,7 +871,7 @@ } -def _get_all_usage_data_as_team_rows(period_start: datetime, period_end: datetime) -> Dict[str, Any]: +def _get_all_usage_data_as_team_rows(period_start: datetime, period_end: datetime) -> dict[str, Any]: """ Gets all usage data for the specified period as a map of team_id -> value. This makes it faster to access the data than looping over all_data to find what we want.
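Note that only the container aliases change here: Union and Optional still come from typing throughout, because the bare `X | Y` runtime syntax needs Python 3.10+ while `list[...]`/`dict[...]` only need 3.9. A simplified sketch of the same shape as convert_team_usage_rows_to_dict above (the "team_id"/"total" keys are illustrative, not taken from the real query rows):

from typing import Union


def rows_to_map(rows: list[Union[dict, tuple[int, int]]]) -> dict[int, int]:
    # Accepts either dict rows or (team_id, count) tuples and normalises
    # them into a team_id -> count mapping.
    team_id_map: dict[int, int] = {}
    for row in rows:
        if isinstance(row, dict):
            team_id_map[row["team_id"]] = row["total"]
        else:
            team_id, count = row
            team_id_map[team_id] = count
    return team_id_map


assert rows_to_map([(1, 5), {"team_id": 2, "total": 7}]) == {1: 5, 2: 7}
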
@@ -887,7 +891,7 @@ def _get_teams_for_usage_reports() -> Sequence[Team]: ) -def _get_team_report(all_data: Dict[str, Any], team: Team) -> UsageReportCounters: +def _get_team_report(all_data: dict[str, Any], team: Team) -> UsageReportCounters: decide_requests_count_in_month = all_data["teams_with_decide_requests_count_in_month"].get(team.id, 0) decide_requests_count_in_period = all_data["teams_with_decide_requests_count_in_period"].get(team.id, 0) local_evaluation_requests_count_in_period = all_data["teams_with_local_evaluation_requests_count_in_period"].get( @@ -942,7 +946,7 @@ def _get_team_report(all_data: Dict[str, Any], team: Team) -> UsageReportCounter def _add_team_report_to_org_reports( - org_reports: Dict[str, OrgReport], + org_reports: dict[str, OrgReport], team: Team, team_report: UsageReportCounters, period_start: datetime, @@ -975,12 +979,12 @@ def _add_team_report_to_org_reports( ) -def _get_all_org_reports(period_start: datetime, period_end: datetime) -> Dict[str, OrgReport]: +def _get_all_org_reports(period_start: datetime, period_end: datetime) -> dict[str, OrgReport]: all_data = _get_all_usage_data_as_team_rows(period_start, period_end) teams = _get_teams_for_usage_reports() - org_reports: Dict[str, OrgReport] = {} + org_reports: dict[str, OrgReport] = {} print("Generating reports for teams...") # noqa T201 time_now = datetime.now() @@ -1000,7 +1004,7 @@ def _get_full_org_usage_report(org_report: OrgReport, instance_metadata: Instanc ) -def _get_full_org_usage_report_as_dict(full_report: FullUsageReport) -> Dict[str, Any]: +def _get_full_org_usage_report_as_dict(full_report: FullUsageReport) -> dict[str, Any]: return dataclasses.asdict(full_report) diff --git a/posthog/tasks/verify_persons_data_in_sync.py b/posthog/tasks/verify_persons_data_in_sync.py index 02a53b0176c7b..5ed2a3ec074db 100644 --- a/posthog/tasks/verify_persons_data_in_sync.py +++ b/posthog/tasks/verify_persons_data_in_sync.py @@ -1,7 +1,7 @@ import json from collections import Counter, defaultdict from datetime import timedelta -from typing import Any, Dict, List +from typing import Any import structlog from celery import shared_task @@ -80,7 +80,7 @@ def verify_persons_data_in_sync( return results -def _team_integrity_statistics(person_data: List[Any]) -> Counter: +def _team_integrity_statistics(person_data: list[Any]) -> Counter: person_ids = [id for id, _, _ in person_data] person_uuids = [uuid for _, uuid, _ in person_data] team_ids = list({team_id for _, _, team_id in person_data}) @@ -159,8 +159,8 @@ def _emit_metrics(integrity_results: Counter) -> None: statsd.gauge(f"posthog_person_integrity_{key}", value) -def _index_by(collection: List[Any], key_fn: Any, flat: bool = True) -> Dict: - result: Dict = {} if flat else defaultdict(list) +def _index_by(collection: list[Any], key_fn: Any, flat: bool = True) -> dict: + result: dict = {} if flat else defaultdict(list) for item in collection: if flat: result[key_fn(item)] = item diff --git a/posthog/templatetags/posthog_assets.py b/posthog/templatetags/posthog_assets.py index 422bd687d9a07..dd8a1c1bb1981 100644 --- a/posthog/templatetags/posthog_assets.py +++ b/posthog/templatetags/posthog_assets.py @@ -1,5 +1,4 @@ import re -from typing import List from django.conf import settings from django.template import Library @@ -26,7 +25,7 @@ def absolute_asset_url(path: str) -> str: @register.simple_tag -def human_social_providers(providers: List[str]) -> str: +def human_social_providers(providers: list[str]) -> str: """ Returns a human-friendly name for a 
social login provider. Example: diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index c522a75bce2c5..68ea47c19c6a8 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -188,14 +188,14 @@ def iter_records( timestamp_predicates = "" if fields is None: - query_fields = ",".join((f"{field['expression']} AS {field['alias']}" for field in default_fields())) + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in default_fields()) else: if "_inserted_at" not in [field["alias"] for field in fields]: control_fields = [BatchExportField(expression="COALESCE(inserted_at, _timestamp)", alias="_inserted_at")] else: control_fields = [] - query_fields = ",".join((f"{field['expression']} AS {field['alias']}" for field in fields + control_fields)) + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields) query = SELECT_QUERY_TEMPLATE.substitute( fields=query_fields, @@ -219,8 +219,7 @@ def iter_records( else: query_parameters = base_query_parameters - for record_batch in client.stream_query_as_arrow(query, query_parameters=query_parameters): - yield record_batch + yield from client.stream_query_as_arrow(query, query_parameters=query_parameters) def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.datetime, dt.datetime]: diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 6ebede565bc35..a4c1712a12e3c 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -98,7 +98,7 @@ async def copy_tsv_to_postgres( # TODO: Switch to binary encoding as CSV has a million edge cases. sql.SQL("COPY {table_name} ({fields}) FROM STDIN WITH (FORMAT CSV, DELIMITER '\t')").format( table_name=sql.Identifier(table_name), - fields=sql.SQL(",").join((sql.Identifier(column) for column in schema_columns)), + fields=sql.SQL(",").join(sql.Identifier(column) for column in schema_columns), ) ) as copy: while data := tsv_file.read(): diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index c769862af96f1..373312303be2e 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -283,7 +283,7 @@ async def create_table_in_snowflake( table_name: fields: An iterable of (name, type) tuples representing the fields of the table. 
""" - field_ddl = ", ".join((f'"{field[0]}" {field[1]}' for field in fields)) + field_ddl = ", ".join(f'"{field[0]}" {field[1]}' for field in fields) await execute_async_query( connection, diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index f165ae070a83f..c10ede32d778c 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -24,8 +24,7 @@ def peek_first_and_rewind( def rewind_gen() -> collections.abc.Generator[T, None, None]: """Yield the item we popped to rewind the generator.""" yield first - for i in gen: - yield i + yield from gen return (first, rewind_gen()) diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index d548d3871d805..2640bf95c1f97 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -24,7 +24,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: return b"NULL" case uuid.UUID(): - return f"{quote_char}{data}{quote_char}".encode("utf-8") + return f"{quote_char}{data}{quote_char}".encode() case int() | float(): return b"%d" % data @@ -35,8 +35,8 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: timezone_arg = f", '{data:%Z}'" if data.microsecond == 0: - return f"toDateTime('{data:%Y-%m-%d %H:%M:%S}'{timezone_arg})".encode("utf-8") - return f"toDateTime64('{data:%Y-%m-%d %H:%M:%S.%f}', 6{timezone_arg})".encode("utf-8") + return f"toDateTime('{data:%Y-%m-%d %H:%M:%S}'{timezone_arg})".encode() + return f"toDateTime64('{data:%Y-%m-%d %H:%M:%S.%f}', 6{timezone_arg})".encode() case list(): encoded_data = [encode_clickhouse_data(value) for value in data] @@ -62,7 +62,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: value = str(value) encoded_data.append( - f'"{str(key)}"'.encode("utf-8") + b":" + encode_clickhouse_data(value, quote_char=quote_char) + f'"{str(key)}"'.encode() + b":" + encode_clickhouse_data(value, quote_char=quote_char) ) result = b"{" + b",".join(encoded_data) + b"}" @@ -71,7 +71,7 @@ def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: case _: str_data = str(data) str_data = str_data.replace("\\", "\\\\").replace("'", "\\'") - return f"{quote_char}{str_data}{quote_char}".encode("utf-8") + return f"{quote_char}{str_data}{quote_char}".encode() class ClickHouseError(Exception): @@ -355,8 +355,7 @@ def stream_query_as_arrow( """ with self.post_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: with pa.ipc.open_stream(pa.PythonFile(response.raw)) as reader: - for batch in reader: - yield batch + yield from reader async def __aenter__(self): """Enter method part of the AsyncContextManager protocol.""" diff --git a/posthog/temporal/common/codec.py b/posthog/temporal/common/codec.py index faf91c31173cb..42e775a24ba3b 100644 --- a/posthog/temporal/common/codec.py +++ b/posthog/temporal/common/codec.py @@ -1,5 +1,5 @@ import base64 -from typing import Iterable +from collections.abc import Iterable from cryptography.fernet import Fernet from temporalio.api.common.v1 import Payload diff --git a/posthog/temporal/common/sentry.py b/posthog/temporal/common/sentry.py index 290cc0182d2d8..81af9367914cb 100644 --- a/posthog/temporal/common/sentry.py +++ b/posthog/temporal/common/sentry.py @@ -1,5 +1,5 @@ from dataclasses import is_dataclass -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Union from temporalio import activity, workflow from 
temporalio.worker import ( @@ -83,5 +83,5 @@ def intercept_activity(self, next: ActivityInboundInterceptor) -> ActivityInboun def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput - ) -> Optional[Type[WorkflowInboundInterceptor]]: + ) -> Optional[type[WorkflowInboundInterceptor]]: return _SentryWorkflowInterceptor diff --git a/posthog/temporal/common/utils.py b/posthog/temporal/common/utils.py index 022c8270d7748..e8e03332c1a98 100644 --- a/posthog/temporal/common/utils.py +++ b/posthog/temporal/common/utils.py @@ -103,7 +103,7 @@ def from_activity(cls, activity): async def should_resume_from_activity_heartbeat( - activity, heartbeat_type: typing.Type[HeartbeatType], logger + activity, heartbeat_type: type[HeartbeatType], logger ) -> tuple[bool, HeartbeatType | None]: """Check if a batch export should resume from an activity's heartbeat details. diff --git a/posthog/temporal/data_imports/external_data_job.py b/posthog/temporal/data_imports/external_data_job.py index 9c9245e003dbd..dc111fb4b834d 100644 --- a/posthog/temporal/data_imports/external_data_job.py +++ b/posthog/temporal/data_imports/external_data_job.py @@ -34,7 +34,6 @@ ExternalDataSource, ) from posthog.temporal.common.logger import bind_temporal_worker_logger -from typing import Dict @dataclasses.dataclass @@ -67,7 +66,7 @@ class ValidateSchemaInputs: team_id: int schema_id: uuid.UUID table_schema: TSchemaTables - table_row_counts: Dict[str, int] + table_row_counts: dict[str, int] @activity.defn diff --git a/posthog/temporal/data_imports/pipelines/hubspot/__init__.py b/posthog/temporal/data_imports/pipelines/hubspot/__init__.py index 3ffa3c8ffa161..49d84aa41f2d9 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/__init__.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/__init__.py @@ -23,7 +23,8 @@ >>> resources = hubspot(api_key="hubspot_access_code") """ -from typing import Literal, Sequence, Iterator, Iterable +from typing import Literal +from collections.abc import Sequence, Iterator, Iterable import dlt from dlt.common.typing import TDataItems @@ -114,13 +115,11 @@ def crm_objects( if len(props) > 10000: raise ValueError( - ( - "Your request to Hubspot is too long to process. " - "Maximum allowed query length is 10000 symbols, while " - f"your list of properties `{props[:200]}`... is {len(props)} " - "symbols long. Use the `props` argument of the resource to " - "set the list of properties to extract from the endpoint." - ) + "Your request to Hubspot is too long to process. " + "Maximum allowed query length is 10000 symbols, while " + f"your list of properties `{props[:200]}`... is {len(props)} " + "symbols long. Use the `props` argument of the resource to " + "set the list of properties to extract from the endpoint." 
) params = {"properties": props, "limit": 100} diff --git a/posthog/temporal/data_imports/pipelines/hubspot/auth.py b/posthog/temporal/data_imports/pipelines/hubspot/auth.py index 490552cfe237d..b88aa731499bf 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/auth.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/auth.py @@ -1,6 +1,5 @@ import requests from django.conf import settings -from typing import Tuple def refresh_access_token(refresh_token: str) -> str: @@ -21,7 +20,7 @@ def refresh_access_token(refresh_token: str) -> str: return res.json()["access_token"] -def get_access_token_from_code(code: str, redirect_uri: str) -> Tuple[str, str]: +def get_access_token_from_code(code: str, redirect_uri: str) -> tuple[str, str]: res = requests.post( "https://api.hubapi.com/oauth/v1/token", data={ diff --git a/posthog/temporal/data_imports/pipelines/hubspot/helpers.py b/posthog/temporal/data_imports/pipelines/hubspot/helpers.py index 0ef03b6db23d6..d47616f251abb 100644 --- a/posthog/temporal/data_imports/pipelines/hubspot/helpers.py +++ b/posthog/temporal/data_imports/pipelines/hubspot/helpers.py @@ -1,7 +1,8 @@ """Hubspot source helpers""" import urllib.parse -from typing import Iterator, Dict, Any, List, Optional +from typing import Any, Optional +from collections.abc import Iterator from dlt.sources.helpers import requests import requests as http_requests @@ -16,7 +17,7 @@ def get_url(endpoint: str) -> str: return urllib.parse.urljoin(BASE_URL, endpoint) -def _get_headers(api_key: str) -> Dict[str, str]: +def _get_headers(api_key: str) -> dict[str, str]: """ Return a dictionary of HTTP headers to use for API requests, including the specified API key. @@ -32,7 +33,7 @@ def _get_headers(api_key: str) -> dict[str, str]: return {"authorization": f"Bearer {api_key}"} -def extract_property_history(objects: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]: +def extract_property_history(objects: list[dict[str, Any]]) -> Iterator[dict[str, Any]]: for item in objects: history = item.get("propertiesWithHistory") if not history: @@ -49,8 +50,8 @@ def fetch_property_history( endpoint: str, api_key: str, props: str, - params: Optional[Dict[str, Any]] = None, -) -> Iterator[List[Dict[str, Any]]]: + params: Optional[dict[str, Any]] = None, +) -> Iterator[list[dict[str, Any]]]: """Fetch property history from the given CRM endpoint. Args: @@ -91,8 +92,8 @@ def fetch_property_history( def fetch_data( - endpoint: str, api_key: str, refresh_token: str, params: Optional[Dict[str, Any]] = None -) -> Iterator[List[Dict[str, Any]]]: + endpoint: str, api_key: str, refresh_token: str, params: Optional[dict[str, Any]] = None +) -> Iterator[list[dict[str, Any]]]: """ Fetch data from HUBSPOT endpoint using a specified API key and yield the properties of each result. For paginated endpoints this function yields items from all pages. @@ -141,7 +142,7 @@ def fetch_data( # Yield the properties of each result in the API response while _data is not None: if "results" in _data: - _objects: List[Dict[str, Any]] = [] + _objects: list[dict[str, Any]] = [] for _result in _data["results"]: _obj = _result.get("properties", _result) if "id" not in _obj and "id" in _result: @@ -176,7 +177,7 @@ def fetch_data( _data = None -def _get_property_names(api_key: str, refresh_token: str, object_type: str) -> List[str]: +def _get_property_names(api_key: str, refresh_token: str, object_type: str) -> list[str]: """ Retrieve property names for a given entity from the HubSpot API.
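The same deprecation applies to the abstract container types: Iterator, Iterable, Generator, Sequence, and Callable now import from collections.abc rather than typing, and the manual re-yield loops elsewhere in this diff collapse into `yield from`. A small self-contained sketch of both patterns (a hypothetical pagination shape, not the HubSpot client):

from collections.abc import Iterator


def iter_results(pages: list[list[dict]]) -> Iterator[dict]:
    # Flatten pre-fetched pages of API results one item at a time;
    # `yield from page` replaces `for item in page: yield item`.
    for page in pages:
        yield from page


assert list(iter_results([[{"id": 1}], [{"id": 2}]])) == [{"id": 1}, {"id": 2}]
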
diff --git a/posthog/temporal/data_imports/pipelines/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline.py index 0b3f7c448f129..0ac469c214d04 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline.py +++ b/posthog/temporal/data_imports/pipelines/pipeline.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Literal +from typing import Literal from uuid import UUID import dlt @@ -95,7 +95,7 @@ def _create_pipeline(self): dataset_name=self.inputs.dataset_name, ) - def _run(self) -> Dict[str, int]: + def _run(self) -> dict[str, int]: pipeline = self._create_pipeline() total_counts: Counter = Counter({}) @@ -121,7 +121,7 @@ def _run(self) -> Dict[str, int]: return dict(total_counts) - async def run(self) -> Dict[str, int]: + async def run(self) -> dict[str, int]: try: return await asyncio.to_thread(self._run) except PipelineStepFailed: diff --git a/posthog/temporal/data_imports/pipelines/postgres/__init__.py b/posthog/temporal/data_imports/pipelines/postgres/__init__.py index 438b25fbe9dac..07a368ed572e2 100644 --- a/posthog/temporal/data_imports/pipelines/postgres/__init__.py +++ b/posthog/temporal/data_imports/pipelines/postgres/__init__.py @@ -1,7 +1,8 @@ """Source that loads tables form any SQLAlchemy supported database, supports batching requests and incremental loads.""" -from typing import List, Optional, Union, Iterable, Any -from sqlalchemy import MetaData, Table, text +from typing import Optional, Union, List # noqa: UP035 +from collections.abc import Iterable +from sqlalchemy import MetaData, Table from sqlalchemy.engine import Engine import dlt @@ -35,7 +36,7 @@ def sql_database( credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, - table_names: Optional[List[str]] = dlt.config.value, + table_names: Optional[List[str]] = dlt.config.value, # noqa: UP006 ) -> Iterable[DltResource]: """ A DLT source which loads data from an SQL database using SQLAlchemy. 
diff --git a/posthog/temporal/data_imports/pipelines/postgres/helpers.py b/posthog/temporal/data_imports/pipelines/postgres/helpers.py index a288205063f15..5805e7899189a 100644 --- a/posthog/temporal/data_imports/pipelines/postgres/helpers.py +++ b/posthog/temporal/data_imports/pipelines/postgres/helpers.py @@ -2,11 +2,10 @@ from typing import ( Any, - List, Optional, - Iterator, Union, ) +from collections.abc import Iterator import operator import dlt @@ -63,7 +62,7 @@ def make_query(self) -> Select[Any]: return query return query.where(filter_op(self.cursor_column, self.last_value)) # type: ignore - def load_rows(self) -> Iterator[List[TDataItem]]: + def load_rows(self) -> Iterator[list[TDataItem]]: query = self.make_query() with self.engine.connect() as conn: result = conn.execution_options(yield_per=self.chunk_size).execute(query) @@ -104,7 +103,7 @@ def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engi return create_engine(credentials) -def get_primary_key(table: Table) -> List[str]: +def get_primary_key(table: Table) -> list[str]: return [c.name for c in table.primary_key] diff --git a/posthog/temporal/data_imports/pipelines/stripe/helpers.py b/posthog/temporal/data_imports/pipelines/stripe/helpers.py index 7e2e02017b2c0..56494d3d47ce2 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/helpers.py +++ b/posthog/temporal/data_imports/pipelines/stripe/helpers.py @@ -1,6 +1,7 @@ """Stripe analytics source helpers""" -from typing import Any, Dict, Optional, Union, Iterable, Tuple +from typing import Any, Optional, Union +from collections.abc import Iterable import stripe import dlt @@ -32,7 +33,7 @@ async def stripe_get_data( start_date: Optional[Any] = None, end_date: Optional[Any] = None, **kwargs: Any, -) -> Dict[Any, Any]: +) -> dict[Any, Any]: if start_date: start_date = transform_date(start_date) if end_date: @@ -148,7 +149,7 @@ async def stripe_pagination( def stripe_source( api_key: str, account_id: str, - endpoints: Tuple[str, ...], + endpoints: tuple[str, ...], team_id, job_id, schema_id, diff --git a/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py b/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py index a1747d96c78aa..c478060940d4f 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/api_helpers.py @@ -1,4 +1,4 @@ -from typing import Optional, TypedDict, Dict +from typing import Optional, TypedDict from dlt.common import pendulum from dlt.common.time import ensure_pendulum_datetime @@ -18,7 +18,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: def process_ticket( ticket: DictStrAny, - custom_fields: Dict[str, TCustomFieldInfo], + custom_fields: dict[str, TCustomFieldInfo], pivot_custom_fields: bool = True, ) -> DictStrAny: """ @@ -78,7 +78,7 @@ def process_ticket( return ticket -def process_ticket_field(field: DictStrAny, custom_fields_state: Dict[str, TCustomFieldInfo]) -> TDataItem: +def process_ticket_field(field: DictStrAny, custom_fields_state: dict[str, TCustomFieldInfo]) -> TDataItem: """Update custom field mapping in dlt state for the given field.""" # grab id and update state dict # if the id is new, add a new key to indicate that this is the initial value for title diff --git a/posthog/temporal/data_imports/pipelines/zendesk/credentials.py b/posthog/temporal/data_imports/pipelines/zendesk/credentials.py index aa0463bb4411f..d056528059530 100644 --- 
a/posthog/temporal/data_imports/pipelines/zendesk/credentials.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/credentials.py @@ -2,7 +2,7 @@ This module handles how credentials are read in dlt sources """ -from typing import ClassVar, List, Union +from typing import ClassVar, Union import dlt from dlt.common.configuration import configspec from dlt.common.configuration.specs import CredentialsConfiguration @@ -16,7 +16,7 @@ class ZendeskCredentialsBase(CredentialsConfiguration): """ subdomain: str - __config_gen_annotations__: ClassVar[List[str]] = [] + __config_gen_annotations__: ClassVar[list[str]] = [] @configspec diff --git a/posthog/temporal/data_imports/pipelines/zendesk/helpers.py b/posthog/temporal/data_imports/pipelines/zendesk/helpers.py index 8c0e0427c3fbb..c29f41279a06b 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/helpers.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/helpers.py @@ -1,4 +1,5 @@ -from typing import Iterator, Optional, Iterable, Tuple +from typing import Optional +from collections.abc import Iterator, Iterable from itertools import chain import dlt @@ -211,7 +212,7 @@ def chats_table_resource( def zendesk_support( team_id: int, credentials: TZendeskCredentials = dlt.secrets.value, - endpoints: Tuple[str, ...] = (), + endpoints: tuple[str, ...] = (), pivot_ticket_fields: bool = True, start_date: Optional[TAnyDateTime] = DEFAULT_START_DATE, end_date: Optional[TAnyDateTime] = None, diff --git a/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py b/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py index 5db9a28eafc74..4ebf375bf7050 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/talk_api.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Dict, Iterator, Optional, Tuple, Any +from typing import Optional, Any +from collections.abc import Iterator from dlt.common.typing import DictStrStr, TDataItems, TSecretValue from dlt.sources.helpers.requests import client @@ -27,7 +28,7 @@ class ZendeskAPIClient: subdomain: str = "" url: str = "" headers: Optional[DictStrStr] - auth: Optional[Tuple[str, TSecretValue]] + auth: Optional[tuple[str, TSecretValue]] def __init__(self, credentials: TZendeskCredentials, url_prefix: Optional[str] = None) -> None: """ @@ -64,7 +65,7 @@ def get_pages( endpoint: str, data_point_name: str, pagination: PaginationType, - params: Optional[Dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, ) -> Iterator[TDataItems]: """ Makes a request to a paginated endpoint and returns a generator of data items per page. 
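The test-suite hunks that follow rewrite chained context managers of the form `with a() as x, b() as y:` into the parenthesized multi-context form, official grammar since Python 3.10 and consistent with the requires-python = ">=3.10" floor declared in pyproject.toml at the end of this diff. A runnable stand-in, using contextlib.nullcontext in place of mock.patch and override_settings:

# Stand-in sketch of the parenthesized `with` style adopted in the hunks below.
# nullcontext simply hands back its argument, standing in for real managers.
from contextlib import nullcontext

with (
    nullcontext("patched-stripe-client") as mock_client,
    nullcontext({"CHUNK_SIZE": 0}) as overridden_settings,
):
    print(mock_client, overridden_settings["CHUNK_SIZE"])

The parenthesized form lets each manager sit on its own line with a trailing comma, which is what makes the long mock.patch/override_settings stacks in these tests diff cleanly.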
diff --git a/posthog/temporal/data_imports/workflow_activities/create_job_model.py b/posthog/temporal/data_imports/workflow_activities/create_job_model.py index a838961137462..e6407e9f78598 100644 --- a/posthog/temporal/data_imports/workflow_activities/create_job_model.py +++ b/posthog/temporal/data_imports/workflow_activities/create_job_model.py @@ -3,7 +3,6 @@ from asgiref.sync import sync_to_async from temporalio import activity -from typing import Tuple # TODO: remove dependency from posthog.temporal.data_imports.pipelines.schemas import PIPELINE_TYPE_SCHEMA_DEFAULT_MAPPING @@ -24,7 +23,7 @@ class CreateExternalDataJobModelActivityInputs: @activity.defn -async def create_external_data_job_model_activity(inputs: CreateExternalDataJobModelActivityInputs) -> Tuple[str, bool]: +async def create_external_data_job_model_activity(inputs: CreateExternalDataJobModelActivityInputs) -> tuple[str, bool]: run = await sync_to_async(create_external_data_job)( team_id=inputs.team_id, external_data_source_id=inputs.source_id, diff --git a/posthog/temporal/data_imports/workflow_activities/import_data.py b/posthog/temporal/data_imports/workflow_activities/import_data.py index b1730221f8bac..b6806071e721e 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data.py @@ -17,7 +17,6 @@ get_external_data_job, ) from posthog.temporal.common.logger import bind_temporal_worker_logger -from typing import Dict, Tuple import asyncio from django.conf import settings from django.utils import timezone @@ -34,7 +33,7 @@ class ImportDataActivityInputs: @activity.defn -async def import_data_activity(inputs: ImportDataActivityInputs) -> Tuple[TSchemaTables, Dict[str, int]]: # noqa: F821 +async def import_data_activity(inputs: ImportDataActivityInputs) -> tuple[TSchemaTables, dict[str, int]]: # noqa: F821 model: ExternalDataJob = await get_external_data_job( job_id=inputs.run_id, ) diff --git a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py index b0163b8fee798..7b7e2b566743f 100644 --- a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py @@ -190,8 +190,9 @@ async def test_insert_into_http_activity_inserts_data_into_http_endpoint( ) mock_server = MockServer() - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) await activity_environment.run(insert_into_http_activity, insert_inputs) @@ -239,22 +240,25 @@ async def test_insert_into_http_activity_throws_on_bad_http_status( **http_config, ) - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=400, repeat=True) with pytest.raises(NonRetryableResponseError): await activity_environment.run(insert_into_http_activity, insert_inputs) - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - 
BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=429, repeat=True) with pytest.raises(RetryableResponseError): await activity_environment.run(insert_into_http_activity, insert_inputs) - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=500, repeat=True) with pytest.raises(RetryableResponseError): @@ -352,8 +356,9 @@ async def test_http_export_workflow( ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) @@ -589,8 +594,9 @@ def assert_heartbeat_details(*raw_details): ) mock_server = MockServer() - with aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, override_settings( - BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 + with ( + aioresponses(passthrough=[settings.CLICKHOUSE_HTTP_URL]) as m, + override_settings(BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2), ): m.post(TEST_URL, status=200, callback=mock_server.post, repeat=True) await activity_environment.run(insert_into_http_activity, insert_inputs) diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index 459dff8dc3c00..6652ac224b22a 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -175,7 +175,7 @@ def query_request_handler(request: PreparedRequest): # contents as a string in `staged_files`. 
if match := re.match(r"^PUT file://(?P.*) @%(?P.*)$", sql_text): file_path = match.group("file_path") - with open(file_path, "r") as f: + with open(file_path) as f: staged_files.append(f.read()) if fail == "put": @@ -414,9 +414,12 @@ async def test_snowflake_export_workflow_exports_events( ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with unittest.mock.patch( - "posthog.temporal.batch_exports.snowflake_batch_export.snowflake.connector.connect", - ) as mock, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1): + with ( + unittest.mock.patch( + "posthog.temporal.batch_exports.snowflake_batch_export.snowflake.connector.connect", + ) as mock, + override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1), + ): fake_conn = FakeSnowflakeConnection() mock.return_value = fake_conn @@ -482,10 +485,13 @@ async def test_snowflake_export_workflow_without_events(ateam, snowflake_batch_e ], workflow_runner=UnsandboxedWorkflowRunner(), ): - with responses.RequestsMock( - target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send", - assert_all_requests_are_fired=False, - ) as rsps, override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2): + with ( + responses.RequestsMock( + target="snowflake.connector.vendored.requests.adapters.HTTPAdapter.send", + assert_all_requests_are_fired=False, + ) as rsps, + override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1**2), + ): queries, staged_files = add_mock_snowflake_api(rsps) await activity_environment.client.execute_workflow( SnowflakeBatchExportWorkflow.run, diff --git a/posthog/temporal/tests/external_data/test_external_data_job.py b/posthog/temporal/tests/external_data/test_external_data_job.py index 80d5aa7b2cb18..37431206fe331 100644 --- a/posthog/temporal/tests/external_data/test_external_data_job.py +++ b/posthog/temporal/tests/external_data/test_external_data_job.py @@ -330,12 +330,14 @@ async def setup_job_2(): job_1, job_1_inputs = await setup_job_1() job_2, job_2_inputs = await setup_job_2() - with mock.patch("stripe.Customer.list") as mock_customer_list, mock.patch( - "stripe.Charge.list" - ) as mock_charge_list, override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + mock.patch("stripe.Charge.list") as mock_charge_list, + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -410,10 +412,13 @@ async def setup_job_1(): job_1, job_1_inputs = await setup_job_1() - with mock.patch("stripe.Customer.list") as mock_customer_list, override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -475,12 +480,14 @@ async def setup_job_1(): job_1, job_1_inputs = await setup_job_1() - with mock.patch("stripe.Customer.list") as mock_customer_list, mock.patch( - 
"posthog.temporal.data_imports.pipelines.helpers.CHUNK_SIZE", 0 - ), override_settings( - BUCKET_URL=f"s3://{BUCKET_NAME}", - AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, - AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + with ( + mock.patch("stripe.Customer.list") as mock_customer_list, + mock.patch("posthog.temporal.data_imports.pipelines.helpers.CHUNK_SIZE", 0), + override_settings( + BUCKET_URL=f"s3://{BUCKET_NAME}", + AIRBYTE_BUCKET_KEY=settings.OBJECT_STORAGE_ACCESS_KEY_ID, + AIRBYTE_BUCKET_SECRET=settings.OBJECT_STORAGE_SECRET_ACCESS_KEY, + ), ): mock_customer_list.return_value = { "data": [ @@ -527,9 +534,10 @@ async def test_validate_schema_and_update_table_activity(activity_environment, t test_1_schema = await _create_schema("test-1", new_source, team) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -597,9 +605,10 @@ async def test_validate_schema_and_update_table_activity_with_existing(activity_ test_1_schema = await _create_schema("test-1", new_source, team, table_id=existing_table.id) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -640,9 +649,13 @@ async def test_validate_schema_and_update_table_activity_half_run(activity_envir rows_synced=0, ) - with mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, mock.patch( - "posthog.warehouse.data_load.validate_schema.validate_schema", - ) as mock_validate, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + mock.patch( + "posthog.warehouse.data_load.validate_schema.validate_schema", + ) as mock_validate, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} credential = await sync_to_async(DataWarehouseCredential.objects.create)( team=team, @@ -708,9 +721,10 @@ async def test_create_schema_activity(activity_environment, team, **kwargs): test_1_schema = await _create_schema("test-1", new_source, team) - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns" - ) as mock_get_columns, override_settings(**AWS_BUCKET_MOCK_SETTINGS): + with ( + mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns") as mock_get_columns, + override_settings(**AWS_BUCKET_MOCK_SETTINGS), + ): mock_get_columns.return_value = {"id": "string"} await activity_environment.run( validate_schema_activity, @@ -763,9 +777,10 @@ async def test_external_data_job_workflow_with_schema(team, **kwargs): async def mock_async_func(inputs): return {} - with mock.patch( - "posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"} - ), mock.patch.object(DataImportPipeline, "run", mock_async_func): + with ( + 
mock.patch("posthog.warehouse.models.table.DataWarehouseTable.get_columns", return_value={"id": "string"}), + mock.patch.object(DataImportPipeline, "run", mock_async_func), + ): with override_settings(AIRBYTE_BUCKET_KEY="test-key", AIRBYTE_BUCKET_SECRET="test-secret"): async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( @@ -910,13 +925,17 @@ async def test_check_schedule_activity_with_missing_schema_id_but_with_schedule( should_sync=True, ) - with mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_trigger_external_data_workflow" - ) as mock_a_trigger_external_data_workflow: + with ( + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_trigger_external_data_workflow" + ) as mock_a_trigger_external_data_workflow, + ): should_exit = await activity_environment.run( check_schedule_activity, ExternalDataWorkflowInputs( @@ -950,13 +969,17 @@ async def test_check_schedule_activity_with_missing_schema_id_and_no_schedule(ac should_sync=True, ) - with mock.patch( - "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=False - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True - ), mock.patch( - "posthog.temporal.data_imports.external_data_job.a_sync_external_data_job_workflow" - ) as mock_a_sync_external_data_job_workflow: + with ( + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_external_data_workflow_exists", return_value=False + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_delete_external_data_schedule", return_value=True + ), + mock.patch( + "posthog.temporal.data_imports.external_data_job.a_sync_external_data_job_workflow" + ) as mock_a_sync_external_data_job_workflow, + ): should_exit = await activity_environment.run( check_schedule_activity, ExternalDataWorkflowInputs( diff --git a/posthog/test/base.py b/posthog/test/base.py index c96738aafa139..2ebfa6178e259 100644 --- a/posthog/test/base.py +++ b/posthog/test/base.py @@ -7,7 +7,8 @@ import uuid from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, List, Optional, Tuple, Union, Generator +from typing import Any, Optional, Union +from collections.abc import Generator from unittest.mock import patch import freezegun @@ -86,8 +87,8 @@ freezegun.configure(extend_ignore_list=["posthog.test.assert_faster_than"]) # type: ignore -persons_cache_tests: List[Dict[str, Any]] = [] -events_cache_tests: List[Dict[str, Any]] = [] +persons_cache_tests: list[dict[str, Any]] = [] +events_cache_tests: list[dict[str, Any]] = [] persons_ordering_int: int = 1 @@ -124,7 +125,7 @@ class FuzzyInt(int): highest: int def __new__(cls, lowest, highest): - obj = super(FuzzyInt, cls).__new__(cls, highest) + obj = super().__new__(cls, highest) obj.lowest = lowest obj.highest = highest return obj @@ -144,7 +145,7 @@ class ErrorResponsesMixin: "attr": None, } - def not_found_response(self, message: str = "Not found.") -> Dict[str, 
Optional[str]]: + def not_found_response(self, message: str = "Not found.") -> dict[str, Optional[str]]: return { "type": "invalid_request", "code": "not_found", @@ -154,7 +155,7 @@ def not_found_response(self, message: str = "Not found.") -> Dict[str, Optional[ def permission_denied_response( self, message: str = "You do not have permission to perform this action." - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "authentication_error", "code": "permission_denied", @@ -162,7 +163,7 @@ def permission_denied_response( "attr": None, } - def method_not_allowed_response(self, method: str) -> Dict[str, Optional[str]]: + def method_not_allowed_response(self, method: str) -> dict[str, Optional[str]]: return { "type": "invalid_request", "code": "method_not_allowed", @@ -174,7 +175,7 @@ def unauthenticated_response( self, message: str = "Authentication credentials were not provided.", code: str = "not_authenticated", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "authentication_error", "code": code, @@ -187,7 +188,7 @@ def validation_error_response( message: str = "Malformed request", code: str = "invalid_input", attr: Optional[str] = None, - ) -> Dict[str, Optional[str]]: + ) -> dict[str, Optional[str]]: return { "type": "validation_error", "code": code, @@ -820,7 +821,7 @@ def capture_select_queries(self): return self.capture_queries(("SELECT", "WITH", "select", "with")) @contextmanager - def capture_queries(self, query_prefixes: Union[str, Tuple[str, ...]]): + def capture_queries(self, query_prefixes: Union[str, tuple[str, ...]]): queries = [] original_get_client = ch_pool.get_client @@ -863,7 +864,7 @@ def raise_hook(args: threading.ExceptHookArgs): threading.excepthook = old_hook -def run_clickhouse_statement_in_parallel(statements: List[str]): +def run_clickhouse_statement_in_parallel(statements: list[str]): jobs = [] with failhard_threadhook_context(): for item in statements: @@ -1063,8 +1064,8 @@ def fn_with_poe_v2(self, *args, **kwargs): def _create_insight( - team: Team, insight_filters: Dict[str, Any], dashboard_filters: Dict[str, Any] -) -> Tuple[Insight, Dashboard, DashboardTile]: + team: Team, insight_filters: dict[str, Any], dashboard_filters: dict[str, Any] +) -> tuple[Insight, Dashboard, DashboardTile]: dashboard = Dashboard.objects.create(team=team, filters=dashboard_filters) insight = Insight.objects.create(team=team, filters=insight_filters) dashboard_tile = DashboardTile.objects.create(dashboard=dashboard, insight=insight) @@ -1088,7 +1089,7 @@ def create_person_id_override_by_distinct_id( """ ) - person_id_from, person_id_to = [row[1] for row in person_ids_result] + person_id_from, person_id_to = (row[1] for row in person_ids_result) sync_execute( f""" diff --git a/posthog/test/db_context_capturing.py b/posthog/test/db_context_capturing.py index 6060023545637..44c1b05d23cc0 100644 --- a/posthog/test/db_context_capturing.py +++ b/posthog/test/db_context_capturing.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Generator +from collections.abc import Generator from django.db import DEFAULT_DB_ALIAS, connections from django.test.utils import CaptureQueriesContext diff --git a/posthog/test/test_feature_flag.py b/posthog/test/test_feature_flag.py index 38afbe7dbbcd7..91db555b31b9e 100644 --- a/posthog/test/test_feature_flag.py +++ b/posthog/test/test_feature_flag.py @@ -2784,8 +2784,9 @@ def test_multiple_flags(self): key="variant", ) - with self.assertNumQueries(10), 
snapshot_postgres_queries_context( - self + with ( + self.assertNumQueries(10), + snapshot_postgres_queries_context(self), ): # 1 to fill group cache, 2 to match feature flags with group properties (of each type), 1 to match feature flags with person properties matches, reasons, payloads, _ = FeatureFlagMatcher( [ @@ -2859,8 +2860,9 @@ def test_multiple_flags(self): self.assertEqual(payloads, {"variant": {"color": "blue"}}) - with self.assertNumQueries(9), snapshot_postgres_queries_context( - self + with ( + self.assertNumQueries(9), + snapshot_postgres_queries_context(self), ): # 1 to fill group cache, 1 to match feature flags with group properties (only 1 group provided), 1 to match feature flags with person properties matches, reasons, payloads, _ = FeatureFlagMatcher( [ @@ -6016,8 +6018,9 @@ def __call__(self, execute, sql, *args, **kwargs): properties={"email": "tim@posthog.com", "team": "posthog"}, ) - with snapshot_postgres_queries_context(self, capture_all_queries=True), connection.execute_wrapper( - InsertFailOnce() + with ( + snapshot_postgres_queries_context(self, capture_all_queries=True), + connection.execute_wrapper(InsertFailOnce()), ): flags, reasons, payloads, errors = get_all_feature_flags( team.pk, "other_id", {}, hash_key_override="example_id" diff --git a/posthog/test/test_feature_flag_analytics.py b/posthog/test/test_feature_flag_analytics.py index f5a5f37e0ac0a..ed8228ff21170 100644 --- a/posthog/test/test_feature_flag_analytics.py +++ b/posthog/test/test_feature_flag_analytics.py @@ -77,8 +77,9 @@ def test_capture_team_decide_usage(self): team_uuid = "team-uuid" other_team_uuid = "other-team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -299,8 +300,9 @@ def test_no_interference_between_different_types_of_new_incoming_increments(self other_team_id = 1243 team_uuid = "team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -400,8 +402,9 @@ def test_locking_works_for_capture_team_decide_usage(self): team_uuid = "team-uuid" other_team_uuid = "other-team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket @@ -489,8 +492,9 @@ def test_locking_in_redis_doesnt_block_new_incoming_increments(self): other_team_id = 1243 team_uuid = "team-uuid" - with freeze_time("2022-05-07 12:23:07") as frozen_datetime, self.settings( - DECIDE_BILLING_ANALYTICS_TOKEN="token" + with ( + freeze_time("2022-05-07 12:23:07") as frozen_datetime, + self.settings(DECIDE_BILLING_ANALYTICS_TOKEN="token"), ): for _ in range(10): # 10 requests in first bucket diff --git a/posthog/test/test_health.py b/posthog/test/test_health.py index 89611fb11ee31..2ce4e464e8cf7 100644 --- a/posthog/test/test_health.py +++ b/posthog/test/test_health.py @@ -1,7 +1,7 @@ import logging from contextlib import contextmanager import random -from typing import List, Optional +from 
typing import Optional from unittest import mock from unittest.mock import patch @@ -70,7 +70,13 @@ def test_livez_returns_200_and_doesnt_require_any_dependencies(client: Client): just be an indicator that the python process hasn't hung. """ - with simulate_postgres_error(), simulate_kafka_cannot_connect(), simulate_clickhouse_cannot_connect(), simulate_celery_cannot_connect(), simulate_cache_cannot_connect(): + with ( + simulate_postgres_error(), + simulate_kafka_cannot_connect(), + simulate_clickhouse_cannot_connect(), + simulate_celery_cannot_connect(), + simulate_cache_cannot_connect(), + ): resp = get_livez(client) assert resp.status_code == 200, resp.content @@ -263,7 +269,7 @@ def test_readyz_complains_if_role_does_not_exist(client: Client): assert data["error"] == "InvalidRole" -def get_readyz(client: Client, exclude: Optional[List[str]] = None, role: Optional[str] = None) -> HttpResponse: +def get_readyz(client: Client, exclude: Optional[list[str]] = None, role: Optional[str] = None) -> HttpResponse: return client.get("/_readyz", data={"exclude": exclude or [], "role": role or ""}) diff --git a/posthog/test/test_journeys.py b/posthog/test/test_journeys.py index 0e535437076e9..69bb2050d8f3b 100644 --- a/posthog/test/test_journeys.py +++ b/posthog/test/test_journeys.py @@ -3,7 +3,7 @@ import json from datetime import datetime import os -from typing import Any, Dict, List +from typing import Any from uuid import UUID, uuid4 from django.utils import timezone @@ -15,10 +15,10 @@ def journeys_for( - events_by_person: Dict[str, List[Dict[str, Any]]], + events_by_person: dict[str, list[dict[str, Any]]], team: Team, create_people: bool = True, -) -> Dict[str, Person]: +) -> dict[str, Person]: """ Helper for creating specific events for a team. @@ -115,11 +115,11 @@ def journeys_for( return people -def _create_all_events_raw(all_events: List[Dict]): +def _create_all_events_raw(all_events: list[dict]): parsed = "" for event in all_events: timestamp = timezone.now() - data: Dict[str, Any] = { + data: dict[str, Any] = { "properties": {}, "timestamp": timestamp.strftime("%Y-%m-%d %H:%M:%S.%f"), "person_id": str(uuid4()), @@ -162,7 +162,7 @@ def _create_all_events_raw(all_events: List[Dict]): ) -def create_all_events(all_events: List[dict]): +def create_all_events(all_events: list[dict]): for event in all_events: _create_event(**event) @@ -175,15 +175,15 @@ class InMemoryEvent: distinct_id: str team: Team timestamp: str - properties: Dict + properties: dict person_id: str person_created_at: datetime - person_properties: Dict - group0_properties: Dict - group1_properties: Dict - group2_properties: Dict - group3_properties: Dict - group4_properties: Dict + person_properties: dict + group0_properties: dict + group1_properties: dict + group2_properties: dict + group3_properties: dict + group4_properties: dict group0_created_at: datetime group1_created_at: datetime group2_created_at: datetime @@ -191,7 +191,7 @@ class InMemoryEvent: group4_created_at: datetime -def update_or_create_person(distinct_ids: List[str], team_id: int, **kwargs): +def update_or_create_person(distinct_ids: list[str], team_id: int, **kwargs): (person, _) = Person.objects.update_or_create( persondistinctid__distinct_id__in=distinct_ids, persondistinctid__team_id=team_id, diff --git a/posthog/test/test_utils.py b/posthog/test/test_utils.py index 827c5dd1de851..dab6a4d1e0ea7 100644 --- a/posthog/test/test_utils.py +++ b/posthog/test/test_utils.py @@ -434,7 +434,7 @@ def test_should_not_refresh_with_refresh_gibberish(self): def 
test_refresh_requested_by_client_with_data_true(self): drf_request = Request(HttpRequest()) drf_request._full_data = {"refresh": True} # type: ignore - self.assertTrue(refresh_requested_by_client((drf_request))) + self.assertTrue(refresh_requested_by_client(drf_request)) def test_should_not_refresh_with_data_false(self): drf_request = Request(HttpRequest()) diff --git a/posthog/urls.py b/posthog/urls.py index b047f897307e4..3681f4a1ca4f2 100644 --- a/posthog/urls.py +++ b/posthog/urls.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, List, Optional, cast +from typing import Any, Optional, cast +from collections.abc import Callable from posthog.models.instance_setting import get_instance_setting from urllib.parse import urlparse @@ -60,7 +61,7 @@ logger = structlog.get_logger(__name__) -ee_urlpatterns: List[Any] = [] +ee_urlpatterns: list[Any] = [] try: from ee.urls import extend_api_router from ee.urls import urlpatterns as ee_urlpatterns diff --git a/posthog/user_permissions.py b/posthog/user_permissions.py index 30a6bfca298b1..7b4d9b07728ca 100644 --- a/posthog/user_permissions.py +++ b/posthog/user_permissions.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, List, Optional, cast +from typing import Any, Optional, cast from uuid import UUID from posthog.constants import AvailableFeature @@ -32,10 +32,10 @@ def __init__(self, user: User, team: Optional[Team] = None): self.user = user self._current_team = team - self._tiles: Optional[List[DashboardTile]] = None - self._team_permissions: Dict[int, UserTeamPermissions] = {} - self._dashboard_permissions: Dict[int, UserDashboardPermissions] = {} - self._insight_permissions: Dict[int, UserInsightPermissions] = {} + self._tiles: Optional[list[DashboardTile]] = None + self._team_permissions: dict[int, UserTeamPermissions] = {} + self._dashboard_permissions: dict[int, UserDashboardPermissions] = {} + self._insight_permissions: dict[int, UserInsightPermissions] = {} @cached_property def current_team(self) -> "UserTeamPermissions": @@ -68,7 +68,7 @@ def insight(self, insight: Insight) -> "UserInsightPermissions": return self._insight_permissions[insight.pk] @cached_property - def team_ids_visible_for_user(self) -> List[int]: + def team_ids_visible_for_user(self) -> list[int]: candidate_teams = Team.objects.filter(organization_id__in=self.organizations.keys()).only( "pk", "organization_id", "access_control" ) @@ -86,16 +86,16 @@ def get_organization(self, organization_id: UUID) -> Optional[Organization]: return self.organizations.get(organization_id) @cached_property - def organizations(self) -> Dict[UUID, Organization]: + def organizations(self) -> dict[UUID, Organization]: return {member.organization_id: member.organization for member in self.organization_memberships.values()} @cached_property - def organization_memberships(self) -> Dict[UUID, OrganizationMembership]: + def organization_memberships(self) -> dict[UUID, OrganizationMembership]: memberships = OrganizationMembership.objects.filter(user=self.user).select_related("organization") return {membership.organization_id: membership for membership in memberships} @cached_property - def explicit_team_memberships(self) -> Dict[int, Any]: + def explicit_team_memberships(self) -> dict[int, Any]: try: from ee.models import ExplicitTeamMembership except ImportError: @@ -107,7 +107,7 @@ def explicit_team_memberships(self) -> Dict[int, Any]: return {membership.team_id: membership.level for membership in memberships} @cached_property - def 
dashboard_privileges(self) -> Dict[int, Dashboard.PrivilegeLevel]: + def dashboard_privileges(self) -> dict[int, Dashboard.PrivilegeLevel]: try: from ee.models import DashboardPrivilege @@ -116,14 +116,14 @@ def dashboard_privileges(self) -> dict[int, Dashboard.PrivilegeLevel]: except ImportError: return {} - def set_preloaded_dashboard_tiles(self, tiles: List[DashboardTile]): + def set_preloaded_dashboard_tiles(self, tiles: list[DashboardTile]): """ Allows for speeding up insight-related permissions code """ self._tiles = tiles @cached_property - def preloaded_insight_dashboards(self) -> Optional[List[Dashboard]]: + def preloaded_insight_dashboards(self) -> Optional[list[Dashboard]]: if self._tiles is None: return None diff --git a/posthog/utils.py b/posthog/utils.py index f7c32736b2557..cdc0a4ed48fd2 100644 --- a/posthog/utils.py +++ b/posthog/utils.py @@ -19,15 +19,11 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Generator, - List, - Mapping, Optional, - Tuple, Union, cast, ) +from collections.abc import Generator, Mapping from urllib.parse import urljoin, urlparse from zoneinfo import ZoneInfo @@ -125,7 +121,7 @@ def absolute_uri(url: Optional[str] = None) -> str: return urljoin(settings.SITE_URL.rstrip("/") + "/", url.lstrip("/")) -def get_previous_day(at: Optional[datetime.datetime] = None) -> Tuple[datetime.datetime, datetime.datetime]: +def get_previous_day(at: Optional[datetime.datetime] = None) -> tuple[datetime.datetime, datetime.datetime]: """ Returns a pair of datetimes, representing the start and end of the preceding day. `at` is the datetime to use as a reference point. @@ -149,7 +145,7 @@ def get_previous_day(at: Optional[datetime.d return (period_start, period_end) -def get_current_day(at: Optional[datetime.datetime] = None) -> Tuple[datetime.datetime, datetime.datetime]: +def get_current_day(at: Optional[datetime.datetime] = None) -> tuple[datetime.datetime, datetime.datetime]: """ Returns a pair of datetimes, representing the start and end of the current day. `at` is the datetime to use as a reference point. @@ -179,7 +175,7 @@ def relative_date_parse_with_delta_mapping( *, always_truncate: bool = False, now: Optional[datetime.datetime] = None, -) -> Tuple[datetime.datetime, Optional[Dict[str, int]], str | None]: +) -> tuple[datetime.datetime, Optional[dict[str, int]], str | None]: """Returns the parsed datetime, along with the period mapping - if the input was a relative datetime string.""" try: try: @@ -202,7 +198,7 @@ def relative_date_parse_with_delta_mapping( regex = r"\-?(?P<number>[0-9]+)?(?P<type>[a-z])(?P<position>Start|End)?"
match = re.search(regex, input) parsed_dt = (now or dt.datetime.now()).astimezone(timezone_info) - delta_mapping: Dict[str, int] = {} + delta_mapping: dict[str, int] = {} if not match: return parsed_dt, delta_mapping, None if match.group("type") == "h": @@ -276,7 +272,7 @@ def get_js_url(request: HttpRequest) -> str: def render_template( template_name: str, request: HttpRequest, - context: Optional[Dict] = None, + context: Optional[dict] = None, *, team_for_public_context: Optional["Team"] = None, ) -> HttpResponse: @@ -331,13 +327,13 @@ def render_template( except: year_in_hog_url = None - posthog_app_context: Dict[str, Any] = { + posthog_app_context: dict[str, Any] = { "persisted_feature_flags": settings.PERSISTED_FEATURE_FLAGS, "anonymous": not request.user or not request.user.is_authenticated, "year_in_hog_url": year_in_hog_url, } - posthog_bootstrap: Dict[str, Any] = {} + posthog_bootstrap: dict[str, Any] = {} posthog_distinct_id: Optional[str] = None # Set the frontend app context @@ -453,7 +449,7 @@ def get_default_event_name(team: "Team"): return "$pageview" -def get_frontend_apps(team_id: int) -> Dict[int, Dict[str, Any]]: +def get_frontend_apps(team_id: int) -> dict[int, dict[str, Any]]: from posthog.models import Plugin, PluginSourceFile plugin_configs = ( @@ -541,10 +537,10 @@ def convert_property_value(input: Union[str, bool, dict, list, int, Optional[str def get_compare_period_dates( date_from: datetime.datetime, date_to: datetime.datetime, - date_from_delta_mapping: Optional[Dict[str, int]], - date_to_delta_mapping: Optional[Dict[str, int]], + date_from_delta_mapping: Optional[dict[str, int]], + date_to_delta_mapping: Optional[dict[str, int]], interval: str, -) -> Tuple[datetime.datetime, datetime.datetime]: +) -> tuple[datetime.datetime, datetime.datetime]: diff = date_to - date_from new_date_from = date_from - diff if interval == "hour": @@ -783,7 +779,7 @@ def get_plugin_server_version() -> Optional[str]: return None -def get_plugin_server_job_queues() -> Optional[List[str]]: +def get_plugin_server_job_queues() -> Optional[list[str]]: cache_key_value = get_client().get("@posthog-plugin-server/enabled-job-queues") if cache_key_value: qs = cache_key_value.decode("utf-8").replace('"', "") @@ -861,13 +857,13 @@ def get_can_create_org(user: Union["AbstractBaseUser", "AnonymousUser"]) -> bool return False -def get_instance_available_sso_providers() -> Dict[str, bool]: +def get_instance_available_sso_providers() -> dict[str, bool]: """ Returns a dictionary containing final determination to which SSO providers are available. SAML is not included in this method as it can only be configured domain-based and not instance-based (see `OrganizationDomain` for details) Validates configuration settings and license validity (if applicable). 
""" - output: Dict[str, bool] = { + output: dict[str, bool] = { "github": bool(settings.SOCIAL_AUTH_GITHUB_KEY and settings.SOCIAL_AUTH_GITHUB_SECRET), "gitlab": bool(settings.SOCIAL_AUTH_GITLAB_KEY and settings.SOCIAL_AUTH_GITLAB_SECRET), "google-oauth2": False, @@ -897,7 +893,7 @@ def get_instance_available_sso_providers() -> Dict[str, bool]: return output -def flatten(i: Union[List, Tuple], max_depth=10) -> Generator: +def flatten(i: Union[list, tuple], max_depth=10) -> Generator: for el in i: if isinstance(el, list) and max_depth > 0: yield from flatten(el, max_depth=max_depth - 1) @@ -909,7 +905,7 @@ def get_daterange( start_date: Optional[datetime.datetime], end_date: Optional[datetime.datetime], frequency: str, -) -> List[Any]: +) -> list[Any]: """ Returns list of a fixed frequency Datetime objects between given bounds. @@ -981,7 +977,7 @@ class GenericEmails: """ def __init__(self): - with open(get_absolute_path("helpers/generic_emails.txt"), "r") as f: + with open(get_absolute_path("helpers/generic_emails.txt")) as f: self.emails = {x.rstrip(): True for x in f} def is_generic(self, email: str) -> bool: @@ -992,7 +988,7 @@ def is_generic(self, email: str) -> bool: @lru_cache(maxsize=1) -def get_available_timezones_with_offsets() -> Dict[str, float]: +def get_available_timezones_with_offsets() -> dict[str, float]: now = dt.datetime.now() result = {} for tz in pytz.common_timezones: @@ -1066,7 +1062,7 @@ def get_milliseconds_between_dates(d1: dt.datetime, d2: dt.datetime) -> int: return abs(int((d1 - d2).total_seconds() * 1000)) -def encode_get_request_params(data: Dict[str, Any]) -> Dict[str, str]: +def encode_get_request_params(data: dict[str, Any]) -> dict[str, str]: return { key: encode_value_as_param(value=value) for key, value in data.items() @@ -1083,7 +1079,7 @@ def default(self, o): def encode_value_as_param(value: Union[str, list, dict, datetime.datetime]) -> str: - if isinstance(value, (list, dict, tuple)): + if isinstance(value, list | dict | tuple): return json.dumps(value, cls=DataclassJSONEncoder) elif isinstance(value, Enum): return value.value @@ -1311,7 +1307,7 @@ def patch(wrapper): def label_for_team_id_to_track(team_id: int) -> str: - team_id_filter: List[str] = settings.DECIDE_TRACK_TEAM_IDS + team_id_filter: list[str] = settings.DECIDE_TRACK_TEAM_IDS team_id_as_string = str(team_id) diff --git a/posthog/version_requirement.py b/posthog/version_requirement.py index 0f60d553e762e..ad0979abc3b32 100644 --- a/posthog/version_requirement.py +++ b/posthog/version_requirement.py @@ -1,5 +1,3 @@ -from typing import Tuple - from semantic_version.base import SimpleSpec, Version from posthog import redis @@ -24,7 +22,7 @@ def __init__(self, service, supported_version): f"The provided supported_version for service {service} is invalid. 
See the Docs for SimpleSpec: https://pypi.org/project/semantic-version/" ) - def is_service_in_accepted_version(self) -> Tuple[bool, Version]: + def is_service_in_accepted_version(self) -> tuple[bool, Version]: service_version = self.get_service_version() return service_version in self.supported_version, service_version diff --git a/posthog/views.py b/posthog/views.py index b9cae80fde3d7..6797b3ab7f823 100644 --- a/posthog/views.py +++ b/posthog/views.py @@ -1,6 +1,6 @@ import os from functools import wraps -from typing import Dict, Union +from typing import Union import sentry_sdk from django.conf import settings @@ -70,7 +70,7 @@ def health(request): def stats(request): - stats_response: Dict[str, Union[int, str]] = {} + stats_response: dict[str, Union[int, str]] = {} stats_response["worker_heartbeat"] = get_celery_heartbeat() return JsonResponse(stats_response) diff --git a/posthog/warehouse/api/external_data_schema.py b/posthog/warehouse/api/external_data_schema.py index e7abb808ce555..c02f6c146f7c9 100644 --- a/posthog/warehouse/api/external_data_schema.py +++ b/posthog/warehouse/api/external_data_schema.py @@ -2,7 +2,7 @@ import structlog import temporalio from posthog.warehouse.models import ExternalDataSchema, ExternalDataJob -from typing import Optional, Dict, Any +from typing import Optional, Any from posthog.api.routing import TeamAndOrgViewSetMixin from rest_framework import viewsets, filters, status from rest_framework.decorators import action @@ -47,7 +47,7 @@ def get_table(self, schema: ExternalDataSchema) -> Optional[dict]: return SimpleTableSerializer(schema.table, context={"database": hogql_context}).data or None - def update(self, instance: ExternalDataSchema, validated_data: Dict[str, Any]) -> ExternalDataSchema: + def update(self, instance: ExternalDataSchema, validated_data: dict[str, Any]) -> ExternalDataSchema: should_sync = validated_data.get("should_sync", None) schedule_exists = external_data_workflow_exists(str(instance.id)) @@ -77,7 +77,7 @@ class ExternalDataSchemaViewset(TeamAndOrgViewSetMixin, viewsets.ModelViewSet): search_fields = ["name"] ordering = "-created_at" - def get_serializer_context(self) -> Dict[str, Any]: + def get_serializer_context(self) -> dict[str, Any]: context = super().get_serializer_context() context["database"] = create_hogql_database(team_id=self.team_id) return context diff --git a/posthog/warehouse/api/external_data_source.py b/posthog/warehouse/api/external_data_source.py index 36142a805938c..c8e7031540931 100644 --- a/posthog/warehouse/api/external_data_source.py +++ b/posthog/warehouse/api/external_data_source.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, List, Tuple, Dict +from typing import Any import structlog from rest_framework import filters, serializers, status, viewsets @@ -71,7 +71,7 @@ def get_last_run_at(self, instance: ExternalDataSource) -> str: return latest_completed_run.created_at if latest_completed_run else None def get_status(self, instance: ExternalDataSource) -> str: - active_schemas: List[ExternalDataSchema] = list(instance.schemas.filter(should_sync=True).all()) + active_schemas: list[ExternalDataSchema] = list(instance.schemas.filter(should_sync=True).all()) any_failures = any(schema.status == ExternalDataSchema.Status.ERROR for schema in active_schemas) any_cancelled = any(schema.status == ExternalDataSchema.Status.CANCELLED for schema in active_schemas) any_paused = any(schema.status == ExternalDataSchema.Status.PAUSED for schema in active_schemas) @@ -122,7 +122,7 @@ class 
ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     search_fields = ["source_id"]
     ordering = "-created_at"
 
-    def get_serializer_context(self) -> Dict[str, Any]:
+    def get_serializer_context(self) -> dict[str, Any]:
         context = super().get_serializer_context()
         context["database"] = create_hogql_database(team_id=self.team_id)
         return context
@@ -193,7 +193,7 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
 
         disabled_schemas = [schema for schema in default_schemas if schema not in enabled_schemas]
 
-        active_schemas: List[ExternalDataSchema] = []
+        active_schemas: list[ExternalDataSchema] = []
 
         for schema in enabled_schemas:
             active_schemas.append(
@@ -289,7 +289,7 @@ def _handle_hubspot_source(self, request: Request, *args: Any, **kwargs: Any) ->
     def _handle_postgres_source(
         self, request: Request, *args: Any, **kwargs: Any
-    ) -> Tuple[ExternalDataSource, List[Any]]:
+    ) -> tuple[ExternalDataSource, list[Any]]:
         payload = request.data["payload"]
         prefix = request.data.get("prefix", None)
         source_type = request.data["source_type"]
diff --git a/posthog/warehouse/api/saved_query.py b/posthog/warehouse/api/saved_query.py
index f341b5779d0b3..581593377f299 100644
--- a/posthog/warehouse/api/saved_query.py
+++ b/posthog/warehouse/api/saved_query.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any
 
 from django.conf import settings
 from rest_framework import exceptions, filters, serializers, viewsets
@@ -33,7 +33,7 @@ class Meta:
         ]
         read_only_fields = ["id", "created_by", "created_at", "columns"]
 
-    def get_columns(self, view: DataWarehouseSavedQuery) -> List[SerializedField]:
+    def get_columns(self, view: DataWarehouseSavedQuery) -> list[SerializedField]:
         team_id = self.context["team_id"]
         context = HogQLContext(team_id=team_id, database=create_hogql_database(team_id=team_id))
 
diff --git a/posthog/warehouse/api/table.py b/posthog/warehouse/api/table.py
index fcfdd7eee8843..7e149b0faba19 100644
--- a/posthog/warehouse/api/table.py
+++ b/posthog/warehouse/api/table.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Dict
+from typing import Any
 
 from rest_framework import filters, request, response, serializers, status, viewsets
 from rest_framework.exceptions import NotAuthenticated
@@ -53,7 +53,7 @@ class Meta:
         ]
         read_only_fields = ["id", "created_by", "created_at", "columns", "external_data_source", "external_schema"]
 
-    def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
+    def get_columns(self, table: DataWarehouseTable) -> list[SerializedField]:
         hogql_context = self.context.get("database", None)
         if not hogql_context:
             hogql_context = create_hogql_database(team_id=self.context["team_id"])
@@ -91,7 +91,7 @@ class Meta:
         fields = ["id", "name", "columns", "row_count"]
         read_only_fields = ["id", "name", "columns", "row_count"]
 
-    def get_columns(self, table: DataWarehouseTable) -> List[SerializedField]:
+    def get_columns(self, table: DataWarehouseTable) -> list[SerializedField]:
         hogql_context = self.context.get("database", None)
         if not hogql_context:
             hogql_context = create_hogql_database(team_id=self.context["team_id"])
@@ -111,7 +111,7 @@ class TableViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
     search_fields = ["name"]
     ordering = "-created_at"
 
-    def get_serializer_context(self) -> Dict[str, Any]:
+    def get_serializer_context(self) -> dict[str, Any]:
         context = super().get_serializer_context()
         context["database"] = create_hogql_database(team_id=self.team_id)
         return context
diff --git a/posthog/warehouse/data_load/validate_schema.py b/posthog/warehouse/data_load/validate_schema.py
index 6a7e251258391..f3755442d3cea 100644
--- a/posthog/warehouse/data_load/validate_schema.py
+++ b/posthog/warehouse/data_load/validate_schema.py
@@ -29,13 +29,12 @@
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from clickhouse_driver.errors import ServerException
 from asgiref.sync import sync_to_async
-from typing import Dict, Type
 from posthog.utils import camel_to_snake_case
 from posthog.warehouse.models.external_data_schema import ExternalDataSchema
 
 
 def dlt_to_hogql_type(dlt_type: TDataType | None) -> str:
-    hogql_type: Type[DatabaseField] = DatabaseField
+    hogql_type: type[DatabaseField] = DatabaseField
 
     if dlt_type is None:
         hogql_type = StringDatabaseField
@@ -69,7 +68,7 @@ def dlt_to_hogql_type(dlt_type: TDataType | None) -> str:
 
 async def validate_schema(
     credential: DataWarehouseCredential, table_name: str, new_url_pattern: str, team_id: int, row_count: int
-) -> Dict:
+) -> dict:
     params = {
         "credential": credential,
         "name": table_name,
@@ -97,7 +96,7 @@ async def validate_schema_and_update_table(
     team_id: int,
     schema_id: uuid.UUID,
     table_schema: TSchemaTables,
-    table_row_counts: Dict[str, int],
+    table_row_counts: dict[str, int],
 ) -> None:
 
     """
@@ -167,7 +166,7 @@ async def validate_schema_and_update_table(
     for schema in table_schema.values():
         if schema.get("resource") == _schema_name:
             schema_columns = schema.get("columns") or {}
-            db_columns: Dict[str, str] = await sync_to_async(table_created.get_columns)()
+            db_columns: dict[str, str] = await sync_to_async(table_created.get_columns)()
 
             columns = {}
             for column_name, db_column_type in db_columns.items():
diff --git a/posthog/warehouse/external_data_source/source.py b/posthog/warehouse/external_data_source/source.py
index f722bae1f33b4..99e49a39a1df0 100644
--- a/posthog/warehouse/external_data_source/source.py
+++ b/posthog/warehouse/external_data_source/source.py
@@ -1,5 +1,5 @@
 import datetime as dt
-from typing import Dict, Optional
+from typing import Optional
 
 from pydantic import BaseModel, field_validator
 
@@ -71,7 +71,7 @@ def create_stripe_source(payload: StripeSourcePayload, workspace_id: str) -> Ext
     return _create_source(payload)
 
 
-def _create_source(payload: Dict) -> ExternalDataSource:
+def _create_source(payload: dict) -> ExternalDataSource:
     response = send_request(AIRBYTE_SOURCE_URL, method="POST", payload=payload)
     return ExternalDataSource(
         source_id=response["sourceId"],
diff --git a/posthog/warehouse/models/datawarehouse_saved_query.py b/posthog/warehouse/models/datawarehouse_saved_query.py
index ffa890ba45b8a..0513cc3b7d1c2 100644
--- a/posthog/warehouse/models/datawarehouse_saved_query.py
+++ b/posthog/warehouse/models/datawarehouse_saved_query.py
@@ -1,5 +1,4 @@
 import re
-from typing import Dict
 from sentry_sdk import capture_exception
 from django.core.exceptions import ValidationError
 from django.db import models
@@ -47,7 +46,7 @@ class Meta:
             )
         ]
 
-    def get_columns(self) -> Dict[str, str]:
+    def get_columns(self) -> dict[str, str]:
         from posthog.api.services.query import process_query
 
         # TODO: catch and raise error
diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py
index ed883f6d623fd..045a4e10d8a98 100644
--- a/posthog/warehouse/models/external_data_schema.py
+++ b/posthog/warehouse/models/external_data_schema.py
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any
 from django.db import models
 from posthog.models.team import Team
@@ -80,7 +80,7 @@ def sync_old_schemas_with_new_schemas(new_schemas: list, source_id: uuid.UUID, t
         ExternalDataSchema.objects.create(name=schema, team_id=team_id, source_id=source_id, should_sync=False)
 
 
-def get_postgres_schemas(host: str, port: str, database: str, user: str, password: str, schema: str) -> List[Any]:
+def get_postgres_schemas(host: str, port: str, database: str, user: str, password: str, schema: str) -> list[Any]:
     connection = psycopg.Connection.connect(
         host=host,
         port=int(port),
diff --git a/posthog/warehouse/models/external_table_definitions.py b/posthog/warehouse/models/external_table_definitions.py
index 405ffa150e6ae..6a684d96eca60 100644
--- a/posthog/warehouse/models/external_table_definitions.py
+++ b/posthog/warehouse/models/external_table_definitions.py
@@ -1,4 +1,3 @@
-from typing import Dict
 from posthog.hogql import ast
 from posthog.hogql.database.models import (
     BooleanDatabaseField,
@@ -10,7 +9,7 @@
 )
 
 
-external_tables: Dict[str, Dict[str, FieldOrTable]] = {
+external_tables: dict[str, dict[str, FieldOrTable]] = {
     "*": {
         "__dlt_id": StringDatabaseField(name="_dlt_id", hidden=True),
         "__dlt_load_id": StringDatabaseField(name="_dlt_load_id", hidden=True),
diff --git a/posthog/warehouse/models/join.py b/posthog/warehouse/models/join.py
index 5a3e46658fdbb..d3edfb864c434 100644
--- a/posthog/warehouse/models/join.py
+++ b/posthog/warehouse/models/join.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 from warnings import warn
 
 from django.db import models
@@ -45,7 +45,7 @@ def join_function(self):
     def _join_function(
         from_table: str,
         to_table: str,
-        requested_fields: Dict[str, Any],
+        requested_fields: dict[str, Any],
         context: HogQLContext,
         node: SelectQuery,
     ):
diff --git a/posthog/warehouse/models/table.py b/posthog/warehouse/models/table.py
index 2b4609e79a6ad..229c81168a8d3 100644
--- a/posthog/warehouse/models/table.py
+++ b/posthog/warehouse/models/table.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from django.db import models
 from posthog.client import sync_execute
@@ -111,7 +111,7 @@ def table_name_without_prefix(self) -> str:
             prefix = ""
         return self.name[len(prefix) :]
 
-    def get_columns(self, safe_expose_ch_error=True) -> Dict[str, str]:
+    def get_columns(self, safe_expose_ch_error=True) -> dict[str, str]:
         try:
             result = sync_execute(
                 """DESCRIBE TABLE (
@@ -160,7 +160,7 @@ def hogql_definition(self) -> S3Table:
         if not self.columns:
             raise Exception("Columns must be fetched and saved to use in HogQL.")
 
-        fields: Dict[str, FieldOrTable] = {}
+        fields: dict[str, FieldOrTable] = {}
         structure = []
         for column, type in self.columns.items():
             # Support for 'old' style columns
diff --git a/posthog/year_in_posthog/calculate_2023.py b/posthog/year_in_posthog/calculate_2023.py
index 29477cfd15007..03428d2711d30 100644
--- a/posthog/year_in_posthog/calculate_2023.py
+++ b/posthog/year_in_posthog/calculate_2023.py
@@ -1,5 +1,5 @@
 from datetime import timedelta
-from typing import Dict, Optional
+from typing import Optional
 
 from django.conf import settings
 from django.db import connection
@@ -147,7 +147,7 @@ def dictfetchall(cursor):
 
 
 @cache_for(timedelta(seconds=0 if settings.DEBUG else 30))
-def calculate_year_in_posthog_2023(user_uuid: str) -> Optional[Dict]:
+def calculate_year_in_posthog_2023(user_uuid: str) -> Optional[dict]:
     with connection.cursor() as cursor:
         cursor.execute(query, {"user_uuid": user_uuid})
         rows = dictfetchall(cursor)
diff --git a/posthog/year_in_posthog/year_in_posthog.py b/posthog/year_in_posthog/year_in_posthog.py
index 3bf05d821c27e..a6ac65fa2fdaa 100644
--- a/posthog/year_in_posthog/year_in_posthog.py
+++ b/posthog/year_in_posthog/year_in_posthog.py
@@ -2,7 +2,7 @@
 from django.template.loader import get_template
 from django.views.decorators.cache import cache_control
 import os
-from typing import Dict, List, Union
+from typing import Union
 
 import structlog
 
@@ -58,7 +58,7 @@
 }
 
 
-def stats_for_user(data: Dict) -> List[Dict[str, Union[int, str]]]:
+def stats_for_user(data: dict) -> list[dict[str, Union[int, str]]]:
     stats = data["stats"]
 
     return [
@@ -75,7 +75,7 @@ def stats_for_user(data: Dict) -> List[Dict[str, Union[int, str]]]:
     ]
 
 
-def sort_list_based_on_preference(badges: List[str]) -> str:
+def sort_list_based_on_preference(badges: list[str]) -> str:
     """sort a list based on its order in badge_preferences and then choose the last one"""
     if len(badges) >= 3:
         return "champion"
diff --git a/pyproject.toml b/pyproject.toml
index 2701b5a74d699..cb19ccadb8178 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,6 @@
+[project]
+requires-python = ">=3.10"
+
 [tool.black]
 line-length = 120
 target-version = ['py310']
@@ -28,6 +31,8 @@ ignore = [
     "F403",
     "F541",
     "F601",
+    "UP007",
+    "UP032",
 ]
 select = [
     "B",
@@ -40,6 +45,7 @@ select = [
     "RUF015",
     "RUF019",
     "T2",
+    "UP",
     "W",
 ]
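
Note on the pattern this diff applies throughout: with `requires-python = ">=3.10"` declared, annotations can use the PEP 585 builtin generics (`dict`, `list`, `tuple`, `type`) instead of their deprecated `typing` aliases, and selecting ruff's `UP` (pyupgrade) rules keeps new code on that style. `UP007` and `UP032` stay in the ignore list, which is why `Optional[...]`/`Union[...]` annotations and `str.format()` calls are left untouched in the hunks above. A minimal before/after sketch of the rewrite; the `summarize_*` functions are illustrative, not from the codebase:

```python
# Before: deprecated typing aliases (flagged by ruff's UP006 rule).
from typing import Dict, List, Optional, Tuple


def summarize_old(counts: Dict[str, int], tags: List[str]) -> Tuple[int, Optional[str]]:
    # Sum all counts and return the first tag, if any.
    return sum(counts.values()), tags[0] if tags else None


# After: builtin generics, valid at runtime on Python >= 3.9.
# Optional is kept as-is because UP007 (rewriting Optional[X] to `X | None`)
# is in the ruff ignore list above.
from typing import Optional


def summarize_new(counts: dict[str, int], tags: list[str]) -> tuple[int, Optional[str]]:
    return sum(counts.values()), tags[0] if tags else None
```

Ruff's autofix can apply these rewrites mechanically, which is presumably how the bulk of this diff was generated rather than by hand-editing each annotation.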