Skip to content

Commit

Permalink
Cleanup plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jopel committed Nov 12, 2024
1 parent 0378493 commit 05caebb
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions scripts/plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
from __future__ import annotations

import re
import os
import sys
import struct
import inspect
from enum import IntEnum
from typing import List, Optional
Expand All @@ -24,7 +24,9 @@
import black
import isort.api

INLINE_OPTIMIZATION = True # Whether to inline the size and serialization functions
INLINE_OPTIMIZATION = True
FILE_PATH_PREFIX = "snowflake.telemetry._internal"
FILE_NAME_SUFFIX = "_marshaler"

# Inline utility functions

Expand Down Expand Up @@ -126,7 +128,7 @@ def from_descriptor(descriptor: EnumValueDescriptorProto) -> "EnumValueTemplate"
@dataclass
class EnumTemplate:
name: str
values: List["EnumValueTemplate"] = field(default_factory=list)
values: List[EnumValueTemplate] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: EnumDescriptorProto) -> "EnumTemplate":
Expand All @@ -147,15 +149,11 @@ def tag_to_repr_varint(tag: int) -> str:
class FieldTemplate:
name: str
attr_name: str
generator: str
number: int
tag: str
generator: str
python_type: str
proto_type: str
repeated: bool
group: str
default_val: str
encode_presence: bool
serialize_field_inline: str
size_field_inline: str

Expand Down Expand Up @@ -217,27 +215,22 @@ def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = Non
# Inline the size and serialization functions for the field
if INLINE_OPTIMIZATION:
serialize_field_inline = inline_serialize_function(proto_type, attr_name, tag)
serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline)
size_field_inline = inline_size_function(proto_type, attr_name, tag)
size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline)
else:
serialize_field_inline = f"self.serialize_{proto_type}(out, {tag}, self.{attr_name})"
serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline)
size_field_inline = f"size += self.size_{proto_type}({tag}, self.{attr_name})"
size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline)

serialize_field_inline = add_presence_check(proto_type, encode_presence, attr_name, serialize_field_inline)
size_field_inline = add_presence_check(proto_type, encode_presence, attr_name, size_field_inline)

return FieldTemplate(
name=field_name,
attr_name=attr_name,
generator=generator,
tag=tag,
number=descriptor.number,
generator=generator,
python_type=python_type,
proto_type=proto_type,
repeated=repeated,
group=group,
default_val=default_val,
encode_presence=encode_presence,
serialize_field_inline=serialize_field_inline,
size_field_inline=size_field_inline,
)
Expand All @@ -247,9 +240,8 @@ class MessageTemplate:
name: str
super_class_init: str
fields: List[FieldTemplate] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
messages: List["MessageTemplate"] = field(default_factory=list)
type_hints: List[str] = field(default_factory=list)
enums: List[EnumTemplate] = field(default_factory=list)
messages: List[MessageTemplate] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: DescriptorProto) -> "MessageTemplate":
Expand All @@ -276,8 +268,8 @@ def get_group(field: FieldDescriptorProto) -> str:

@dataclass
class FileTemplate:
messages: List["MessageTemplate"] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
messages: List[MessageTemplate] = field(default_factory=list)
enums: List[EnumTemplate] = field(default_factory=list)
imports: List[str] = field(default_factory=list)
name: str = ""

Expand All @@ -291,7 +283,7 @@ def from_descriptor(descriptor: FileDescriptorProto) -> "FileTemplate":
if descriptor.name.startswith(path):
continue
path = path.replace("/", ".")
path = "snowflake.telemetry._internal." + path + "_marshaler"
path = f"{FILE_PATH_PREFIX}.{path}{FILE_NAME_SUFFIX}"
imports.append(path)

return FileTemplate(
Expand All @@ -314,7 +306,7 @@ def main():
jinja_body_template = template_env.get_template("template.py.jinja2")

for proto_file in request.proto_file:
file_name = re.sub(r"\.proto$", "_marshaler.py", proto_file.name)
file_name = re.sub(r"\.proto$", f"{FILE_NAME_SUFFIX}.py", proto_file.name)
file_descriptor_proto = proto_file

file_template = FileTemplate.from_descriptor(file_descriptor_proto)
Expand Down

0 comments on commit 05caebb

Please sign in to comment.