diff --git a/scripts/plugin.py b/scripts/plugin.py
index 82345cd..cdff712 100755
--- a/scripts/plugin.py
+++ b/scripts/plugin.py
@@ -31,7 +31,7 @@ # Inline utility functions
 # Inline the size function for a given proto message field
-def inline_size_function(proto_type: str, field_name: str, field_tag: str) -> str:
+def inline_size_function(proto_type: str, attr_name: str, field_tag: str) -> str:
     """
     For example:
@@ -48,8 +48,8 @@ def size_uint32(self, TAG: bytes, FIELD_ATTR: int) -> int:
     function_definition = function_definition.splitlines()[1:]
     function_definition = "\n".join(function_definition)
     function_definition = dedent(function_definition)
-    # Replace the field name
-    function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}")
+    # Replace the attribute name
+    function_definition = function_definition.replace("FIELD_ATTR", f"self.{attr_name}")
     # Replace the TAG
     function_definition = function_definition.replace("TAG", field_tag)
     # Inline the return statement
@@ -57,7 +57,7 @@ def size_uint32(self, TAG: bytes, FIELD_ATTR: int) -> int:
     return function_definition
 
 # Inline the serialization function for a given proto message field
-def inline_serialize_function(proto_type: str, field_name: str, field_tag: str) -> str:
+def inline_serialize_function(proto_type: str, attr_name: str, field_tag: str) -> str:
     """
     For example:
@@ -76,8 +76,8 @@ def serialize_uint32(self, out: BytesIO, TAG: bytes, FIELD_ATTR: int) -> None:
     function_definition = function_definition.splitlines()[1:]
     function_definition = "\n".join(function_definition)
     function_definition = dedent(function_definition)
-    # Replace the field name
-    function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}")
+    # Replace the attribute name
+    function_definition = function_definition.replace("FIELD_ATTR", f"self.{attr_name}")
     # Replace the TAG
     function_definition = function_definition.replace("TAG", field_tag)
     return function_definition
@@ -93,16 +93,16 @@ def inline_init() -> str:
 
 # Add a presence check to a function definition
 # https://protobuf.dev/programming-guides/proto3/#default
-def add_presence_check(proto_type: str, encode_presence: bool, field_name: str, function_definition: str) -> str:
+def add_presence_check(proto_type: str, encode_presence: bool, attr_name: str, function_definition: str) -> str:
     # oneof, optional (virtual oneof), and message fields are encoded if they are not None
     function_definition = indent(function_definition, "    ")
     if encode_presence:
-        return f"if self.{field_name} is not None:\n{function_definition}"
+        return f"if self.{attr_name} is not None:\n{function_definition}"
     # Other fields are encoded if they are not the default value
     # Which happens to align with the bool(x) check for all primitive types
     # TODO: Except
     # - double and float -0.0 should be encoded, even though bool(-0.0) is False
-    return f"if self.{field_name}:\n{function_definition}"
+    return f"if self.{attr_name}:\n{function_definition}"
 
 class WireType(IntEnum):
     VARINT = 0
@@ -234,6 +234,8 @@ def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = Non
     else:
         # https://protobuf.dev/reference/python/python-generated/#embedded_message
         generator = f"{python_type}()"
+    # the attribute name is prefixed with an underscore as message and repeated attributes
+    # are hidden behind a property that has the actual proto field name
     attr_name = f"_{field_name}"
 
     # Inline the size and serialization functions for the field
diff --git a/tests/test_proto_serialization.py b/tests/test_proto_serialization.py
index 17a9eca..7bb3feb 100644
--- a/tests/test_proto_serialization.py
+++ b/tests/test_proto_serialization.py
@@ -1,14 +1,16 @@
 from __future__ import annotations
+from typing import (
+    Any,
+    Dict,
+    Mapping,
+    Sequence
+)
 import unittest
-
-from dataclasses import dataclass
-from typing import Any, Dict, Callable
 import hypothesis
-from hypothesis.strategies import composite, text, booleans, integers, floats, lists, binary, sampled_from
-from hypothesis.control import assume
+import hypothesis.control as hc
+import hypothesis.strategies as st
-import hypothesis.strategies
 import opentelemetry.proto.logs.v1.logs_pb2 as logs_pb2
 import opentelemetry.proto.trace.v1.trace_pb2 as trace_pb2
 import opentelemetry.proto.common.v1.common_pb2 as common_pb2
@@ -22,158 +24,117 @@ import snowflake.telemetry._internal.opentelemetry.proto.resource.v1.resource_marshaler as resource_sf
 # Strategy for generating protobuf types
-def pb_uint32(): return integers(min_value=0, max_value=2**32-1)
-def pb_uint64(): return integers(min_value=0, max_value=2**64-1)
-def pb_int32(): return integers(min_value=-2**31, max_value=2**31-1)
-def pb_int64(): return integers(min_value=-2**63, max_value=2**63-1)
-def pb_sint32(): return integers(min_value=-2**31, max_value=2**31-1)
-def pb_sint64(): return integers(min_value=-2**63, max_value=2**63-1)
-def pb_float(): return floats(allow_nan=False, allow_infinity=False, width=32)
-def pb_double(): return floats(allow_nan=False, allow_infinity=False, width=64)
+def pb_uint32(): return st.integers(min_value=0, max_value=2**32-1)
+def pb_uint64(): return st.integers(min_value=0, max_value=2**64-1)
+def pb_int32(): return st.integers(min_value=-2**31, max_value=2**31-1)
+def pb_int64(): return st.integers(min_value=-2**63, max_value=2**63-1)
+def pb_sint32(): return st.integers(min_value=-2**31, max_value=2**31-1)
+def pb_sint64(): return st.integers(min_value=-2**63, max_value=2**63-1)
+def pb_float(): return st.floats(allow_nan=False, allow_infinity=False, width=32)
+def pb_double(): return st.floats(allow_nan=False, allow_infinity=False, width=64)
 
 def draw_pb_double(draw):
     # -0.0 is an edge case that is not handled by the custom serialization library
     double = draw(pb_double())
-    assume(str(double) != "-0.0")
+    hc.assume(str(double) != "-0.0")
     return double
 
 def pb_fixed64(): return pb_uint64()
 def pb_fixed32(): return pb_uint32()
 def pb_sfixed64(): return pb_int64()
 def pb_sfixed32(): return pb_int32()
-def pb_bool(): return booleans()
-def pb_string(): return text(max_size=20)
-def pb_bytes(): return binary(max_size=20)
-def draw_pb_enum(draw, enum: EncodeStrategy):
+def pb_bool(): return st.booleans()
+def pb_string(): return st.text(max_size=20)
+def pb_bytes(): return st.binary(max_size=20)
+def draw_pb_enum(draw, enum):
     # Sample int val of enum, will be converted to member in encode_recurse
     # Sample from pb2 values as it is the source of truth
-    return draw(sampled_from([member for member in enum.pb2.values()]))
-def pb_repeated(type): return lists(type, max_size=3) # limit the size of the repeated field to speed up testing
-def pb_span_id(): return binary(min_size=8, max_size=8)
-def pb_trace_id(): return binary(min_size=16, max_size=16)
-
-# Maps protobuf types to their serialization functions, from the protobuf and snowflake serialization libraries
-@dataclass
-class EncodeStrategy:
-    pb2: Callable[[Any], Any]
-    sf: Callable[[Any], Any]
-
-Resource = EncodeStrategy(pb2=resource_pb2.Resource, sf=resource_sf.Resource)
-
-InstrumentationScope = EncodeStrategy(pb2=common_pb2.InstrumentationScope, sf=common_sf.InstrumentationScope)
-AnyValue = EncodeStrategy(pb2=common_pb2.AnyValue, sf=common_sf.AnyValue)
-ArrayValue = EncodeStrategy(pb2=common_pb2.ArrayValue, sf=common_sf.ArrayValue)
-KeyValue = EncodeStrategy(pb2=common_pb2.KeyValue, sf=common_sf.KeyValue)
-KeyValueList = EncodeStrategy(pb2=common_pb2.KeyValueList, sf=common_sf.KeyValueList)
-
-LogRecord = EncodeStrategy(pb2=logs_pb2.LogRecord, sf=logs_sf.LogRecord)
-ScopeLogs = EncodeStrategy(pb2=logs_pb2.ScopeLogs, sf=logs_sf.ScopeLogs)
-ResourceLogs = EncodeStrategy(pb2=logs_pb2.ResourceLogs, sf=logs_sf.ResourceLogs)
-LogsData = EncodeStrategy(pb2=logs_pb2.LogsData, sf=logs_sf.LogsData)
-SeverityNumber = EncodeStrategy(pb2=logs_pb2.SeverityNumber, sf=logs_sf.SeverityNumber)
-
-TracesData = EncodeStrategy(pb2=trace_pb2.TracesData, sf=trace_sf.TracesData)
-ScopeSpans = EncodeStrategy(pb2=trace_pb2.ScopeSpans, sf=trace_sf.ScopeSpans)
-ResourceSpans = EncodeStrategy(pb2=trace_pb2.ResourceSpans, sf=trace_sf.ResourceSpans)
-Span = EncodeStrategy(pb2=trace_pb2.Span, sf=trace_sf.Span)
-Event = EncodeStrategy(pb2=trace_pb2.Span.Event, sf=trace_sf.Span.Event)
-Link = EncodeStrategy(pb2=trace_pb2.Span.Link, sf=trace_sf.Span.Link)
-Status = EncodeStrategy(pb2=trace_pb2.Status, sf=trace_sf.Status)
-SpanKind = EncodeStrategy(pb2=trace_pb2.Span.SpanKind, sf=trace_sf.Span.SpanKind)
-StatusCode = EncodeStrategy(pb2=trace_pb2.Status.StatusCode, sf=trace_sf.Status.StatusCode)
-
-Metric = EncodeStrategy(pb2=metrics_pb2.Metric, sf=metrics_sf.Metric)
-ScopeMetrics = EncodeStrategy(pb2=metrics_pb2.ScopeMetrics, sf=metrics_sf.ScopeMetrics)
-ResourceMetrics = EncodeStrategy(pb2=metrics_pb2.ResourceMetrics, sf=metrics_sf.ResourceMetrics)
-MetricsData = EncodeStrategy(pb2=metrics_pb2.MetricsData, sf=metrics_sf.MetricsData)
-Gauge = EncodeStrategy(pb2=metrics_pb2.Gauge, sf=metrics_sf.Gauge)
-Sum = EncodeStrategy(pb2=metrics_pb2.Sum, sf=metrics_sf.Sum)
-Histogram = EncodeStrategy(pb2=metrics_pb2.Histogram, sf=metrics_sf.Histogram)
-ExponentialHistogram = EncodeStrategy(pb2=metrics_pb2.ExponentialHistogram, sf=metrics_sf.ExponentialHistogram)
-Summary = EncodeStrategy(pb2=metrics_pb2.Summary, sf=metrics_sf.Summary)
-NumberDataPoint = EncodeStrategy(pb2=metrics_pb2.NumberDataPoint, sf=metrics_sf.NumberDataPoint)
-Exemplar = EncodeStrategy(pb2=metrics_pb2.Exemplar, sf=metrics_sf.Exemplar)
-HistogramDataPoint = EncodeStrategy(pb2=metrics_pb2.HistogramDataPoint, sf=metrics_sf.HistogramDataPoint)
-ExponentialHistogramDataPoint = EncodeStrategy(pb2=metrics_pb2.ExponentialHistogramDataPoint, sf=metrics_sf.ExponentialHistogramDataPoint)
-SummaryDataPoint = EncodeStrategy(pb2=metrics_pb2.SummaryDataPoint, sf=metrics_sf.SummaryDataPoint)
-ValueAtQuantile = EncodeStrategy(pb2=metrics_pb2.SummaryDataPoint.ValueAtQuantile, sf=metrics_sf.SummaryDataPoint.ValueAtQuantile)
-Buckets = EncodeStrategy(pb2=metrics_pb2.ExponentialHistogramDataPoint.Buckets, sf=metrics_sf.ExponentialHistogramDataPoint.Buckets)
-AggregationTemporality = EncodeStrategy(pb2=metrics_pb2.AggregationTemporality, sf=metrics_sf.AggregationTemporality)
-
-# Package the protobuf type with its arguments for serialization
-@dataclass
-class EncodeWithArgs:
-    kwargs: Dict[str, Any]
-    cls: EncodeStrategy
-
-# Package enum value with their serialization strategies
-@dataclass
-class EncodeEnumWithVal:
-    val: int
-    cls: EncodeStrategy
+    return draw(st.sampled_from([member for member in enum.values()]))
+def pb_repeated(type): return st.lists(type, max_size=3) # limit the size of the repeated field to speed up testing
+def pb_span_id(): return st.binary(min_size=8, max_size=8)
+def pb_trace_id(): return st.binary(min_size=16, max_size=16)
+# For drawing oneof fields
+# call with pb_oneof(draw, field1=pb_type1_callable, field2=pb_type2_callable, ...)
+def pb_oneof(draw, **kwargs):
+    n = len(kwargs)
+    r = draw(st.integers(min_value=0, max_value=n-1))
+    k, v = list(kwargs.items())[r]
+    return {k: draw(v())}
+
+SF = "_sf"
+PB = "_pb2"
 
 # Strategies for generating opentelemetry-proto types
-@composite
+@st.composite
 def instrumentation_scope(draw):
-    return EncodeWithArgs({
+    return {
+        SF: common_sf.InstrumentationScope,
+        PB: common_pb2.InstrumentationScope,
         "name": draw(pb_string()),
         "version": draw(pb_string()),
         "attributes": draw(pb_repeated(key_value())),
         "dropped_attributes_count": draw(pb_uint32()),
-    }, InstrumentationScope)
+    }
 
-@composite
+@st.composite
 def resource(draw):
-    return EncodeWithArgs({
+    return {
+        SF: resource_sf.Resource,
+        PB: resource_pb2.Resource,
         "attributes": draw(pb_repeated(key_value())),
         "dropped_attributes_count": draw(pb_uint32()),
-    }, Resource)
+    }
 
-@composite
+@st.composite
 def any_value(draw):
-    # oneof field so only set one
-    oneof = draw(integers(min_value=0, max_value=6))
-    if oneof == 0:
-        kwargs = {"string_value": draw(pb_string())}
-    elif oneof == 1:
-        kwargs = {"bool_value": draw(pb_bool())}
-    elif oneof == 2:
-        kwargs = {"int_value": draw(pb_int64())}
-    elif oneof == 3:
-        kwargs = {"double_value": draw_pb_double(draw)}
-    elif oneof == 4:
-        kwargs = {"array_value": draw(array_value())}
-    elif oneof == 5:
-        kwargs = {"kvlist_value": draw(key_value_list())}
-    elif oneof == 6:
-        kwargs = {"bytes_value": draw(pb_bytes())}
-    return EncodeWithArgs(kwargs, AnyValue)
-
-@composite
+    return {
+        SF: common_sf.AnyValue,
+        PB: common_pb2.AnyValue,
+        **pb_oneof(
+            draw,
+            string_value=pb_string,
+            bool_value=pb_bool,
+            int_value=pb_int64,
+            double_value=pb_double,
+            array_value=array_value,
+            kvlist_value=key_value_list,
+            bytes_value=pb_bytes,
+        ),
+    }
+
+@st.composite
 def array_value(draw):
-    return EncodeWithArgs({
+    return {
+        SF: common_sf.ArrayValue,
+        PB: common_pb2.ArrayValue,
         "values": draw(pb_repeated(any_value())),
-    }, ArrayValue)
+    }
 
-@composite
+@st.composite
 def key_value(draw):
-    return EncodeWithArgs({
+    return {
+        SF: common_sf.KeyValue,
+        PB: common_pb2.KeyValue,
         "key": draw(pb_string()),
         "value": draw(any_value()),
-    }, KeyValue)
+    }
 
-@composite
+@st.composite
 def key_value_list(draw):
-    return EncodeWithArgs({
+    return {
+        SF: common_sf.KeyValueList,
+        PB: common_pb2.KeyValueList,
         "values": draw(pb_repeated(key_value())),
-    }, KeyValueList)
+    }
 
-@composite
+@st.composite
 def logs_data(draw):
-    @composite
+    @st.composite
     def log_record(draw):
-        return EncodeWithArgs({
+        return {
+            SF: logs_sf.LogRecord,
+            PB: logs_pb2.LogRecord,
             "time_unix_nano": draw(pb_fixed64()),
             "observed_time_unix_nano": draw(pb_fixed64()),
-            "severity_number": draw_pb_enum(draw, SeverityNumber),
+            "severity_number": draw_pb_enum(draw, logs_pb2.SeverityNumber),
             "severity_text": draw(pb_string()),
             "body": draw(any_value()),
             "attributes": draw(pb_repeated(key_value())),
@@ -181,66 +142,80 @@ def log_record(draw):
             "flags": draw(pb_fixed32()),
             "span_id": draw(pb_span_id()),
             "trace_id": draw(pb_trace_id()),
-        }, LogRecord)
+        }
 
-    @composite
+    @st.composite
     def scope_logs(draw):
-        return EncodeWithArgs({
+        return {
+            SF: logs_sf.ScopeLogs,
+            PB: logs_pb2.ScopeLogs,
             "scope": draw(instrumentation_scope()),
"log_records": draw(pb_repeated(log_record())), "schema_url": draw(pb_string()), - }, ScopeLogs) + } - @composite + @st.composite def resource_logs(draw): - return EncodeWithArgs({ + return { + SF: logs_sf.ResourceLogs, + PB: logs_pb2.ResourceLogs, "resource": draw(resource()), "scope_logs": draw(pb_repeated(scope_logs())), "schema_url": draw(pb_string()), - }, ResourceLogs) + } - return EncodeWithArgs({ + return { + SF: logs_sf.LogsData, + PB: logs_pb2.LogsData, "resource_logs": draw(pb_repeated(resource_logs())), - }, LogsData) + } -@composite +@st.composite def traces_data(draw): - @composite + @st.composite def event(draw): - return EncodeWithArgs({ + return { + SF: trace_sf.Span.Event, + PB: trace_pb2.Span.Event, "time_unix_nano": draw(pb_fixed64()), "name": draw(pb_string()), "attributes": draw(pb_repeated(key_value())), "dropped_attributes_count": draw(pb_uint32()), - }, Event) + } - @composite + @st.composite def link(draw): - return EncodeWithArgs({ + return { + SF: trace_sf.Span.Link, + PB: trace_pb2.Span.Link, "trace_id": draw(pb_trace_id()), "span_id": draw(pb_span_id()), "trace_state": draw(pb_string()), "attributes": draw(pb_repeated(key_value())), "dropped_attributes_count": draw(pb_uint32()), "flags": draw(pb_fixed32()), - }, Link) + } - @composite + @st.composite def status(draw): - return EncodeWithArgs({ - "code": draw_pb_enum(draw, StatusCode), + return { + SF: trace_sf.Status, + PB: trace_pb2.Status, + "code": draw_pb_enum(draw, trace_pb2.Status.StatusCode), "message": draw(pb_string()), - }, Status) + } - @composite + @st.composite def span(draw): - return EncodeWithArgs({ + return { + SF: trace_sf.Span, + PB: trace_pb2.Span, "trace_id": draw(pb_trace_id()), "span_id": draw(pb_span_id()), "trace_state": draw(pb_string()), "parent_span_id": draw(pb_span_id()), "name": draw(pb_string()), - "kind": draw_pb_enum(draw, SpanKind), + "kind": draw_pb_enum(draw, trace_pb2.Span.SpanKind), "start_time_unix_nano": draw(pb_fixed64()), "end_time_unix_nano": draw(pb_fixed64()), "attributes": draw(pb_repeated(key_value())), @@ -251,57 +226,66 @@ def span(draw): "dropped_events_count": draw(pb_uint32()), "dropped_links_count": draw(pb_uint32()), "flags": draw(pb_fixed32()), - }, Span) + } - @composite + @st.composite def scope_spans(draw): - return EncodeWithArgs({ + return { + SF: trace_sf.ScopeSpans, + PB: trace_pb2.ScopeSpans, "scope": draw(instrumentation_scope()), "spans": draw(pb_repeated(span())), "schema_url": draw(pb_string()), - }, ScopeSpans) - - @composite + } + + @st.composite def resource_spans(draw): - return EncodeWithArgs({ + return { + SF: trace_sf.ResourceSpans, + PB: trace_pb2.ResourceSpans, "resource": draw(resource()), "scope_spans": draw(pb_repeated(scope_spans())), "schema_url": draw(pb_string()), - }, ResourceSpans) - - return EncodeWithArgs({ + } + + return { + SF: trace_sf.TracesData, + PB: trace_pb2.TracesData, "resource_spans": draw(pb_repeated(resource_spans())), - }, TracesData) + } -@composite +@st.composite def metrics_data(draw): - @composite + @st.composite def exemplar(draw): - kwargs = {} - oneof = draw(integers(min_value=0, max_value=1)) - if oneof == 0: - kwargs["as_double"] = draw(pb_double()) - elif oneof == 1: - kwargs["as_int"] = draw(pb_sfixed64()) - - return EncodeWithArgs({ - **kwargs, + return { + SF: metrics_sf.Exemplar, + PB: metrics_pb2.Exemplar, + **pb_oneof( + draw, + as_double=pb_double, + as_int=pb_sfixed64, + ), "time_unix_nano": draw(pb_fixed64()), "trace_id": draw(pb_trace_id()), "span_id": draw(pb_span_id()), 
"filtered_attributes": draw(pb_repeated(key_value())), - }, Exemplar) + } - @composite + @st.composite def value_at_quantile(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.SummaryDataPoint.ValueAtQuantile, + PB: metrics_pb2.SummaryDataPoint.ValueAtQuantile, "quantile": draw_pb_double(draw), "value": draw_pb_double(draw), - }, ValueAtQuantile) + } - @composite + @st.composite def summary_data_point(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.SummaryDataPoint, + PB: metrics_pb2.SummaryDataPoint, "start_time_unix_nano": draw(pb_fixed64()), "time_unix_nano": draw(pb_fixed64()), "count": draw(pb_fixed64()), @@ -309,33 +293,43 @@ def summary_data_point(draw): "quantile_values": draw(pb_repeated(value_at_quantile())), "attributes": draw(pb_repeated(key_value())), "flags": draw(pb_uint32()), - }, SummaryDataPoint) + } - @composite + @st.composite def buckets(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.ExponentialHistogramDataPoint.Buckets, + PB: metrics_pb2.ExponentialHistogramDataPoint.Buckets, "offset": draw(pb_sint32()), "bucket_counts": draw(pb_repeated(pb_uint64())), - }, Buckets) + } - @composite + @st.composite def exponential_histogram_data_point(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.ExponentialHistogramDataPoint, + PB: metrics_pb2.ExponentialHistogramDataPoint, "start_time_unix_nano": draw(pb_fixed64()), "time_unix_nano": draw(pb_fixed64()), "count": draw(pb_fixed64()), "sum": draw_pb_double(draw), - "positive": draw(buckets()), + **pb_oneof( + draw, + positive=buckets, + negative=buckets, + ), "attributes": draw(pb_repeated(key_value())), "flags": draw(pb_uint32()), "exemplars": draw(pb_repeated(exemplar())), "max": draw_pb_double(draw), "zero_threshold": draw_pb_double(draw), - }, ExponentialHistogramDataPoint) + } - @composite + @st.composite def histogram_data_point(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.HistogramDataPoint, + PB: metrics_pb2.HistogramDataPoint, "start_time_unix_nano": draw(pb_fixed64()), "time_unix_nano": draw(pb_fixed64()), "count": draw(pb_fixed64()), @@ -345,150 +339,156 @@ def histogram_data_point(draw): "flags": draw(pb_uint32()), "exemplars": draw(pb_repeated(exemplar())), "explicit_bounds": draw(pb_repeated(pb_double())), - "max": draw_pb_double(draw), - }, HistogramDataPoint) + **pb_oneof( + draw, + max=pb_double, + min=pb_double, + ), + } - @composite + @st.composite def number_data_point(draw): - oneof = draw(integers(min_value=0, max_value=3)) - kwargs = {} - if oneof == 0: - kwargs["as_int"] = draw(pb_sfixed32()) - elif oneof == 1: - kwargs["as_double"] = draw(pb_double()) - - return EncodeWithArgs({ + return { + SF: metrics_sf.NumberDataPoint, + PB: metrics_pb2.NumberDataPoint, "start_time_unix_nano": draw(pb_fixed64()), "time_unix_nano": draw(pb_fixed64()), - **kwargs, + **pb_oneof( + draw, + as_int=pb_sfixed64, + as_double=pb_double, + ), "exemplars": draw(pb_repeated(exemplar())), "attributes": draw(pb_repeated(key_value())), "flags": draw(pb_uint32()), - }, NumberDataPoint) + } - @composite + @st.composite def summary(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.Summary, + PB: metrics_pb2.Summary, "data_points": draw(pb_repeated(summary_data_point())), - }, Summary) + } - @composite + @st.composite def exponential_histogram(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.ExponentialHistogram, + PB: metrics_pb2.ExponentialHistogram, "data_points": draw(pb_repeated(exponential_histogram_data_point())), - "aggregation_temporality": 
draw_pb_enum(draw, AggregationTemporality), - }, ExponentialHistogram) + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), + } - @composite + @st.composite def histogram(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.Histogram, + PB: metrics_pb2.Histogram, "data_points": draw(pb_repeated(histogram_data_point())), - "aggregation_temporality": draw_pb_enum(draw, AggregationTemporality), - }, Histogram) + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), + } - @composite + @st.composite def sum(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.Sum, + PB: metrics_pb2.Sum, "data_points": draw(pb_repeated(number_data_point())), - "aggregation_temporality": draw_pb_enum(draw, AggregationTemporality), + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), "is_monotonic": draw(pb_bool()), - }, Sum) + } - @composite + @st.composite def gauge(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.Gauge, + PB: metrics_pb2.Gauge, "data_points": draw(pb_repeated(number_data_point())), - }, Gauge) + } - @composite + @st.composite def metric(draw): - oneof = draw(integers(min_value=0, max_value=3)) - kwargs = {} - if oneof == 0: - kwargs["gauge"] = draw(gauge()) - elif oneof == 1: - kwargs["sum"] = draw(sum()) - elif oneof == 2: - kwargs["histogram"] = draw(histogram()) - elif oneof == 3: - kwargs["exponential_histogram"] = draw(exponential_histogram()) - - return EncodeWithArgs({ + return { + SF: metrics_sf.Metric, + PB: metrics_pb2.Metric, "name": draw(pb_string()), "description": draw(pb_string()), "unit": draw(pb_string()), - **kwargs, + **pb_oneof( + draw, + gauge=gauge, + sum=sum, + summary=summary, + histogram=histogram, + exponential_histogram=exponential_histogram, + ), "metadata": draw(pb_repeated(key_value())), - }, Metric) + } - @composite + @st.composite def scope_metrics(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.ScopeMetrics, + PB: metrics_pb2.ScopeMetrics, "scope": draw(instrumentation_scope()), "metrics": draw(pb_repeated(metric())), "schema_url": draw(pb_string()), - }, ScopeMetrics) + } - @composite + @st.composite def resource_metrics(draw): - return EncodeWithArgs({ + return { + SF: metrics_sf.ResourceMetrics, + PB: metrics_pb2.ResourceMetrics, "resource": draw(resource()), "scope_metrics": draw(pb_repeated(scope_metrics())), "schema_url": draw(pb_string()), - }, ResourceMetrics) + } - return EncodeWithArgs({ + return { + SF: metrics_sf.MetricsData, + PB: metrics_pb2.MetricsData, "resource_metrics": draw(pb_repeated(resource_metrics())), - }, MetricsData) + } # Helper functions to recursively encode protobuf types using the generated args # and the given serialization strategy -def encode_enum(enum: EncodeEnumWithVal, strategy: str) -> Any: - if strategy == "pb2": - return enum.cls.pb2(enum.val) - elif strategy == "sf": - return enum.cls.sf(enum.val) - -def encode_recurse(obj: EncodeWithArgs, strategy: str) -> Any: +def encode_recurse(obj: Dict[str, Any], strategy: str) -> Any: kwargs = {} - for key, value in obj.kwargs.items(): - if isinstance(value, EncodeWithArgs): + for key, value in obj.items(): + if key in [SF, PB]: + continue + elif isinstance(value, Mapping): kwargs[key] = encode_recurse(value, strategy) - elif isinstance(value, EncodeEnumWithVal): - kwargs[key] = encode_enum(value, strategy) - elif isinstance(value, list) and value and isinstance(value[0], EncodeWithArgs): + elif isinstance(value, Sequence) and value and 
isinstance(value[0], Mapping): kwargs[key] = [encode_recurse(v, strategy) for v in value] - elif isinstance(value, list) and value and isinstance(value[0], EncodeEnumWithVal): - kwargs[key] = [encode_enum(v, strategy) for v in value] else: kwargs[key] = value - if strategy == "pb2": - return obj.cls.pb2(**kwargs) - elif strategy == "sf": - return obj.cls.sf(**kwargs) + return obj[strategy](**kwargs) class TestProtoSerialization(unittest.TestCase): @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) @hypothesis.given(logs_data()) def test_log_data(self, logs_data): self.assertEqual( - encode_recurse(logs_data, "pb2").SerializeToString(deterministic=True), - bytes(encode_recurse(logs_data, "sf")) + encode_recurse(logs_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(logs_data, SF)) ) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) @hypothesis.given(traces_data()) def test_trace_data(self, traces_data): self.assertEqual( - encode_recurse(traces_data, "pb2").SerializeToString(deterministic=True), - bytes(encode_recurse(traces_data, "sf")) + encode_recurse(traces_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(traces_data, SF)) ) @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) @hypothesis.given(metrics_data()) def test_metrics_data(self, metrics_data): self.assertEqual( - encode_recurse(metrics_data, "pb2").SerializeToString(deterministic=True), - bytes(encode_recurse(metrics_data, "sf")) + encode_recurse(metrics_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(metrics_data, SF)) )