diff --git a/.github/workflows/check-codegen.yml b/.github/workflows/check-codegen.yml new file mode 100644 index 0000000..17b4c6f --- /dev/null +++ b/.github/workflows/check-codegen.yml @@ -0,0 +1,36 @@ +# This workflow will delete and regenerate the opentelemetry marshaling code using scripts/proto_codegen.sh. +# If generating the code produces any changes from what is currently checked in, the workflow will fail and prompt the user to regenerate the code. +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Check Codegen + +on: + push: + branches: [ "main" ] + paths: + - "scripts/**" + - "src/snowflake/telemetry/_internal/opentelemetry/proto/**" + - ".github/workflows/check-codegen.yml" + pull_request: + branches: [ "main" ] + paths: + - "scripts/**" + - "src/snowflake/telemetry/_internal/opentelemetry/proto/**" + - ".github/workflows/check-codegen.yml" + +jobs: + check-codegen: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.11" + - name: Run codegen script + run: | + rm -rf src/snowflake/telemetry/_internal/opentelemetry/proto/ + ./scripts/proto_codegen.sh + - name: Check for changes + run: | + git diff --exit-code || { echo "Code generation produced changes! Regenerate the code using ./scripts/proto_codegen.sh"; exit 1; } diff --git a/README.md b/README.md index ef6e6ae..ca9d2b3 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ pip install --upgrade pip pip install . ``` +## Development + To develop this package, run ```bash @@ -33,3 +35,9 @@ source .venv/bin/activate pip install --upgrade pip pip install . ./tests/snowflake-telemetry-test-utils ``` + +### Code generation + +To regenerate the code under `src/snowflake/telemetry/_internal/opentelemetry/proto/`, execute the script `./scripts/proto_codegen.sh`. 
The script expects the `src/snowflake/telemetry/_internal/opentelemetry/proto/` directory to exist, and will delete all .py files in it before regenerating the code. + +The commit/branch/tag of [opentelemetry-proto](https://github.com/open-telemetry/opentelemetry-proto) that the code is generated from is pinned to PROTO_REPO_BRANCH_OR_COMMIT, which can be configured in the script. It is currently pinned to the same tag as [opentelemetry-python](https://github.com/open-telemetry/opentelemetry-python/blob/main/scripts/proto_codegen.sh#L15). diff --git a/scripts/gen-requirements.txt b/scripts/gen-requirements.txt new file mode 100644 index 0000000..bf8c682 --- /dev/null +++ b/scripts/gen-requirements.txt @@ -0,0 +1,5 @@ +Jinja2==3.1.4 +grpcio-tools==1.62.3 +protobuf==4.25.5 +black==24.10.0 +isort==5.13.2 diff --git a/scripts/plugin.py b/scripts/plugin.py new file mode 100755 index 0000000..d01cb00 --- /dev/null +++ b/scripts/plugin.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import re +import os +import sys +import inspect +from enum import IntEnum +from typing import List, Optional +from textwrap import dedent, indent +from dataclasses import dataclass, field +# Must be imported into globals for the inline functions to work +from snowflake.telemetry._internal.serialize import MessageMarshaler # noqa + +from google.protobuf.compiler import plugin_pb2 as plugin +from google.protobuf.descriptor_pb2 import ( + FileDescriptorProto, + FieldDescriptorProto, + EnumDescriptorProto, + EnumValueDescriptorProto, + DescriptorProto, +) +from jinja2 import Environment, FileSystemLoader +import black +import isort.api + +INLINE_OPTIMIZATION = True +FILE_PATH_PREFIX = "snowflake.telemetry._internal" +FILE_NAME_SUFFIX = "_marshaler" + +# Inline utility functions + +# Inline the size function for a given proto message field +def inline_size_function(proto_type: str, attr_name: str, field_tag: str) -> str: + """ + For example: + + class MessageMarshaler: + 
def size_uint32(self, TAG: bytes, FIELD_ATTR: int) -> int: + return len(TAG) + Varint.size_varint_u32(FIELD_ATTR) + + Becomes: + + size += len(b"\x10") + Varint.size_varint_u32(self.int_value) + """ + function_definition = inspect.getsource(globals()["MessageMarshaler"].__dict__[f"size_{proto_type}"]) + # Remove the function header and unindent the function body + function_definition = function_definition.splitlines()[1:] + function_definition = "\n".join(function_definition) + function_definition = dedent(function_definition) + # Replace the attribute name + function_definition = function_definition.replace("FIELD_ATTR", f"self.{attr_name}") + # Replace the TAG + function_definition = function_definition.replace("TAG", field_tag) + # Inline the return statement + function_definition = function_definition.replace("return ", "size += ") + return function_definition + +# Inline the serialization function for a given proto message field +def inline_serialize_function(proto_type: str, attr_name: str, field_tag: str) -> str: + """ + For example: + + class MessageMarshaler: + def serialize_uint32(self, out: BytesIO, TAG: bytes, FIELD_ATTR: int) -> None: + out.write(TAG) + Varint.serialize_varint_u32(out, FIELD_ATTR) + + Becomes: + + out.write(b"\x10") + Varint.serialize_varint_u32(out, self.int_value) + """ + function_definition = inspect.getsource(globals()["MessageMarshaler"].__dict__[f"serialize_{proto_type}"]) + # Remove the function header and unindent the function body + function_definition = function_definition.splitlines()[1:] + function_definition = "\n".join(function_definition) + function_definition = dedent(function_definition) + # Replace the attribute name + function_definition = function_definition.replace("FIELD_ATTR", f"self.{attr_name}") + # Replace the TAG + function_definition = function_definition.replace("TAG", field_tag) + return function_definition + +# Add a presence check to a function definition +# 
https://protobuf.dev/programming-guides/proto3/#default +def add_presence_check(proto_type: str, encode_presence: bool, attr_name: str, function_definition: str) -> str: + # oneof, optional (virtual oneof), and message fields are encoded if they are not None + function_definition = indent(function_definition, " ") + if encode_presence: + return f"if self.{attr_name} is not None:\n{function_definition}" + # Other fields are encoded if they are not the default value + # Which happens to align with the bool(x) check for all primitive types + # TODO: Except + # - double and float -0.0 should be encoded, even though bool(-0.0) is False + return f"if self.{attr_name}:\n{function_definition}" + +class WireType(IntEnum): + VARINT = 0 + I64 = 1 + LEN = 2 + I32 = 5 + +@dataclass +class ProtoTypeDescriptor: + name: str + wire_type: WireType + python_type: str + default_val: str + +proto_type_to_descriptor = { + FieldDescriptorProto.TYPE_BOOL: ProtoTypeDescriptor("bool", WireType.VARINT, "bool", "False"), + FieldDescriptorProto.TYPE_ENUM: ProtoTypeDescriptor("enum", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_INT32: ProtoTypeDescriptor("int32", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_INT64: ProtoTypeDescriptor("int64", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_UINT32: ProtoTypeDescriptor("uint32", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_UINT64: ProtoTypeDescriptor("uint64", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_SINT32: ProtoTypeDescriptor("sint32", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_SINT64: ProtoTypeDescriptor("sint64", WireType.VARINT, "int", "0"), + FieldDescriptorProto.TYPE_FIXED32: ProtoTypeDescriptor("fixed32", WireType.I32, "int", "0"), + FieldDescriptorProto.TYPE_FIXED64: ProtoTypeDescriptor("fixed64", WireType.I64, "int", "0"), + FieldDescriptorProto.TYPE_SFIXED32: ProtoTypeDescriptor("sfixed32", WireType.I32, "int", "0"), + FieldDescriptorProto.TYPE_SFIXED64: 
ProtoTypeDescriptor("sfixed64", WireType.I64, "int", "0"), + FieldDescriptorProto.TYPE_FLOAT: ProtoTypeDescriptor("float", WireType.I32, "float", "0.0"), + FieldDescriptorProto.TYPE_DOUBLE: ProtoTypeDescriptor("double", WireType.I64, "float", "0.0"), + FieldDescriptorProto.TYPE_STRING: ProtoTypeDescriptor("string", WireType.LEN, "str", '""'), + FieldDescriptorProto.TYPE_BYTES: ProtoTypeDescriptor("bytes", WireType.LEN, "bytes", 'b""'), + FieldDescriptorProto.TYPE_MESSAGE: ProtoTypeDescriptor("message", WireType.LEN, "PLACEHOLDER", "None"), +} + +@dataclass +class EnumValueTemplate: + name: str + number: int + + @staticmethod + def from_descriptor(descriptor: EnumValueDescriptorProto) -> "EnumValueTemplate": + return EnumValueTemplate( + name=descriptor.name, + number=descriptor.number, + ) + +@dataclass +class EnumTemplate: + name: str + values: List[EnumValueTemplate] = field(default_factory=list) + + @staticmethod + def from_descriptor(descriptor: EnumDescriptorProto) -> "EnumTemplate": + return EnumTemplate( + name=descriptor.name, + values=[EnumValueTemplate.from_descriptor(value) for value in descriptor.value], + ) + +def tag_to_repr_varint(tag: int) -> str: + out = bytearray() + while tag >= 128: + out.append((tag & 0x7F) | 0x80) + tag >>= 7 + out.append(tag) + return repr(bytes(out)) + +@dataclass +class FieldTemplate: + name: str + attr_name: str + number: int + generator: str + python_type: str + proto_type: str + default_val: str + serialize_field_inline: str + size_field_inline: str + + @staticmethod + def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = None) -> "FieldTemplate": + type_descriptor = proto_type_to_descriptor[descriptor.type] + python_type = type_descriptor.python_type + proto_type = type_descriptor.name + default_val = type_descriptor.default_val + + if proto_type == "message" or proto_type == "enum": + # Extract the class name of message fields, to use as python type + python_type = 
re.sub(r"^[a-zA-Z0-9_\.]+\.v1\.", "", descriptor.type_name) + + repeated = descriptor.label == FieldDescriptorProto.LABEL_REPEATED + if repeated: + # Update type for repeated fields + python_type = f"List[{python_type}]" + proto_type = f"repeated_{proto_type}" + # Default value is None, since we can't use a mutable default value like [] + default_val = "None" + + # Calculate the tag for the field to save some computation at runtime + tag = (descriptor.number << 3) | type_descriptor.wire_type.value + if repeated and type_descriptor.wire_type != WireType.LEN: + # Special case: repeated primitive fields are packed, so need to use LEN wire type + # Note: packed fields can be disabled in proto files, but we don't handle that case + # https://protobuf.dev/programming-guides/encoding/#packed + tag = (descriptor.number << 3) | WireType.LEN.value + # Convert the tag to a varint representation to inline it in the generated code + tag = tag_to_repr_varint(tag) + + # For oneof and optional fields, we need to encode the presence of the field. + # Optional fields are treated as virtual oneof fields, with a single field in the oneof. + # For message fields, we need to encode the presence of the field if it is not None. 
+ # https://protobuf.dev/programming-guides/field_presence/ + encode_presence = group is not None or proto_type == "message" + if group is not None: + # The default value for oneof fields must be None, so that the default value is not encoded + default_val = "None" + + field_name = descriptor.name + attr_name = field_name + generator = None + if proto_type == "message" or repeated: + # For message and repeated fields, store as a private attribute that is + # initialized on access to match protobuf embedded message access pattern + if repeated: + # In python protobuf, repeated fields return an implementation of the list interface + # with a self.add() method to add and initialize elements + # This can be supported with a custom list implementation, but we use a simple list for now + # https://protobuf.dev/reference/python/python-generated/#repeated-message-fields + generator = "list()" + else: + # https://protobuf.dev/reference/python/python-generated/#embedded_message + generator = f"{python_type}()" + # the attribute name is prefixed with an underscore as message and repeated attributes + # are hidden behind a property that has the actual proto field name + attr_name = f"_{field_name}" + + # Inline the size and serialization functions for the field + if INLINE_OPTIMIZATION: + serialize_field_inline = inline_serialize_function(proto_type, attr_name, tag) + size_field_inline = inline_size_function(proto_type, attr_name, tag) + else: + serialize_field_inline = f"self.serialize_{proto_type}(out, {tag}, self.{attr_name})" + size_field_inline = f"size += self.size_{proto_type}({tag}, self.{attr_name})" + + serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline) + size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline) + + return FieldTemplate( + name=field_name, + attr_name=attr_name, + number=descriptor.number, + generator=generator, + python_type=python_type, + 
proto_type=proto_type, + default_val=default_val, + serialize_field_inline=serialize_field_inline, + size_field_inline=size_field_inline, + ) + +@dataclass +class MessageTemplate: + name: str + fields: List[FieldTemplate] = field(default_factory=list) + enums: List[EnumTemplate] = field(default_factory=list) + messages: List[MessageTemplate] = field(default_factory=list) + + @staticmethod + def from_descriptor(descriptor: DescriptorProto) -> "MessageTemplate": + # Helper function to extract the group name for a field, if it exists + def get_group(field: FieldDescriptorProto) -> str: + return descriptor.oneof_decl[field.oneof_index].name if field.HasField("oneof_index") else None + fields = [FieldTemplate.from_descriptor(field, get_group(field)) for field in descriptor.field] + fields.sort(key=lambda field: field.number) + + name = descriptor.name + return MessageTemplate( + name=name, + fields=fields, + enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type], + messages=[MessageTemplate.from_descriptor(message) for message in descriptor.nested_type], + ) + +@dataclass +class FileTemplate: + messages: List[MessageTemplate] = field(default_factory=list) + enums: List[EnumTemplate] = field(default_factory=list) + imports: List[str] = field(default_factory=list) + name: str = "" + + @staticmethod + def from_descriptor(descriptor: FileDescriptorProto) -> "FileTemplate": + + # Extract the import paths for the proto file + imports = [] + for dependency in descriptor.dependency: + path = re.sub(r"\.proto$", "", dependency) + if descriptor.name.startswith(path): + continue + path = path.replace("/", ".") + path = f"{FILE_PATH_PREFIX}.{path}{FILE_NAME_SUFFIX}" + imports.append(path) + + return FileTemplate( + messages=[MessageTemplate.from_descriptor(message) for message in descriptor.message_type], + enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type], + imports=imports, + name=descriptor.name, + ) + +def main(): + request = 
plugin.CodeGeneratorRequest() + request.ParseFromString(sys.stdin.buffer.read()) + + response = plugin.CodeGeneratorResponse() + # Needed since metrics.proto uses proto3 optional fields + # https://github.com/protocolbuffers/protobuf/blob/main/docs/implementing_proto3_presence.md + response.supported_features = plugin.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL + + template_env = Environment(loader=FileSystemLoader(f"{os.path.dirname(os.path.realpath(__file__))}/templates")) + jinja_body_template = template_env.get_template("template.py.jinja2") + + for proto_file in request.proto_file: + file_name = re.sub(r"\.proto$", f"{FILE_NAME_SUFFIX}.py", proto_file.name) + file_descriptor_proto = proto_file + + file_template = FileTemplate.from_descriptor(file_descriptor_proto) + + code = jinja_body_template.render(file_template=file_template) + code = isort.api.sort_code_string( + code = code, + show_diff=False, + profile="black", + combine_as_imports=True, + lines_after_imports=2, + quiet=True, + force_grid_wrap=2, + ) + code = black.format_str( + src_contents=code, + mode=black.Mode(), + ) + + response_file = response.file.add() + response_file.name = file_name + response_file.content = code + + sys.stdout.buffer.write(response.SerializeToString()) + +if __name__ == '__main__': + main() diff --git a/scripts/proto_codegen.sh b/scripts/proto_codegen.sh new file mode 100755 index 0000000..9db288f --- /dev/null +++ b/scripts/proto_codegen.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# +# Regenerate python code from OTLP protos in +# https://github.com/open-telemetry/opentelemetry-proto +# +# To use, update PROTO_REPO_BRANCH_OR_COMMIT variable below to a commit hash or +# tag in opentelemetry-proto repo that you want to build off of. Then, just run +# this script to update the proto files. Commit the changes as well as any +# fixes needed in the OTLP exporter. 
+# +# Optional envars: +# PROTO_REPO_DIR - the path to an existing checkout of the opentelemetry-proto repo + +# Pinned commit/branch/tag for the current version used in opentelemetry-proto python package. +PROTO_REPO_BRANCH_OR_COMMIT="v1.2.0" + +set -e + +PROTO_REPO_DIR=${PROTO_REPO_DIR:-"/tmp/opentelemetry-proto"} +# root of opentelemetry-python repo +repo_root="$(git rev-parse --show-toplevel)" +venv_dir="/tmp/proto_codegen_venv" + +# run on exit even if crash +cleanup() { + echo "Deleting $venv_dir" + rm -rf $venv_dir +} +trap cleanup EXIT + +echo "Creating temporary virtualenv at $venv_dir using $(python3 --version)" +python3 -m venv $venv_dir +source $venv_dir/bin/activate +python -m pip install \ + -c $repo_root/scripts/gen-requirements.txt \ + protobuf Jinja2 grpcio-tools black isort . + +echo 'python -m grpc_tools.protoc --version' +python -m grpc_tools.protoc --version + +# Clone the proto repo if it doesn't exist +if [ ! -d "$PROTO_REPO_DIR" ]; then + git clone https://github.com/open-telemetry/opentelemetry-proto.git $PROTO_REPO_DIR +fi + +# Pull in changes and switch to requested branch +( + cd $PROTO_REPO_DIR + git fetch --all + git checkout $PROTO_REPO_BRANCH_OR_COMMIT + # pull if PROTO_REPO_BRANCH_OR_COMMIT is not a detached head + git symbolic-ref -q HEAD && git pull --ff-only || true +) + +cd $repo_root/src/snowflake/telemetry/_internal + +# clean up old generated code +mkdir -p opentelemetry/proto +find opentelemetry/proto/ -regex ".*_marshaler\.py" -exec rm {} + + +# generate proto code for all protos +all_protos=$(find $PROTO_REPO_DIR/ -iname "*.proto") +python -m grpc_tools.protoc \ + -I $PROTO_REPO_DIR \ + --plugin=protoc-gen-custom-plugin=$repo_root/scripts/plugin.py \ + --custom-plugin_out=. 
\ + $all_protos diff --git a/scripts/templates/template.py.jinja2 b/scripts/templates/template.py.jinja2 new file mode 100644 index 0000000..9de445f --- /dev/null +++ b/scripts/templates/template.py.jinja2 @@ -0,0 +1,76 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! +# sources: {{ file_template.name }} + +from __future__ import annotations + +import struct +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) +from typing import List + +{% for import in file_template.imports %} +from {{ import }} import * +{% endfor %} + +{% for enum in file_template.enums %} +class {{ enum.name }}(Enum): +{%- for value in enum.values %} + {{ value.name }} = {{ value.number }} +{%- endfor %} +{% endfor %} + +{% macro render_message(message) %} +class {{ message.name }}(MessageMarshaler): + +{%- for field in message.fields %} +{%- if field.generator %} + @property + def {{ field.name }}(self) -> {{ field.python_type }}: + if self.{{ field.attr_name }} is None: + self.{{ field.attr_name }} = {{ field.generator }} + return self.{{ field.attr_name }} +{%- else %} + {{ field.name }}: {{ field.python_type }} +{%- endif %} +{%- endfor %} + + def __init__( + self, +{%- for field in message.fields %} + {{ field.name }}: {{ field.python_type }} = {{ field.default_val }}, +{%- endfor %} + ): +{%- for field in message.fields %} + self.{{ field.attr_name }}: {{ field.python_type }} = {{ field.name }} +{%- endfor %} + + def calculate_size(self) -> int: + size = 0 +{%- for field in message.fields %} + {{ field.size_field_inline | indent(8) }} +{%- endfor %} + return size + + def write_to(self, out: bytearray) -> None: +{%- for field in message.fields %} + {{ field.serialize_field_inline | indent(8) }} +{%- endfor %} + +{% for nested_enum in message.enums %} + class {{ nested_enum.name }}(Enum): +{%- for value in nested_enum.values %} + {{ value.name }} = {{ value.number }} +{%- endfor %} +{% endfor %} + +{% for nested_message in 
message.messages %} +{{ render_message(nested_message) | indent(4) }} +{% endfor %} +{% endmacro %} + +{% for message in file_template.messages %} +{{ render_message(message) }} +{% endfor %} \ No newline at end of file diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py new file mode 100644 index 0000000..52711d2 --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/logs/v1/logs_service_marshaler.py @@ -0,0 +1,108 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! +# sources: opentelemetry/proto/collector/logs/v1/logs_service.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.opentelemetry.proto.logs.v1.logs_marshaler import * +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class ExportLogsServiceRequest(MessageMarshaler): + @property + def resource_logs(self) -> List[ResourceLogs]: + if self._resource_logs is None: + self._resource_logs = list() + return self._resource_logs + + def __init__( + self, + resource_logs: List[ResourceLogs] = None, + ): + self._resource_logs: List[ResourceLogs] = resource_logs + + def calculate_size(self) -> int: + size = 0 + if self._resource_logs: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_logs + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_logs: + for v in self._resource_logs: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ExportLogsServiceResponse(MessageMarshaler): + @property + def partial_success(self) -> ExportLogsPartialSuccess: + if self._partial_success is None: + self._partial_success = ExportLogsPartialSuccess() + return 
self._partial_success + + def __init__( + self, + partial_success: ExportLogsPartialSuccess = None, + ): + self._partial_success: ExportLogsPartialSuccess = partial_success + + def calculate_size(self) -> int: + size = 0 + if self._partial_success is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._partial_success._get_size()) + + self._partial_success._get_size() + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._partial_success is not None: + out += b"\n" + Varint.write_varint_u32(out, self._partial_success._get_size()) + self._partial_success.write_to(out) + + +class ExportLogsPartialSuccess(MessageMarshaler): + rejected_log_records: int + error_message: str + + def __init__( + self, + rejected_log_records: int = 0, + error_message: str = "", + ): + self.rejected_log_records: int = rejected_log_records + self.error_message: str = error_message + + def calculate_size(self) -> int: + size = 0 + if self.rejected_log_records: + size += len(b"\x08") + Varint.size_varint_i64(self.rejected_log_records) + if self.error_message: + v = self.error_message.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self.rejected_log_records: + out += b"\x08" + Varint.write_varint_i64(out, self.rejected_log_records) + if self.error_message: + v = self.error_message.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py new file mode 100644 index 0000000..7701775 --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/metrics/v1/metrics_service_marshaler.py @@ -0,0 +1,108 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! 
+# sources: opentelemetry/proto/collector/metrics/v1/metrics_service.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.opentelemetry.proto.metrics.v1.metrics_marshaler import * +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class ExportMetricsServiceRequest(MessageMarshaler): + @property + def resource_metrics(self) -> List[ResourceMetrics]: + if self._resource_metrics is None: + self._resource_metrics = list() + return self._resource_metrics + + def __init__( + self, + resource_metrics: List[ResourceMetrics] = None, + ): + self._resource_metrics: List[ResourceMetrics] = resource_metrics + + def calculate_size(self) -> int: + size = 0 + if self._resource_metrics: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_metrics + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_metrics: + for v in self._resource_metrics: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ExportMetricsServiceResponse(MessageMarshaler): + @property + def partial_success(self) -> ExportMetricsPartialSuccess: + if self._partial_success is None: + self._partial_success = ExportMetricsPartialSuccess() + return self._partial_success + + def __init__( + self, + partial_success: ExportMetricsPartialSuccess = None, + ): + self._partial_success: ExportMetricsPartialSuccess = partial_success + + def calculate_size(self) -> int: + size = 0 + if self._partial_success is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._partial_success._get_size()) + + self._partial_success._get_size() + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._partial_success is not None: + out += b"\n" + Varint.write_varint_u32(out, self._partial_success._get_size()) + self._partial_success.write_to(out) + 
+ +class ExportMetricsPartialSuccess(MessageMarshaler): + rejected_data_points: int + error_message: str + + def __init__( + self, + rejected_data_points: int = 0, + error_message: str = "", + ): + self.rejected_data_points: int = rejected_data_points + self.error_message: str = error_message + + def calculate_size(self) -> int: + size = 0 + if self.rejected_data_points: + size += len(b"\x08") + Varint.size_varint_i64(self.rejected_data_points) + if self.error_message: + v = self.error_message.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self.rejected_data_points: + out += b"\x08" + Varint.write_varint_i64(out, self.rejected_data_points) + if self.error_message: + v = self.error_message.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py new file mode 100644 index 0000000..2488f6c --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/collector/trace/v1/trace_service_marshaler.py @@ -0,0 +1,108 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! 
+# sources: opentelemetry/proto/collector/trace/v1/trace_service.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.opentelemetry.proto.trace.v1.trace_marshaler import * +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class ExportTraceServiceRequest(MessageMarshaler): + @property + def resource_spans(self) -> List[ResourceSpans]: + if self._resource_spans is None: + self._resource_spans = list() + return self._resource_spans + + def __init__( + self, + resource_spans: List[ResourceSpans] = None, + ): + self._resource_spans: List[ResourceSpans] = resource_spans + + def calculate_size(self) -> int: + size = 0 + if self._resource_spans: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_spans + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_spans: + for v in self._resource_spans: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ExportTraceServiceResponse(MessageMarshaler): + @property + def partial_success(self) -> ExportTracePartialSuccess: + if self._partial_success is None: + self._partial_success = ExportTracePartialSuccess() + return self._partial_success + + def __init__( + self, + partial_success: ExportTracePartialSuccess = None, + ): + self._partial_success: ExportTracePartialSuccess = partial_success + + def calculate_size(self) -> int: + size = 0 + if self._partial_success is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._partial_success._get_size()) + + self._partial_success._get_size() + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._partial_success is not None: + out += b"\n" + Varint.write_varint_u32(out, self._partial_success._get_size()) + self._partial_success.write_to(out) + + +class 
ExportTracePartialSuccess(MessageMarshaler): + rejected_spans: int + error_message: str + + def __init__( + self, + rejected_spans: int = 0, + error_message: str = "", + ): + self.rejected_spans: int = rejected_spans + self.error_message: str = error_message + + def calculate_size(self) -> int: + size = 0 + if self.rejected_spans: + size += len(b"\x08") + Varint.size_varint_i64(self.rejected_spans) + if self.error_message: + v = self.error_message.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self.rejected_spans: + out += b"\x08" + Varint.write_varint_i64(out, self.rejected_spans) + if self.error_message: + v = self.error_message.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py new file mode 100644 index 0000000..2d25bcc --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/common/v1/common_marshaler.py @@ -0,0 +1,281 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! 
+# sources: opentelemetry/proto/common/v1/common.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class AnyValue(MessageMarshaler): + string_value: str + bool_value: bool + int_value: int + double_value: float + + @property + def array_value(self) -> ArrayValue: + if self._array_value is None: + self._array_value = ArrayValue() + return self._array_value + + @property + def kvlist_value(self) -> KeyValueList: + if self._kvlist_value is None: + self._kvlist_value = KeyValueList() + return self._kvlist_value + + bytes_value: bytes + + def __init__( + self, + string_value: str = None, + bool_value: bool = None, + int_value: int = None, + double_value: float = None, + array_value: ArrayValue = None, + kvlist_value: KeyValueList = None, + bytes_value: bytes = None, + ): + self.string_value: str = string_value + self.bool_value: bool = bool_value + self.int_value: int = int_value + self.double_value: float = double_value + self._array_value: ArrayValue = array_value + self._kvlist_value: KeyValueList = kvlist_value + self.bytes_value: bytes = bytes_value + + def calculate_size(self) -> int: + size = 0 + if self.string_value is not None: + v = self.string_value.encode("utf-8") + size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) + if self.bool_value is not None: + size += len(b"\x10") + 1 + if self.int_value is not None: + size += len(b"\x18") + Varint.size_varint_i64(self.int_value) + if self.double_value is not None: + size += len(b"!") + 8 + if self._array_value is not None: + size += ( + len(b"*") + + Varint.size_varint_u32(self._array_value._get_size()) + + self._array_value._get_size() + ) + if self._kvlist_value is not None: + size += ( + len(b"2") + + Varint.size_varint_u32(self._kvlist_value._get_size()) + + self._kvlist_value._get_size() + ) + if self.bytes_value is not None: + size += ( + len(b":") + + 
Varint.size_varint_u32(len(self.bytes_value)) + + len(self.bytes_value) + ) + return size + + def write_to(self, out: bytearray) -> None: + if self.string_value is not None: + v = self.string_value.encode("utf-8") + out += b"\n" + Varint.write_varint_u32(out, len(v)) + out += v + if self.bool_value is not None: + out += b"\x10" + Varint.write_varint_u32(out, 1 if self.bool_value else 0) + if self.int_value is not None: + out += b"\x18" + Varint.write_varint_i64(out, self.int_value) + if self.double_value is not None: + out += b"!" + out += struct.pack(" List[AnyValue]: + if self._values is None: + self._values = list() + return self._values + + def __init__( + self, + values: List[AnyValue] = None, + ): + self._values: List[AnyValue] = values + + def calculate_size(self) -> int: + size = 0 + if self._values: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._values + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._values: + for v in self._values: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class KeyValueList(MessageMarshaler): + @property + def values(self) -> List[KeyValue]: + if self._values is None: + self._values = list() + return self._values + + def __init__( + self, + values: List[KeyValue] = None, + ): + self._values: List[KeyValue] = values + + def calculate_size(self) -> int: + size = 0 + if self._values: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._values + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._values: + for v in self._values: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class KeyValue(MessageMarshaler): + key: str + + @property + def value(self) -> AnyValue: + if self._value is None: + self._value = AnyValue() + return self._value + + def __init__( + self, + key: 
str = "", + value: AnyValue = None, + ): + self.key: str = key + self._value: AnyValue = value + + def calculate_size(self) -> int: + size = 0 + if self.key: + v = self.key.encode("utf-8") + size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) + if self._value is not None: + size += ( + len(b"\x12") + + Varint.size_varint_u32(self._value._get_size()) + + self._value._get_size() + ) + return size + + def write_to(self, out: bytearray) -> None: + if self.key: + v = self.key.encode("utf-8") + out += b"\n" + Varint.write_varint_u32(out, len(v)) + out += v + if self._value is not None: + out += b"\x12" + Varint.write_varint_u32(out, self._value._get_size()) + self._value.write_to(out) + + +class InstrumentationScope(MessageMarshaler): + name: str + version: str + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + + def __init__( + self, + name: str = "", + version: str = "", + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + ): + self.name: str = name + self.version: str = version + self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + + def calculate_size(self) -> int: + size = 0 + if self.name: + v = self.name.encode("utf-8") + size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) + if self.version: + v = self.version.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + if self._attributes: + size += sum( + message._get_size() + + len(b"\x1a") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b" ") + Varint.size_varint_u32(self.dropped_attributes_count) + return size + + def write_to(self, out: bytearray) -> None: + if self.name: + v = self.name.encode("utf-8") + out += b"\n" + Varint.write_varint_u32(out, len(v)) + out += v + if 
self.version: + v = self.version.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v + if self._attributes: + for v in self._attributes: + out += b"\x1a" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.dropped_attributes_count: + out += b" " + Varint.write_varint_u32(out, self.dropped_attributes_count) diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py new file mode 100644 index 0000000..18996f0 --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/logs/v1/logs_marshaler.py @@ -0,0 +1,339 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! +# sources: opentelemetry/proto/logs/v1/logs.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.opentelemetry.proto.common.v1.common_marshaler import * +from snowflake.telemetry._internal.opentelemetry.proto.resource.v1.resource_marshaler import * +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class SeverityNumber(Enum): + SEVERITY_NUMBER_UNSPECIFIED = 0 + SEVERITY_NUMBER_TRACE = 1 + SEVERITY_NUMBER_TRACE2 = 2 + SEVERITY_NUMBER_TRACE3 = 3 + SEVERITY_NUMBER_TRACE4 = 4 + SEVERITY_NUMBER_DEBUG = 5 + SEVERITY_NUMBER_DEBUG2 = 6 + SEVERITY_NUMBER_DEBUG3 = 7 + SEVERITY_NUMBER_DEBUG4 = 8 + SEVERITY_NUMBER_INFO = 9 + SEVERITY_NUMBER_INFO2 = 10 + SEVERITY_NUMBER_INFO3 = 11 + SEVERITY_NUMBER_INFO4 = 12 + SEVERITY_NUMBER_WARN = 13 + SEVERITY_NUMBER_WARN2 = 14 + SEVERITY_NUMBER_WARN3 = 15 + SEVERITY_NUMBER_WARN4 = 16 + SEVERITY_NUMBER_ERROR = 17 + SEVERITY_NUMBER_ERROR2 = 18 + SEVERITY_NUMBER_ERROR3 = 19 + SEVERITY_NUMBER_ERROR4 = 20 + SEVERITY_NUMBER_FATAL = 21 + SEVERITY_NUMBER_FATAL2 = 22 + SEVERITY_NUMBER_FATAL3 = 23 + SEVERITY_NUMBER_FATAL4 = 24 + + +class LogRecordFlags(Enum): + 
LOG_RECORD_FLAGS_DO_NOT_USE = 0 + LOG_RECORD_FLAGS_TRACE_FLAGS_MASK = 255 + + +class LogsData(MessageMarshaler): + @property + def resource_logs(self) -> List[ResourceLogs]: + if self._resource_logs is None: + self._resource_logs = list() + return self._resource_logs + + def __init__( + self, + resource_logs: List[ResourceLogs] = None, + ): + self._resource_logs: List[ResourceLogs] = resource_logs + + def calculate_size(self) -> int: + size = 0 + if self._resource_logs: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_logs + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_logs: + for v in self._resource_logs: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ResourceLogs(MessageMarshaler): + @property + def resource(self) -> Resource: + if self._resource is None: + self._resource = Resource() + return self._resource + + @property + def scope_logs(self) -> List[ScopeLogs]: + if self._scope_logs is None: + self._scope_logs = list() + return self._scope_logs + + schema_url: str + + def __init__( + self, + resource: Resource = None, + scope_logs: List[ScopeLogs] = None, + schema_url: str = "", + ): + self._resource: Resource = resource + self._scope_logs: List[ScopeLogs] = scope_logs + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if self._resource is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._resource._get_size()) + + self._resource._get_size() + ) + if self._scope_logs: + size += sum( + message._get_size() + + len(b"\x12") + + Varint.size_varint_u32(message._get_size()) + for message in self._scope_logs + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource is not None: + out += b"\n" + 
Varint.write_varint_u32(out, self._resource._get_size()) + self._resource.write_to(out) + if self._scope_logs: + for v in self._scope_logs: + out += b"\x12" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class ScopeLogs(MessageMarshaler): + @property + def scope(self) -> InstrumentationScope: + if self._scope is None: + self._scope = InstrumentationScope() + return self._scope + + @property + def log_records(self) -> List[LogRecord]: + if self._log_records is None: + self._log_records = list() + return self._log_records + + schema_url: str + + def __init__( + self, + scope: InstrumentationScope = None, + log_records: List[LogRecord] = None, + schema_url: str = "", + ): + self._scope: InstrumentationScope = scope + self._log_records: List[LogRecord] = log_records + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if self._scope is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._scope._get_size()) + + self._scope._get_size() + ) + if self._log_records: + size += sum( + message._get_size() + + len(b"\x12") + + Varint.size_varint_u32(message._get_size()) + for message in self._log_records + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._scope is not None: + out += b"\n" + Varint.write_varint_u32(out, self._scope._get_size()) + self._scope.write_to(out) + if self._log_records: + for v in self._log_records: + out += b"\x12" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class LogRecord(MessageMarshaler): + time_unix_nano: int + severity_number: SeverityNumber + 
severity_text: str + + @property + def body(self) -> AnyValue: + if self._body is None: + self._body = AnyValue() + return self._body + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + flags: int + trace_id: bytes + span_id: bytes + observed_time_unix_nano: int + + def __init__( + self, + time_unix_nano: int = 0, + severity_number: SeverityNumber = 0, + severity_text: str = "", + body: AnyValue = None, + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + flags: int = 0, + trace_id: bytes = b"", + span_id: bytes = b"", + observed_time_unix_nano: int = 0, + ): + self.time_unix_nano: int = time_unix_nano + self.severity_number: SeverityNumber = severity_number + self.severity_text: str = severity_text + self._body: AnyValue = body + self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + self.flags: int = flags + self.trace_id: bytes = trace_id + self.span_id: bytes = span_id + self.observed_time_unix_nano: int = observed_time_unix_nano + + def calculate_size(self) -> int: + size = 0 + if self.time_unix_nano: + size += len(b"\t") + 8 + if self.severity_number: + v = self.severity_number + if not isinstance(v, int): + v = v.value + size += len(b"\x10") + Varint.size_varint_u32(v) + if self.severity_text: + v = self.severity_text.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + if self._body is not None: + size += ( + len(b"*") + + Varint.size_varint_u32(self._body._get_size()) + + self._body._get_size() + ) + if self._attributes: + size += sum( + message._get_size() + + len(b"2") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b"8") + Varint.size_varint_u32(self.dropped_attributes_count) + if self.flags: + size += len(b"E") + 4 + if self.trace_id: 
+ size += ( + len(b"J") + + Varint.size_varint_u32(len(self.trace_id)) + + len(self.trace_id) + ) + if self.span_id: + size += ( + len(b"R") + + Varint.size_varint_u32(len(self.span_id)) + + len(self.span_id) + ) + if self.observed_time_unix_nano: + size += len(b"Y") + 8 + return size + + def write_to(self, out: bytearray) -> None: + if self.time_unix_nano: + out += b"\t" + out += struct.pack(" List[ResourceMetrics]: + if self._resource_metrics is None: + self._resource_metrics = list() + return self._resource_metrics + + def __init__( + self, + resource_metrics: List[ResourceMetrics] = None, + ): + self._resource_metrics: List[ResourceMetrics] = resource_metrics + + def calculate_size(self) -> int: + size = 0 + if self._resource_metrics: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_metrics + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_metrics: + for v in self._resource_metrics: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ResourceMetrics(MessageMarshaler): + @property + def resource(self) -> Resource: + if self._resource is None: + self._resource = Resource() + return self._resource + + @property + def scope_metrics(self) -> List[ScopeMetrics]: + if self._scope_metrics is None: + self._scope_metrics = list() + return self._scope_metrics + + schema_url: str + + def __init__( + self, + resource: Resource = None, + scope_metrics: List[ScopeMetrics] = None, + schema_url: str = "", + ): + self._resource: Resource = resource + self._scope_metrics: List[ScopeMetrics] = scope_metrics + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if self._resource is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._resource._get_size()) + + self._resource._get_size() + ) + if self._scope_metrics: + size += sum( + message._get_size() + + len(b"\x12") + + 
Varint.size_varint_u32(message._get_size()) + for message in self._scope_metrics + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource is not None: + out += b"\n" + Varint.write_varint_u32(out, self._resource._get_size()) + self._resource.write_to(out) + if self._scope_metrics: + for v in self._scope_metrics: + out += b"\x12" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class ScopeMetrics(MessageMarshaler): + @property + def scope(self) -> InstrumentationScope: + if self._scope is None: + self._scope = InstrumentationScope() + return self._scope + + @property + def metrics(self) -> List[Metric]: + if self._metrics is None: + self._metrics = list() + return self._metrics + + schema_url: str + + def __init__( + self, + scope: InstrumentationScope = None, + metrics: List[Metric] = None, + schema_url: str = "", + ): + self._scope: InstrumentationScope = scope + self._metrics: List[Metric] = metrics + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if self._scope is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._scope._get_size()) + + self._scope._get_size() + ) + if self._metrics: + size += sum( + message._get_size() + + len(b"\x12") + + Varint.size_varint_u32(message._get_size()) + for message in self._metrics + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._scope is not None: + out += b"\n" + Varint.write_varint_u32(out, self._scope._get_size()) + self._scope.write_to(out) + if self._metrics: + for v in self._metrics: + out += b"\x12" + 
Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class Metric(MessageMarshaler): + name: str + description: str + unit: str + + @property + def gauge(self) -> Gauge: + if self._gauge is None: + self._gauge = Gauge() + return self._gauge + + @property + def sum(self) -> Sum: + if self._sum is None: + self._sum = Sum() + return self._sum + + @property + def histogram(self) -> Histogram: + if self._histogram is None: + self._histogram = Histogram() + return self._histogram + + @property + def exponential_histogram(self) -> ExponentialHistogram: + if self._exponential_histogram is None: + self._exponential_histogram = ExponentialHistogram() + return self._exponential_histogram + + @property + def summary(self) -> Summary: + if self._summary is None: + self._summary = Summary() + return self._summary + + @property + def metadata(self) -> List[KeyValue]: + if self._metadata is None: + self._metadata = list() + return self._metadata + + def __init__( + self, + name: str = "", + description: str = "", + unit: str = "", + gauge: Gauge = None, + sum: Sum = None, + histogram: Histogram = None, + exponential_histogram: ExponentialHistogram = None, + summary: Summary = None, + metadata: List[KeyValue] = None, + ): + self.name: str = name + self.description: str = description + self.unit: str = unit + self._gauge: Gauge = gauge + self._sum: Sum = sum + self._histogram: Histogram = histogram + self._exponential_histogram: ExponentialHistogram = exponential_histogram + self._summary: Summary = summary + self._metadata: List[KeyValue] = metadata + + def calculate_size(self) -> int: + size = 0 + if self.name: + v = self.name.encode("utf-8") + size += len(b"\n") + Varint.size_varint_u32(len(v)) + len(v) + if self.description: + v = self.description.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + if 
self.unit: + v = self.unit.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + if self._gauge is not None: + size += ( + len(b"*") + + Varint.size_varint_u32(self._gauge._get_size()) + + self._gauge._get_size() + ) + if self._sum is not None: + size += ( + len(b":") + + Varint.size_varint_u32(self._sum._get_size()) + + self._sum._get_size() + ) + if self._histogram is not None: + size += ( + len(b"J") + + Varint.size_varint_u32(self._histogram._get_size()) + + self._histogram._get_size() + ) + if self._exponential_histogram is not None: + size += ( + len(b"R") + + Varint.size_varint_u32(self._exponential_histogram._get_size()) + + self._exponential_histogram._get_size() + ) + if self._summary is not None: + size += ( + len(b"Z") + + Varint.size_varint_u32(self._summary._get_size()) + + self._summary._get_size() + ) + if self._metadata: + size += sum( + message._get_size() + + len(b"b") + + Varint.size_varint_u32(message._get_size()) + for message in self._metadata + ) + return size + + def write_to(self, out: bytearray) -> None: + if self.name: + v = self.name.encode("utf-8") + out += b"\n" + Varint.write_varint_u32(out, len(v)) + out += v + if self.description: + v = self.description.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v + if self.unit: + v = self.unit.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + if self._gauge is not None: + out += b"*" + Varint.write_varint_u32(out, self._gauge._get_size()) + self._gauge.write_to(out) + if self._sum is not None: + out += b":" + Varint.write_varint_u32(out, self._sum._get_size()) + self._sum.write_to(out) + if self._histogram is not None: + out += b"J" + Varint.write_varint_u32(out, self._histogram._get_size()) + self._histogram.write_to(out) + if self._exponential_histogram is not None: + out += b"R" + Varint.write_varint_u32(out, self._exponential_histogram._get_size()) + self._exponential_histogram.write_to(out) 
+ if self._summary is not None: + out += b"Z" + Varint.write_varint_u32(out, self._summary._get_size()) + self._summary.write_to(out) + if self._metadata: + for v in self._metadata: + out += b"b" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class Gauge(MessageMarshaler): + @property + def data_points(self) -> List[NumberDataPoint]: + if self._data_points is None: + self._data_points = list() + return self._data_points + + def __init__( + self, + data_points: List[NumberDataPoint] = None, + ): + self._data_points: List[NumberDataPoint] = data_points + + def calculate_size(self) -> int: + size = 0 + if self._data_points: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._data_points + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._data_points: + for v in self._data_points: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class Sum(MessageMarshaler): + @property + def data_points(self) -> List[NumberDataPoint]: + if self._data_points is None: + self._data_points = list() + return self._data_points + + aggregation_temporality: AggregationTemporality + is_monotonic: bool + + def __init__( + self, + data_points: List[NumberDataPoint] = None, + aggregation_temporality: AggregationTemporality = 0, + is_monotonic: bool = False, + ): + self._data_points: List[NumberDataPoint] = data_points + self.aggregation_temporality: AggregationTemporality = aggregation_temporality + self.is_monotonic: bool = is_monotonic + + def calculate_size(self) -> int: + size = 0 + if self._data_points: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._data_points + ) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + size += len(b"\x10") + Varint.size_varint_u32(v) + if self.is_monotonic: + size += 
len(b"\x18") + 1 + return size + + def write_to(self, out: bytearray) -> None: + if self._data_points: + for v in self._data_points: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + out += b"\x10" + Varint.write_varint_u32(out, v) + if self.is_monotonic: + out += b"\x18" + Varint.write_varint_u32(out, 1 if self.is_monotonic else 0) + + +class Histogram(MessageMarshaler): + @property + def data_points(self) -> List[HistogramDataPoint]: + if self._data_points is None: + self._data_points = list() + return self._data_points + + aggregation_temporality: AggregationTemporality + + def __init__( + self, + data_points: List[HistogramDataPoint] = None, + aggregation_temporality: AggregationTemporality = 0, + ): + self._data_points: List[HistogramDataPoint] = data_points + self.aggregation_temporality: AggregationTemporality = aggregation_temporality + + def calculate_size(self) -> int: + size = 0 + if self._data_points: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._data_points + ) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + size += len(b"\x10") + Varint.size_varint_u32(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._data_points: + for v in self._data_points: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + out += b"\x10" + Varint.write_varint_u32(out, v) + + +class ExponentialHistogram(MessageMarshaler): + @property + def data_points(self) -> List[ExponentialHistogramDataPoint]: + if self._data_points is None: + self._data_points = list() + return self._data_points + + aggregation_temporality: AggregationTemporality 
+ + def __init__( + self, + data_points: List[ExponentialHistogramDataPoint] = None, + aggregation_temporality: AggregationTemporality = 0, + ): + self._data_points: List[ExponentialHistogramDataPoint] = data_points + self.aggregation_temporality: AggregationTemporality = aggregation_temporality + + def calculate_size(self) -> int: + size = 0 + if self._data_points: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._data_points + ) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + size += len(b"\x10") + Varint.size_varint_u32(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._data_points: + for v in self._data_points: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.aggregation_temporality: + v = self.aggregation_temporality + if not isinstance(v, int): + v = v.value + out += b"\x10" + Varint.write_varint_u32(out, v) + + +class Summary(MessageMarshaler): + @property + def data_points(self) -> List[SummaryDataPoint]: + if self._data_points is None: + self._data_points = list() + return self._data_points + + def __init__( + self, + data_points: List[SummaryDataPoint] = None, + ): + self._data_points: List[SummaryDataPoint] = data_points + + def calculate_size(self) -> int: + size = 0 + if self._data_points: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._data_points + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._data_points: + for v in self._data_points: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class NumberDataPoint(MessageMarshaler): + start_time_unix_nano: int + time_unix_nano: int + as_double: float + + @property + def exemplars(self) -> List[Exemplar]: + if self._exemplars is None: + self._exemplars = list() + return 
self._exemplars + + as_int: int + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + flags: int + + def __init__( + self, + start_time_unix_nano: int = 0, + time_unix_nano: int = 0, + as_double: float = None, + exemplars: List[Exemplar] = None, + as_int: int = None, + attributes: List[KeyValue] = None, + flags: int = 0, + ): + self.start_time_unix_nano: int = start_time_unix_nano + self.time_unix_nano: int = time_unix_nano + self.as_double: float = as_double + self._exemplars: List[Exemplar] = exemplars + self.as_int: int = as_int + self._attributes: List[KeyValue] = attributes + self.flags: int = flags + + def calculate_size(self) -> int: + size = 0 + if self.start_time_unix_nano: + size += len(b"\x11") + 8 + if self.time_unix_nano: + size += len(b"\x19") + 8 + if self.as_double is not None: + size += len(b"!") + 8 + if self._exemplars: + size += sum( + message._get_size() + + len(b"*") + + Varint.size_varint_u32(message._get_size()) + for message in self._exemplars + ) + if self.as_int is not None: + size += len(b"1") + 8 + if self._attributes: + size += sum( + message._get_size() + + len(b":") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.flags: + size += len(b"@") + Varint.size_varint_u32(self.flags) + return size + + def write_to(self, out: bytearray) -> None: + if self.start_time_unix_nano: + out += b"\x11" + out += struct.pack(" List[int]: + if self._bucket_counts is None: + self._bucket_counts = list() + return self._bucket_counts + + @property + def explicit_bounds(self) -> List[float]: + if self._explicit_bounds is None: + self._explicit_bounds = list() + return self._explicit_bounds + + @property + def exemplars(self) -> List[Exemplar]: + if self._exemplars is None: + self._exemplars = list() + return self._exemplars + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + 
self._attributes = list() + return self._attributes + + flags: int + min: float + max: float + + def __init__( + self, + start_time_unix_nano: int = 0, + time_unix_nano: int = 0, + count: int = 0, + sum: float = None, + bucket_counts: List[int] = None, + explicit_bounds: List[float] = None, + exemplars: List[Exemplar] = None, + attributes: List[KeyValue] = None, + flags: int = 0, + min: float = None, + max: float = None, + ): + self.start_time_unix_nano: int = start_time_unix_nano + self.time_unix_nano: int = time_unix_nano + self.count: int = count + self.sum: float = sum + self._bucket_counts: List[int] = bucket_counts + self._explicit_bounds: List[float] = explicit_bounds + self._exemplars: List[Exemplar] = exemplars + self._attributes: List[KeyValue] = attributes + self.flags: int = flags + self.min: float = min + self.max: float = max + + def calculate_size(self) -> int: + size = 0 + if self.start_time_unix_nano: + size += len(b"\x11") + 8 + if self.time_unix_nano: + size += len(b"\x19") + 8 + if self.count: + size += len(b"!") + 8 + if self.sum is not None: + size += len(b")") + 8 + if self._bucket_counts: + size += ( + len(b"2") + + len(self._bucket_counts) * 8 + + Varint.size_varint_u32(len(self._bucket_counts) * 8) + ) + if self._explicit_bounds: + size += ( + len(b":") + + len(self._explicit_bounds) * 8 + + Varint.size_varint_u32(len(self._explicit_bounds) * 8) + ) + if self._exemplars: + size += sum( + message._get_size() + + len(b"B") + + Varint.size_varint_u32(message._get_size()) + for message in self._exemplars + ) + if self._attributes: + size += sum( + message._get_size() + + len(b"J") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.flags: + size += len(b"P") + Varint.size_varint_u32(self.flags) + if self.min is not None: + size += len(b"Y") + 8 + if self.max is not None: + size += len(b"a") + 8 + return size + + def write_to(self, out: bytearray) -> None: + if self.start_time_unix_nano: + out += 
b"\x11" + out += struct.pack(" List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + start_time_unix_nano: int + time_unix_nano: int + count: int + sum: float + scale: int + zero_count: int + + @property + def positive(self) -> ExponentialHistogramDataPoint.Buckets: + if self._positive is None: + self._positive = ExponentialHistogramDataPoint.Buckets() + return self._positive + + @property + def negative(self) -> ExponentialHistogramDataPoint.Buckets: + if self._negative is None: + self._negative = ExponentialHistogramDataPoint.Buckets() + return self._negative + + flags: int + + @property + def exemplars(self) -> List[Exemplar]: + if self._exemplars is None: + self._exemplars = list() + return self._exemplars + + min: float + max: float + zero_threshold: float + + def __init__( + self, + attributes: List[KeyValue] = None, + start_time_unix_nano: int = 0, + time_unix_nano: int = 0, + count: int = 0, + sum: float = None, + scale: int = 0, + zero_count: int = 0, + positive: ExponentialHistogramDataPoint.Buckets = None, + negative: ExponentialHistogramDataPoint.Buckets = None, + flags: int = 0, + exemplars: List[Exemplar] = None, + min: float = None, + max: float = None, + zero_threshold: float = 0.0, + ): + self._attributes: List[KeyValue] = attributes + self.start_time_unix_nano: int = start_time_unix_nano + self.time_unix_nano: int = time_unix_nano + self.count: int = count + self.sum: float = sum + self.scale: int = scale + self.zero_count: int = zero_count + self._positive: ExponentialHistogramDataPoint.Buckets = positive + self._negative: ExponentialHistogramDataPoint.Buckets = negative + self.flags: int = flags + self._exemplars: List[Exemplar] = exemplars + self.min: float = min + self.max: float = max + self.zero_threshold: float = zero_threshold + + def calculate_size(self) -> int: + size = 0 + if self._attributes: + size += sum( + message._get_size() + + len(b"\n") + + 
Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.start_time_unix_nano: + size += len(b"\x11") + 8 + if self.time_unix_nano: + size += len(b"\x19") + 8 + if self.count: + size += len(b"!") + 8 + if self.sum is not None: + size += len(b")") + 8 + if self.scale: + size += len(b"0") + Varint.size_varint_s32(self.scale) + if self.zero_count: + size += len(b"9") + 8 + if self._positive is not None: + size += ( + len(b"B") + + Varint.size_varint_u32(self._positive._get_size()) + + self._positive._get_size() + ) + if self._negative is not None: + size += ( + len(b"J") + + Varint.size_varint_u32(self._negative._get_size()) + + self._negative._get_size() + ) + if self.flags: + size += len(b"P") + Varint.size_varint_u32(self.flags) + if self._exemplars: + size += sum( + message._get_size() + + len(b"Z") + + Varint.size_varint_u32(message._get_size()) + for message in self._exemplars + ) + if self.min is not None: + size += len(b"a") + 8 + if self.max is not None: + size += len(b"i") + 8 + if self.zero_threshold: + size += len(b"q") + 8 + return size + + def write_to(self, out: bytearray) -> None: + if self._attributes: + for v in self._attributes: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.start_time_unix_nano: + out += b"\x11" + out += struct.pack(" List[int]: + if self._bucket_counts is None: + self._bucket_counts = list() + return self._bucket_counts + + def __init__( + self, + offset: int = 0, + bucket_counts: List[int] = None, + ): + self.offset: int = offset + self._bucket_counts: List[int] = bucket_counts + + def calculate_size(self) -> int: + size = 0 + if self.offset: + size += len(b"\x08") + Varint.size_varint_s32(self.offset) + if self._bucket_counts: + s = sum( + Varint.size_varint_u64(uint64) for uint64 in self._bucket_counts + ) + self.marshaler_cache[b"\x12"] = s + size += len(b"\x12") + s + Varint.size_varint_u32(s) + return size + + def write_to(self, out: bytearray) -> 
None: + if self.offset: + out += b"\x08" + Varint.write_varint_s32(out, self.offset) + if self._bucket_counts: + out += b"\x12" + Varint.write_varint_u32(out, self.marshaler_cache[b"\x12"]) + for v in self._bucket_counts: + Varint.write_varint_u64(out, v) + + +class SummaryDataPoint(MessageMarshaler): + start_time_unix_nano: int + time_unix_nano: int + count: int + sum: float + + @property + def quantile_values(self) -> List[SummaryDataPoint.ValueAtQuantile]: + if self._quantile_values is None: + self._quantile_values = list() + return self._quantile_values + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + flags: int + + def __init__( + self, + start_time_unix_nano: int = 0, + time_unix_nano: int = 0, + count: int = 0, + sum: float = 0.0, + quantile_values: List[SummaryDataPoint.ValueAtQuantile] = None, + attributes: List[KeyValue] = None, + flags: int = 0, + ): + self.start_time_unix_nano: int = start_time_unix_nano + self.time_unix_nano: int = time_unix_nano + self.count: int = count + self.sum: float = sum + self._quantile_values: List[SummaryDataPoint.ValueAtQuantile] = quantile_values + self._attributes: List[KeyValue] = attributes + self.flags: int = flags + + def calculate_size(self) -> int: + size = 0 + if self.start_time_unix_nano: + size += len(b"\x11") + 8 + if self.time_unix_nano: + size += len(b"\x19") + 8 + if self.count: + size += len(b"!") + 8 + if self.sum: + size += len(b")") + 8 + if self._quantile_values: + size += sum( + message._get_size() + + len(b"2") + + Varint.size_varint_u32(message._get_size()) + for message in self._quantile_values + ) + if self._attributes: + size += sum( + message._get_size() + + len(b":") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.flags: + size += len(b"@") + Varint.size_varint_u32(self.flags) + return size + + def write_to(self, out: bytearray) -> None: + if 
self.start_time_unix_nano: + out += b"\x11" + out += struct.pack(" int: + size = 0 + if self.quantile: + size += len(b"\t") + 8 + if self.value: + size += len(b"\x11") + 8 + return size + + def write_to(self, out: bytearray) -> None: + if self.quantile: + out += b"\t" + out += struct.pack(" List[KeyValue]: + if self._filtered_attributes is None: + self._filtered_attributes = list() + return self._filtered_attributes + + def __init__( + self, + time_unix_nano: int = 0, + as_double: float = None, + span_id: bytes = b"", + trace_id: bytes = b"", + as_int: int = None, + filtered_attributes: List[KeyValue] = None, + ): + self.time_unix_nano: int = time_unix_nano + self.as_double: float = as_double + self.span_id: bytes = span_id + self.trace_id: bytes = trace_id + self.as_int: int = as_int + self._filtered_attributes: List[KeyValue] = filtered_attributes + + def calculate_size(self) -> int: + size = 0 + if self.time_unix_nano: + size += len(b"\x11") + 8 + if self.as_double is not None: + size += len(b"\x19") + 8 + if self.span_id: + size += ( + len(b'"') + + Varint.size_varint_u32(len(self.span_id)) + + len(self.span_id) + ) + if self.trace_id: + size += ( + len(b"*") + + Varint.size_varint_u32(len(self.trace_id)) + + len(self.trace_id) + ) + if self.as_int is not None: + size += len(b"1") + 8 + if self._filtered_attributes: + size += sum( + message._get_size() + + len(b":") + + Varint.size_varint_u32(message._get_size()) + for message in self._filtered_attributes + ) + return size + + def write_to(self, out: bytearray) -> None: + if self.time_unix_nano: + out += b"\x11" + out += struct.pack(" List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + + def __init__( + self, + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + ): + self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + + def calculate_size(self) -> 
int: + size = 0 + if self._attributes: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b"\x10") + Varint.size_varint_u32(self.dropped_attributes_count) + return size + + def write_to(self, out: bytearray) -> None: + if self._attributes: + for v in self._attributes: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.dropped_attributes_count: + out += b"\x10" + Varint.write_varint_u32(out, self.dropped_attributes_count) diff --git a/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py b/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py new file mode 100644 index 0000000..ce957f6 --- /dev/null +++ b/src/snowflake/telemetry/_internal/opentelemetry/proto/trace/v1/trace_marshaler.py @@ -0,0 +1,597 @@ +# Generated by the protoc compiler with a custom plugin. DO NOT EDIT! 
+# sources: opentelemetry/proto/trace/v1/trace.proto + +from __future__ import annotations + +import struct +from typing import List + +from snowflake.telemetry._internal.opentelemetry.proto.common.v1.common_marshaler import * +from snowflake.telemetry._internal.opentelemetry.proto.resource.v1.resource_marshaler import * +from snowflake.telemetry._internal.serialize import ( + Enum, + MessageMarshaler, + Varint, +) + + +class SpanFlags(Enum): + SPAN_FLAGS_DO_NOT_USE = 0 + SPAN_FLAGS_TRACE_FLAGS_MASK = 255 + SPAN_FLAGS_CONTEXT_HAS_IS_REMOTE_MASK = 256 + SPAN_FLAGS_CONTEXT_IS_REMOTE_MASK = 512 + + +class TracesData(MessageMarshaler): + @property + def resource_spans(self) -> List[ResourceSpans]: + if self._resource_spans is None: + self._resource_spans = list() + return self._resource_spans + + def __init__( + self, + resource_spans: List[ResourceSpans] = None, + ): + self._resource_spans: List[ResourceSpans] = resource_spans + + def calculate_size(self) -> int: + size = 0 + if self._resource_spans: + size += sum( + message._get_size() + + len(b"\n") + + Varint.size_varint_u32(message._get_size()) + for message in self._resource_spans + ) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource_spans: + for v in self._resource_spans: + out += b"\n" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + + +class ResourceSpans(MessageMarshaler): + @property + def resource(self) -> Resource: + if self._resource is None: + self._resource = Resource() + return self._resource + + @property + def scope_spans(self) -> List[ScopeSpans]: + if self._scope_spans is None: + self._scope_spans = list() + return self._scope_spans + + schema_url: str + + def __init__( + self, + resource: Resource = None, + scope_spans: List[ScopeSpans] = None, + schema_url: str = "", + ): + self._resource: Resource = resource + self._scope_spans: List[ScopeSpans] = scope_spans + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if 
self._resource is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._resource._get_size()) + + self._resource._get_size() + ) + if self._scope_spans: + size += sum( + message._get_size() + + len(b"\x12") + + Varint.size_varint_u32(message._get_size()) + for message in self._scope_spans + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._resource is not None: + out += b"\n" + Varint.write_varint_u32(out, self._resource._get_size()) + self._resource.write_to(out) + if self._scope_spans: + for v in self._scope_spans: + out += b"\x12" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class ScopeSpans(MessageMarshaler): + @property + def scope(self) -> InstrumentationScope: + if self._scope is None: + self._scope = InstrumentationScope() + return self._scope + + @property + def spans(self) -> List[Span]: + if self._spans is None: + self._spans = list() + return self._spans + + schema_url: str + + def __init__( + self, + scope: InstrumentationScope = None, + spans: List[Span] = None, + schema_url: str = "", + ): + self._scope: InstrumentationScope = scope + self._spans: List[Span] = spans + self.schema_url: str = schema_url + + def calculate_size(self) -> int: + size = 0 + if self._scope is not None: + size += ( + len(b"\n") + + Varint.size_varint_u32(self._scope._get_size()) + + self._scope._get_size() + ) + if self._spans: + size += sum( + message._get_size() + + len(b"\x12") + + Varint.size_varint_u32(message._get_size()) + for message in self._spans + ) + if self.schema_url: + v = self.schema_url.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + return size + + def write_to(self, out: bytearray) -> None: + if self._scope is not 
None: + out += b"\n" + Varint.write_varint_u32(out, self._scope._get_size()) + self._scope.write_to(out) + if self._spans: + for v in self._spans: + out += b"\x12" + Varint.write_varint_u32(out, v._get_size()) + v.write_to(out) + if self.schema_url: + v = self.schema_url.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + + +class Span(MessageMarshaler): + trace_id: bytes + span_id: bytes + trace_state: str + parent_span_id: bytes + name: str + kind: Span.SpanKind + start_time_unix_nano: int + end_time_unix_nano: int + + @property + def attributes(self) -> List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + + @property + def events(self) -> List[Span.Event]: + if self._events is None: + self._events = list() + return self._events + + dropped_events_count: int + + @property + def links(self) -> List[Span.Link]: + if self._links is None: + self._links = list() + return self._links + + dropped_links_count: int + + @property + def status(self) -> Status: + if self._status is None: + self._status = Status() + return self._status + + flags: int + + def __init__( + self, + trace_id: bytes = b"", + span_id: bytes = b"", + trace_state: str = "", + parent_span_id: bytes = b"", + name: str = "", + kind: Span.SpanKind = 0, + start_time_unix_nano: int = 0, + end_time_unix_nano: int = 0, + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + events: List[Span.Event] = None, + dropped_events_count: int = 0, + links: List[Span.Link] = None, + dropped_links_count: int = 0, + status: Status = None, + flags: int = 0, + ): + self.trace_id: bytes = trace_id + self.span_id: bytes = span_id + self.trace_state: str = trace_state + self.parent_span_id: bytes = parent_span_id + self.name: str = name + self.kind: Span.SpanKind = kind + self.start_time_unix_nano: int = start_time_unix_nano + self.end_time_unix_nano: int = end_time_unix_nano + 
self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + self._events: List[Span.Event] = events + self.dropped_events_count: int = dropped_events_count + self._links: List[Span.Link] = links + self.dropped_links_count: int = dropped_links_count + self._status: Status = status + self.flags: int = flags + + def calculate_size(self) -> int: + size = 0 + if self.trace_id: + size += ( + len(b"\n") + + Varint.size_varint_u32(len(self.trace_id)) + + len(self.trace_id) + ) + if self.span_id: + size += ( + len(b"\x12") + + Varint.size_varint_u32(len(self.span_id)) + + len(self.span_id) + ) + if self.trace_state: + v = self.trace_state.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + if self.parent_span_id: + size += ( + len(b'"') + + Varint.size_varint_u32(len(self.parent_span_id)) + + len(self.parent_span_id) + ) + if self.name: + v = self.name.encode("utf-8") + size += len(b"*") + Varint.size_varint_u32(len(v)) + len(v) + if self.kind: + v = self.kind + if not isinstance(v, int): + v = v.value + size += len(b"0") + Varint.size_varint_u32(v) + if self.start_time_unix_nano: + size += len(b"9") + 8 + if self.end_time_unix_nano: + size += len(b"A") + 8 + if self._attributes: + size += sum( + message._get_size() + + len(b"J") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b"P") + Varint.size_varint_u32(self.dropped_attributes_count) + if self._events: + size += sum( + message._get_size() + + len(b"Z") + + Varint.size_varint_u32(message._get_size()) + for message in self._events + ) + if self.dropped_events_count: + size += len(b"`") + Varint.size_varint_u32(self.dropped_events_count) + if self._links: + size += sum( + message._get_size() + + len(b"j") + + Varint.size_varint_u32(message._get_size()) + for message in self._links + ) + if self.dropped_links_count: + size += len(b"p") + 
Varint.size_varint_u32(self.dropped_links_count) + if self._status is not None: + size += ( + len(b"z") + + Varint.size_varint_u32(self._status._get_size()) + + self._status._get_size() + ) + if self.flags: + size += len(b"\x85\x01") + 4 + return size + + def write_to(self, out: bytearray) -> None: + if self.trace_id: + out += b"\n" + Varint.write_varint_u32(out, len(self.trace_id)) + out += self.trace_id + if self.span_id: + out += b"\x12" + Varint.write_varint_u32(out, len(self.span_id)) + out += self.span_id + if self.trace_state: + v = self.trace_state.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + if self.parent_span_id: + out += b'"' + Varint.write_varint_u32(out, len(self.parent_span_id)) + out += self.parent_span_id + if self.name: + v = self.name.encode("utf-8") + out += b"*" + Varint.write_varint_u32(out, len(v)) + out += v + if self.kind: + v = self.kind + if not isinstance(v, int): + v = v.value + out += b"0" + Varint.write_varint_u32(out, v) + if self.start_time_unix_nano: + out += b"9" + out += struct.pack(" List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + + def __init__( + self, + time_unix_nano: int = 0, + name: str = "", + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + ): + self.time_unix_nano: int = time_unix_nano + self.name: str = name + self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + + def calculate_size(self) -> int: + size = 0 + if self.time_unix_nano: + size += len(b"\t") + 8 + if self.name: + v = self.name.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + if self._attributes: + size += sum( + message._get_size() + + len(b"\x1a") + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b" ") + Varint.size_varint_u32( + 
self.dropped_attributes_count + ) + return size + + def write_to(self, out: bytearray) -> None: + if self.time_unix_nano: + out += b"\t" + out += struct.pack(" List[KeyValue]: + if self._attributes is None: + self._attributes = list() + return self._attributes + + dropped_attributes_count: int + flags: int + + def __init__( + self, + trace_id: bytes = b"", + span_id: bytes = b"", + trace_state: str = "", + attributes: List[KeyValue] = None, + dropped_attributes_count: int = 0, + flags: int = 0, + ): + self.trace_id: bytes = trace_id + self.span_id: bytes = span_id + self.trace_state: str = trace_state + self._attributes: List[KeyValue] = attributes + self.dropped_attributes_count: int = dropped_attributes_count + self.flags: int = flags + + def calculate_size(self) -> int: + size = 0 + if self.trace_id: + size += ( + len(b"\n") + + Varint.size_varint_u32(len(self.trace_id)) + + len(self.trace_id) + ) + if self.span_id: + size += ( + len(b"\x12") + + Varint.size_varint_u32(len(self.span_id)) + + len(self.span_id) + ) + if self.trace_state: + v = self.trace_state.encode("utf-8") + size += len(b"\x1a") + Varint.size_varint_u32(len(v)) + len(v) + if self._attributes: + size += sum( + message._get_size() + + len(b'"') + + Varint.size_varint_u32(message._get_size()) + for message in self._attributes + ) + if self.dropped_attributes_count: + size += len(b"(") + Varint.size_varint_u32( + self.dropped_attributes_count + ) + if self.flags: + size += len(b"5") + 4 + return size + + def write_to(self, out: bytearray) -> None: + if self.trace_id: + out += b"\n" + Varint.write_varint_u32(out, len(self.trace_id)) + out += self.trace_id + if self.span_id: + out += b"\x12" + Varint.write_varint_u32(out, len(self.span_id)) + out += self.span_id + if self.trace_state: + v = self.trace_state.encode("utf-8") + out += b"\x1a" + Varint.write_varint_u32(out, len(v)) + out += v + if self._attributes: + for v in self._attributes: + out += b'"' + Varint.write_varint_u32(out, v._get_size()) + 
v.write_to(out) + if self.dropped_attributes_count: + out += b"(" + Varint.write_varint_u32(out, self.dropped_attributes_count) + if self.flags: + out += b"5" + out += struct.pack(" int: + size = 0 + if self.message: + v = self.message.encode("utf-8") + size += len(b"\x12") + Varint.size_varint_u32(len(v)) + len(v) + if self.code: + v = self.code + if not isinstance(v, int): + v = v.value + size += len(b"\x18") + Varint.size_varint_u32(v) + return size + + def write_to(self, out: bytearray) -> None: + if self.message: + v = self.message.encode("utf-8") + out += b"\x12" + Varint.write_varint_u32(out, len(v)) + out += v + if self.code: + v = self.code + if not isinstance(v, int): + v = v.value + out += b"\x18" + Varint.write_varint_u32(out, v) + + class StatusCode(Enum): + STATUS_CODE_UNSET = 0 + STATUS_CODE_OK = 1 + STATUS_CODE_ERROR = 2 diff --git a/src/snowflake/telemetry/_internal/serialize/__init__.py b/src/snowflake/telemetry/_internal/serialize/__init__.py new file mode 100644 index 0000000..bb4d92a --- /dev/null +++ b/src/snowflake/telemetry/_internal/serialize/__init__.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import struct +from enum import IntEnum +from typing import List, Union, Dict, Any + +# Alias Enum to IntEnum +Enum = IntEnum + +# Static class to handle varint encoding +# There is code duplication for performance reasons +# https://developers.google.com/protocol-buffers/docs/encoding#varints +class Varint: + @staticmethod + def size_varint_u32(value: int) -> int: + size = 1 + while value >= 128: + value >>= 7 + size += 1 + return size + + size_varint_u64 = size_varint_u32 + + @staticmethod + def size_varint_i32(value: int) -> int: + value = value + (1 << 32) if value < 0 else value + size = 1 + while value >= 128: + value >>= 7 + size += 1 + return size + + @staticmethod + def size_varint_i64(value: int) -> int: + value = value + (1 << 64) if value < 0 else value + size = 1 + while value >= 128: + value >>= 7 + size += 1 + return 
class Varint:
    """Protocol-buffer base-128 varint sizing and encoding helpers.

    The unsigned / two's-complement / zigzag variants are written out
    separately rather than sharing one implementation, to keep the hot
    encoding paths free of extra branches.
    https://developers.google.com/protocol-buffers/docs/encoding#varints
    """

    @staticmethod
    def size_varint_u32(value: int) -> int:
        """Byte length of ``value`` encoded as an unsigned varint."""
        size = 1
        while value >= 128:
            value >>= 7
            size += 1
        return size

    # In pure Python, uint64 sizing is identical to uint32 sizing.
    size_varint_u64 = size_varint_u32

    @staticmethod
    def size_varint_i64(value: int) -> int:
        """Byte length of a signed int64 varint (two's complement)."""
        value = value + (1 << 64) if value < 0 else value
        size = 1
        while value >= 128:
            value >>= 7
            size += 1
        return size

    # Per the protobuf wire format, a negative int32 is sign-extended to
    # 64 bits and always encodes as a 10-byte varint, exactly like int64.
    # (Adding only 1 << 32 for negatives would produce a non-conformant
    # 5-byte varint that disagrees with protobuf's own serializers.)
    size_varint_i32 = size_varint_i64

    @staticmethod
    def size_varint_s32(value: int) -> int:
        """Byte length of a zigzag-encoded (sint32/sint64) varint."""
        value = value << 1 if value >= 0 else (value << 1) ^ (~0)
        size = 1
        while value >= 128:
            value >>= 7
            size += 1
        return size

    size_varint_s64 = size_varint_s32

    @staticmethod
    def write_varint_u32(out: bytearray, value: int) -> None:
        """Append ``value`` to ``out`` as an unsigned varint."""
        while value >= 128:
            out.append((value & 0x7F) | 0x80)
            value >>= 7
        out.append(value)

    write_varint_u64 = write_varint_u32

    @staticmethod
    def write_varint_i64(out: bytearray, value: int) -> None:
        """Append a signed int64 as a two's-complement varint."""
        value = value + (1 << 64) if value < 0 else value
        while value >= 128:
            out.append((value & 0x7F) | 0x80)
            value >>= 7
        out.append(value)

    # int32 must sign-extend to 64 bits on the wire (10-byte varint for
    # negatives), so its writer is the int64 writer.
    write_varint_i32 = write_varint_i64

    @staticmethod
    def write_varint_s32(out: bytearray, value: int) -> None:
        """Append a zigzag-encoded signed value (sint32/sint64)."""
        value = value << 1 if value >= 0 else (value << 1) ^ (~0)
        while value >= 128:
            out.append((value & 0x7F) | 0x80)
            value >>= 7
        out.append(value)

    write_varint_s64 = write_varint_s32


class MessageMarshaler:
    """Base class for generated protobuf message marshalers.

    Generated subclasses implement ``calculate_size()`` and ``write_to()``;
    this base supplies size caching, the public ``SerializeToString()``
    entry point, and the ``size_*`` / ``serialize_*`` helpers that the code
    generator may inline (it textually substitutes TAG and FIELD_ATTR).
    """

    # Creating an empty dict per instance has measurable overhead, so the
    # cache used by packed repeated fields is created lazily on first use.
    @property
    def marshaler_cache(self) -> Dict[bytes, Any]:
        if not hasattr(self, "_marshaler_cache"):
            self._marshaler_cache = {}
        return self._marshaler_cache

    def write_to(self, out: bytearray) -> None:
        """Append this message's encoded bytes to ``out`` (generated)."""
        ...

    def calculate_size(self) -> int:
        """Return this message's encoded byte size (generated)."""
        ...

    def _get_size(self) -> int:
        # The size is computed at most once and memoized.  NOTE(review):
        # the cache is never invalidated, so mutating a message after it
        # has been sized/serialized yields stale output — messages are
        # treated as write-once.
        if not hasattr(self, "_size"):
            self._size = self.calculate_size()
        return self._size

    def SerializeToString(self) -> bytes:
        # The size MUST be calculated before serializing: calculate_size()
        # also pre-computes packed-field byte lengths into marshaler_cache,
        # which write_to() reads back.
        self._get_size()
        stream = bytearray()
        self.write_to(stream)
        return bytes(stream)

    def __bytes__(self) -> bytes:
        return self.SerializeToString()

    # ------------------------------------------------------------------
    # Size helpers.  TAG is the pre-encoded field tag bytes; FIELD_ATTR is
    # the field value.  Negative enum values are not supported by the
    # u32 varint path (OTel protos define none).
    # ------------------------------------------------------------------

    def size_bool(self, TAG: bytes, _) -> int:
        return len(TAG) + 1

    def size_enum(self, TAG: bytes, FIELD_ATTR: Union[Enum, int]) -> int:
        v = FIELD_ATTR
        if not isinstance(v, int):
            v = v.value
        return len(TAG) + Varint.size_varint_u32(v)

    def size_uint32(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_u32(FIELD_ATTR)

    def size_uint64(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_u64(FIELD_ATTR)

    def size_sint32(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_s32(FIELD_ATTR)

    def size_sint64(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_s64(FIELD_ATTR)

    def size_int32(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_i32(FIELD_ATTR)

    def size_int64(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + Varint.size_varint_i64(FIELD_ATTR)

    def size_float(self, TAG: bytes, FIELD_ATTR: float) -> int:
        return len(TAG) + 4

    def size_double(self, TAG: bytes, FIELD_ATTR: float) -> int:
        return len(TAG) + 8

    def size_fixed32(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + 4

    def size_fixed64(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + 8

    def size_sfixed32(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + 4

    def size_sfixed64(self, TAG: bytes, FIELD_ATTR: int) -> int:
        return len(TAG) + 8

    def size_bytes(self, TAG: bytes, FIELD_ATTR: bytes) -> int:
        return len(TAG) + Varint.size_varint_u32(len(FIELD_ATTR)) + len(FIELD_ATTR)

    def size_string(self, TAG: bytes, FIELD_ATTR: str) -> int:
        v = FIELD_ATTR.encode("utf-8")
        return len(TAG) + Varint.size_varint_u32(len(v)) + len(v)

    def size_message(self, TAG: bytes, FIELD_ATTR: MessageMarshaler) -> int:
        return len(TAG) + Varint.size_varint_u32(FIELD_ATTR._get_size()) + FIELD_ATTR._get_size()

    def size_repeated_message(self, TAG: bytes, FIELD_ATTR: List[MessageMarshaler]) -> int:
        return sum(message._get_size() + len(TAG) + Varint.size_varint_u32(message._get_size()) for message in FIELD_ATTR)

    def size_repeated_double(self, TAG: bytes, FIELD_ATTR: List[float]) -> int:
        # Packed encoding: one tag, one length, then 8 bytes per element.
        return len(TAG) + len(FIELD_ATTR) * 8 + Varint.size_varint_u32(len(FIELD_ATTR) * 8)

    def size_repeated_fixed64(self, TAG: bytes, FIELD_ATTR: List[int]) -> int:
        return len(TAG) + len(FIELD_ATTR) * 8 + Varint.size_varint_u32(len(FIELD_ATTR) * 8)

    def size_repeated_uint64(self, TAG: bytes, FIELD_ATTR: List[int]) -> int:
        # Packed varints have a data-dependent payload length; cache it so
        # serialize_repeated_uint64 does not recompute it during write_to.
        s = sum(Varint.size_varint_u64(uint64) for uint64 in FIELD_ATTR)
        self.marshaler_cache[TAG] = s
        return len(TAG) + s + Varint.size_varint_u32(s)

    # ------------------------------------------------------------------
    # Serialize helpers (mirror the size helpers above).
    # ------------------------------------------------------------------

    def serialize_bool(self, out: bytearray, TAG: bytes, FIELD_ATTR: bool) -> None:
        out += TAG
        Varint.write_varint_u32(out, 1 if FIELD_ATTR else 0)

    def serialize_enum(self, out: bytearray, TAG: bytes, FIELD_ATTR: Union[Enum, int]) -> None:
        v = FIELD_ATTR
        if not isinstance(v, int):
            v = v.value
        out += TAG
        Varint.write_varint_u32(out, v)

    def serialize_uint32(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_u32(out, FIELD_ATTR)

    def serialize_uint64(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_u64(out, FIELD_ATTR)

    def serialize_sint32(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_s32(out, FIELD_ATTR)

    def serialize_sint64(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_s64(out, FIELD_ATTR)

    def serialize_int32(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_i32(out, FIELD_ATTR)

    def serialize_int64(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        Varint.write_varint_i64(out, FIELD_ATTR)

    def serialize_float(self, out: bytearray, TAG: bytes, FIELD_ATTR: float) -> None:
        out += TAG
        out += struct.pack("<f", FIELD_ATTR)

    def serialize_double(self, out: bytearray, TAG: bytes, FIELD_ATTR: float) -> None:
        out += TAG
        out += struct.pack("<d", FIELD_ATTR)

    def serialize_fixed32(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        out += struct.pack("<I", FIELD_ATTR)

    def serialize_fixed64(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        out += struct.pack("<Q", FIELD_ATTR)

    def serialize_sfixed32(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        out += struct.pack("<i", FIELD_ATTR)

    def serialize_sfixed64(self, out: bytearray, TAG: bytes, FIELD_ATTR: int) -> None:
        out += TAG
        out += struct.pack("<q", FIELD_ATTR)

    def serialize_bytes(self, out: bytearray, TAG: bytes, FIELD_ATTR: bytes) -> None:
        out += TAG
        Varint.write_varint_u32(out, len(FIELD_ATTR))
        out += FIELD_ATTR

    def serialize_string(self, out: bytearray, TAG: bytes, FIELD_ATTR: str) -> None:
        v = FIELD_ATTR.encode("utf-8")
        out += TAG
        Varint.write_varint_u32(out, len(v))
        out += v

    def serialize_message(self, out: bytearray, TAG: bytes, FIELD_ATTR: MessageMarshaler) -> None:
        out += TAG
        Varint.write_varint_u32(out, FIELD_ATTR._get_size())
        FIELD_ATTR.write_to(out)

    def serialize_repeated_message(self, out: bytearray, TAG: bytes, FIELD_ATTR: List[MessageMarshaler]) -> None:
        # Repeated messages are NOT packed: each element carries its own tag.
        for v in FIELD_ATTR:
            out += TAG
            Varint.write_varint_u32(out, v._get_size())
            v.write_to(out)

    def serialize_repeated_double(self, out: bytearray, TAG: bytes, FIELD_ATTR: List[float]) -> None:
        out += TAG
        Varint.write_varint_u32(out, len(FIELD_ATTR) * 8)
        for v in FIELD_ATTR:
            out += struct.pack("<d", v)

    def serialize_repeated_fixed64(self, out: bytearray, TAG: bytes, FIELD_ATTR: List[int]) -> None:
        out += TAG
        Varint.write_varint_u32(out, len(FIELD_ATTR) * 8)
        for v in FIELD_ATTR:
            out += struct.pack("<Q", v)

    def serialize_repeated_uint64(self, out: bytearray, TAG: bytes, FIELD_ATTR: List[int]) -> None:
        # Payload length was pre-computed by size_repeated_uint64 during
        # the mandatory sizing pass (see SerializeToString).
        out += TAG
        Varint.write_varint_u32(out, self.marshaler_cache[TAG])
        for v in FIELD_ATTR:
            Varint.write_varint_u64(out, v)
a/tests/snowflake-telemetry-test-utils/setup.py b/tests/snowflake-telemetry-test-utils/setup.py index 4f78b56..0511b8e 100644 --- a/tests/snowflake-telemetry-test-utils/setup.py +++ b/tests/snowflake-telemetry-test-utils/setup.py @@ -17,6 +17,11 @@ install_requires=[ "pytest >= 7.0.0", "snowflake-telemetry-python == 0.6.0.dev", + "Jinja2 == 3.1.4", + "grpcio-tools >= 1.62.3", + "black >= 24.1.0", + "isort >= 5.12.0", + "hypothesis >= 6.0.0", ], packages=find_namespace_packages( where='src' diff --git a/tests/test_proto_serialization.py b/tests/test_proto_serialization.py new file mode 100644 index 0000000..5e3458f --- /dev/null +++ b/tests/test_proto_serialization.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from typing import ( + Any, + Dict, + List, + Mapping, +) +import unittest +import hypothesis +import hypothesis.control as hc +import hypothesis.strategies as st + +import opentelemetry.proto.logs.v1.logs_pb2 as logs_pb2 +import opentelemetry.proto.trace.v1.trace_pb2 as trace_pb2 +import opentelemetry.proto.common.v1.common_pb2 as common_pb2 +import opentelemetry.proto.metrics.v1.metrics_pb2 as metrics_pb2 +import opentelemetry.proto.resource.v1.resource_pb2 as resource_pb2 + +import snowflake.telemetry._internal.opentelemetry.proto.logs.v1.logs_marshaler as logs_sf +import snowflake.telemetry._internal.opentelemetry.proto.trace.v1.trace_marshaler as trace_sf +import snowflake.telemetry._internal.opentelemetry.proto.common.v1.common_marshaler as common_sf +import snowflake.telemetry._internal.opentelemetry.proto.metrics.v1.metrics_marshaler as metrics_sf +import snowflake.telemetry._internal.opentelemetry.proto.resource.v1.resource_marshaler as resource_sf + +# Strategy for generating protobuf types +def nullable(type): return st.one_of(st.none(), type) +def pb_uint32(): return nullable(st.integers(min_value=0, max_value=2**32-1)) +def pb_uint64(): return nullable(st.integers(min_value=0, max_value=2**64-1)) +def pb_int32(): return 
nullable(st.integers(min_value=-2**31, max_value=2**31-1)) +def pb_int64(): return nullable(st.integers(min_value=-2**63, max_value=2**63-1)) +def pb_sint32(): return nullable(st.integers(min_value=-2**31, max_value=2**31-1)) +def pb_sint64(): return nullable(st.integers(min_value=-2**63, max_value=2**63-1)) +def pb_float(): return nullable(st.floats(allow_nan=False, allow_infinity=False, width=32)) +def pb_double(): return nullable(st.floats(allow_nan=False, allow_infinity=False, width=64)) +def draw_pb_double(draw): + # -0.0 is an edge case that is not handled by the custom serialization library + double = draw(pb_double()) + hc.assume(str(double) != "-0.0") + return double +def pb_fixed64(): return pb_uint64() +def pb_fixed32(): return pb_uint32() +def pb_sfixed64(): return pb_int64() +def pb_sfixed32(): return pb_int32() +def pb_bool(): return nullable(st.booleans()) +def pb_string(): return nullable(st.text(max_size=20)) +def pb_bytes(): return nullable(st.binary(max_size=20)) +def draw_pb_enum(draw, enum): + # Sample int val of enum, will be converted to member in encode_recurse + # Sample from pb2 values as it is the source of truth + return draw(nullable(st.sampled_from([member for member in enum.values()]))) +def pb_repeated(type): return nullable(st.lists(type, max_size=3)) # limit the size of the repeated field to speed up testing +def pb_span_id(): return nullable(st.binary(min_size=8, max_size=8)) +def pb_trace_id(): return nullable(st.binary(min_size=16, max_size=16)) +# For drawing oneof fields +# call with pb_oneof(draw, field1=pb_type1_callable, field2=pb_type2_callable, ...) 
+def pb_oneof(draw, **kwargs): + n = len(kwargs) + r = draw(st.integers(min_value=0, max_value=n-1)) + k, v = list(kwargs.items())[r] + return {k: draw(v())} +def pb_message(type): + return nullable(type) + +SF = "_sf" +PB = "_pb2" + +# Strategies for generating opentelemetry-proto types +@st.composite +def instrumentation_scope(draw): + return { + SF: common_sf.InstrumentationScope, + PB: common_pb2.InstrumentationScope, + "name": draw(pb_string()), + "version": draw(pb_string()), + "attributes": draw(pb_repeated(key_value())), + "dropped_attributes_count": draw(pb_uint32()), + } + +@st.composite +def resource(draw): + return { + SF: resource_sf.Resource, + PB: resource_pb2.Resource, + "attributes": draw(pb_repeated(key_value())), + "dropped_attributes_count": draw(pb_uint32()), + } + +@st.composite +def any_value(draw): + return { + SF: common_sf.AnyValue, + PB: common_pb2.AnyValue, + **pb_oneof( + draw, + string_value=pb_string, + bool_value=pb_bool, + int_value=pb_int64, + double_value=pb_double, + array_value=array_value, + kvlist_value=key_value_list, + bytes_value=pb_bytes, + ), + } + +@st.composite +def array_value(draw): + return { + SF: common_sf.ArrayValue, + PB: common_pb2.ArrayValue, + "values": draw(pb_repeated(any_value())), + } + +@st.composite +def key_value(draw): + return { + SF: common_sf.KeyValue, + PB: common_pb2.KeyValue, + "key": draw(pb_string()), + "value": draw(any_value()), + } + +@st.composite +def key_value_list(draw): + return { + SF: common_sf.KeyValueList, + PB: common_pb2.KeyValueList, + "values": draw(pb_repeated(key_value())), + } + +@st.composite +def logs_data(draw): + @st.composite + def log_record(draw): + return { + SF: logs_sf.LogRecord, + PB: logs_pb2.LogRecord, + "time_unix_nano": draw(pb_fixed64()), + "observed_time_unix_nano": draw(pb_fixed64()), + "severity_number": draw_pb_enum(draw, logs_pb2.SeverityNumber), + "severity_text": draw(pb_string()), + "body": draw(pb_message(any_value())), + "attributes": 
draw(pb_repeated(key_value())), + "dropped_attributes_count": draw(pb_uint32()), + "flags": draw(pb_fixed32()), + "span_id": draw(pb_span_id()), + "trace_id": draw(pb_trace_id()), + } + + @st.composite + def scope_logs(draw): + return { + SF: logs_sf.ScopeLogs, + PB: logs_pb2.ScopeLogs, + "scope": draw(pb_message(instrumentation_scope())), + "log_records": draw(pb_repeated(log_record())), + "schema_url": draw(pb_string()), + } + + @st.composite + def resource_logs(draw): + return { + SF: logs_sf.ResourceLogs, + PB: logs_pb2.ResourceLogs, + "resource": draw(pb_message(resource())), + "scope_logs": draw(pb_repeated(scope_logs())), + "schema_url": draw(pb_string()), + } + + return { + SF: logs_sf.LogsData, + PB: logs_pb2.LogsData, + "resource_logs": draw(pb_repeated(resource_logs())), + } + +@st.composite +def traces_data(draw): + @st.composite + def event(draw): + return { + SF: trace_sf.Span.Event, + PB: trace_pb2.Span.Event, + "time_unix_nano": draw(pb_fixed64()), + "name": draw(pb_string()), + "attributes": draw(pb_repeated(key_value())), + "dropped_attributes_count": draw(pb_uint32()), + } + + @st.composite + def link(draw): + return { + SF: trace_sf.Span.Link, + PB: trace_pb2.Span.Link, + "trace_id": draw(pb_trace_id()), + "span_id": draw(pb_span_id()), + "trace_state": draw(pb_string()), + "attributes": draw(pb_repeated(key_value())), + "dropped_attributes_count": draw(pb_uint32()), + "flags": draw(pb_fixed32()), + } + + @st.composite + def status(draw): + return { + SF: trace_sf.Status, + PB: trace_pb2.Status, + "code": draw_pb_enum(draw, trace_pb2.Status.StatusCode), + "message": draw(pb_string()), + } + + @st.composite + def span(draw): + return { + SF: trace_sf.Span, + PB: trace_pb2.Span, + "trace_id": draw(pb_trace_id()), + "span_id": draw(pb_span_id()), + "trace_state": draw(pb_string()), + "parent_span_id": draw(pb_span_id()), + "name": draw(pb_string()), + "kind": draw_pb_enum(draw, trace_pb2.Span.SpanKind), + "start_time_unix_nano": draw(pb_fixed64()), 
+ "end_time_unix_nano": draw(pb_fixed64()), + "attributes": draw(pb_repeated(key_value())), + "events": draw(pb_repeated(event())), + "links": draw(pb_repeated(link())), + "status": draw(pb_message(status())), + "dropped_attributes_count": draw(pb_uint32()), + "dropped_events_count": draw(pb_uint32()), + "dropped_links_count": draw(pb_uint32()), + "flags": draw(pb_fixed32()), + } + + @st.composite + def scope_spans(draw): + return { + SF: trace_sf.ScopeSpans, + PB: trace_pb2.ScopeSpans, + "scope": draw(pb_message(instrumentation_scope())), + "spans": draw(pb_repeated(span())), + "schema_url": draw(pb_string()), + } + + @st.composite + def resource_spans(draw): + return { + SF: trace_sf.ResourceSpans, + PB: trace_pb2.ResourceSpans, + "resource": draw(pb_message(resource())), + "scope_spans": draw(pb_repeated(scope_spans())), + "schema_url": draw(pb_string()), + } + + return { + SF: trace_sf.TracesData, + PB: trace_pb2.TracesData, + "resource_spans": draw(pb_repeated(resource_spans())), + } + +@st.composite +def metrics_data(draw): + @st.composite + def exemplar(draw): + return { + SF: metrics_sf.Exemplar, + PB: metrics_pb2.Exemplar, + **pb_oneof( + draw, + as_double=pb_double, + as_int=pb_sfixed64, + ), + "time_unix_nano": draw(pb_fixed64()), + "trace_id": draw(pb_trace_id()), + "span_id": draw(pb_span_id()), + "filtered_attributes": draw(pb_repeated(key_value())), + } + + @st.composite + def value_at_quantile(draw): + return { + SF: metrics_sf.SummaryDataPoint.ValueAtQuantile, + PB: metrics_pb2.SummaryDataPoint.ValueAtQuantile, + "quantile": draw_pb_double(draw), + "value": draw_pb_double(draw), + } + + @st.composite + def summary_data_point(draw): + return { + SF: metrics_sf.SummaryDataPoint, + PB: metrics_pb2.SummaryDataPoint, + "start_time_unix_nano": draw(pb_fixed64()), + "time_unix_nano": draw(pb_fixed64()), + "count": draw(pb_fixed64()), + "sum": draw_pb_double(draw), + "quantile_values": draw(pb_repeated(value_at_quantile())), + "attributes": 
draw(pb_repeated(key_value())), + "flags": draw(pb_uint32()), + } + + @st.composite + def buckets(draw): + return { + SF: metrics_sf.ExponentialHistogramDataPoint.Buckets, + PB: metrics_pb2.ExponentialHistogramDataPoint.Buckets, + "offset": draw(pb_sint32()), + "bucket_counts": draw(pb_repeated(pb_uint64())), + } + + @st.composite + def exponential_histogram_data_point(draw): + return { + SF: metrics_sf.ExponentialHistogramDataPoint, + PB: metrics_pb2.ExponentialHistogramDataPoint, + "start_time_unix_nano": draw(pb_fixed64()), + "time_unix_nano": draw(pb_fixed64()), + "count": draw(pb_fixed64()), + "sum": draw_pb_double(draw), + **pb_oneof( + draw, + positive=buckets, + negative=buckets, + ), + "attributes": draw(pb_repeated(key_value())), + "flags": draw(pb_uint32()), + "exemplars": draw(pb_repeated(exemplar())), + "max": draw_pb_double(draw), + "zero_threshold": draw_pb_double(draw), + } + + @st.composite + def histogram_data_point(draw): + return { + SF: metrics_sf.HistogramDataPoint, + PB: metrics_pb2.HistogramDataPoint, + "start_time_unix_nano": draw(pb_fixed64()), + "time_unix_nano": draw(pb_fixed64()), + "count": draw(pb_fixed64()), + "sum": draw_pb_double(draw), + "bucket_counts": draw(pb_repeated(pb_uint64())), + "attributes": draw(pb_repeated(key_value())), + "flags": draw(pb_uint32()), + "exemplars": draw(pb_repeated(exemplar())), + "explicit_bounds": draw(pb_repeated(pb_double())), + **pb_oneof( + draw, + max=pb_double, + min=pb_double, + ), + } + + @st.composite + def number_data_point(draw): + return { + SF: metrics_sf.NumberDataPoint, + PB: metrics_pb2.NumberDataPoint, + "start_time_unix_nano": draw(pb_fixed64()), + "time_unix_nano": draw(pb_fixed64()), + **pb_oneof( + draw, + as_int=pb_sfixed64, + as_double=pb_double, + ), + "exemplars": draw(pb_repeated(exemplar())), + "attributes": draw(pb_repeated(key_value())), + "flags": draw(pb_uint32()), + } + + @st.composite + def summary(draw): + return { + SF: metrics_sf.Summary, + PB: metrics_pb2.Summary, 
+ "data_points": draw(pb_repeated(summary_data_point())), + } + + @st.composite + def exponential_histogram(draw): + return { + SF: metrics_sf.ExponentialHistogram, + PB: metrics_pb2.ExponentialHistogram, + "data_points": draw(pb_repeated(exponential_histogram_data_point())), + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), + } + + @st.composite + def histogram(draw): + return { + SF: metrics_sf.Histogram, + PB: metrics_pb2.Histogram, + "data_points": draw(pb_repeated(histogram_data_point())), + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), + } + + @st.composite + def sum(draw): + return { + SF: metrics_sf.Sum, + PB: metrics_pb2.Sum, + "data_points": draw(pb_repeated(number_data_point())), + "aggregation_temporality": draw_pb_enum(draw, metrics_pb2.AggregationTemporality), + "is_monotonic": draw(pb_bool()), + } + + @st.composite + def gauge(draw): + return { + SF: metrics_sf.Gauge, + PB: metrics_pb2.Gauge, + "data_points": draw(pb_repeated(number_data_point())), + } + + @st.composite + def metric(draw): + return { + SF: metrics_sf.Metric, + PB: metrics_pb2.Metric, + "name": draw(pb_string()), + "description": draw(pb_string()), + "unit": draw(pb_string()), + **pb_oneof( + draw, + gauge=gauge, + sum=sum, + summary=summary, + histogram=histogram, + exponential_histogram=exponential_histogram, + ), + "metadata": draw(pb_repeated(key_value())), + } + + @st.composite + def scope_metrics(draw): + return { + SF: metrics_sf.ScopeMetrics, + PB: metrics_pb2.ScopeMetrics, + "scope": draw(pb_message(instrumentation_scope())), + "metrics": draw(pb_repeated(metric())), + "schema_url": draw(pb_string()), + } + + @st.composite + def resource_metrics(draw): + return { + SF: metrics_sf.ResourceMetrics, + PB: metrics_pb2.ResourceMetrics, + "resource": draw(pb_message(resource())), + "scope_metrics": draw(pb_repeated(scope_metrics())), + "schema_url": draw(pb_string()), + } + + return { + SF: 
metrics_sf.MetricsData, + PB: metrics_pb2.MetricsData, + "resource_metrics": draw(pb_repeated(resource_metrics())), + } + + +# Helper functions to recursively encode protobuf types using the generated args +# and the given serialization strategy +def encode_recurse(obj: Dict[str, Any], strategy: str) -> Any: + kwargs = {} + for key, value in obj.items(): + if key in [SF, PB]: + continue + elif value is None: + continue + elif isinstance(value, Mapping): + kwargs[key] = encode_recurse(value, strategy) + elif isinstance(value, List) and value and isinstance(value[0], Mapping): + kwargs[key] = [encode_recurse(v, strategy) for v in value if v is not None] + elif isinstance(value, List): + kwargs[key] = [v for v in value if v is not None] + else: + kwargs[key] = value + return obj[strategy](**kwargs) + +class TestProtoSerialization(unittest.TestCase): + @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) + @hypothesis.given(logs_data()) + def test_log_data(self, logs_data): + self.assertEqual( + encode_recurse(logs_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(logs_data, SF)) + ) + + @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) + @hypothesis.given(traces_data()) + def test_trace_data(self, traces_data): + self.assertEqual( + encode_recurse(traces_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(traces_data, SF)) + ) + + @hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.too_slow]) + @hypothesis.given(metrics_data()) + def test_metrics_data(self, metrics_data): + self.assertEqual( + encode_recurse(metrics_data, PB).SerializeToString(deterministic=True), + bytes(encode_recurse(metrics_data, SF)) + ) diff --git a/tests/test_protoc_plugin.py b/tests/test_protoc_plugin.py new file mode 100644 index 0000000..b0508d3 --- /dev/null +++ b/tests/test_protoc_plugin.py @@ -0,0 +1,91 @@ +""" +Test protoc code generator plugin for custom protoc message types 
+""" +import unittest +import tempfile +import subprocess +import os + +# Import into globals() so generated code string can be compiled +from snowflake.telemetry._internal.serialize import * + +class TestProtocPlugin(unittest.TestCase): + def namespace_serialize_message(self, message_type: str, local_namespace: dict, **kwargs) -> bytes: + assert message_type in local_namespace, f"Message type {message_type} not found in local namespace" + return local_namespace[message_type](**kwargs).SerializeToString() + + def test_protoc_plugin(self): + with tempfile.NamedTemporaryFile(suffix=".proto", mode="w", delete=False) as proto_file: + # Define a simple proto file + proto_file.write( + """syntax = "proto3"; +package opentelemetry.proto.common.v1; + +message AnyValue { + oneof value { + string string_value = 1; + bool bool_value = 2; + int64 int_value = 3; + double double_value = 4; + ArrayValue array_value = 5; + KeyValueList kvlist_value = 6; + bytes bytes_value = 7; + } +} + +message ArrayValue { + repeated AnyValue values = 1; +} + +message KeyValueList { + repeated KeyValue values = 1; +} + +message KeyValue { + string key = 1; + AnyValue value = 2; +} + +message InstrumentationScope { + string name = 1; + string version = 2; + repeated KeyValue attributes = 3; + uint32 dropped_attributes_count = 4; +} +""" + ) + proto_file.flush() + proto_file.close() + + proto_file_dir = os.path.dirname(proto_file.name) + proto_file_name = os.path.basename(proto_file.name) + + # Run protoc with custom plugin to generate serialization code for messages + result = subprocess.run([ + "python", + "-m", + "grpc_tools.protoc", + "-I", + proto_file_dir, + "--plugin=protoc-gen-custom-plugin=scripts/plugin.py", + f"--custom-plugin_out={tempfile.gettempdir()}", + proto_file_name, + ], capture_output=True) + + # Ensure protoc ran successfully + self.assertEqual(result.returncode, 0) + + generated_code_file_dir = tempfile.gettempdir() + generated_code_file_name = 
proto_file_name.replace(".proto", "_marshaler.py") + generated_code_file = os.path.join(generated_code_file_dir, generated_code_file_name) + + # Ensure generated code file exists + self.assertTrue(os.path.exists(generated_code_file)) + + # Ensure code can be executed and serializes correctly + with open(generated_code_file, "r") as f: + generated_code = f.read() + local_namespace = {} + eval(compile(generated_code, generated_code_file, "exec"), globals(), local_namespace) + + self.assertEqual(b'\n\x04test', self.namespace_serialize_message("AnyValue", local_namespace, string_value="test"))