diff --git a/scripts/plugin.py b/scripts/plugin.py index dcef463..5740236 100755 --- a/scripts/plugin.py +++ b/scripts/plugin.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 +from __future__ import annotations import re import os import sys -import struct import inspect from enum import IntEnum from typing import List, Optional @@ -24,7 +24,9 @@ import black import isort.api -INLINE_OPTIMIZATION = True # Whether to inline the size and serialization functions +INLINE_OPTIMIZATION = True +FILE_PATH_PREFIX = "snowflake.telemetry._internal" +FILE_NAME_SUFFIX = "_marshaler" # Inline utility functions @@ -126,7 +128,7 @@ def from_descriptor(descriptor: EnumValueDescriptorProto) -> "EnumValueTemplate" @dataclass class EnumTemplate: name: str - values: List["EnumValueTemplate"] = field(default_factory=list) + values: List[EnumValueTemplate] = field(default_factory=list) @staticmethod def from_descriptor(descriptor: EnumDescriptorProto) -> "EnumTemplate": @@ -147,15 +149,11 @@ def tag_to_repr_varint(tag: int) -> str: class FieldTemplate: name: str attr_name: str - generator: str number: int - tag: str + generator: str python_type: str proto_type: str - repeated: bool - group: str default_val: str - encode_presence: bool serialize_field_inline: str size_field_inline: str @@ -217,27 +215,22 @@ def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = Non # Inline the size and serialization functions for the field if INLINE_OPTIMIZATION: serialize_field_inline = inline_serialize_function(proto_type, attr_name, tag) - serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline) size_field_inline = inline_size_function(proto_type, attr_name, tag) - size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline) else: serialize_field_inline = f"self.serialize_{proto_type}(out, {tag}, self.{attr_name})" - serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline) size_field_inline = f"size += self.size_{proto_type}({tag}, self.{attr_name})" - size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline) + + serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline) + size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline) return FieldTemplate( name=field_name, attr_name=attr_name, - generator=generator, - tag=tag, number=descriptor.number, + generator=generator, python_type=python_type, proto_type=proto_type, - repeated=repeated, - group=group, default_val=default_val, - encode_presence=encode_presence, serialize_field_inline=serialize_field_inline, size_field_inline=size_field_inline, ) @@ -247,9 +240,8 @@ class MessageTemplate: name: str super_class_init: str fields: List[FieldTemplate] = field(default_factory=list) - enums: List["EnumTemplate"] = field(default_factory=list) - messages: List["MessageTemplate"] = field(default_factory=list) - type_hints: List[str] = field(default_factory=list) + enums: List[EnumTemplate] = field(default_factory=list) + messages: List[MessageTemplate] = field(default_factory=list) @staticmethod def from_descriptor(descriptor: DescriptorProto) -> "MessageTemplate": @@ -276,8 +268,8 @@ def get_group(field: FieldDescriptorProto) -> str: @dataclass class FileTemplate: - messages: List["MessageTemplate"] = field(default_factory=list) - enums: List["EnumTemplate"] = field(default_factory=list) + messages: List[MessageTemplate] = field(default_factory=list) + enums: List[EnumTemplate] = field(default_factory=list) imports: List[str] = field(default_factory=list) name: str = "" @@ -291,7 +283,7 @@ def from_descriptor(descriptor: FileDescriptorProto) -> "FileTemplate": if descriptor.name.startswith(path): continue path = path.replace("/", ".") - path = "snowflake.telemetry._internal." + path + "_marshaler" + path = f"{FILE_PATH_PREFIX}.{path}{FILE_NAME_SUFFIX}" imports.append(path) return FileTemplate( @@ -314,7 +306,7 @@ def main(): jinja_body_template = template_env.get_template("template.py.jinja2") for proto_file in request.proto_file: - file_name = re.sub(r"\.proto$", "_marshaler.py", proto_file.name) + file_name = re.sub(r"\.proto$", f"{FILE_NAME_SUFFIX}.py", proto_file.name) file_descriptor_proto = proto_file file_template = FileTemplate.from_descriptor(file_descriptor_proto)