Skip to content

Commit

Permalink
Add tests for encoding/decoding highly recursive messages
Browse files Browse the repository at this point in the history
This adds a test case found in the `orjson` repo to ensure that we
properly respect recursion limits when encoding or decoding deeply
recursive messages.

No code change is needed at this time, we properly manage recursion
limits already.
  • Loading branch information
jcrist committed Feb 22, 2024
1 parent 43c239e commit ac3ea04
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 31 deletions.
58 changes: 57 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)

import pytest
from utils import temp_module
from utils import temp_module, max_call_depth

try:
import attrs
Expand Down Expand Up @@ -295,6 +295,62 @@ class Custom(metaclass=Metaclass):
dec.decode(msg)


@pytest.mark.skipif(
PY312,
reason=(
"Python 3.12 harcodes the C recursion limit, making this "
"behavior harder to test in CI"
),
)
class TestRecursion:
@staticmethod
def nested(n, is_array):
if is_array:
obj = []
for _ in range(n):
obj = [obj]
else:
obj = {}
for _ in range(n):
obj = {"": obj}
return obj

@pytest.mark.parametrize("is_array", [True, False])
def test_encode_highly_recursive_msg_errors(self, is_array, proto):
N = 200
obj = self.nested(N, is_array)

# Errors if above the recursion limit
with max_call_depth(N // 2):
with pytest.raises(RecursionError):
proto.encode(obj)

# Works if below the recursion limit
with max_call_depth(N * 2):
proto.encode(obj)

@pytest.mark.parametrize("is_array", [True, False])
def test_decode_highly_recursive_msg_errors(self, is_array, proto):
"""Ensure recursion is properly handled when decoding.
Test case seen in https://github.com/ijl/orjson/issues/458."""
N = 200
obj = self.nested(N, is_array)

with max_call_depth(N * 2):
msg = proto.encode(obj)

# Errors if above the recursion limit
with max_call_depth(N // 2):
with pytest.raises(RecursionError):
proto.decode(msg)

# Works if below the recursion limit
with max_call_depth(N * 2):
obj2 = proto.decode(msg)

assert obj2


class TestThreadSafe:
def test_encode_threadsafe(self, proto):
class Nested:
Expand Down
31 changes: 1 addition & 30 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import decimal
import enum
import gc
import inspect
import math
import sys
import uuid
from base64 import b64encode
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
Any,
Expand All @@ -25,7 +23,7 @@
)

import pytest
from utils import temp_module
from utils import temp_module, max_call_depth

import msgspec
from msgspec import Meta, Struct, ValidationError, convert, to_builtins
Expand Down Expand Up @@ -196,33 +194,6 @@ def assert_eq(x, y):
assert x == y


@contextmanager
def max_call_depth(n):
cur_depth = len(inspect.stack(0))
orig = sys.getrecursionlimit()
try:
# Our measure of the current stack depth can be off by a bit. Trying to
# set a recursionlimit < the current depth will raise a RecursionError.
# We just try again with a slightly higher limit, bailing after an
# unreasonable amount of adjustments.
#
# Note that python 3.8 also has a minimum recursion limit of 64, so
# there's some additional fiddliness there.
for i in range(64):
try:
sys.setrecursionlimit(cur_depth + i + n)
break
except RecursionError:
pass
else:
raise ValueError(
"Failed to set low recursion limit, something is wrong here"
)
yield
finally:
sys.setrecursionlimit(orig)


def roundtrip(obj, typ):
return convert(to_builtins(obj), typ)

Expand Down
28 changes: 28 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import inspect
import textwrap
import types
import uuid
Expand All @@ -20,3 +21,30 @@ def temp_module(code):
yield mod
finally:
sys.modules.pop(name, None)


@contextmanager
def max_call_depth(n):
cur_depth = len(inspect.stack(0))
orig = sys.getrecursionlimit()
try:
# Our measure of the current stack depth can be off by a bit. Trying to
# set a recursionlimit < the current depth will raise a RecursionError.
# We just try again with a slightly higher limit, bailing after an
# unreasonable amount of adjustments.
#
# Note that python 3.8 also has a minimum recursion limit of 64, so
# there's some additional fiddliness there.
for i in range(64):
try:
sys.setrecursionlimit(cur_depth + i + n)
break
except RecursionError:
pass
else:
raise ValueError(
"Failed to set low recursion limit, something is wrong here"
)
yield
finally:
sys.setrecursionlimit(orig)

0 comments on commit ac3ea04

Please sign in to comment.