Skip to content

Commit

Permalink
Nullable arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jopel committed Nov 13, 2024
1 parent 8d31aee commit 4b465c3
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions tests/test_proto_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import (
Any,
Dict,
Mapping,
Sequence
List,
Mapping,
)
import unittest
import hypothesis
Expand All @@ -24,14 +24,15 @@
import snowflake.telemetry._internal.opentelemetry.proto.resource.v1.resource_marshaler as resource_sf

# Strategy for generating protobuf types
def pb_uint32(): return st.integers(min_value=0, max_value=2**32-1)
def pb_uint64(): return st.integers(min_value=0, max_value=2**64-1)
def pb_int32(): return st.integers(min_value=-2**31, max_value=2**31-1)
def pb_int64(): return st.integers(min_value=-2**63, max_value=2**63-1)
def pb_sint32(): return st.integers(min_value=-2**31, max_value=2**31-1)
def pb_sint64(): return st.integers(min_value=-2**63, max_value=2**63-1)
def pb_float(): return st.floats(allow_nan=False, allow_infinity=False, width=32)
def pb_double(): return st.floats(allow_nan=False, allow_infinity=False, width=64)
def nullable(type): return st.one_of(st.none(), type)
def pb_uint32(): return nullable(st.integers(min_value=0, max_value=2**32-1))
def pb_uint64(): return nullable(st.integers(min_value=0, max_value=2**64-1))
def pb_int32(): return nullable(st.integers(min_value=-2**31, max_value=2**31-1))
def pb_int64(): return nullable(st.integers(min_value=-2**63, max_value=2**63-1))
def pb_sint32(): return nullable(st.integers(min_value=-2**31, max_value=2**31-1))
def pb_sint64(): return nullable(st.integers(min_value=-2**63, max_value=2**63-1))
def pb_float(): return nullable(st.floats(allow_nan=False, allow_infinity=False, width=32))
def pb_double(): return nullable(st.floats(allow_nan=False, allow_infinity=False, width=64))
def draw_pb_double(draw):
# -0.0 is an edge case that is not handled by the custom serialization library
double = draw(pb_double())
Expand All @@ -41,23 +42,25 @@ def pb_fixed64(): return pb_uint64()
def pb_fixed32(): return pb_uint32()
def pb_sfixed64(): return pb_int64()
def pb_sfixed32(): return pb_int32()
def pb_bool(): return st.booleans()
def pb_string(): return st.text(max_size=20)
def pb_bytes(): return st.binary(max_size=20)
def pb_bool(): return nullable(st.booleans())
def pb_string(): return nullable(st.text(max_size=20))
def pb_bytes(): return nullable(st.binary(max_size=20))
def draw_pb_enum(draw, enum):
# Sample int val of enum, will be converted to member in encode_recurse
# Sample from pb2 values as it is the source of truth
return draw(st.sampled_from([member for member in enum.values()]))
def pb_repeated(type): return st.lists(type, max_size=3) # limit the size of the repeated field to speed up testing
def pb_span_id(): return st.binary(min_size=8, max_size=8)
def pb_trace_id(): return st.binary(min_size=16, max_size=16)
return draw(nullable(st.sampled_from([member for member in enum.values()])))
def pb_repeated(type): return nullable(st.lists(type, max_size=3)) # limit the size of the repeated field to speed up testing
def pb_span_id(): return nullable(st.binary(min_size=8, max_size=8))
def pb_trace_id(): return nullable(st.binary(min_size=16, max_size=16))
# For drawing oneof fields
# call with pb_oneof(draw, field1=pb_type1_callable, field2=pb_type2_callable, ...)
def pb_oneof(draw, **kwargs):
n = len(kwargs)
r = draw(st.integers(min_value=0, max_value=n-1))
k, v = list(kwargs.items())[r]
return {k: draw(v())}
def pb_message(type):
return nullable(type)

SF = "_sf"
PB = "_pb2"
Expand Down Expand Up @@ -136,7 +139,7 @@ def log_record(draw):
"observed_time_unix_nano": draw(pb_fixed64()),
"severity_number": draw_pb_enum(draw, logs_pb2.SeverityNumber),
"severity_text": draw(pb_string()),
"body": draw(any_value()),
"body": draw(pb_message(any_value())),
"attributes": draw(pb_repeated(key_value())),
"dropped_attributes_count": draw(pb_uint32()),
"flags": draw(pb_fixed32()),
Expand All @@ -149,7 +152,7 @@ def scope_logs(draw):
return {
SF: logs_sf.ScopeLogs,
PB: logs_pb2.ScopeLogs,
"scope": draw(instrumentation_scope()),
"scope": draw(pb_message(instrumentation_scope())),
"log_records": draw(pb_repeated(log_record())),
"schema_url": draw(pb_string()),
}
Expand All @@ -159,7 +162,7 @@ def resource_logs(draw):
return {
SF: logs_sf.ResourceLogs,
PB: logs_pb2.ResourceLogs,
"resource": draw(resource()),
"resource": draw(pb_message(resource())),
"scope_logs": draw(pb_repeated(scope_logs())),
"schema_url": draw(pb_string()),
}
Expand Down Expand Up @@ -221,7 +224,7 @@ def span(draw):
"attributes": draw(pb_repeated(key_value())),
"events": draw(pb_repeated(event())),
"links": draw(pb_repeated(link())),
"status": draw(status()),
"status": draw(pb_message(status())),
"dropped_attributes_count": draw(pb_uint32()),
"dropped_events_count": draw(pb_uint32()),
"dropped_links_count": draw(pb_uint32()),
Expand All @@ -233,7 +236,7 @@ def scope_spans(draw):
return {
SF: trace_sf.ScopeSpans,
PB: trace_pb2.ScopeSpans,
"scope": draw(instrumentation_scope()),
"scope": draw(pb_message(instrumentation_scope())),
"spans": draw(pb_repeated(span())),
"schema_url": draw(pb_string()),
}
Expand All @@ -243,7 +246,7 @@ def resource_spans(draw):
return {
SF: trace_sf.ResourceSpans,
PB: trace_pb2.ResourceSpans,
"resource": draw(resource()),
"resource": draw(pb_message(resource())),
"scope_spans": draw(pb_repeated(scope_spans())),
"schema_url": draw(pb_string()),
}
Expand Down Expand Up @@ -431,7 +434,7 @@ def scope_metrics(draw):
return {
SF: metrics_sf.ScopeMetrics,
PB: metrics_pb2.ScopeMetrics,
"scope": draw(instrumentation_scope()),
"scope": draw(pb_message(instrumentation_scope())),
"metrics": draw(pb_repeated(metric())),
"schema_url": draw(pb_string()),
}
Expand All @@ -441,7 +444,7 @@ def resource_metrics(draw):
return {
SF: metrics_sf.ResourceMetrics,
PB: metrics_pb2.ResourceMetrics,
"resource": draw(resource()),
"resource": draw(pb_message(resource())),
"scope_metrics": draw(pb_repeated(scope_metrics())),
"schema_url": draw(pb_string()),
}
Expand All @@ -460,10 +463,14 @@ def encode_recurse(obj: Dict[str, Any], strategy: str) -> Any:
for key, value in obj.items():
if key in [SF, PB]:
continue
elif value is None:
continue
elif isinstance(value, Mapping):
kwargs[key] = encode_recurse(value, strategy)
elif isinstance(value, Sequence) and value and isinstance(value[0], Mapping):
kwargs[key] = [encode_recurse(v, strategy) for v in value]
elif isinstance(value, List) and value and isinstance(value[0], Mapping):
kwargs[key] = [encode_recurse(v, strategy) for v in value if v is not None]
elif isinstance(value, List):
kwargs[key] = [v for v in value if v is not None]
else:
kwargs[key] = value
return obj[strategy](**kwargs)
Expand Down

0 comments on commit 4b465c3

Please sign in to comment.