Skip to content

Commit

Permalink
Test exception handling preserves the call stack.
Browse files Browse the repository at this point in the history
This way, using a once decorator will not swallow all exception traces.
  • Loading branch information
aebrahim committed Nov 8, 2023
1 parent fc159ff commit dfba7db
Showing 1 changed file with 78 additions and 20 deletions.
98 changes: 78 additions & 20 deletions once_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# pylint: disable=missing-function-docstring
import asyncio
import concurrent.futures
import contextlib
import functools
import gc
import inspect
import math
import sys
import threading
import traceback
import unittest
import weakref

Expand Down Expand Up @@ -120,6 +122,41 @@ def counting_fn(*args) -> int:
return counting_fn, counter


class LineCapture:
def __init__(self):
self.line = None

def record_next_line(self):
"""Record the next line in the parent frame"""
self.line = inspect.currentframe().f_back.f_lineno + 1


class ExceptionContextManager:
exception: Exception


@contextlib.contextmanager
def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture):
try:
container = ExceptionContextManager()
yield container
except exception_type as exception:
container.exception = exception
traceback_exception = traceback.TracebackException.from_exception(exception)
if not len(traceback_exception.stack):
test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?")
locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack]
line_number = line.line
error_message = [
f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n"
]
error_message.extend(traceback_exception.format())
test.assertIn((__file__, line_number), locations, msg="".join(error_message))

else:
test.fail("expected exception not called")


class TestFunctionInspection(unittest.TestCase):
"""Unit tests for function inspection"""

Expand Down Expand Up @@ -317,33 +354,42 @@ def test_partial(self):

def test_failing_function(self):
counter = Counter()
failing_line = LineCapture()

@once.once
def sample_failing_fn():
nonlocal failing_line
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm:
sample_failing_fn()
self.assertEqual(cm.exception.args[0], "expected failure")
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")

def test_failing_function_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
def sample_failing_fn():
nonlocal failing_line
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
sample_failing_fn()
# This ensures that this was a new function call, not a cached result.
self.assertEqual(counter.get_incremented(), 4)
Expand All @@ -363,13 +409,15 @@ def yielding_iterator():

def test_failing_generator(self):
counter = Counter()
failing_line = LineCapture()

@once.once
def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("expected failure after 2.")

# Both of these calls should return the same results.
Expand All @@ -379,9 +427,9 @@ def sample_failing_fn():
self.assertEqual(next(call2), 1)
self.assertEqual(next(call1), 2)
self.assertEqual(next(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call2)
# These next 2 calls should also fail.
call3 = sample_failing_fn()
Expand All @@ -390,20 +438,22 @@ def sample_failing_fn():
self.assertEqual(next(call4), 1)
self.assertEqual(next(call3), 2)
self.assertEqual(next(call4), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call3)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call4)

def test_failing_generator_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("expected failure after 2.")

# Both of these calls should return the same results.
Expand All @@ -413,9 +463,9 @@ def sample_failing_fn():
self.assertEqual(next(call2), 1)
self.assertEqual(next(call1), 2)
self.assertEqual(next(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
next(call2)
# These next 2 calls should succeed.
call3 = sample_failing_fn()
Expand Down Expand Up @@ -906,33 +956,37 @@ def execute(*args):

async def test_failing_function(self):
counter = Counter()
failing_line = LineCapture()

@once.once
async def sample_failing_fn():
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter")

async def test_failing_function_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
async def sample_failing_fn():
if counter.get_incremented() < 4:
failing_line.record_next_line()
raise ValueError("expected failure")
return 1

with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
self.assertEqual(counter.get_incremented(), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await sample_failing_fn()
# This ensures that this was a new function call, not a cached result.
self.assertEqual(counter.get_incremented(), 4)
Expand Down Expand Up @@ -985,13 +1039,15 @@ async def async_yielding_iterator():

async def test_failing_generator(self):
counter = Counter()
failing_line = LineCapture()

@once.once
async def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("we raise an error when result is exactly 2")

# Both of these calls should return the same results.
Expand All @@ -1001,9 +1057,9 @@ async def sample_failing_fn():
self.assertEqual(await anext(call2), 1)
self.assertEqual(await anext(call1), 2)
self.assertEqual(await anext(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call2)
# These next 2 calls should also fail.
call3 = sample_failing_fn()
Expand All @@ -1012,20 +1068,22 @@ async def sample_failing_fn():
self.assertEqual(await anext(call4), 1)
self.assertEqual(await anext(call3), 2)
self.assertEqual(await anext(call4), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call3)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call4)

async def test_failing_generator_retry_exceptions(self):
counter = Counter()
failing_line = LineCapture()

@once.once(retry_exceptions=True)
async def sample_failing_fn():
yield counter.get_incremented()
result = counter.get_incremented()
yield result
if result == 2:
failing_line.record_next_line()
raise ValueError("we raise an error when result is exactly 2")

# Both of these calls should return the same results.
Expand All @@ -1035,9 +1093,9 @@ async def sample_failing_fn():
self.assertEqual(await anext(call2), 1)
self.assertEqual(await anext(call1), 2)
self.assertEqual(await anext(call2), 2)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call1)
with self.assertRaises(ValueError):
with assertRaisesWithLineInStackTrace(self, ValueError, failing_line):
await anext(call2)
# These next 2 calls should succeed.
call3 = sample_failing_fn()
Expand Down

0 comments on commit dfba7db

Please sign in to comment.