From fb123b41cea8603438f0a263b3234a685b2f9885 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 4 Nov 2024 06:58:40 -0500 Subject: [PATCH 01/26] feat(subscriptions): Add create subscriptions RPC --- requirements.txt | 2 +- snuba/subscriptions/data.py | 197 ++++++++++++++++++++++++---- snuba/subscriptions/subscription.py | 4 +- 3 files changed, 176 insertions(+), 27 deletions(-) diff --git a/requirements.txt b/requirements.txt index cc8c0b6c44..a80d4bc316 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,4 +46,4 @@ sqlparse==0.5.0 google-api-python-client==2.88.0 sentry-usage-accountant==0.0.10 freezegun==1.2.2 -sentry-protos==0.1.30 +sentry-protos==0.1.32 diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 4c7d005bbe..82cb2d5a00 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -1,13 +1,16 @@ from __future__ import annotations +import base64 import logging from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass, field from datetime import datetime, timedelta +from enum import Enum from functools import partial from typing import ( Any, + Generic, Iterator, List, Mapping, @@ -15,10 +18,20 @@ NewType, Optional, Tuple, + TypeVar, Union, ) from uuid import UUID +from google.protobuf.timestamp_pb2 import Timestamp +from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import ( + TraceItemTableRequest, + TraceItemTableResponse, +) +from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( + CreateTraceItemTableSubscriptionRequest, +) + from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity @@ -36,13 +49,15 @@ from snuba.query.expressions import Column, Expression, Literal from snuba.query.logical import Query from snuba.query.query_settings import SubscriptionQuerySettings -from snuba.reader import Result +from snuba.reader import Result, Row from snuba.request import Request from snuba.request.schema import RequestSchema from snuba.request.validation import build_request, parse_snql_query from snuba.subscriptions.utils import Tick from snuba.utils.metrics import MetricsBackend from snuba.utils.metrics.timer import Timer +from snuba.web import QueryResult +from snuba.web.rpc import RPCEndpoint SUBSCRIPTION_REFERRER = "subscription" @@ -55,10 +70,18 @@ "tenant_ids", } + +class SubscriptionType(Enum): + SNQL = "snql" + RPC = "rpc" + + logger = logging.getLogger("snuba.subscriptions") PartitionId = NewType("PartitionId", int) +TRequest = TypeVar("TRequest") + @dataclass(frozen=True) class SubscriptionIdentifier: @@ -75,7 +98,7 @@ def from_string(cls, value: str) -> SubscriptionIdentifier: @dataclass(frozen=True, kw_only=True) -class SubscriptionData(ABC): +class _SubscriptionData(ABC, Generic[TRequest]): project_id: int resolution_sec: int time_window_sec: int @@ -83,9 +106,20 @@ class SubscriptionData(ABC): metadata: Mapping[str, Any] tenant_ids: Mapping[str, Any] = field(default_factory=lambda: dict()) - @abstractmethod def validate(self) -> None: - raise NotImplementedError + if self.time_window_sec < 60: + raise InvalidSubscriptionError( + "Time window must be greater than or equal to 1 minute" + ) + elif self.time_window_sec > 60 * 60 * 24: + raise InvalidSubscriptionError( + "Time window must be less than or equal to 24 hours" + ) + + if self.resolution_sec < 60: + raise InvalidSubscriptionError( + "Resolution must be greater than or equal to 1 minute" + ) @abstractmethod def build_request( @@ -96,14 +130,14 @@ def build_request( timer: Timer, metrics: Optional[MetricsBackend] = None, referrer: str = SUBSCRIPTION_REFERRER, - ) -> Request: + ) -> TRequest: raise NotImplementedError @classmethod @abstractmethod def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey - ) -> SubscriptionData: + ) -> _SubscriptionData: raise NotImplementedError @abstractmethod @@ -112,7 +146,7 @@ def to_dict(self) -> Mapping[str, Any]: @dataclass(frozen=True, kw_only=True) -class SnQLSubscriptionData(SubscriptionData): +class SnQLSubscriptionData(_SubscriptionData[Request]): """ Represents the state of a subscription. """ @@ -187,21 +221,6 @@ def add_conditions( "At least one Entity must have a timestamp column for subscriptions" ) - def validate(self) -> None: - if self.time_window_sec < 60: - raise InvalidSubscriptionError( - "Time window must be greater than or equal to 1 minute" - ) - elif self.time_window_sec > 60 * 60 * 24: - raise InvalidSubscriptionError( - "Time window must be less than or equal to 24 hours" - ) - - if self.resolution_sec < 60: - raise InvalidSubscriptionError( - "Resolution must be greater than or equal to 1 minute" - ) - def build_request( self, dataset: Dataset, @@ -246,7 +265,7 @@ def build_request( @classmethod def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey - ) -> SubscriptionData: + ) -> SnQLSubscriptionData: entity: Entity = get_entity(entity_key) metadata = {} @@ -271,6 +290,7 @@ def to_dict(self) -> Mapping[str, Any]: "time_window": self.time_window_sec, "resolution": self.resolution_sec, "query": self.query, + "subscription_type": SubscriptionType.SNQL.value, } subscription_processors = self.entity.get_subscription_processors() @@ -280,6 +300,133 @@ def to_dict(self) -> Mapping[str, Any]: return subscription_data_dict +@dataclass(frozen=True, kw_only=True) +class RPCSubscriptionData(_SubscriptionData[TraceItemTableRequest]): + """ + Represents the state of an RPC subscription. + """ + + trace_item_table_request: str + proto_name: str + proto_version: str + + def validate(self): + super().validate() + # TODO: Validate no group by, having, order by etc + + def build_request( + self, + dataset: Dataset, + timestamp: datetime, + offset: Optional[int], + timer: Timer, + metrics: Optional[MetricsBackend] = None, + referrer: str = SUBSCRIPTION_REFERRER, + ) -> TraceItemTableRequest: + + table_request = TraceItemTableRequest() + table_request.ParseFromString(base64.b64decode(self.trace_item_table_request)) + start_time_proto = Timestamp() + start_time_proto.FromDatetime(timestamp - timedelta(self.time_window_sec)) + end_time_proto = Timestamp() + end_time_proto.FromDatetime(timestamp) + table_request.meta.start_timestamp.CopyFrom(start_time_proto) + table_request.meta.end_timestamp.CopyFrom(end_time_proto) + + return table_request + + def run_query( + self, + request, + ): + enpdoint = RPCEndpoint.get_from_name(self.proto_name, self.proto_version)() + response: TraceItemTableResponse = enpdoint.execute(request) + + column_values = response.column_values + num_rows = len(column_values[0].results) + data = [] + for i in range(num_rows): + data_row: Row = {} + for column in column_values: + value_key = column.results[i].WhichOneof("value") + if value_key: + value = getattr(column.results[i], value_key) + else: + value = None + + data_row[column.attribute_name] = value + + data.append(data_row) + + result: Result = {"meta": [], "data": data, "trace_output": ""} + return QueryResult( + result=result, extra={"stats": {}, "sql": "", "experiments": {}} + ) + + @classmethod + def from_dict( + cls, data: Mapping[str, Any], entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + + metadata = {} + for key in data.keys(): + if key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: + metadata[key] = data[key] + + return RPCSubscriptionData( + project_id=data["project_id"], + time_window_sec=int(data["time_window"]), + resolution_sec=int(data["resolution"]), + trace_item_table_request=data["trace_item_table_request"], + proto_version=data["proto_version"], + proto_name=data["proto_name"], + entity=entity, + metadata=metadata, + tenant_ids=data.get("tenant_ids", dict()), + ) + + @classmethod + def from_proto( + cls, item: CreateTraceItemTableSubscriptionRequest, entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + cls = item.table_request.__class__ + class_name = cls.__name__ + class_version = cls.__module__.split(".", 3)[2] + + return RPCSubscriptionData( + project_id=item.project_id, + time_window_sec=item.time_window, + resolution_sec=item.resolution, + trace_item_table_request=base64.b64encode( + item.table_request.SerializeToString() + ).decode("utf-8"), + entity=entity, + metadata={}, + tenant_ids={}, + proto_version=class_version, + proto_name=class_name, + ) + + def to_dict(self) -> Mapping[str, Any]: + subscription_data_dict = { + "project_id": self.project_id, + "time_window": self.time_window_sec, + "resolution": self.resolution_sec, + "trace_item_table_request": self.trace_item_table_request, + "proto_version": self.proto_version, + "proto_name": self.proto_name, + "subscription_type": SubscriptionType.RPC.value, + } + + return subscription_data_dict + + +SubscriptionData = Union[RPCSubscriptionData, SnQLSubscriptionData] +SubscriptionRequest = Union[Request, TraceItemTableRequest] + + class Subscription(NamedTuple): identifier: SubscriptionIdentifier data: SubscriptionData @@ -325,9 +472,9 @@ def find(self, tick: Tick) -> Iterator[ScheduledSubscriptionTask]: class SubscriptionTaskResultFuture(NamedTuple): task: ScheduledSubscriptionTask - future: Future[Tuple[Request, Result]] + future: Future[Tuple[SubscriptionRequest, Result]] class SubscriptionTaskResult(NamedTuple): task: ScheduledSubscriptionTask - result: Tuple[Request, Result] + result: Tuple[SubscriptionRequest, Result] diff --git a/snuba/subscriptions/subscription.py b/snuba/subscriptions/subscription.py index 021de963fc..2cb3f11836 100644 --- a/snuba/subscriptions/subscription.py +++ b/snuba/subscriptions/subscription.py @@ -51,7 +51,9 @@ def create(self, data: SubscriptionData, timer: Timer) -> SubscriptionIdentifier return identifier def _test_request(self, data: SubscriptionData, timer: Timer) -> None: - request = data.build_request(self.dataset, datetime.utcnow(), None, timer) + request = data.build_request_and_run_query( + self.dataset, datetime.utcnow(), None, timer + ) run_query(self.dataset, request, timer) From 0f862e93a1fdd4ccddca3e338827d1fc87c16e25 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 4 Nov 2024 07:00:55 -0500 Subject: [PATCH 02/26] create subscriptions file --- snuba/web/rpc/v1/create_subscription.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 snuba/web/rpc/v1/create_subscription.py diff --git a/snuba/web/rpc/v1/create_subscription.py b/snuba/web/rpc/v1/create_subscription.py new file mode 100644 index 0000000000..f90518728e --- /dev/null +++ b/snuba/web/rpc/v1/create_subscription.py @@ -0,0 +1,46 @@ +from typing import Type + +from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( + CreateTraceItemTableSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( + CreateTraceItemTableSubscriptionResponse, +) + +from snuba.datasets.entities.entity_key import EntityKey +from snuba.datasets.pluggable_dataset import PluggableDataset +from snuba.subscriptions.data import RPCSubscriptionData +from snuba.web.rpc import RPCEndpoint + + +class CreateTraceItemTableSubscriptionRequest( + RPCEndpoint[ + CreateSubscriptionRequestProto, CreateTraceItemTableSubscriptionResponse + ] +): + @classmethod + def version(cls) -> str: + return "v1" + + @classmethod + def request_class(cls) -> Type[CreateSubscriptionRequestProto]: + return CreateSubscriptionRequestProto + + @classmethod + def response_class(cls) -> Type[CreateTraceItemTableSubscriptionResponse]: + return CreateTraceItemTableSubscriptionResponse + + def _execute( + self, in_msg: CreateSubscriptionRequestProto + ) -> CreateTraceItemTableSubscriptionResponse: + from snuba.subscriptions.subscription import SubscriptionCreator + + dataset = PluggableDataset(name="eap", all_entities=[]) + entity_key = EntityKey("eap_spans") + + subscription = RPCSubscriptionData.from_proto(in_msg, entity_key=entity_key) + identifier = SubscriptionCreator(dataset, entity_key).create( + subscription, self._timer + ) + + return CreateTraceItemTableSubscriptionResponse(subscription_id=str(identifier)) From 0e90ffcfca93a19e0ba0b291761e2a3e72fb187b Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 4 Nov 2024 08:56:29 -0500 Subject: [PATCH 03/26] fix typing --- snuba/subscriptions/codecs.py | 25 +++- snuba/subscriptions/data.py | 140 ++++++++++++++--------- snuba/subscriptions/executor_consumer.py | 9 +- snuba/subscriptions/subscription.py | 7 +- snuba/web/rpc/v1/create_subscription.py | 22 ++-- tests/subscriptions/test_data.py | 5 +- 6 files changed, 126 insertions(+), 82 deletions(-) diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index b565e32579..b31214578a 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -4,17 +4,21 @@ import rapidjson from arroyo.backends.kafka import KafkaPayload +from google.protobuf.json_format import MessageToDict +from google.protobuf.message import Message as ProtobufMessage from sentry_kafka_schemas.schema_types import events_subscription_results_v1 from snuba.datasets.entities.entity_key import EntityKey from snuba.query.exceptions import InvalidQueryException from snuba.subscriptions.data import ( + RPCSubscriptionData, ScheduledSubscriptionTask, SnQLSubscriptionData, Subscription, SubscriptionData, SubscriptionIdentifier, SubscriptionTaskResult, + SubscriptionType, SubscriptionWithMetadata, ) from snuba.utils.codecs import Codec, Encoder @@ -33,6 +37,9 @@ def decode(self, value: bytes) -> SubscriptionData: except json.JSONDecodeError: raise InvalidQueryException("Invalid JSON") + if data.get("subscription_type") == SubscriptionType.RPC: + return RPCSubscriptionData.from_dict(data, self.entity_key) + return SnQLSubscriptionData.from_dict(data, self.entity_key) @@ -42,11 +49,16 @@ def encode(self, value: SubscriptionTaskResult) -> KafkaPayload: subscription_id = str(subscription.identifier) request, result = value.result + if isinstance(request, ProtobufMessage): + original_body = {**MessageToDict(request)} + else: + original_body = {**request.original_body} + data: events_subscription_results_v1.SubscriptionResult = { "version": 3, "payload": { "subscription_id": subscription_id, - "request": {**request.original_body}, + "request": original_body, "result": { "data": result["data"], "meta": result["meta"], @@ -98,15 +110,20 @@ def decode(self, value: KafkaPayload) -> ScheduledSubscriptionTask: entity_key = EntityKey(scheduled_subscription_dict["entity"]) + data = scheduled_subscription_dict["task"]["data"] + subscription: SubscriptionData + if data.get("subscription_type") == SubscriptionType.RPC: + subscription = RPCSubscriptionData.from_dict(data, entity_key) + else: + subscription = SnQLSubscriptionData.from_dict(data, entity_key) + return ScheduledSubscriptionTask( datetime.fromisoformat(scheduled_subscription_dict["timestamp"]), SubscriptionWithMetadata( entity_key, Subscription( SubscriptionIdentifier.from_string(subscription_identifier), - SnQLSubscriptionData.from_dict( - scheduled_subscription_dict["task"]["data"], entity_key - ), + subscription, ), scheduled_subscription_dict["tick_upper_offset"], ), diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 82cb2d5a00..a13f1415ae 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -24,12 +24,12 @@ from uuid import UUID from google.protobuf.timestamp_pb2 import Timestamp -from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import ( - TraceItemTableRequest, - TraceItemTableResponse, +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest, ) -from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( - CreateTraceItemTableSubscriptionRequest, +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import ( + TimeSeriesRequest, + TimeSeriesResponse, ) from snuba.datasets.dataset import Dataset @@ -49,14 +49,16 @@ from snuba.query.expressions import Column, Expression, Literal from snuba.query.logical import Query from snuba.query.query_settings import SubscriptionQuerySettings -from snuba.reader import Result, Row +from snuba.reader import Result from snuba.request import Request from snuba.request.schema import RequestSchema from snuba.request.validation import build_request, parse_snql_query from snuba.subscriptions.utils import Tick from snuba.utils.metrics import MetricsBackend +from snuba.utils.metrics.gauge import Gauge from snuba.utils.metrics.timer import Timer from snuba.web import QueryResult +from snuba.web.query import run_query from snuba.web.rpc import RPCEndpoint SUBSCRIPTION_REFERRER = "subscription" @@ -70,6 +72,9 @@ "tenant_ids", } +PROTOBUF_ALLOWLIST = ["TimeSeriesRequest"] +PROTOBUF_VERSION_ALLOWLIST = ["v1"] + class SubscriptionType(Enum): SNQL = "snql" @@ -133,11 +138,22 @@ def build_request( ) -> TRequest: raise NotImplementedError + @abstractmethod + def run_query( + self, + dataset: Dataset, + request: TRequest, + timer: Timer, + robust: bool = False, + concurrent_queries_gauge: Optional[Gauge] = None, + ) -> QueryResult: + raise NotImplementedError + @classmethod @abstractmethod def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey - ) -> _SubscriptionData: + ) -> _SubscriptionData[TRequest]: raise NotImplementedError @abstractmethod @@ -262,6 +278,22 @@ def build_request( ) return request + def run_query( + self, + dataset: Dataset, + request: Request, + timer: Timer, + robust: bool = False, + concurrent_queries_gauge: Optional[Gauge] = None, + ) -> QueryResult: + return run_query( + dataset, + request, + timer, + robust=robust, + concurrent_queries_gauge=concurrent_queries_gauge, + ) + @classmethod def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey @@ -301,17 +333,33 @@ def to_dict(self) -> Mapping[str, Any]: @dataclass(frozen=True, kw_only=True) -class RPCSubscriptionData(_SubscriptionData[TraceItemTableRequest]): +class RPCSubscriptionData(_SubscriptionData[TimeSeriesRequest]): """ Represents the state of an RPC subscription. """ - trace_item_table_request: str + time_series_request: str proto_name: str proto_version: str - def validate(self): + @property + def endpoint(self) -> RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse]: + return RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse].get_from_name( + self.proto_name, self.proto_version + )() + + def validate(self) -> None: super().validate() + if self.proto_name not in PROTOBUF_ALLOWLIST: + raise InvalidSubscriptionError( + f"{self.proto_name} is not supported. Supported request types are: {PROTOBUF_ALLOWLIST}" + ) + + if self.proto_version not in PROTOBUF_VERSION_ALLOWLIST: + raise InvalidSubscriptionError( + f"{self.proto_version} version not supported. Supported versions are: {PROTOBUF_VERSION_ALLOWLIST}" + ) + # TODO: Validate no group by, having, order by etc def build_request( @@ -322,41 +370,32 @@ def build_request( timer: Timer, metrics: Optional[MetricsBackend] = None, referrer: str = SUBSCRIPTION_REFERRER, - ) -> TraceItemTableRequest: + ) -> TimeSeriesRequest: + + request_class = self.endpoint.request_class()() - table_request = TraceItemTableRequest() - table_request.ParseFromString(base64.b64decode(self.trace_item_table_request)) + request_class.ParseFromString(base64.b64decode(self.time_series_request)) start_time_proto = Timestamp() start_time_proto.FromDatetime(timestamp - timedelta(self.time_window_sec)) end_time_proto = Timestamp() end_time_proto.FromDatetime(timestamp) - table_request.meta.start_timestamp.CopyFrom(start_time_proto) - table_request.meta.end_timestamp.CopyFrom(end_time_proto) + request_class.meta.start_timestamp.CopyFrom(start_time_proto) + request_class.meta.end_timestamp.CopyFrom(end_time_proto) - return table_request + return request_class def run_query( self, - request, - ): - enpdoint = RPCEndpoint.get_from_name(self.proto_name, self.proto_version)() - response: TraceItemTableResponse = enpdoint.execute(request) - - column_values = response.column_values - num_rows = len(column_values[0].results) - data = [] - for i in range(num_rows): - data_row: Row = {} - for column in column_values: - value_key = column.results[i].WhichOneof("value") - if value_key: - value = getattr(column.results[i], value_key) - else: - value = None - - data_row[column.attribute_name] = value - - data.append(data_row) + dataset: Dataset, + request: TimeSeriesRequest, + timer: Timer, + robust: bool = False, + concurrent_queries_gauge: Optional[Gauge] = None, + ) -> QueryResult: + response = self.endpoint.execute(request) + + timeseries = response.result_timeseries[0] + data = [{timeseries.label: timeseries.data_points[0].data}] result: Result = {"meta": [], "data": data, "trace_output": ""} return QueryResult( @@ -369,38 +408,33 @@ def from_dict( ) -> RPCSubscriptionData: entity: Entity = get_entity(entity_key) - metadata = {} - for key in data.keys(): - if key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: - metadata[key] = data[key] - return RPCSubscriptionData( project_id=data["project_id"], time_window_sec=int(data["time_window"]), resolution_sec=int(data["resolution"]), - trace_item_table_request=data["trace_item_table_request"], + time_series_request=data["time_series_request"], proto_version=data["proto_version"], proto_name=data["proto_name"], entity=entity, - metadata=metadata, + metadata={}, tenant_ids=data.get("tenant_ids", dict()), ) @classmethod def from_proto( - cls, item: CreateTraceItemTableSubscriptionRequest, entity_key: EntityKey + cls, item: CreateSubscriptionRequest, entity_key: EntityKey ) -> RPCSubscriptionData: entity: Entity = get_entity(entity_key) - cls = item.table_request.__class__ - class_name = cls.__name__ - class_version = cls.__module__.split(".", 3)[2] + request_class = item.time_series_request.__class__ + class_name = request_class.__name__ + class_version = request_class.__module__.split(".", 3)[2] return RPCSubscriptionData( project_id=item.project_id, - time_window_sec=item.time_window, - resolution_sec=item.resolution, - trace_item_table_request=base64.b64encode( - item.table_request.SerializeToString() + time_window_sec=item.time_window_secs, + resolution_sec=item.resolution_secs, + time_series_request=base64.b64encode( + item.time_series_request.SerializeToString() ).decode("utf-8"), entity=entity, metadata={}, @@ -414,7 +448,7 @@ def to_dict(self) -> Mapping[str, Any]: "project_id": self.project_id, "time_window": self.time_window_sec, "resolution": self.resolution_sec, - "trace_item_table_request": self.trace_item_table_request, + "time_series_request": self.time_series_request, "proto_version": self.proto_version, "proto_name": self.proto_name, "subscription_type": SubscriptionType.RPC.value, @@ -424,7 +458,7 @@ def to_dict(self) -> Mapping[str, Any]: SubscriptionData = Union[RPCSubscriptionData, SnQLSubscriptionData] -SubscriptionRequest = Union[Request, TraceItemTableRequest] +SubscriptionRequest = Union[Request, TimeSeriesRequest] class Subscription(NamedTuple): diff --git a/snuba/subscriptions/executor_consumer.py b/snuba/subscriptions/executor_consumer.py index 6273320a79..a38340c6dc 100644 --- a/snuba/subscriptions/executor_consumer.py +++ b/snuba/subscriptions/executor_consumer.py @@ -29,13 +29,13 @@ from snuba.datasets.factory import get_dataset from snuba.datasets.table_storage import KafkaTopicSpec from snuba.reader import Result -from snuba.request import Request from snuba.subscriptions.codecs import ( SubscriptionScheduledTaskEncoder, SubscriptionTaskResultEncoder, ) from snuba.subscriptions.data import ( ScheduledSubscriptionTask, + SubscriptionRequest, SubscriptionTaskResult, SubscriptionTaskResultFuture, ) @@ -46,7 +46,6 @@ from snuba.utils.streams.topics import Topic as SnubaTopic from snuba.web import QueryException from snuba.web.constants import NON_RETRYABLE_CLICKHOUSE_ERROR_CODES -from snuba.web.query import run_query logger = logging.getLogger(__name__) @@ -255,7 +254,7 @@ def __init__( def __execute_query( self, task: ScheduledSubscriptionTask, tick_upper_offset: int - ) -> Tuple[Request, Result]: + ) -> Tuple[SubscriptionRequest, Result]: # Measure the amount of time that took between the task's scheduled # time and it beginning to execute. self.__metrics.timing( @@ -274,9 +273,9 @@ def __execute_query( "subscriptions_executor", ) - result = run_query( + result = task.task.subscription.data.run_query( self.__dataset, - request, + request, # type: ignore timer, robust=True, concurrent_queries_gauge=self.__concurrent_clickhouse_gauge, diff --git a/snuba/subscriptions/subscription.py b/snuba/subscriptions/subscription.py index 2cb3f11836..be13f0c0bc 100644 --- a/snuba/subscriptions/subscription.py +++ b/snuba/subscriptions/subscription.py @@ -13,7 +13,6 @@ from snuba.subscriptions.partitioner import TopicSubscriptionDataPartitioner from snuba.subscriptions.store import RedisSubscriptionDataStore from snuba.utils.metrics.timer import Timer -from snuba.web.query import run_query redis_client = get_redis_client(RedisClientKey.SUBSCRIPTION_STORE) @@ -51,10 +50,8 @@ def create(self, data: SubscriptionData, timer: Timer) -> SubscriptionIdentifier return identifier def _test_request(self, data: SubscriptionData, timer: Timer) -> None: - request = data.build_request_and_run_query( - self.dataset, datetime.utcnow(), None, timer - ) - run_query(self.dataset, request, timer) + request = data.build_request(self.dataset, datetime.utcnow(), None, timer) + data.run_query(self.dataset, request, timer) # type: ignore class SubscriptionDeleter: diff --git a/snuba/web/rpc/v1/create_subscription.py b/snuba/web/rpc/v1/create_subscription.py index f90518728e..48fa9f589e 100644 --- a/snuba/web/rpc/v1/create_subscription.py +++ b/snuba/web/rpc/v1/create_subscription.py @@ -1,10 +1,10 @@ from typing import Type -from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( - CreateTraceItemTableSubscriptionRequest as CreateSubscriptionRequestProto, +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, ) -from sentry_protos.snuba.v1.endpoint_trace_item_table_subscription_pb2 import ( - CreateTraceItemTableSubscriptionResponse, +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionResponse, ) from snuba.datasets.entities.entity_key import EntityKey @@ -13,10 +13,8 @@ from snuba.web.rpc import RPCEndpoint -class CreateTraceItemTableSubscriptionRequest( - RPCEndpoint[ - CreateSubscriptionRequestProto, CreateTraceItemTableSubscriptionResponse - ] +class CreateSubscriptionRequest( + RPCEndpoint[CreateSubscriptionRequestProto, CreateSubscriptionResponse] ): @classmethod def version(cls) -> str: @@ -27,12 +25,12 @@ def request_class(cls) -> Type[CreateSubscriptionRequestProto]: return CreateSubscriptionRequestProto @classmethod - def response_class(cls) -> Type[CreateTraceItemTableSubscriptionResponse]: - return CreateTraceItemTableSubscriptionResponse + def response_class(cls) -> Type[CreateSubscriptionResponse]: + return CreateSubscriptionResponse def _execute( self, in_msg: CreateSubscriptionRequestProto - ) -> CreateTraceItemTableSubscriptionResponse: + ) -> CreateSubscriptionResponse: from snuba.subscriptions.subscription import SubscriptionCreator dataset = PluggableDataset(name="eap", all_entities=[]) @@ -43,4 +41,4 @@ def _execute( subscription, self._timer ) - return CreateTraceItemTableSubscriptionResponse(subscription_id=str(identifier)) + return CreateSubscriptionResponse(subscription_id=str(identifier)) diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index c0ce08af7a..98b301b8e2 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -9,7 +9,6 @@ from snuba.query.exceptions import InvalidQueryException from snuba.subscriptions.data import SnQLSubscriptionData, SubscriptionData from snuba.utils.metrics.timer import Timer -from snuba.web.query import run_query from tests.subscriptions import BaseSubscriptionTest TESTS = [ @@ -102,7 +101,7 @@ def compare_conditions( 100, timer, ) - run_query(self.dataset, request, timer) + subscription.run_query(self.dataset, request, timer) # type: ignore return request = subscription.build_request( @@ -111,7 +110,7 @@ def compare_conditions( 100, timer, ) - result = run_query(self.dataset, request, timer) + result = subscription.run_query(self.dataset, request, timer) # type: ignore assert result.result["data"][0][aggregate] == value From efa8aaa985ce4501430126927e7b1b1d61e5238a Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 6 Nov 2024 15:14:08 -0500 Subject: [PATCH 04/26] tests --- snuba/datasets/slicing.py | 2 +- snuba/subscriptions/codecs.py | 2 +- snuba/subscriptions/data.py | 72 ++++---- snuba/web/rpc/v1/create_subscription.py | 2 +- .../test_filter_subscriptions.py | 89 ++++++++++ .../subscriptions/test_scheduler_consumer.py | 164 ++++++++++++++++++ tests/subscriptions/test_subscription.py | 87 +++++++++- tests/web/rpc/v1/test_create_subscription.py | 109 ++++++++++++ 8 files changed, 493 insertions(+), 34 deletions(-) create mode 100644 tests/web/rpc/v1/test_create_subscription.py diff --git a/snuba/datasets/slicing.py b/snuba/datasets/slicing.py index c4c0597367..60a0e78c10 100644 --- a/snuba/datasets/slicing.py +++ b/snuba/datasets/slicing.py @@ -3,6 +3,7 @@ should be stored. These do not require individual physical partitions but allow for repartitioning with less code changes per physical change. """ + from snuba.clusters.storage_sets import StorageSetKey SENTRY_LOGICAL_PARTITIONS = 256 @@ -30,7 +31,6 @@ def map_logical_partition_to_slice( assert ( storage_set.value in LOGICAL_PARTITION_MAPPING ), f"logical partition mapping missing for storage set {storage_set}" - return LOGICAL_PARTITION_MAPPING[storage_set.value][logical_partition] diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index b31214578a..72a2d16101 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -37,7 +37,7 @@ def decode(self, value: bytes) -> SubscriptionData: except json.JSONDecodeError: raise InvalidQueryException("Invalid JSON") - if data.get("subscription_type") == SubscriptionType.RPC: + if data.get("subscription_type") == SubscriptionType.RPC.value: return RPCSubscriptionData.from_dict(data, self.entity_key) return SnQLSubscriptionData.from_dict(data, self.entity_key) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index a13f1415ae..495de11800 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from enum import Enum from functools import partial from typing import ( @@ -27,10 +27,7 @@ from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( CreateSubscriptionRequest, ) -from sentry_protos.snuba.v1.endpoint_time_series_pb2 import ( - TimeSeriesRequest, - TimeSeriesResponse, -) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey @@ -59,7 +56,7 @@ from snuba.utils.metrics.timer import Timer from snuba.web import QueryResult from snuba.web.query import run_query -from snuba.web.rpc import RPCEndpoint +from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries SUBSCRIPTION_REFERRER = "subscription" @@ -339,25 +336,20 @@ class RPCSubscriptionData(_SubscriptionData[TimeSeriesRequest]): """ time_series_request: str - proto_name: str - proto_version: str - @property - def endpoint(self) -> RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse]: - return RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse].get_from_name( - self.proto_name, self.proto_version - )() + request_name: str + request_version: str def validate(self) -> None: super().validate() - if self.proto_name not in PROTOBUF_ALLOWLIST: + if self.request_name not in PROTOBUF_ALLOWLIST: raise InvalidSubscriptionError( - f"{self.proto_name} is not supported. Supported request types are: {PROTOBUF_ALLOWLIST}" + f"{self.request_name} is not supported. Supported request types are: {PROTOBUF_ALLOWLIST}" ) - if self.proto_version not in PROTOBUF_VERSION_ALLOWLIST: + if self.request_version not in PROTOBUF_VERSION_ALLOWLIST: raise InvalidSubscriptionError( - f"{self.proto_version} version not supported. Supported versions are: {PROTOBUF_VERSION_ALLOWLIST}" + f"{self.request_version} version not supported. Supported versions are: {PROTOBUF_VERSION_ALLOWLIST}" ) # TODO: Validate no group by, having, order by etc @@ -372,13 +364,23 @@ def build_request( referrer: str = SUBSCRIPTION_REFERRER, ) -> TimeSeriesRequest: - request_class = self.endpoint.request_class()() - + request_class = EndpointTimeSeries().request_class()() request_class.ParseFromString(base64.b64decode(self.time_series_request)) + + # TODO: update it to round to the lowest granularity + # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 + rounded_ts = ( + int(timestamp.replace(tzinfo=UTC).timestamp() / self.time_window_sec) + * self.time_window_sec + ) + rounded_start = datetime.utcfromtimestamp(rounded_ts) + start_time_proto = Timestamp() - start_time_proto.FromDatetime(timestamp - timedelta(self.time_window_sec)) + start_time_proto.FromDatetime( + rounded_start - timedelta(seconds=self.time_window_sec) + ) end_time_proto = Timestamp() - end_time_proto.FromDatetime(timestamp) + end_time_proto.FromDatetime(rounded_start) request_class.meta.start_timestamp.CopyFrom(start_time_proto) request_class.meta.end_timestamp.CopyFrom(end_time_proto) @@ -392,7 +394,12 @@ def run_query( robust: bool = False, concurrent_queries_gauge: Optional[Gauge] = None, ) -> QueryResult: - response = self.endpoint.execute(request) + response = EndpointTimeSeries().execute(request) + if not response.result_timeseries: + result: Result = {"meta": [], "data": [], "trace_output": ""} + return QueryResult( + result=result, extra={"stats": {}, "sql": "", "experiments": {}} + ) timeseries = response.result_timeseries[0] data = [{timeseries.label: timeseries.data_points[0].data}] @@ -413,10 +420,10 @@ def from_dict( time_window_sec=int(data["time_window"]), resolution_sec=int(data["resolution"]), time_series_request=data["time_series_request"], - proto_version=data["proto_version"], - proto_name=data["proto_name"], + request_version=data["request_version"], + request_name=data["request_name"], entity=entity, - metadata={}, + metadata=data.get("metadata", dict()), tenant_ids=data.get("tenant_ids", dict()), ) @@ -429,6 +436,10 @@ def from_proto( class_name = request_class.__name__ class_version = request_class.__module__.split(".", 3)[2] + metadata = dict() + if item.time_series_request.meta: + metadata["organization"] = item.time_series_request.meta.organization_id + return RPCSubscriptionData( project_id=item.project_id, time_window_sec=item.time_window_secs, @@ -437,10 +448,10 @@ def from_proto( item.time_series_request.SerializeToString() ).decode("utf-8"), entity=entity, - metadata={}, + metadata=metadata, tenant_ids={}, - proto_version=class_version, - proto_name=class_name, + request_version=class_version, + request_name=class_name, ) def to_dict(self) -> Mapping[str, Any]: @@ -449,9 +460,10 @@ def to_dict(self) -> Mapping[str, Any]: "time_window": self.time_window_sec, "resolution": self.resolution_sec, "time_series_request": self.time_series_request, - "proto_version": self.proto_version, - "proto_name": self.proto_name, + "request_version": self.request_version, + "request_name": self.request_name, "subscription_type": SubscriptionType.RPC.value, + "metadata": self.metadata, } return subscription_data_dict diff --git a/snuba/web/rpc/v1/create_subscription.py b/snuba/web/rpc/v1/create_subscription.py index 48fa9f589e..950943205b 100644 --- a/snuba/web/rpc/v1/create_subscription.py +++ b/snuba/web/rpc/v1/create_subscription.py @@ -9,7 +9,6 @@ from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.pluggable_dataset import PluggableDataset -from snuba.subscriptions.data import RPCSubscriptionData from snuba.web.rpc import RPCEndpoint @@ -31,6 +30,7 @@ def response_class(cls) -> Type[CreateSubscriptionResponse]: def _execute( self, in_msg: CreateSubscriptionRequestProto ) -> CreateSubscriptionResponse: + from snuba.subscriptions.data import RPCSubscriptionData from snuba.subscriptions.subscription import SubscriptionCreator dataset = PluggableDataset(name="eap", all_entities=[]) diff --git a/tests/subscriptions/test_filter_subscriptions.py b/tests/subscriptions/test_filter_subscriptions.py index ea5b9ce983..76d2fe34a4 100644 --- a/tests/subscriptions/test_filter_subscriptions.py +++ b/tests/subscriptions/test_filter_subscriptions.py @@ -6,12 +6,29 @@ from unittest.mock import patch import pytest +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + AttributeValue, + ExtrapolationMode, + Function, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, +) from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity from snuba.subscriptions import scheduler from snuba.subscriptions.data import ( PartitionId, + RPCSubscriptionData, SnQLSubscriptionData, Subscription, SubscriptionIdentifier, @@ -60,3 +77,75 @@ def test_filter_subscriptions(expected_subs, extra_subs) -> None: # type: ignor slice_id=2, ) assert filtered_subs == expected_subs + + +def build_rpc_subscription(resolution: timedelta, org_id: int) -> Subscription: + return Subscription( + SubscriptionIdentifier(PartitionId(1), uuid.uuid4()), + RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=org_id, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), + op=ComparisonFilter.OP_NOT_EQUALS, + value=AttributeValue(val_str="bar"), + ) + ), + granularity_secs=300, + ), + time_window_secs=300, + resolution_secs=int(resolution.total_seconds()), + ), + EntityKey.EAP_SPANS, + ), + ) + + +@pytest.fixture +def expected_rpc_subs() -> MutableSequence[Subscription]: + return [ + build_rpc_subscription(timedelta(minutes=1), 2) + for count in range(randint(1, 50)) + ] + + +@pytest.fixture +def extra_rpc_subs() -> MutableSequence[Subscription]: + return [ + build_rpc_subscription(timedelta(minutes=3), 1) + for count in range(randint(1, 50)) + ] + + +@patch("snuba.settings.SLICED_STORAGE_SETS", {"events_analytics_platform": 3}) +@patch( + "snuba.settings.LOGICAL_PARTITION_MAPPING", + {"events_analytics_platform": {0: 0, 1: 1, 2: 2}}, +) +def test_filter_rpc_subscriptions(expected_rpc_subs, extra_rpc_subs) -> None: # type: ignore + importlib.reload(scheduler) + + filtered_subs = filter_subscriptions( + subscriptions=expected_rpc_subs + extra_rpc_subs, + entity_key=EntityKey.EAP_SPANS, + metrics=DummyMetricsBackend(strict=True), + slice_id=2, + ) + assert filtered_subs == expected_rpc_subs diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index 9f114ea124..0c36604766 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -1,4 +1,5 @@ import importlib +import json import logging import time import uuid @@ -17,6 +18,22 @@ from arroyo.utils.clock import TestingClock from confluent_kafka.admin import AdminClient from py._path.local import LocalPath +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + AttributeValue, + ExtrapolationMode, + Function, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, +) from snuba import settings from snuba.datasets.entities.entity_key import EntityKey @@ -141,6 +158,153 @@ def test_scheduler_consumer(tmpdir: LocalPath) -> None: settings.TOPIC_PARTITION_COUNTS = {} +@pytest.mark.redis_db +def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: + settings.TOPIC_PARTITION_COUNTS = {"snuba-spans": 2} + importlib.reload(scheduler_consumer) + + admin_client = AdminClient(get_default_kafka_configuration()) + create_topics(admin_client, [SnubaTopic.EAP_SPANS_COMMIT_LOG]) + + metrics_backend = TestingMetricsBackend() + entity_name = "eap_spans" + entity = get_entity(EntityKey(entity_name)) + storage = entity.get_writable_storage() + assert storage is not None + stream_loader = storage.get_table_writer().get_stream_loader() + + commit_log_topic = Topic("snuba-eap-spans-commit-log") + + mock_scheduler_producer = mock.Mock() + + from snuba.redis import RedisClientKey, get_redis_client + from snuba.subscriptions.data import PartitionId, RPCSubscriptionData + from snuba.subscriptions.store import RedisSubscriptionDataStore + + entity_key = EntityKey(entity_name) + partition_index = 0 + + store = RedisSubscriptionDataStore( + get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), + entity_key, + PartitionId(partition_index), + ) + entity = get_entity(EntityKey.EVENTS) + store.create( + uuid.uuid4(), + RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), + op=ComparisonFilter.OP_NOT_EQUALS, + value=AttributeValue(val_str="bar"), + ) + ), + granularity_secs=300, + ), + time_window_secs=300, + resolution_secs=60, + ), + EntityKey.EAP_SPANS, + ), + ) + + builder = scheduler_consumer.SchedulerBuilder( + entity_name, + str(uuid.uuid1().hex), + "eap_spans", + [], + mock_scheduler_producer, + "latest", + False, + 60 * 5, + None, + metrics_backend, + health_check_file=(tmpdir / "health.txt").strpath, + ) + scheduler = builder.build_consumer() + time.sleep(2) + scheduler._run_once() + scheduler._run_once() + scheduler._run_once() + + epoch = 1000 + + producer = KafkaProducer( + build_kafka_producer_configuration( + stream_loader.get_default_topic_spec().topic, + ) + ) + + for partition, offset, ts in [ + (0, 0, epoch), + (1, 0, epoch + 60), + (0, 1, epoch + 120), + (1, 1, epoch + 180), + ]: + fut = producer.produce( + commit_log_topic, + payload=commit_codec.encode( + Commit( + "eap_spans", + Partition(commit_log_topic, partition), + offset, + ts, + ts, + ) + ), + ) + fut.result() + + producer.close() + + for _ in range(5): + scheduler._run_once() + + scheduler._shutdown() + + assert (tmpdir / "health.txt").check() + assert mock_scheduler_producer.produce.call_count == 2 + assert json.loads( + mock_scheduler_producer.produce.call_args_list[0][0][1].value + ) == { + "timestamp": "1970-01-01T00:16:00", + "entity": "eap_spans", + "task": { + "data": { + "project_id": 0, + "time_window": 300, + "resolution": 60, + "time_series_request": "Ch0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAxIUIhIKBwgBEgNmb28QBhoFEgNiYXIaGggBEg8IAxILdGVzdF9tZXRyaWMaA3N1bSABIKwC", + "request_version": "v1", + "request_name": "TimeSeriesRequest", + "subscription_type": "rpc", + } + }, + "tick_upper_offset": 1, + } + + settings.TOPIC_PARTITION_COUNTS = {} + + def test_tick_time_shift() -> None: partition = 0 offsets = Interval(0, 1) diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index cce21a710a..28ba62abdd 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -3,6 +3,22 @@ import pytest from pytest import raises +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + AttributeValue, + ExtrapolationMode, + Function, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, +) from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity @@ -11,7 +27,11 @@ from snuba.query.exceptions import InvalidQueryException, ValidationException from snuba.query.validation.validators import ColumnValidationMode from snuba.redis import RedisClientKey, get_redis_client -from snuba.subscriptions.data import SnQLSubscriptionData, SubscriptionData +from snuba.subscriptions.data import ( + RPCSubscriptionData, + SnQLSubscriptionData, + SubscriptionData, +) from snuba.subscriptions.store import RedisSubscriptionDataStore from snuba.subscriptions.subscription import SubscriptionCreator, SubscriptionDeleter from snuba.utils.metrics.timer import Timer @@ -322,3 +342,68 @@ def test(self) -> None: ).all() == [] ) + + +TESTS_CREATE_RPC_SUBSCRIPTIONS = [ + pytest.param( + RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), + op=ComparisonFilter.OP_NOT_EQUALS, + value=AttributeValue(val_str="bar"), + ) + ), + granularity_secs=300, + ), + time_window_secs=300, + resolution_secs=60, + ), + EntityKey.EAP_SPANS, + ), + id="EAP spans RPC subscription", + ), +] + + +class TestEAPSpansRPCSubscriptionCreator: + timer = Timer("test") + + def setup_method(self) -> None: + self.dataset = get_dataset("metrics") + + @pytest.mark.parametrize("subscription", TESTS_CREATE_RPC_SUBSCRIPTIONS) + @pytest.mark.clickhouse_db + @pytest.mark.redis_db + def test(self, subscription: SubscriptionData) -> None: + creator = SubscriptionCreator(self.dataset, EntityKey.EAP_SPANS) + identifier = creator.create(subscription, self.timer) + assert ( + cast( + List[Tuple[UUID, SubscriptionData]], + RedisSubscriptionDataStore( + get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), + EntityKey.EAP_SPANS, + identifier.partition, + ).all(), + )[0][1] + == subscription + ) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py new file mode 100644 index 0000000000..b1a5d7ed20 --- /dev/null +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -0,0 +1,109 @@ +from datetime import UTC, datetime, timedelta + +import pytest +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionResponse, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.error_pb2 import Error +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + ExtrapolationMode, + Function, +) + +from tests.base import BaseApiTest +from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries + +END_TIME = datetime.utcnow().replace(second=0, microsecond=0, tzinfo=UTC) +START_TIME = END_TIME - timedelta(hours=1) + + +@pytest.mark.clickhouse_db +@pytest.mark.redis_db +class TestCreateSubscriptionApi(BaseApiTest): + def test_create_valid_subscription(self) -> None: + store_timeseries( + START_TIME, + 1, + 3600, + metrics=[DummyMetric("test_metric", get_value=lambda x: 1)], + ) + + message = CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + granularity_secs=300, + ), + time_window_secs=300, + resolution_secs=60, + ) + response = self.app.post( + "/rpc/CreateSubscriptionRequest/v1", data=message.SerializeToString() + ) + assert response.status_code == 200 + response_class = CreateSubscriptionResponse() + response_class.ParseFromString(response.data) + assert response_class.subscription_id + + def test_create_invalid_subscription(self) -> None: + store_timeseries( + START_TIME, + 1, + 3600, + metrics=[DummyMetric("test_metric", get_value=lambda x: 1)], + ) + + message = CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + granularity_secs=172800, + ), + time_window_secs=172800, + resolution_secs=60, + ) + response = self.app.post( + "/rpc/CreateSubscriptionRequest/v1", data=message.SerializeToString() + ) + assert response.status_code == 500 + error = Error() + error.ParseFromString(response.data) + assert ( + error.message + == "internal error occurred while executing this RPC call: Time window must be less than or equal to 24 hours" + ) From 69cfc87973a6f67e714e7c7d02b1a3731a1ac81a Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 6 Nov 2024 16:23:17 -0500 Subject: [PATCH 05/26] subscription data test --- snuba/datasets/slicing.py | 1 + snuba/subscriptions/data.py | 6 +- tests/subscriptions/__init__.py | 94 ++++++++++++++++++++++++++- tests/subscriptions/test_data.py | 106 +++++++++++++++++++++++++++++-- 4 files changed, 200 insertions(+), 7 deletions(-) diff --git a/snuba/datasets/slicing.py b/snuba/datasets/slicing.py index 60a0e78c10..07995ca14e 100644 --- a/snuba/datasets/slicing.py +++ b/snuba/datasets/slicing.py @@ -31,6 +31,7 @@ def map_logical_partition_to_slice( assert ( storage_set.value in LOGICAL_PARTITION_MAPPING ), f"logical partition mapping missing for storage set {storage_set}" + return LOGICAL_PARTITION_MAPPING[storage_set.value][logical_partition] diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 495de11800..797c3466c5 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -396,7 +396,11 @@ def run_query( ) -> QueryResult: response = EndpointTimeSeries().execute(request) if not response.result_timeseries: - result: Result = {"meta": [], "data": [], "trace_output": ""} + result: Result = { + "meta": [], + "data": [{request.aggregations[0].label: 0}], + "trace_output": "", + } return QueryResult( result=result, extra={"stats": {}, "sql": "", "experiments": {}} ) diff --git a/tests/subscriptions/__init__.py b/tests/subscriptions/__init__.py index b2c07719b1..360453ba88 100644 --- a/tests/subscriptions/__init__.py +++ b/tests/subscriptions/__init__.py @@ -1,6 +1,8 @@ import calendar +import random import uuid -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta +from typing import Any, Mapping import pytest @@ -9,11 +11,92 @@ from snuba.datasets.entities.factory import get_entity, get_entity_name from snuba.datasets.entity import Entity from snuba.datasets.factory import get_dataset -from snuba.datasets.storages.factory import get_writable_storage +from snuba.datasets.storages.factory import get_storage, get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.processor import InsertEvent from tests.helpers import write_raw_unprocessed_events, write_unprocessed_events +_RELEASE_TAG = "backend@24.7.0.dev0+c45b49caed1e5fcbf70097ab3f434b487c359b6b" +_SERVER_NAME = "D23CXQ4GK2.local" + + +def gen_span_message( + dt: datetime, + tags: dict[str, str] | None = None, +) -> Mapping[str, Any]: + tags = tags or {} + return { + "description": "/api/0/relays/projectconfigs/", + "duration_ms": 152, + "event_id": "d826225de75d42d6b2f01b957d51f18f", + "exclusive_time_ms": 0.228, + "is_segment": True, + "data": { + "sentry.environment": "development", + "sentry.release": _RELEASE_TAG, + "thread.name": "uWSGIWorker1Core0", + "thread.id": "8522009600", + "sentry.segment.name": "/api/0/relays/projectconfigs/", + "sentry.sdk.name": "sentry.python.django", + "sentry.sdk.version": "2.7.0", + "my.float.field": 101.2, + "my.int.field": 2000, + "my.neg.field": -100, + "my.neg.float.field": -101.2, + "my.true.bool.field": True, + "my.false.bool.field": False, + "my.numeric.attribute": 1, + }, + "measurements": { + "num_of_spans": {"value": 50.0}, + "eap.measurement": {"value": random.choice([1, 100, 1000])}, + }, + "organization_id": 1, + "origin": "auto.http.django", + "project_id": 1, + "received": 1721319572.877828, + "retention_days": 90, + "segment_id": "8873a98879faf06d", + "sentry_tags": { + "category": "http", + "environment": "development", + "op": "http.server", + "platform": "python", + "release": _RELEASE_TAG, + "sdk.name": "sentry.python.django", + "sdk.version": "2.7.0", + "status": "ok", + "status_code": "200", + "thread.id": "8522009600", + "thread.name": "uWSGIWorker1Core0", + "trace.status": "ok", + "transaction": "/api/0/relays/projectconfigs/", + "transaction.method": "POST", + "transaction.op": "http.server", + "user": "ip:127.0.0.1", + }, + "span_id": "123456781234567D", + "tags": { + "http.status_code": "200", + "relay_endpoint_version": "3", + "relay_id": "88888888-4444-4444-8444-cccccccccccc", + "relay_no_cache": "False", + "relay_protocol_version": "3", + "relay_use_post_or_schedule": "True", + "relay_use_post_or_schedule_rejected": "version", + "server_name": _SERVER_NAME, + "spans_over_limit": "False", + "color": random.choice(["red", "green", "blue"]), + "location": random.choice(["mobile", "frontend", "backend"]), + **tags, + }, + "trace_id": uuid.uuid4().hex, + "start_timestamp_ms": int(dt.replace(tzinfo=UTC).timestamp()) * 1000 + - int(random.gauss(1000, 200)), + "start_timestamp_precise": dt.replace(tzinfo=UTC).timestamp(), + "end_timestamp_precise": dt.replace(tzinfo=UTC).timestamp() + 1, + } + class BaseSubscriptionTest: @pytest.fixture(autouse=True) @@ -86,6 +169,13 @@ def setup_teardown(self, clickhouse_db: None) -> None: ], ) + spans_storage = get_storage(StorageKey("eap_spans")) + messages = [ + gen_span_message(self.base_time + timedelta(minutes=tick)) + for tick in range(self.minutes) + ] + write_raw_unprocessed_events(spans_storage, messages) + def __entity_eq__(self: Entity, other: object) -> bool: if not isinstance(other, Entity): diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index 98b301b8e2..0d0b647003 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -2,12 +2,32 @@ from typing import Optional, Type, Union import pytest +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + AttributeValue, + ExtrapolationMode, + Function, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, +) from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity from snuba.query.exceptions import InvalidQueryException -from snuba.subscriptions.data import SnQLSubscriptionData, SubscriptionData +from snuba.subscriptions.data import ( + RPCSubscriptionData, + SnQLSubscriptionData, + SubscriptionData, +) from snuba.utils.metrics.timer import Timer from tests.subscriptions import BaseSubscriptionTest @@ -26,6 +46,7 @@ entity=get_entity(EntityKey.EVENTS), metadata={}, ), + 10, None, id="SnQL subscription", ), @@ -42,6 +63,7 @@ entity=get_entity(EntityKey.EVENTS), metadata={}, ), + 10, None, id="SnQL subscription", ), @@ -59,6 +81,7 @@ entity=get_entity(EntityKey.EVENTS), metadata={}, ), + None, InvalidQueryException, id="SnQL subscription with 2 many aggregates", ), @@ -76,9 +99,81 @@ entity=get_entity(EntityKey.EVENTS), metadata={}, ), + None, InvalidQueryException, id="SnQL subscription with disallowed clause", ), + pytest.param( + RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_COUNT, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="my.float.field" + ), + label="count", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + granularity_secs=3600, + ), + time_window_secs=3600, + resolution_secs=60, + ), + EntityKey.EAP_SPANS, + ), + 20.0, + None, + id="RPC subscription", + ), + pytest.param( + RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_COUNT, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="my.float.field" + ), + label="count", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey( + type=AttributeKey.TYPE_STRING, name="sentry.sdk.version" + ), + op=ComparisonFilter.OP_EQUALS, + value=AttributeValue(val_str="3.0.0"), + ) + ), + granularity_secs=3600, + ), + time_window_secs=3600, + resolution_secs=60, + ), + EntityKey.EAP_SPANS, + ), + 0.0, + None, + id="RPC subscription with filter", + ), ] @@ -116,10 +211,13 @@ def compare_conditions( class TestBuildRequest(BaseSubscriptionTest, TestBuildRequestBase): - @pytest.mark.parametrize("subscription, exception", TESTS) + @pytest.mark.parametrize("subscription, expected_value, exception", TESTS) @pytest.mark.clickhouse_db @pytest.mark.redis_db def test_conditions( - self, subscription: SubscriptionData, exception: Optional[Type[Exception]] + self, + subscription: SubscriptionData, + expected_value: Optional[int | float], + exception: Optional[Type[Exception]], ) -> None: - self.compare_conditions(subscription, exception, "count", 10) + self.compare_conditions(subscription, exception, "count", expected_value) From f5323427fc6419af44fe30fa0fa7d0d7e6531581 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Thu, 7 Nov 2024 10:52:34 -0500 Subject: [PATCH 06/26] fix typing --- snuba/subscriptions/data.py | 4 ++-- tests/subscriptions/__init__.py | 4 ++-- tests/subscriptions/test_data.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 797c3466c5..5da8cb659a 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -408,7 +408,7 @@ def run_query( timeseries = response.result_timeseries[0] data = [{timeseries.label: timeseries.data_points[0].data}] - result: Result = {"meta": [], "data": data, "trace_output": ""} + result = {"meta": [], "data": data, "trace_output": ""} return QueryResult( result=result, extra={"stats": {}, "sql": "", "experiments": {}} ) @@ -445,7 +445,7 @@ def from_proto( metadata["organization"] = item.time_series_request.meta.organization_id return RPCSubscriptionData( - project_id=item.project_id, + project_id=item.time_series_request.meta.project_ids[0], time_window_sec=item.time_window_secs, resolution_sec=item.resolution_secs, time_series_request=base64.b64encode( diff --git a/tests/subscriptions/__init__.py b/tests/subscriptions/__init__.py index 360453ba88..698b5533f2 100644 --- a/tests/subscriptions/__init__.py +++ b/tests/subscriptions/__init__.py @@ -11,7 +11,7 @@ from snuba.datasets.entities.factory import get_entity, get_entity_name from snuba.datasets.entity import Entity from snuba.datasets.factory import get_dataset -from snuba.datasets.storages.factory import get_storage, get_writable_storage +from snuba.datasets.storages.factory import get_writable_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.processor import InsertEvent from tests.helpers import write_raw_unprocessed_events, write_unprocessed_events @@ -169,7 +169,7 @@ def setup_teardown(self, clickhouse_db: None) -> None: ], ) - spans_storage = get_storage(StorageKey("eap_spans")) + spans_storage = get_writable_storage(StorageKey("eap_spans")) messages = [ gen_span_message(self.base_time + timedelta(minutes=tick)) for tick in range(self.minutes) diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index 0d0b647003..fb0ea902fb 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, Type, Union +from typing import Optional, Type import pytest from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( @@ -185,7 +185,7 @@ def compare_conditions( subscription: SubscriptionData, exception: Optional[Type[Exception]], aggregate: str, - value: Union[int, float], + value: Optional[int | float], ) -> None: timer = Timer("test") if exception is not None: From 2cd47135475a3da9a0dd5457dca2d14c3ee2896f Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 11:42:09 -0500 Subject: [PATCH 07/26] codec tests --- snuba/subscriptions/codecs.py | 2 +- snuba/subscriptions/data.py | 29 +-- tests/subscriptions/test_codecs.py | 177 +++++++++++++++++- tests/subscriptions/test_data.py | 2 - .../test_filter_subscriptions.py | 1 - .../subscriptions/test_scheduler_consumer.py | 1 - tests/subscriptions/test_subscription.py | 1 - 7 files changed, 191 insertions(+), 22 deletions(-) diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index 72a2d16101..10983f6195 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -112,7 +112,7 @@ def decode(self, value: KafkaPayload) -> ScheduledSubscriptionTask: data = scheduled_subscription_dict["task"]["data"] subscription: SubscriptionData - if data.get("subscription_type") == SubscriptionType.RPC: + if data.get("subscription_type") == SubscriptionType.RPC.value: subscription = RPCSubscriptionData.from_dict(data, entity_key) else: subscription = SnQLSubscriptionData.from_dict(data, entity_key) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 5da8cb659a..e0ec440c3f 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -67,10 +67,13 @@ "resolution", "query", "tenant_ids", + "subscription_type", + "time_series_request", + "request_name", + "request_version", } -PROTOBUF_ALLOWLIST = ["TimeSeriesRequest"] -PROTOBUF_VERSION_ALLOWLIST = ["v1"] +REQUEST_TYPE_ALLOWLIST = [("TimeSeriesRequest", "v1")] class SubscriptionType(Enum): @@ -342,14 +345,9 @@ class RPCSubscriptionData(_SubscriptionData[TimeSeriesRequest]): def validate(self) -> None: super().validate() - if self.request_name not in PROTOBUF_ALLOWLIST: + if (self.request_name, self.request_version) not in REQUEST_TYPE_ALLOWLIST: raise InvalidSubscriptionError( - f"{self.request_name} is not supported. Supported request types are: {PROTOBUF_ALLOWLIST}" - ) - - if self.request_version not in PROTOBUF_VERSION_ALLOWLIST: - raise InvalidSubscriptionError( - f"{self.request_version} version not supported. Supported versions are: {PROTOBUF_VERSION_ALLOWLIST}" + f"{self.request_name} {self.request_version} not supported." ) # TODO: Validate no group by, having, order by etc @@ -384,6 +382,8 @@ def build_request( request_class.meta.start_timestamp.CopyFrom(start_time_proto) request_class.meta.end_timestamp.CopyFrom(end_time_proto) + request_class.granularity_secs = self.time_window_sec + return request_class def run_query( @@ -418,6 +418,10 @@ def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey ) -> RPCSubscriptionData: entity: Entity = get_entity(entity_key) + metadata = data.pop("metadata", {}) + for key in data.keys(): + if key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: + metadata[key] = data[key] return RPCSubscriptionData( project_id=data["project_id"], @@ -427,7 +431,7 @@ def from_dict( request_version=data["request_version"], request_name=data["request_name"], entity=entity, - metadata=data.get("metadata", dict()), + metadata=metadata, tenant_ids=data.get("tenant_ids", dict()), ) @@ -440,7 +444,7 @@ def from_proto( class_name = request_class.__name__ class_version = request_class.__module__.split(".", 3)[2] - metadata = dict() + metadata = {} if item.time_series_request.meta: metadata["organization"] = item.time_series_request.meta.organization_id @@ -467,8 +471,9 @@ def to_dict(self) -> Mapping[str, Any]: "request_version": self.request_version, "request_name": self.request_name, "subscription_type": SubscriptionType.RPC.value, - "metadata": self.metadata, } + if self.metadata: + subscription_data_dict["metadata"] = self.metadata return subscription_data_dict diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index 03a2080e91..f2ad2db200 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -6,6 +6,22 @@ from typing import Any, Callable, Mapping import pytest +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + AttributeValue, + ExtrapolationMode, + Function, +) +from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( + ComparisonFilter, + TraceItemFilter, +) from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity @@ -19,6 +35,7 @@ ) from snuba.subscriptions.data import ( PartitionId, + RPCSubscriptionData, ScheduledSubscriptionTask, SnQLSubscriptionData, Subscription, @@ -30,6 +47,76 @@ from snuba.utils.metrics.timer import Timer +def build_rpc_subscription_data_from_proto( + entity_key: EntityKey, metadata: Mapping[str, Any] +) -> SubscriptionData: + + return RPCSubscriptionData.from_proto( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), + op=ComparisonFilter.OP_NOT_EQUALS, + value=AttributeValue(val_str="bar"), + ) + ), + ), + time_window_secs=300, + resolution_secs=60, + ), + EntityKey.EAP_SPANS, + ) + + +def build_rpc_subscription_data( + entity_key: EntityKey, metadata: Mapping[str, Any] +) -> SubscriptionData: + + return RPCSubscriptionData( + project_id=1, + time_window_sec=300, + resolution_sec=60, + entity=get_entity(entity_key), + metadata=metadata, + time_series_request="Ch0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAxIUIhIKBwgBEgNmb28QBhoFEgNiYXIaGggBEg8IAxILdGVzdF9tZXRyaWMaA3N1bSAB", + request_name="TimeSeriesRequest", + request_version="v1", + ) + + +RPC_CASES = [ + pytest.param( + build_rpc_subscription_data_from_proto, + {"organization": 1}, + EntityKey.EAP_SPANS, + id="rpc", + ), + pytest.param( + build_rpc_subscription_data, + {"organization": 1}, + EntityKey.EAP_SPANS, + id="rpc", + ), +] + + def build_snql_subscription_data( entity_key: EntityKey, metadata: Mapping[str, Any] ) -> SubscriptionData: @@ -66,7 +153,7 @@ def build_snql_subscription_data( ] -@pytest.mark.parametrize("builder, metadata, entity_key", SNQL_CASES) +@pytest.mark.parametrize("builder, metadata, entity_key", SNQL_CASES + RPC_CASES) def test_basic( builder: Callable[[EntityKey, Mapping[str, Any]], SubscriptionData], metadata: Mapping[str, Any], @@ -119,6 +206,53 @@ def test_decode_snql( assert codec.decode(payload) == subscription +@pytest.mark.parametrize("builder, metadata, entity_key", RPC_CASES) +def test_encode_rpc( + builder: Callable[[EntityKey, Mapping[str, Any]], SubscriptionData], + metadata: Mapping[str, Any], + entity_key: EntityKey, +) -> None: + codec = SubscriptionDataCodec(entity_key) + subscription = builder(entity_key, metadata) + + assert isinstance(subscription, RPCSubscriptionData) + + payload = codec.encode(subscription) + data = json.loads(payload.decode("utf-8")) + assert data["project_id"] == subscription.project_id + assert data["time_window"] == subscription.time_window_sec + assert data["resolution"] == subscription.resolution_sec + assert data["time_series_request"] == subscription.time_series_request + assert data["request_name"] == subscription.request_name + assert data["request_version"] == subscription.request_version + assert metadata == subscription.metadata + + +@pytest.mark.parametrize("builder, metadata, entity_key", RPC_CASES) +def test_decode_rpc( + builder: Callable[[EntityKey, Mapping[str, Any]], SubscriptionData], + metadata: Mapping[str, Any], + entity_key: EntityKey, +) -> None: + codec = SubscriptionDataCodec(entity_key) + subscription = builder(entity_key, metadata) + + assert isinstance(subscription, RPCSubscriptionData) + data = { + "project_id": subscription.project_id, + "time_window": subscription.time_window_sec, + "resolution": subscription.resolution_sec, + "time_series_request": subscription.time_series_request, + "request_version": subscription.request_version, + "request_name": subscription.request_name, + "subscription_type": "rpc", + } + if metadata: + data.update(metadata) + payload = json.dumps(data).encode("utf-8") + assert codec.decode(payload) == subscription + + def test_subscription_task_result_encoder() -> None: codec = SubscriptionTaskResultEncoder() @@ -257,7 +391,7 @@ def test_metrics_subscription_task_result_encoder( assert payload["entity"] == entity_key.value -def test_subscription_task_encoder() -> None: +def test_subscription_task_encoder_snql() -> None: encoder = SubscriptionScheduledTaskEncoder() entity = get_entity(EntityKey.EVENTS) subscription_data = SnQLSubscriptionData( @@ -288,17 +422,52 @@ def test_subscription_task_encoder() -> None: encoded = encoder.encode(task) assert encoded.key == b"1/91b46cb6224f11ecb2ddacde48001122" - assert encoded.value == ( b"{" b'"timestamp":"1970-01-01T00:00:00",' b'"entity":"events",' b'"task":{' - b'"data":{"project_id":1,"time_window":60,"resolution":60,"query":"MATCH events SELECT count()"}},' + b'"data":{"project_id":1,"time_window":60,"resolution":60,"query":"MATCH events SELECT count()","subscription_type":"snql"}},' b'"tick_upper_offset":5' b"}" ) decoded = encoder.decode(encoded) + assert decoded == task + +def test_subscription_task_encoder_rpc() -> None: + encoder = SubscriptionScheduledTaskEncoder() + subscription_data = build_rpc_subscription_data(EntityKey.EAP_SPANS, {}) + + subscription_id = uuid.UUID("91b46cb6224f11ecb2ddacde48001122") + + epoch = datetime(1970, 1, 1) + + tick_upper_offset = 5 + + subscription_with_metadata = SubscriptionWithMetadata( + EntityKey.EAP_SPANS, + Subscription( + SubscriptionIdentifier(PartitionId(1), subscription_id), subscription_data + ), + tick_upper_offset, + ) + + task = ScheduledSubscriptionTask(timestamp=epoch, task=subscription_with_metadata) + + encoded = encoder.encode(task) + + assert encoded.key == b"1/91b46cb6224f11ecb2ddacde48001122" + assert encoded.value == ( + b"{" + b'"timestamp":"1970-01-01T00:00:00",' + b'"entity":"eap_spans",' + b'"task":{' + b'"data":{"project_id":1,"time_window":300,"resolution":60,"time_series_request":"Ch0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAxIUIhIKBwgBEgNmb28QBhoFEgNiYXIaGggBEg8IAxILdGVzdF9tZXRyaWMaA3N1bSAB","request_version":"v1","request_name":"TimeSeriesRequest","subscription_type":"rpc"}},' + b'"tick_upper_offset":5' + b"}" + ) + + decoded = encoder.decode(encoded) assert decoded == task diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index fb0ea902fb..0ef67ebc6a 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -123,7 +123,6 @@ extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, ), ], - granularity_secs=3600, ), time_window_secs=3600, resolution_secs=60, @@ -163,7 +162,6 @@ value=AttributeValue(val_str="3.0.0"), ) ), - granularity_secs=3600, ), time_window_secs=3600, resolution_secs=60, diff --git a/tests/subscriptions/test_filter_subscriptions.py b/tests/subscriptions/test_filter_subscriptions.py index 76d2fe34a4..f683cab241 100644 --- a/tests/subscriptions/test_filter_subscriptions.py +++ b/tests/subscriptions/test_filter_subscriptions.py @@ -108,7 +108,6 @@ def build_rpc_subscription(resolution: timedelta, org_id: int) -> Subscription: value=AttributeValue(val_str="bar"), ) ), - granularity_secs=300, ), time_window_secs=300, resolution_secs=int(resolution.total_seconds()), diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index 0c36604766..9605c969ed 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -218,7 +218,6 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: value=AttributeValue(val_str="bar"), ) ), - granularity_secs=300, ), time_window_secs=300, resolution_secs=60, diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index 28ba62abdd..0e59cc6c21 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -372,7 +372,6 @@ def test(self) -> None: value=AttributeValue(val_str="bar"), ) ), - granularity_secs=300, ), time_window_secs=300, resolution_secs=60, From 958926dbfc444c47b905a7797f9fa2ffaa23eccd Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 12:28:38 -0500 Subject: [PATCH 08/26] task result encoder --- snuba/subscriptions/codecs.py | 10 ++++-- snuba/subscriptions/data.py | 6 ++-- tests/subscriptions/test_codecs.py | 55 +++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index 10983f6195..9f0018eefb 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -1,10 +1,10 @@ +import base64 import json from datetime import datetime from typing import cast import rapidjson from arroyo.backends.kafka import KafkaPayload -from google.protobuf.json_format import MessageToDict from google.protobuf.message import Message as ProtobufMessage from sentry_kafka_schemas.schema_types import events_subscription_results_v1 @@ -50,7 +50,13 @@ def encode(self, value: SubscriptionTaskResult) -> KafkaPayload: request, result = value.result if isinstance(request, ProtobufMessage): - original_body = {**MessageToDict(request)} + original_body = { + "request": base64.b64encode(request.SerializeToString()).decode( + "utf-8" + ), + "request_name": request.__name__, + "request_version": request.__module__.split(".", 3)[2], + } else: original_body = {**request.original_body} diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index e0ec440c3f..586ab7ce38 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -418,9 +418,11 @@ def from_dict( cls, data: Mapping[str, Any], entity_key: EntityKey ) -> RPCSubscriptionData: entity: Entity = get_entity(entity_key) - metadata = data.pop("metadata", {}) + metadata = {} for key in data.keys(): - if key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: + if key == "metadata": + metadata.update(data[key]) + elif key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: metadata[key] = data[key] return RPCSubscriptionData( diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index f2ad2db200..efac4b4e27 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -253,23 +253,46 @@ def test_decode_rpc( assert codec.decode(payload) == subscription -def test_subscription_task_result_encoder() -> None: +RESULTS_CASES = [ + pytest.param( + SnQLSubscriptionData( + project_id=1, + query="MATCH (events) SELECT count() AS count", + time_window_sec=60, + resolution_sec=60, + entity=get_entity(EntityKey.EVENTS), + metadata={}, + ), + EntityKey.EVENTS, + { + "query": "MATCH (events) SELECT count() AS count", + "tenant_ids": {"referrer": "subscription", "organization_id": 1}, + }, + id="snql_subscription", + ), + pytest.param( + build_rpc_subscription_data(entity_key=EntityKey.EAP_SPANS, metadata={}), + EntityKey.EAP_SPANS, + { + "request": "Ci0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAyoGCPCDuLkGMgYInIa4uQYSFCISCgcIARIDZm9vEAYaBRIDYmFyGhoIARIPCAMSC3Rlc3RfbWV0cmljGgNzdW0gASCsAg==", + "request_name": "TimeSeriesRequest", + "request_version": "v1", + }, + id="snql_subscription", + ), +] + + +@pytest.mark.parametrize("subscription, entity_key, original_body", RESULTS_CASES) +def test_subscription_task_result_encoder( + subscription: SubscriptionData, entity_key: EntityKey, original_body: dict[str, Any] +) -> None: codec = SubscriptionTaskResultEncoder() timestamp = datetime.now() - entity = get_entity(EntityKey.EVENTS) - subscription_data = SnQLSubscriptionData( - project_id=1, - query="MATCH (events) SELECT count() AS count", - time_window_sec=60, - resolution_sec=60, - entity=entity, - metadata={}, - ) - # XXX: This seems way too coupled to the dataset. - request = subscription_data.build_request( + request = subscription.build_request( get_dataset("events"), timestamp, None, Timer("timer") ) result: Result = { @@ -283,10 +306,10 @@ def test_subscription_task_result_encoder() -> None: ScheduledSubscriptionTask( timestamp, SubscriptionWithMetadata( - EntityKey.EVENTS, + entity_key, Subscription( SubscriptionIdentifier(PartitionId(1), uuid.uuid1()), - subscription_data, + subscription, ), 5, ), @@ -302,10 +325,10 @@ def test_subscription_task_result_encoder() -> None: assert payload["subscription_id"] == str( task_result.task.task.subscription.identifier ) - assert payload["request"] == request.original_body + assert payload["request"] == original_body assert payload["result"]["data"] == result["data"] assert payload["timestamp"] == task_result.task.timestamp.isoformat() - assert payload["entity"] == EntityKey.EVENTS.value + assert payload["entity"] == entity_key.value METRICS_CASES = [ From 9a8f16e71e33f5b5885b59f83027e07f1b2db900 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 12:37:54 -0500 Subject: [PATCH 09/26] fix typing --- snuba/subscriptions/codecs.py | 4 ++-- tests/subscriptions/test_codecs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index 9f0018eefb..9136ed2225 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -54,8 +54,8 @@ def encode(self, value: SubscriptionTaskResult) -> KafkaPayload: "request": base64.b64encode(request.SerializeToString()).decode( "utf-8" ), - "request_name": request.__name__, - "request_version": request.__module__.split(".", 3)[2], + "request_name": request.__class__.__name__, + "request_version": request.__class__.__module__.split(".", 3)[2], } else: original_body = {**request.original_body} diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index efac4b4e27..92cb3f04d9 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -274,7 +274,7 @@ def test_decode_rpc( build_rpc_subscription_data(entity_key=EntityKey.EAP_SPANS, metadata={}), EntityKey.EAP_SPANS, { - "request": "Ci0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAyoGCPCDuLkGMgYInIa4uQYSFCISCgcIARIDZm9vEAYaBRIDYmFyGhoIARIPCAMSC3Rlc3RfbWV0cmljGgNzdW0gASCsAg==", + "request": "Ci0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAyoGCMiIuLkGMgYI9Iq4uQYSFCISCgcIARIDZm9vEAYaBRIDYmFyGhoIARIPCAMSC3Rlc3RfbWV0cmljGgNzdW0gASCsAg==", "request_name": "TimeSeriesRequest", "request_version": "v1", }, From 67372da45276096b78b1cad983621bc0bc26d81f Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 13:41:31 -0500 Subject: [PATCH 10/26] not deterministic --- tests/subscriptions/test_codecs.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index 92cb3f04d9..5046fbc26d 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Mapping import pytest +from google.protobuf.message import Message as ProtobufMessage from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( CreateSubscriptionRequest as CreateSubscriptionRequestProto, ) @@ -264,28 +265,19 @@ def test_decode_rpc( metadata={}, ), EntityKey.EVENTS, - { - "query": "MATCH (events) SELECT count() AS count", - "tenant_ids": {"referrer": "subscription", "organization_id": 1}, - }, id="snql_subscription", ), pytest.param( build_rpc_subscription_data(entity_key=EntityKey.EAP_SPANS, metadata={}), EntityKey.EAP_SPANS, - { - "request": "Ci0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAyoGCMiIuLkGMgYI9Iq4uQYSFCISCgcIARIDZm9vEAYaBRIDYmFyGhoIARIPCAMSC3Rlc3RfbWV0cmljGgNzdW0gASCsAg==", - "request_name": "TimeSeriesRequest", - "request_version": "v1", - }, - id="snql_subscription", + id="rpc_subscriptions", ), ] -@pytest.mark.parametrize("subscription, entity_key, original_body", RESULTS_CASES) +@pytest.mark.parametrize("subscription, entity_key", RESULTS_CASES) def test_subscription_task_result_encoder( - subscription: SubscriptionData, entity_key: EntityKey, original_body: dict[str, Any] + subscription: SubscriptionData, entity_key: EntityKey ) -> None: codec = SubscriptionTaskResultEncoder() @@ -325,7 +317,11 @@ def test_subscription_task_result_encoder( assert payload["subscription_id"] == str( task_result.task.task.subscription.identifier ) - assert payload["request"] == original_body + if isinstance(request, ProtobufMessage): + assert payload["request"]["request_name"] == "TimeSeriesRequest" + assert payload["request"]["request_version"] == "v1" + else: + assert payload["request"] == request.original_body assert payload["result"]["data"] == result["data"] assert payload["timestamp"] == task_result.task.timestamp.isoformat() assert payload["entity"] == entity_key.value From c16e31bd9aee8748c8f39ed19aac8022517f2cc6 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 14:38:35 -0500 Subject: [PATCH 11/26] fix test --- snuba/subscriptions/utils.py | 1 + .../subscriptions/test_scheduler_consumer.py | 29 +++++++------------ 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/snuba/subscriptions/utils.py b/snuba/subscriptions/utils.py index 409b118e4a..68599cd5ae 100644 --- a/snuba/subscriptions/utils.py +++ b/snuba/subscriptions/utils.py @@ -24,6 +24,7 @@ def time_shift(self, delta: float) -> Tick: Returns a new ``Tick`` instance that has had the bounds of its time interval shifted by the provided delta. """ + breakpoint() return Tick( self.partition, self.offsets, diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index 9605c969ed..664a17384d 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -1,3 +1,4 @@ +import base64 import importlib import json import logging @@ -196,7 +197,7 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=1, cogs_category="something", referrer="something", @@ -282,24 +283,14 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: assert (tmpdir / "health.txt").check() assert mock_scheduler_producer.produce.call_count == 2 - assert json.loads( - mock_scheduler_producer.produce.call_args_list[0][0][1].value - ) == { - "timestamp": "1970-01-01T00:16:00", - "entity": "eap_spans", - "task": { - "data": { - "project_id": 0, - "time_window": 300, - "resolution": 60, - "time_series_request": "Ch0IARIJc29tZXRoaW5nGglzb21ldGhpbmciAwECAxIUIhIKBwgBEgNmb28QBhoFEgNiYXIaGggBEg8IAxILdGVzdF9tZXRyaWMaA3N1bSABIKwC", - "request_version": "v1", - "request_name": "TimeSeriesRequest", - "subscription_type": "rpc", - } - }, - "tick_upper_offset": 1, - } + payload = json.loads(mock_scheduler_producer.produce.call_args_list[0][0][1].value) + assert payload["task"]["data"]["project_id"] == 1 + assert payload["task"]["data"]["resolution"] == 60 + assert payload["task"]["data"]["time_window"] == 300 + assert payload["task"]["data"]["request_name"] == "TimeSeriesRequest" + assert payload["task"]["data"]["request_version"] == "v1" + time_series_request = payload["task"]["data"]["time_series_request"] + TimeSeriesRequest().ParseFromString(base64.b64decode(time_series_request)) settings.TOPIC_PARTITION_COUNTS = {} From d1529411e030adfa7c30877dadc057553eac0ed4 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Fri, 8 Nov 2024 14:58:52 -0500 Subject: [PATCH 12/26] remove breakpoint --- snuba/subscriptions/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/snuba/subscriptions/utils.py b/snuba/subscriptions/utils.py index 68599cd5ae..409b118e4a 100644 --- a/snuba/subscriptions/utils.py +++ b/snuba/subscriptions/utils.py @@ -24,7 +24,6 @@ def time_shift(self, delta: float) -> Tick: Returns a new ``Tick`` instance that has had the bounds of its time interval shifted by the provided delta. """ - breakpoint() return Tick( self.partition, self.offsets, From 76a7f0deffedc478d1aa501d63713b934de22a2b Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 18 Nov 2024 11:26:36 -0500 Subject: [PATCH 13/26] feat(subscriptions): Add a create subscription rpc endpoint --- requirements.txt | 2 +- snuba/subscriptions/data.py | 129 ++++++++++++++++++- snuba/subscriptions/rpc_helpers.py | 59 +++++++++ snuba/subscriptions/subscription.py | 22 +++- snuba/web/rpc/v1/create_subscription.py | 44 +++++++ tests/web/rpc/v1/test_create_subscription.py | 109 ++++++++++++++++ 6 files changed, 360 insertions(+), 5 deletions(-) create mode 100644 snuba/subscriptions/rpc_helpers.py create mode 100644 snuba/web/rpc/v1/create_subscription.py create mode 100644 tests/web/rpc/v1/test_create_subscription.py diff --git a/requirements.txt b/requirements.txt index 48863b2e5e..54efdbe65b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,4 +46,4 @@ sqlparse==0.5.0 google-api-python-client==2.88.0 sentry-usage-accountant==0.0.11 freezegun==1.2.2 -sentry-protos==0.1.31 +sentry-protos==0.1.34 diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 4c7d005bbe..ad8c38cb88 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -1,10 +1,12 @@ from __future__ import annotations +import base64 import logging from abc import ABC, abstractmethod from concurrent.futures import Future from dataclasses import dataclass, field from datetime import datetime, timedelta +from enum import Enum from functools import partial from typing import ( Any, @@ -19,6 +21,10 @@ ) from uuid import UUID +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest, +) + from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity @@ -53,8 +59,20 @@ "resolution", "query", "tenant_ids", + "subscription_type", + "time_series_request", + "request_name", + "request_version", } +REQUEST_TYPE_ALLOWLIST = [("TimeSeriesRequest", "v1")] + + +class SubscriptionType(Enum): + SNQL = "snql" + RPC = "rpc" + + logger = logging.getLogger("snuba.subscriptions") PartitionId = NewType("PartitionId", int) @@ -83,9 +101,20 @@ class SubscriptionData(ABC): metadata: Mapping[str, Any] tenant_ids: Mapping[str, Any] = field(default_factory=lambda: dict()) - @abstractmethod def validate(self) -> None: - raise NotImplementedError + if self.time_window_sec < 60: + raise InvalidSubscriptionError( + "Time window must be greater than or equal to 1 minute" + ) + elif self.time_window_sec > 60 * 60 * 24: + raise InvalidSubscriptionError( + "Time window must be less than or equal to 24 hours" + ) + + if self.resolution_sec < 60: + raise InvalidSubscriptionError( + "Resolution must be greater than or equal to 1 minute" + ) @abstractmethod def build_request( @@ -111,6 +140,102 @@ def to_dict(self) -> Mapping[str, Any]: raise NotImplementedError +@dataclass(frozen=True, kw_only=True) +class RPCSubscriptionData(SubscriptionData): + """ + Represents the state of an RPC subscription. + """ + + time_series_request: str + + request_name: str + request_version: str + + def validate(self) -> None: + super().validate() + if (self.request_name, self.request_version) not in REQUEST_TYPE_ALLOWLIST: + raise InvalidSubscriptionError( + f"{self.request_name} {self.request_version} not supported." + ) + + def build_request( + self, + dataset: Dataset, + timestamp: datetime, + offset: Optional[int], + timer: Timer, + metrics: Optional[MetricsBackend] = None, + referrer: str = SUBSCRIPTION_REFERRER, + ) -> Request: + raise NotImplementedError + + @classmethod + def from_dict( + cls, data: Mapping[str, Any], entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + metadata = {} + for key in data.keys(): + if key == "metadata": + metadata.update(data[key]) + elif key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: + metadata[key] = data[key] + + return RPCSubscriptionData( + project_id=data["project_id"], + time_window_sec=int(data["time_window"]), + resolution_sec=int(data["resolution"]), + time_series_request=data["time_series_request"], + request_version=data["request_version"], + request_name=data["request_name"], + entity=entity, + metadata=metadata, + tenant_ids=data.get("tenant_ids", dict()), + ) + + @classmethod + def from_proto( + cls, item: CreateSubscriptionRequest, entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + request_class = item.time_series_request.__class__ + class_name = request_class.__name__ + class_version = request_class.__module__.split(".", 3)[2] + + metadata = {} + if item.time_series_request.meta: + metadata["organization"] = item.time_series_request.meta.organization_id + + return RPCSubscriptionData( + project_id=item.time_series_request.meta.project_ids[0], + time_window_sec=item.time_window_secs, + resolution_sec=item.resolution_secs, + time_series_request=base64.b64encode( + item.time_series_request.SerializeToString() + ).decode("utf-8"), + entity=entity, + metadata=metadata, + tenant_ids={}, + request_version=class_version, + request_name=class_name, + ) + + def to_dict(self) -> Mapping[str, Any]: + subscription_data_dict = { + "project_id": self.project_id, + "time_window": self.time_window_sec, + "resolution": self.resolution_sec, + "time_series_request": self.time_series_request, + "request_version": self.request_version, + "request_name": self.request_name, + "subscription_type": SubscriptionType.RPC.value, + } + if self.metadata: + subscription_data_dict["metadata"] = self.metadata + + return subscription_data_dict + + @dataclass(frozen=True, kw_only=True) class SnQLSubscriptionData(SubscriptionData): """ diff --git a/snuba/subscriptions/rpc_helpers.py b/snuba/subscriptions/rpc_helpers.py new file mode 100644 index 0000000000..b019ad8976 --- /dev/null +++ b/snuba/subscriptions/rpc_helpers.py @@ -0,0 +1,59 @@ +import base64 +from datetime import UTC, datetime, timedelta + +from google.protobuf.timestamp_pb2 import Timestamp +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest + +from snuba.reader import Result +from snuba.web import QueryResult +from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries + + +def build_rpc_request( + timestamp: datetime, + time_window_sec: int, + time_series_request: str, +) -> TimeSeriesRequest: + + request_class = EndpointTimeSeries().request_class()() + request_class.ParseFromString(base64.b64decode(time_series_request)) + + # TODO: update it to round to the lowest granularity + # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 + rounded_ts = ( + int(timestamp.replace(tzinfo=UTC).timestamp() / time_window_sec) + * time_window_sec + ) + rounded_start = datetime.utcfromtimestamp(rounded_ts) + + start_time_proto = Timestamp() + start_time_proto.FromDatetime(rounded_start - timedelta(seconds=time_window_sec)) + end_time_proto = Timestamp() + end_time_proto.FromDatetime(rounded_start) + request_class.meta.start_timestamp.CopyFrom(start_time_proto) + request_class.meta.end_timestamp.CopyFrom(end_time_proto) + + request_class.granularity_secs = time_window_sec + + return request_class + + +def run_rpc_subscription_query( + request: TimeSeriesRequest, +) -> QueryResult: + response = EndpointTimeSeries().execute(request) + if not response.result_timeseries: + result: Result = { + "meta": [], + "data": [{request.aggregations[0].label: 0}], + "trace_output": "", + } + return QueryResult( + result=result, extra={"stats": {}, "sql": "", "experiments": {}} + ) + + timeseries = response.result_timeseries[0] + data = [{timeseries.label: timeseries.data_points[0].data}] + + result = {"meta": [], "data": data, "trace_output": ""} + return QueryResult(result=result, extra={"stats": {}, "sql": "", "experiments": {}}) diff --git a/snuba/subscriptions/subscription.py b/snuba/subscriptions/subscription.py index 021de963fc..0212dd2a31 100644 --- a/snuba/subscriptions/subscription.py +++ b/snuba/subscriptions/subscription.py @@ -1,16 +1,25 @@ from datetime import datetime from uuid import UUID, uuid1 +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest + from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import enforce_table_writer, get_entity from snuba.redis import RedisClientKey, get_redis_client +from snuba.request import Request from snuba.subscriptions.data import ( PartitionId, + RPCSubscriptionData, + SnQLSubscriptionData, SubscriptionData, SubscriptionIdentifier, ) from snuba.subscriptions.partitioner import TopicSubscriptionDataPartitioner +from snuba.subscriptions.rpc_helpers import ( + build_rpc_request, + run_rpc_subscription_query, +) from snuba.subscriptions.store import RedisSubscriptionDataStore from snuba.utils.metrics.timer import Timer from snuba.web.query import run_query @@ -51,8 +60,17 @@ def create(self, data: SubscriptionData, timer: Timer) -> SubscriptionIdentifier return identifier def _test_request(self, data: SubscriptionData, timer: Timer) -> None: - request = data.build_request(self.dataset, datetime.utcnow(), None, timer) - run_query(self.dataset, request, timer) + request: Request | TimeSeriesRequest + if isinstance(data, SnQLSubscriptionData): + request = data.build_request(self.dataset, datetime.utcnow(), None, timer) + run_query(self.dataset, request, timer) + if isinstance(data, RPCSubscriptionData): + request = build_rpc_request( + datetime.utcnow(), + data.time_window_sec, + data.time_series_request, + ) + run_rpc_subscription_query(request) class SubscriptionDeleter: diff --git a/snuba/web/rpc/v1/create_subscription.py b/snuba/web/rpc/v1/create_subscription.py new file mode 100644 index 0000000000..950943205b --- /dev/null +++ b/snuba/web/rpc/v1/create_subscription.py @@ -0,0 +1,44 @@ +from typing import Type + +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionResponse, +) + +from snuba.datasets.entities.entity_key import EntityKey +from snuba.datasets.pluggable_dataset import PluggableDataset +from snuba.web.rpc import RPCEndpoint + + +class CreateSubscriptionRequest( + RPCEndpoint[CreateSubscriptionRequestProto, CreateSubscriptionResponse] +): + @classmethod + def version(cls) -> str: + return "v1" + + @classmethod + def request_class(cls) -> Type[CreateSubscriptionRequestProto]: + return CreateSubscriptionRequestProto + + @classmethod + def response_class(cls) -> Type[CreateSubscriptionResponse]: + return CreateSubscriptionResponse + + def _execute( + self, in_msg: CreateSubscriptionRequestProto + ) -> CreateSubscriptionResponse: + from snuba.subscriptions.data import RPCSubscriptionData + from snuba.subscriptions.subscription import SubscriptionCreator + + dataset = PluggableDataset(name="eap", all_entities=[]) + entity_key = EntityKey("eap_spans") + + subscription = RPCSubscriptionData.from_proto(in_msg, entity_key=entity_key) + identifier = SubscriptionCreator(dataset, entity_key).create( + subscription, self._timer + ) + + return CreateSubscriptionResponse(subscription_id=str(identifier)) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py new file mode 100644 index 0000000000..b1a5d7ed20 --- /dev/null +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -0,0 +1,109 @@ +from datetime import UTC, datetime, timedelta + +import pytest +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionRequest as CreateSubscriptionRequestProto, +) +from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( + CreateSubscriptionResponse, +) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest +from sentry_protos.snuba.v1.error_pb2 import Error +from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta +from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( + AttributeAggregation, + AttributeKey, + ExtrapolationMode, + Function, +) + +from tests.base import BaseApiTest +from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries + +END_TIME = datetime.utcnow().replace(second=0, microsecond=0, tzinfo=UTC) +START_TIME = END_TIME - timedelta(hours=1) + + +@pytest.mark.clickhouse_db +@pytest.mark.redis_db +class TestCreateSubscriptionApi(BaseApiTest): + def test_create_valid_subscription(self) -> None: + store_timeseries( + START_TIME, + 1, + 3600, + metrics=[DummyMetric("test_metric", get_value=lambda x: 1)], + ) + + message = CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + granularity_secs=300, + ), + time_window_secs=300, + resolution_secs=60, + ) + response = self.app.post( + "/rpc/CreateSubscriptionRequest/v1", data=message.SerializeToString() + ) + assert response.status_code == 200 + response_class = CreateSubscriptionResponse() + response_class.ParseFromString(response.data) + assert response_class.subscription_id + + def test_create_invalid_subscription(self) -> None: + store_timeseries( + START_TIME, + 1, + 3600, + metrics=[DummyMetric("test_metric", get_value=lambda x: 1)], + ) + + message = CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + granularity_secs=172800, + ), + time_window_secs=172800, + resolution_secs=60, + ) + response = self.app.post( + "/rpc/CreateSubscriptionRequest/v1", data=message.SerializeToString() + ) + assert response.status_code == 500 + error = Error() + error.ParseFromString(response.data) + assert ( + error.message + == "internal error occurred while executing this RPC call: Time window must be less than or equal to 24 hours" + ) From 5726ade1e941714c79e3f1c291d44814a7858e37 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 18 Nov 2024 11:29:00 -0500 Subject: [PATCH 14/26] move validate to base class --- snuba/subscriptions/data.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index ad8c38cb88..23799c0c87 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -312,21 +312,6 @@ def add_conditions( "At least one Entity must have a timestamp column for subscriptions" ) - def validate(self) -> None: - if self.time_window_sec < 60: - raise InvalidSubscriptionError( - "Time window must be greater than or equal to 1 minute" - ) - elif self.time_window_sec > 60 * 60 * 24: - raise InvalidSubscriptionError( - "Time window must be less than or equal to 24 hours" - ) - - if self.resolution_sec < 60: - raise InvalidSubscriptionError( - "Resolution must be greater than or equal to 1 minute" - ) - def build_request( self, dataset: Dataset, From 66270bf9ef0dba6d62b31b9331fd9d059b0c1d40 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Mon, 18 Nov 2024 11:49:57 -0500 Subject: [PATCH 15/26] inspect what's stored in redis in the test --- tests/web/rpc/v1/test_create_subscription.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index b1a5d7ed20..4ccfb649de 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -1,3 +1,4 @@ +import json from datetime import UTC, datetime, timedelta import pytest @@ -17,6 +18,7 @@ Function, ) +from snuba.redis import RedisClientKey, get_redis_client from tests.base import BaseApiTest from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries @@ -66,6 +68,19 @@ def test_create_valid_subscription(self) -> None: response_class.ParseFromString(response.data) assert response_class.subscription_id + redis_client = get_redis_client(RedisClientKey.SUBSCRIPTION_STORE) + stored_subscription_data = list( + redis_client.hgetall("subscriptions:eap_spans:0").items() + )[0] + subscription_request = stored_subscription_data[1] + subscription_data = json.loads(subscription_request.decode("utf-8")) + + assert "time_series_request" in subscription_data + assert subscription_data["time_window"] == 300 + assert subscription_data["resolution"] == 60 + assert subscription_data["request_name"] == "TimeSeriesRequest" + assert subscription_data["request_version"] == "v1" + def test_create_invalid_subscription(self) -> None: store_timeseries( START_TIME, From c99d49c4c0f10f535f61508acd3fdcc13ec332ed Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Tue, 19 Nov 2024 13:44:06 -0500 Subject: [PATCH 16/26] serialize what's stored in redis --- tests/web/rpc/v1/test_create_subscription.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index 4ccfb649de..4250151e16 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -1,3 +1,4 @@ +import base64 import json from datetime import UTC, datetime, timedelta @@ -18,6 +19,7 @@ Function, ) +from snuba.datasets.entities.entity_key import EntityKey from snuba.redis import RedisClientKey, get_redis_client from tests.base import BaseApiTest from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries @@ -67,15 +69,21 @@ def test_create_valid_subscription(self) -> None: response_class = CreateSubscriptionResponse() response_class.ParseFromString(response.data) assert response_class.subscription_id + partition = response_class.subscription_id.split("/", 1)[0] + entity_key = EntityKey("eap_spans") redis_client = get_redis_client(RedisClientKey.SUBSCRIPTION_STORE) stored_subscription_data = list( - redis_client.hgetall("subscriptions:eap_spans:0").items() + redis_client.hgetall( + f"subscriptions:{entity_key.value}:{partition}" + ).items() )[0] subscription_request = stored_subscription_data[1] subscription_data = json.loads(subscription_request.decode("utf-8")) - assert "time_series_request" in subscription_data + time_series_request = subscription_data["time_series_request"] + request_class = TimeSeriesRequest() + request_class.ParseFromString(base64.b64decode(time_series_request)) assert subscription_data["time_window"] == 300 assert subscription_data["resolution"] == 60 assert subscription_data["request_name"] == "TimeSeriesRequest" From 17a49feb3538e891803e54e3aa00bf8d6d0170e5 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Tue, 19 Nov 2024 13:49:17 -0500 Subject: [PATCH 17/26] add a comment --- tests/web/rpc/v1/test_create_subscription.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index 4250151e16..143d5eb74e 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -73,6 +73,8 @@ def test_create_valid_subscription(self) -> None: entity_key = EntityKey("eap_spans") redis_client = get_redis_client(RedisClientKey.SUBSCRIPTION_STORE) + # TODO[fix]: querying the redis client like this directly is temporary + # because we don't have decode support for rpc queries in the codec yet. stored_subscription_data = list( redis_client.hgetall( f"subscriptions:{entity_key.value}:{partition}" From 83f1693dcb4fe17cf3c46fb8f7c4027ec222939b Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 07:44:34 -0500 Subject: [PATCH 18/26] add more validation --- snuba/subscriptions/data.py | 16 ++ tests/web/rpc/v1/test_create_subscription.py | 159 +++++++++++++++---- 2 files changed, 145 insertions(+), 30 deletions(-) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 23799c0c87..80d3651f5a 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -24,6 +24,7 @@ from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( CreateSubscriptionRequest, ) +from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest from snuba.datasets.dataset import Dataset from snuba.datasets.entities.entity_key import EntityKey @@ -158,6 +159,21 @@ def validate(self) -> None: f"{self.request_name} {self.request_version} not supported." ) + request = TimeSeriesRequest() + request.ParseFromString(base64.b64decode(self.time_series_request)) + + if not (request.meta) or len(request.meta.project_ids) == 0: + raise InvalidSubscriptionError("Project ID is required.") + + if len(request.meta.project_ids) != 1: + raise InvalidSubscriptionError("Multiple project IDs not supported.") + + if not request.aggregations or len(request.aggregations) != 1: + raise InvalidSubscriptionError("Exactly one aggregation required.") + + if request.group_by: + raise InvalidSubscriptionError("Group bys not supported.") + def build_request( self, dataset: Dataset, diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index 143d5eb74e..5f6eea9954 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -28,6 +28,125 @@ START_TIME = END_TIME - timedelta(hours=1) +TESTS_INVALID_RPC_SUBSCRIPTIONS = [ + pytest.param( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + ), + time_window_secs=172800, + resolution_secs=60, + ), + "Time window must be less than or equal to 24 hours", + id="Invalid subscription: time window", + ), + pytest.param( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + ), + time_window_secs=300, + resolution_secs=60, + ), + "Multiple project IDs not supported", + id="Invalid subscription: multiple project ids", + ), + pytest.param( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + ), + time_window_secs=300, + resolution_secs=60, + ), + "Exactly one aggregation required", + id="Invalid subscription: multiple aggregations", + ), + pytest.param( + CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1], + organization_id=1, + cogs_category="something", + referrer="something", + ), + group_by=[ + AttributeKey(type=AttributeKey.TYPE_STRING, name="device.class") + ], + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey( + type=AttributeKey.TYPE_FLOAT, name="test_metric" + ), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + ), + ], + ), + time_window_secs=300, + resolution_secs=60, + ), + "Group bys not supported", + id="Invalid subscription: group by", + ), +] + + @pytest.mark.clickhouse_db @pytest.mark.redis_db class TestCreateSubscriptionApi(BaseApiTest): @@ -42,7 +161,7 @@ def test_create_valid_subscription(self) -> None: message = CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=1, cogs_category="something", referrer="something", @@ -91,7 +210,12 @@ def test_create_valid_subscription(self) -> None: assert subscription_data["request_name"] == "TimeSeriesRequest" assert subscription_data["request_version"] == "v1" - def test_create_invalid_subscription(self) -> None: + @pytest.mark.parametrize( + "create_subscription, error_message", TESTS_INVALID_RPC_SUBSCRIPTIONS + ) + def test_create_invalid_subscription( + self, create_subscription: CreateSubscriptionRequestProto, error_message: str + ) -> None: store_timeseries( START_TIME, 1, @@ -99,36 +223,11 @@ def test_create_invalid_subscription(self) -> None: metrics=[DummyMetric("test_metric", get_value=lambda x: 1)], ) - message = CreateSubscriptionRequestProto( - time_series_request=TimeSeriesRequest( - meta=RequestMeta( - project_ids=[1, 2, 3], - organization_id=1, - cogs_category="something", - referrer="something", - ), - aggregations=[ - AttributeAggregation( - aggregate=Function.FUNCTION_SUM, - key=AttributeKey( - type=AttributeKey.TYPE_FLOAT, name="test_metric" - ), - label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, - ), - ], - granularity_secs=172800, - ), - time_window_secs=172800, - resolution_secs=60, - ) response = self.app.post( - "/rpc/CreateSubscriptionRequest/v1", data=message.SerializeToString() + "/rpc/CreateSubscriptionRequest/v1", + data=create_subscription.SerializeToString(), ) assert response.status_code == 500 error = Error() error.ParseFromString(response.data) - assert ( - error.message - == "internal error occurred while executing this RPC call: Time window must be less than or equal to 24 hours" - ) + assert error_message in error.message From e7fa5fb5aaa5f0e7baaeaccbd536a557c180209c Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 08:34:38 -0500 Subject: [PATCH 19/26] fix tests --- tests/web/rpc/v1/test_create_subscription.py | 31 ++++++++------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index 5f6eea9954..aa71e772e1 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -1,5 +1,4 @@ import base64 -import json from datetime import UTC, datetime, timedelta import pytest @@ -21,6 +20,7 @@ from snuba.datasets.entities.entity_key import EntityKey from snuba.redis import RedisClientKey, get_redis_client +from snuba.subscriptions.store import RedisSubscriptionDataStore from tests.base import BaseApiTest from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries @@ -189,26 +189,21 @@ def test_create_valid_subscription(self) -> None: response_class.ParseFromString(response.data) assert response_class.subscription_id partition = response_class.subscription_id.split("/", 1)[0] - entity_key = EntityKey("eap_spans") - redis_client = get_redis_client(RedisClientKey.SUBSCRIPTION_STORE) - # TODO[fix]: querying the redis client like this directly is temporary - # because we don't have decode support for rpc queries in the codec yet. - stored_subscription_data = list( - redis_client.hgetall( - f"subscriptions:{entity_key.value}:{partition}" - ).items() - )[0] - subscription_request = stored_subscription_data[1] - subscription_data = json.loads(subscription_request.decode("utf-8")) + rpc_subscription_data = RedisSubscriptionDataStore( + get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), + EntityKey("eap_spans"), + partition, + ).all()[0][1] - time_series_request = subscription_data["time_series_request"] request_class = TimeSeriesRequest() - request_class.ParseFromString(base64.b64decode(time_series_request)) - assert subscription_data["time_window"] == 300 - assert subscription_data["resolution"] == 60 - assert subscription_data["request_name"] == "TimeSeriesRequest" - assert subscription_data["request_version"] == "v1" + request_class.ParseFromString( + base64.b64decode(rpc_subscription_data.time_series_request) + ) + assert rpc_subscription_data.time_window_sec == 300 + assert rpc_subscription_data.resolution_sec == 60 + assert rpc_subscription_data.request_name == "TimeSeriesRequest" + assert rpc_subscription_data.request_version == "v1" @pytest.mark.parametrize( "create_subscription, error_message", TESTS_INVALID_RPC_SUBSCRIPTIONS From 15d29f664508345aa937414a0fc56306605c8c7c Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 08:35:58 -0500 Subject: [PATCH 20/26] minimize diff --- snuba/subscriptions/data.py | 322 ++++++++++++++++++------------------ 1 file changed, 161 insertions(+), 161 deletions(-) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index fc6e7bbbb3..49785f1be7 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -161,6 +161,167 @@ def to_dict(self) -> Mapping[str, Any]: raise NotImplementedError +@dataclass(frozen=True, kw_only=True) +class RPCSubscriptionData(_SubscriptionData[TimeSeriesRequest]): + """ + Represents the state of an RPC subscription. + """ + + time_series_request: str + + request_name: str + request_version: str + + def validate(self) -> None: + super().validate() + if (self.request_name, self.request_version) not in REQUEST_TYPE_ALLOWLIST: + raise InvalidSubscriptionError( + f"{self.request_name} {self.request_version} not supported." + ) + + request = TimeSeriesRequest() + request.ParseFromString(base64.b64decode(self.time_series_request)) + + if not (request.meta) or len(request.meta.project_ids) == 0: + raise InvalidSubscriptionError("Project ID is required.") + + if len(request.meta.project_ids) != 1: + raise InvalidSubscriptionError("Multiple project IDs not supported.") + + if not request.aggregations or len(request.aggregations) != 1: + raise InvalidSubscriptionError("Exactly one aggregation required.") + + if request.group_by: + raise InvalidSubscriptionError("Group bys not supported.") + + def build_request( + self, + dataset: Dataset, + timestamp: datetime, + offset: Optional[int], + timer: Timer, + metrics: Optional[MetricsBackend] = None, + referrer: str = SUBSCRIPTION_REFERRER, + ) -> TimeSeriesRequest: + + request_class = EndpointTimeSeries().request_class()() + request_class.ParseFromString(base64.b64decode(self.time_series_request)) + + # TODO: update it to round to the lowest granularity + # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 + rounded_ts = ( + int(timestamp.replace(tzinfo=UTC).timestamp() / self.time_window_sec) + * self.time_window_sec + ) + rounded_start = datetime.utcfromtimestamp(rounded_ts) + + start_time_proto = Timestamp() + start_time_proto.FromDatetime( + rounded_start - timedelta(seconds=self.time_window_sec) + ) + end_time_proto = Timestamp() + end_time_proto.FromDatetime(rounded_start) + request_class.meta.start_timestamp.CopyFrom(start_time_proto) + request_class.meta.end_timestamp.CopyFrom(end_time_proto) + + request_class.granularity_secs = self.time_window_sec + + return request_class + + def run_query( + self, + dataset: Dataset, + request: TimeSeriesRequest, + timer: Timer, + robust: bool = False, + concurrent_queries_gauge: Optional[Gauge] = None, + ) -> QueryResult: + response = EndpointTimeSeries().execute(request) + if not response.result_timeseries: + result: Result = { + "meta": [], + "data": [{request.aggregations[0].label: 0}], + "trace_output": "", + } + return QueryResult( + result=result, extra={"stats": {}, "sql": "", "experiments": {}} + ) + + timeseries = response.result_timeseries[0] + data = [{timeseries.label: timeseries.data_points[0].data}] + + result = {"meta": [], "data": data, "trace_output": ""} + return QueryResult( + result=result, extra={"stats": {}, "sql": "", "experiments": {}} + ) + + @classmethod + def from_dict( + cls, data: Mapping[str, Any], entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + metadata = {} + for key in data.keys(): + if key == "metadata": + metadata.update(data[key]) + elif key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: + metadata[key] = data[key] + + return RPCSubscriptionData( + project_id=data["project_id"], + time_window_sec=int(data["time_window"]), + resolution_sec=int(data["resolution"]), + time_series_request=data["time_series_request"], + request_version=data["request_version"], + request_name=data["request_name"], + entity=entity, + metadata=metadata, + tenant_ids=data.get("tenant_ids", dict()), + ) + + @classmethod + def from_proto( + cls, item: CreateSubscriptionRequest, entity_key: EntityKey + ) -> RPCSubscriptionData: + entity: Entity = get_entity(entity_key) + request_class = item.time_series_request.__class__ + class_name = request_class.__name__ + class_version = request_class.__module__.split(".", 3)[2] + + metadata = {} + if item.time_series_request.meta: + metadata["organization"] = item.time_series_request.meta.organization_id + + return RPCSubscriptionData( + project_id=item.time_series_request.meta.project_ids[0], + time_window_sec=item.time_window_secs, + resolution_sec=item.resolution_secs, + time_series_request=base64.b64encode( + item.time_series_request.SerializeToString() + ).decode("utf-8"), + entity=entity, + metadata=metadata, + tenant_ids={}, + request_version=class_version, + request_name=class_name, + ) + + def to_dict(self) -> Mapping[str, Any]: + subscription_data_dict = { + "project_id": self.project_id, + "time_window": self.time_window_sec, + "resolution": self.resolution_sec, + "time_series_request": self.time_series_request, + "request_version": self.request_version, + "request_name": self.request_name, + "subscription_type": SubscriptionType.RPC.value, + } + if self.metadata: + subscription_data_dict["metadata"] = self.metadata + + return subscription_data_dict + + @dataclass(frozen=True, kw_only=True) class SnQLSubscriptionData(_SubscriptionData[Request]): """ @@ -332,167 +493,6 @@ def to_dict(self) -> Mapping[str, Any]: return subscription_data_dict -@dataclass(frozen=True, kw_only=True) -class RPCSubscriptionData(_SubscriptionData[TimeSeriesRequest]): - """ - Represents the state of an RPC subscription. - """ - - time_series_request: str - - request_name: str - request_version: str - - def validate(self) -> None: - super().validate() - if (self.request_name, self.request_version) not in REQUEST_TYPE_ALLOWLIST: - raise InvalidSubscriptionError( - f"{self.request_name} {self.request_version} not supported." - ) - - request = TimeSeriesRequest() - request.ParseFromString(base64.b64decode(self.time_series_request)) - - if not (request.meta) or len(request.meta.project_ids) == 0: - raise InvalidSubscriptionError("Project ID is required.") - - if len(request.meta.project_ids) != 1: - raise InvalidSubscriptionError("Multiple project IDs not supported.") - - if not request.aggregations or len(request.aggregations) != 1: - raise InvalidSubscriptionError("Exactly one aggregation required.") - - if request.group_by: - raise InvalidSubscriptionError("Group bys not supported.") - - def build_request( - self, - dataset: Dataset, - timestamp: datetime, - offset: Optional[int], - timer: Timer, - metrics: Optional[MetricsBackend] = None, - referrer: str = SUBSCRIPTION_REFERRER, - ) -> TimeSeriesRequest: - - request_class = EndpointTimeSeries().request_class()() - request_class.ParseFromString(base64.b64decode(self.time_series_request)) - - # TODO: update it to round to the lowest granularity - # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 - rounded_ts = ( - int(timestamp.replace(tzinfo=UTC).timestamp() / self.time_window_sec) - * self.time_window_sec - ) - rounded_start = datetime.utcfromtimestamp(rounded_ts) - - start_time_proto = Timestamp() - start_time_proto.FromDatetime( - rounded_start - timedelta(seconds=self.time_window_sec) - ) - end_time_proto = Timestamp() - end_time_proto.FromDatetime(rounded_start) - request_class.meta.start_timestamp.CopyFrom(start_time_proto) - request_class.meta.end_timestamp.CopyFrom(end_time_proto) - - request_class.granularity_secs = self.time_window_sec - - return request_class - - def run_query( - self, - dataset: Dataset, - request: TimeSeriesRequest, - timer: Timer, - robust: bool = False, - concurrent_queries_gauge: Optional[Gauge] = None, - ) -> QueryResult: - response = EndpointTimeSeries().execute(request) - if not response.result_timeseries: - result: Result = { - "meta": [], - "data": [{request.aggregations[0].label: 0}], - "trace_output": "", - } - return QueryResult( - result=result, extra={"stats": {}, "sql": "", "experiments": {}} - ) - - timeseries = response.result_timeseries[0] - data = [{timeseries.label: timeseries.data_points[0].data}] - - result = {"meta": [], "data": data, "trace_output": ""} - return QueryResult( - result=result, extra={"stats": {}, "sql": "", "experiments": {}} - ) - - @classmethod - def from_dict( - cls, data: Mapping[str, Any], entity_key: EntityKey - ) -> RPCSubscriptionData: - entity: Entity = get_entity(entity_key) - metadata = {} - for key in data.keys(): - if key == "metadata": - metadata.update(data[key]) - elif key not in SUBSCRIPTION_DATA_PAYLOAD_KEYS: - metadata[key] = data[key] - - return RPCSubscriptionData( - project_id=data["project_id"], - time_window_sec=int(data["time_window"]), - resolution_sec=int(data["resolution"]), - time_series_request=data["time_series_request"], - request_version=data["request_version"], - request_name=data["request_name"], - entity=entity, - metadata=metadata, - tenant_ids=data.get("tenant_ids", dict()), - ) - - @classmethod - def from_proto( - cls, item: CreateSubscriptionRequest, entity_key: EntityKey - ) -> RPCSubscriptionData: - entity: Entity = get_entity(entity_key) - request_class = item.time_series_request.__class__ - class_name = request_class.__name__ - class_version = request_class.__module__.split(".", 3)[2] - - metadata = {} - if item.time_series_request.meta: - metadata["organization"] = item.time_series_request.meta.organization_id - - return RPCSubscriptionData( - project_id=item.time_series_request.meta.project_ids[0], - time_window_sec=item.time_window_secs, - resolution_sec=item.resolution_secs, - time_series_request=base64.b64encode( - item.time_series_request.SerializeToString() - ).decode("utf-8"), - entity=entity, - metadata=metadata, - tenant_ids={}, - request_version=class_version, - request_name=class_name, - ) - - def to_dict(self) -> Mapping[str, Any]: - subscription_data_dict = { - "project_id": self.project_id, - "time_window": self.time_window_sec, - "resolution": self.resolution_sec, - "time_series_request": self.time_series_request, - "request_version": self.request_version, - "request_name": self.request_name, - "subscription_type": SubscriptionType.RPC.value, - } - if self.metadata: - subscription_data_dict["metadata"] = self.metadata - - return subscription_data_dict - - SubscriptionData = Union[RPCSubscriptionData, SnQLSubscriptionData] SubscriptionRequest = Union[Request, TimeSeriesRequest] From df861f5c7dd2f2ec97feaf7dbc2781cdc4b7674a Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 08:38:30 -0500 Subject: [PATCH 21/26] remove rpc helper file --- snuba/subscriptions/rpc_helpers.py | 59 ------------------------------ 1 file changed, 59 deletions(-) delete mode 100644 snuba/subscriptions/rpc_helpers.py diff --git a/snuba/subscriptions/rpc_helpers.py b/snuba/subscriptions/rpc_helpers.py deleted file mode 100644 index b019ad8976..0000000000 --- a/snuba/subscriptions/rpc_helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -import base64 -from datetime import UTC, datetime, timedelta - -from google.protobuf.timestamp_pb2 import Timestamp -from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest - -from snuba.reader import Result -from snuba.web import QueryResult -from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries - - -def build_rpc_request( - timestamp: datetime, - time_window_sec: int, - time_series_request: str, -) -> TimeSeriesRequest: - - request_class = EndpointTimeSeries().request_class()() - request_class.ParseFromString(base64.b64decode(time_series_request)) - - # TODO: update it to round to the lowest granularity - # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 - rounded_ts = ( - int(timestamp.replace(tzinfo=UTC).timestamp() / time_window_sec) - * time_window_sec - ) - rounded_start = datetime.utcfromtimestamp(rounded_ts) - - start_time_proto = Timestamp() - start_time_proto.FromDatetime(rounded_start - timedelta(seconds=time_window_sec)) - end_time_proto = Timestamp() - end_time_proto.FromDatetime(rounded_start) - request_class.meta.start_timestamp.CopyFrom(start_time_proto) - request_class.meta.end_timestamp.CopyFrom(end_time_proto) - - request_class.granularity_secs = time_window_sec - - return request_class - - -def run_rpc_subscription_query( - request: TimeSeriesRequest, -) -> QueryResult: - response = EndpointTimeSeries().execute(request) - if not response.result_timeseries: - result: Result = { - "meta": [], - "data": [{request.aggregations[0].label: 0}], - "trace_output": "", - } - return QueryResult( - result=result, extra={"stats": {}, "sql": "", "experiments": {}} - ) - - timeseries = response.result_timeseries[0] - data = [{timeseries.label: timeseries.data_points[0].data}] - - result = {"meta": [], "data": data, "trace_output": ""} - return QueryResult(result=result, extra={"stats": {}, "sql": "", "experiments": {}}) From 356d368ffdde9d0547cac6a5066fe5e9c22ed52a Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 10:08:18 -0500 Subject: [PATCH 22/26] update test subscriptions to be valid --- tests/subscriptions/test_codecs.py | 2 +- tests/subscriptions/test_data.py | 8 ++++---- tests/subscriptions/test_filter_subscriptions.py | 4 ++-- tests/subscriptions/test_scheduler_consumer.py | 2 +- tests/subscriptions/test_subscription.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/subscriptions/test_codecs.py b/tests/subscriptions/test_codecs.py index 5046fbc26d..c2818dc915 100644 --- a/tests/subscriptions/test_codecs.py +++ b/tests/subscriptions/test_codecs.py @@ -68,7 +68,7 @@ def build_rpc_subscription_data_from_proto( type=AttributeKey.TYPE_FLOAT, name="test_metric" ), label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], filter=TraceItemFilter( diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index 0ef67ebc6a..26b84c4190 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -108,7 +108,7 @@ CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=1, cogs_category="something", referrer="something", @@ -120,7 +120,7 @@ type=AttributeKey.TYPE_FLOAT, name="my.float.field" ), label="count", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], ), @@ -138,7 +138,7 @@ CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=1, cogs_category="something", referrer="something", @@ -150,7 +150,7 @@ type=AttributeKey.TYPE_FLOAT, name="my.float.field" ), label="count", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], filter=TraceItemFilter( diff --git a/tests/subscriptions/test_filter_subscriptions.py b/tests/subscriptions/test_filter_subscriptions.py index f683cab241..f13cb3c7cb 100644 --- a/tests/subscriptions/test_filter_subscriptions.py +++ b/tests/subscriptions/test_filter_subscriptions.py @@ -86,7 +86,7 @@ def build_rpc_subscription(resolution: timedelta, org_id: int) -> Subscription: CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=org_id, cogs_category="something", referrer="something", @@ -98,7 +98,7 @@ def build_rpc_subscription(resolution: timedelta, org_id: int) -> Subscription: type=AttributeKey.TYPE_FLOAT, name="test_metric" ), label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], filter=TraceItemFilter( diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index 664a17384d..dfc8967674 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -209,7 +209,7 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: type=AttributeKey.TYPE_FLOAT, name="test_metric" ), label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], filter=TraceItemFilter( diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index 0e59cc6c21..8fa2b5089f 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -350,7 +350,7 @@ def test(self) -> None: CreateSubscriptionRequestProto( time_series_request=TimeSeriesRequest( meta=RequestMeta( - project_ids=[1, 2, 3], + project_ids=[1], organization_id=1, cogs_category="something", referrer="something", @@ -362,7 +362,7 @@ def test(self) -> None: type=AttributeKey.TYPE_FLOAT, name="test_metric" ), label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE, + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), ], filter=TraceItemFilter( From baf5566add8795654352944765d275a3918bbbec Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Wed, 20 Nov 2024 13:09:41 -0500 Subject: [PATCH 23/26] fix typing --- tests/web/rpc/v1/test_create_subscription.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/web/rpc/v1/test_create_subscription.py b/tests/web/rpc/v1/test_create_subscription.py index aa71e772e1..4429aae1fd 100644 --- a/tests/web/rpc/v1/test_create_subscription.py +++ b/tests/web/rpc/v1/test_create_subscription.py @@ -20,6 +20,7 @@ from snuba.datasets.entities.entity_key import EntityKey from snuba.redis import RedisClientKey, get_redis_client +from snuba.subscriptions.data import PartitionId, RPCSubscriptionData from snuba.subscriptions.store import RedisSubscriptionDataStore from tests.base import BaseApiTest from tests.web.rpc.v1.test_endpoint_time_series import DummyMetric, store_timeseries @@ -188,13 +189,17 @@ def test_create_valid_subscription(self) -> None: response_class = CreateSubscriptionResponse() response_class.ParseFromString(response.data) assert response_class.subscription_id - partition = response_class.subscription_id.split("/", 1)[0] + partition = int(response_class.subscription_id.split("/", 1)[0]) - rpc_subscription_data = RedisSubscriptionDataStore( - get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), - EntityKey("eap_spans"), - partition, - ).all()[0][1] + rpc_subscription_data = list( + RedisSubscriptionDataStore( + get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), + EntityKey("eap_spans"), + PartitionId(partition), + ).all() + )[0][1] + + assert isinstance(rpc_subscription_data, RPCSubscriptionData) request_class = TimeSeriesRequest() request_class.ParseFromString( From 6321e0c9055653723356de01cd2c70d98499166b Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Thu, 21 Nov 2024 06:56:40 -0500 Subject: [PATCH 24/26] remove rpc helper file --- snuba/subscriptions/rpc_helpers.py | 59 ------------------------------ 1 file changed, 59 deletions(-) delete mode 100644 snuba/subscriptions/rpc_helpers.py diff --git a/snuba/subscriptions/rpc_helpers.py b/snuba/subscriptions/rpc_helpers.py deleted file mode 100644 index b019ad8976..0000000000 --- a/snuba/subscriptions/rpc_helpers.py +++ /dev/null @@ -1,59 +0,0 @@ -import base64 -from datetime import UTC, datetime, timedelta - -from google.protobuf.timestamp_pb2 import Timestamp -from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest - -from snuba.reader import Result -from snuba.web import QueryResult -from snuba.web.rpc.v1.endpoint_time_series import EndpointTimeSeries - - -def build_rpc_request( - timestamp: datetime, - time_window_sec: int, - time_series_request: str, -) -> TimeSeriesRequest: - - request_class = EndpointTimeSeries().request_class()() - request_class.ParseFromString(base64.b64decode(time_series_request)) - - # TODO: update it to round to the lowest granularity - # rounded_ts = int(timestamp.replace(tzinfo=UTC).timestamp() / 15) * 15 - rounded_ts = ( - int(timestamp.replace(tzinfo=UTC).timestamp() / time_window_sec) - * time_window_sec - ) - rounded_start = datetime.utcfromtimestamp(rounded_ts) - - start_time_proto = Timestamp() - start_time_proto.FromDatetime(rounded_start - timedelta(seconds=time_window_sec)) - end_time_proto = Timestamp() - end_time_proto.FromDatetime(rounded_start) - request_class.meta.start_timestamp.CopyFrom(start_time_proto) - request_class.meta.end_timestamp.CopyFrom(end_time_proto) - - request_class.granularity_secs = time_window_sec - - return request_class - - -def run_rpc_subscription_query( - request: TimeSeriesRequest, -) -> QueryResult: - response = EndpointTimeSeries().execute(request) - if not response.result_timeseries: - result: Result = { - "meta": [], - "data": [{request.aggregations[0].label: 0}], - "trace_output": "", - } - return QueryResult( - result=result, extra={"stats": {}, "sql": "", "experiments": {}} - ) - - timeseries = response.result_timeseries[0] - data = [{timeseries.label: timeseries.data_points[0].data}] - - result = {"meta": [], "data": data, "trace_output": ""} - return QueryResult(result=result, extra={"stats": {}, "sql": "", "experiments": {}}) From f7f844865f9ffe368944ff43bea84a8a18fd5df5 Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Thu, 21 Nov 2024 07:42:02 -0500 Subject: [PATCH 25/26] return none, not 0 when no data is found --- snuba/subscriptions/data.py | 2 +- tests/subscriptions/test_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/snuba/subscriptions/data.py b/snuba/subscriptions/data.py index 389f4b410f..ad7bab4138 100644 --- a/snuba/subscriptions/data.py +++ b/snuba/subscriptions/data.py @@ -250,7 +250,7 @@ def run_query( if not response.result_timeseries: result: Result = { "meta": [], - "data": [{request.aggregations[0].label: 0}], + "data": [{request.aggregations[0].label: None}], "trace_output": "", } return QueryResult( diff --git a/tests/subscriptions/test_data.py b/tests/subscriptions/test_data.py index 26b84c4190..d570de012b 100644 --- a/tests/subscriptions/test_data.py +++ b/tests/subscriptions/test_data.py @@ -168,7 +168,7 @@ ), EntityKey.EAP_SPANS, ), - 0.0, + None, None, id="RPC subscription with filter", ), From a642b1b997cf6697615f5406e79b94d5dbdfcf6a Mon Sep 17 00:00:00 2001 From: Shruthilaya Jaganathan Date: Tue, 26 Nov 2024 10:45:45 -0500 Subject: [PATCH 26/26] address comments --- snuba/datasets/slicing.py | 1 - snuba/subscriptions/codecs.py | 2 +- .../test_filter_subscriptions.py | 88 ------------------- .../subscriptions/test_scheduler_consumer.py | 74 ++++++---------- tests/subscriptions/test_subscription.py | 7 +- 5 files changed, 32 insertions(+), 140 deletions(-) diff --git a/snuba/datasets/slicing.py b/snuba/datasets/slicing.py index 07995ca14e..c4c0597367 100644 --- a/snuba/datasets/slicing.py +++ b/snuba/datasets/slicing.py @@ -3,7 +3,6 @@ should be stored. These do not require individual physical partitions but allow for repartitioning with less code changes per physical change. """ - from snuba.clusters.storage_sets import StorageSetKey SENTRY_LOGICAL_PARTITIONS = 256 diff --git a/snuba/subscriptions/codecs.py b/snuba/subscriptions/codecs.py index 9136ed2225..4ce311a185 100644 --- a/snuba/subscriptions/codecs.py +++ b/snuba/subscriptions/codecs.py @@ -84,7 +84,7 @@ def encode(self, value: SubscriptionTaskResult) -> KafkaPayload: class SubscriptionScheduledTaskEncoder(Codec[KafkaPayload, ScheduledSubscriptionTask]): """ Encodes/decodes a scheduled subscription to Kafka payload. - Does not support non SnQL subscriptions. + Supports SnQL and RPC subscriptions. """ def encode(self, value: ScheduledSubscriptionTask) -> KafkaPayload: diff --git a/tests/subscriptions/test_filter_subscriptions.py b/tests/subscriptions/test_filter_subscriptions.py index f13cb3c7cb..ea5b9ce983 100644 --- a/tests/subscriptions/test_filter_subscriptions.py +++ b/tests/subscriptions/test_filter_subscriptions.py @@ -6,29 +6,12 @@ from unittest.mock import patch import pytest -from sentry_protos.snuba.v1.endpoint_create_subscription_pb2 import ( - CreateSubscriptionRequest as CreateSubscriptionRequestProto, -) -from sentry_protos.snuba.v1.endpoint_time_series_pb2 import TimeSeriesRequest -from sentry_protos.snuba.v1.request_common_pb2 import RequestMeta -from sentry_protos.snuba.v1.trace_item_attribute_pb2 import ( - AttributeAggregation, - AttributeKey, - AttributeValue, - ExtrapolationMode, - Function, -) -from sentry_protos.snuba.v1.trace_item_filter_pb2 import ( - ComparisonFilter, - TraceItemFilter, -) from snuba.datasets.entities.entity_key import EntityKey from snuba.datasets.entities.factory import get_entity from snuba.subscriptions import scheduler from snuba.subscriptions.data import ( PartitionId, - RPCSubscriptionData, SnQLSubscriptionData, Subscription, SubscriptionIdentifier, @@ -77,74 +60,3 @@ def test_filter_subscriptions(expected_subs, extra_subs) -> None: # type: ignor slice_id=2, ) assert filtered_subs == expected_subs - - -def build_rpc_subscription(resolution: timedelta, org_id: int) -> Subscription: - return Subscription( - SubscriptionIdentifier(PartitionId(1), uuid.uuid4()), - RPCSubscriptionData.from_proto( - CreateSubscriptionRequestProto( - time_series_request=TimeSeriesRequest( - meta=RequestMeta( - project_ids=[1], - organization_id=org_id, - cogs_category="something", - referrer="something", - ), - aggregations=[ - AttributeAggregation( - aggregate=Function.FUNCTION_SUM, - key=AttributeKey( - type=AttributeKey.TYPE_FLOAT, name="test_metric" - ), - label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, - ), - ], - filter=TraceItemFilter( - comparison_filter=ComparisonFilter( - key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), - op=ComparisonFilter.OP_NOT_EQUALS, - value=AttributeValue(val_str="bar"), - ) - ), - ), - time_window_secs=300, - resolution_secs=int(resolution.total_seconds()), - ), - EntityKey.EAP_SPANS, - ), - ) - - -@pytest.fixture -def expected_rpc_subs() -> MutableSequence[Subscription]: - return [ - build_rpc_subscription(timedelta(minutes=1), 2) - for count in range(randint(1, 50)) - ] - - -@pytest.fixture -def extra_rpc_subs() -> MutableSequence[Subscription]: - return [ - build_rpc_subscription(timedelta(minutes=3), 1) - for count in range(randint(1, 50)) - ] - - -@patch("snuba.settings.SLICED_STORAGE_SETS", {"events_analytics_platform": 3}) -@patch( - "snuba.settings.LOGICAL_PARTITION_MAPPING", - {"events_analytics_platform": {0: 0, 1: 1, 2: 2}}, -) -def test_filter_rpc_subscriptions(expected_rpc_subs, extra_rpc_subs) -> None: # type: ignore - importlib.reload(scheduler) - - filtered_subs = filter_subscriptions( - subscriptions=expected_rpc_subs + extra_rpc_subs, - entity_key=EntityKey.EAP_SPANS, - metrics=DummyMetricsBackend(strict=True), - slice_id=2, - ) - assert filtered_subs == expected_rpc_subs diff --git a/tests/subscriptions/test_scheduler_consumer.py b/tests/subscriptions/test_scheduler_consumer.py index dfc8967674..dc884a9e2e 100644 --- a/tests/subscriptions/test_scheduler_consumer.py +++ b/tests/subscriptions/test_scheduler_consumer.py @@ -49,6 +49,7 @@ get_default_kafka_configuration, ) from snuba.utils.streams.topics import Topic as SnubaTopic +from snuba.web.rpc.v1.create_subscription import CreateSubscriptionRequest from tests.assertions import assert_changes from tests.backends.metrics import TestingMetricsBackend @@ -159,6 +160,7 @@ def test_scheduler_consumer(tmpdir: LocalPath) -> None: settings.TOPIC_PARTITION_COUNTS = {} +@pytest.mark.clickhouse_db @pytest.mark.redis_db def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: settings.TOPIC_PARTITION_COUNTS = {"snuba-spans": 2} @@ -178,54 +180,34 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: mock_scheduler_producer = mock.Mock() - from snuba.redis import RedisClientKey, get_redis_client - from snuba.subscriptions.data import PartitionId, RPCSubscriptionData - from snuba.subscriptions.store import RedisSubscriptionDataStore - - entity_key = EntityKey(entity_name) - partition_index = 0 - - store = RedisSubscriptionDataStore( - get_redis_client(RedisClientKey.SUBSCRIPTION_STORE), - entity_key, - PartitionId(partition_index), - ) - entity = get_entity(EntityKey.EVENTS) - store.create( - uuid.uuid4(), - RPCSubscriptionData.from_proto( - CreateSubscriptionRequestProto( - time_series_request=TimeSeriesRequest( - meta=RequestMeta( - project_ids=[1], - organization_id=1, - cogs_category="something", - referrer="something", - ), - aggregations=[ - AttributeAggregation( - aggregate=Function.FUNCTION_SUM, - key=AttributeKey( - type=AttributeKey.TYPE_FLOAT, name="test_metric" - ), - label="sum", - extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, - ), - ], - filter=TraceItemFilter( - comparison_filter=ComparisonFilter( - key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), - op=ComparisonFilter.OP_NOT_EQUALS, - value=AttributeValue(val_str="bar"), - ) - ), + message = CreateSubscriptionRequestProto( + time_series_request=TimeSeriesRequest( + meta=RequestMeta( + project_ids=[1], + organization_id=1, + cogs_category="something", + referrer="something", + ), + aggregations=[ + AttributeAggregation( + aggregate=Function.FUNCTION_SUM, + key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="test_metric"), + label="sum", + extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED, ), - time_window_secs=300, - resolution_secs=60, + ], + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_STRING, name="foo"), + op=ComparisonFilter.OP_NOT_EQUALS, + value=AttributeValue(val_str="bar"), + ) ), - EntityKey.EAP_SPANS, ), + time_window_secs=300, + resolution_secs=60, ) + CreateSubscriptionRequest().execute(message) builder = scheduler_consumer.SchedulerBuilder( entity_name, @@ -256,9 +238,9 @@ def test_scheduler_consumer_rpc_subscriptions(tmpdir: LocalPath) -> None: for partition, offset, ts in [ (0, 0, epoch), - (1, 0, epoch + 60), + (1, 0, epoch), (0, 1, epoch + 120), - (1, 1, epoch + 180), + (1, 1, epoch + 120), ]: fut = producer.produce( commit_log_topic, diff --git a/tests/subscriptions/test_subscription.py b/tests/subscriptions/test_subscription.py index 8fa2b5089f..31890bc278 100644 --- a/tests/subscriptions/test_subscription.py +++ b/tests/subscriptions/test_subscription.py @@ -24,6 +24,7 @@ from snuba.datasets.entities.factory import get_entity from snuba.datasets.entity_subscriptions.validators import InvalidSubscriptionError from snuba.datasets.factory import get_dataset +from snuba.datasets.pluggable_dataset import PluggableDataset from snuba.query.exceptions import InvalidQueryException, ValidationException from snuba.query.validation.validators import ColumnValidationMode from snuba.redis import RedisClientKey, get_redis_client @@ -386,14 +387,12 @@ def test(self) -> None: class TestEAPSpansRPCSubscriptionCreator: timer = Timer("test") - def setup_method(self) -> None: - self.dataset = get_dataset("metrics") - @pytest.mark.parametrize("subscription", TESTS_CREATE_RPC_SUBSCRIPTIONS) @pytest.mark.clickhouse_db @pytest.mark.redis_db def test(self, subscription: SubscriptionData) -> None: - creator = SubscriptionCreator(self.dataset, EntityKey.EAP_SPANS) + dataset = PluggableDataset(name="eap", all_entities=[]) + creator = SubscriptionCreator(dataset, EntityKey.EAP_SPANS) identifier = creator.create(subscription, self.timer) assert ( cast(