From dfba7db787644260d52f779ac4ed6b7bd9c422ad Mon Sep 17 00:00:00 2001 From: Ali Ebrahim Date: Wed, 8 Nov 2023 05:25:34 +0000 Subject: [PATCH] Test exception handling preserves the call stack. This way, using a once decorator will not swallow all exception traces. --- once_test.py | 98 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 20 deletions(-) diff --git a/once_test.py b/once_test.py index 6c10d10..bd46873 100644 --- a/once_test.py +++ b/once_test.py @@ -2,12 +2,14 @@ # pylint: disable=missing-function-docstring import asyncio import concurrent.futures +import contextlib import functools import gc import inspect import math import sys import threading +import traceback import unittest import weakref @@ -120,6 +122,41 @@ def counting_fn(*args) -> int: return counting_fn, counter +class LineCapture: + def __init__(self): + self.line = None + + def record_next_line(self): + """Record the next line in the parent frame""" + self.line = inspect.currentframe().f_back.f_lineno + 1 + + +class ExceptionContextManager: + exception: Exception + + +@contextlib.contextmanager +def assertRaisesWithLineInStackTrace(test: unittest.TestCase, exception_type, line: LineCapture): + try: + container = ExceptionContextManager() + yield container + except exception_type as exception: + container.exception = exception + traceback_exception = traceback.TracebackException.from_exception(exception) + if not len(traceback_exception.stack): + test.fail("Exception stack not preserved. Did you use the raw assertRaises by mistake?") + locations = [(frame.filename, frame.lineno) for frame in traceback_exception.stack] + line_number = line.line + error_message = [ + f"Traceback for exception {repr(exception)} did not have frame on line {line_number}. Exception below\n" + ] + error_message.extend(traceback_exception.format()) + test.assertIn((__file__, line_number), locations, msg="".join(error_message)) + + else: + test.fail("expected exception not called") + + class TestFunctionInspection(unittest.TestCase): """Unit tests for function inspection""" @@ -317,33 +354,42 @@ def test_partial(self): def test_failing_function(self): counter = Counter() + failing_line = LineCapture() @once.once def sample_failing_fn(): + nonlocal failing_line if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): + sample_failing_fn() + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line) as cm: sample_failing_fn() + self.assertEqual(cm.exception.args[0], "expected failure") self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter") def test_failing_function_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) def sample_failing_fn(): + nonlocal failing_line if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): sample_failing_fn() # This ensures that this was a new function call, not a cached result. self.assertEqual(counter.get_incremented(), 4) @@ -363,6 +409,7 @@ def yielding_iterator(): def test_failing_generator(self): counter = Counter() + failing_line = LineCapture() @once.once def sample_failing_fn(): @@ -370,6 +417,7 @@ def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("expected failure after 2.") # Both of these calls should return the same results. @@ -379,9 +427,9 @@ def sample_failing_fn(): self.assertEqual(next(call2), 1) self.assertEqual(next(call1), 2) self.assertEqual(next(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call2) # These next 2 calls should also fail. call3 = sample_failing_fn() @@ -390,13 +438,14 @@ def sample_failing_fn(): self.assertEqual(next(call4), 1) self.assertEqual(next(call3), 2) self.assertEqual(next(call4), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call3) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call4) def test_failing_generator_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) def sample_failing_fn(): @@ -404,6 +453,7 @@ def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("expected failure after 2.") # Both of these calls should return the same results. @@ -413,9 +463,9 @@ def sample_failing_fn(): self.assertEqual(next(call2), 1) self.assertEqual(next(call1), 2) self.assertEqual(next(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): next(call2) # These next 2 calls should succeed. call3 = sample_failing_fn() @@ -906,33 +956,37 @@ def execute(*args): async def test_failing_function(self): counter = Counter() + failing_line = LineCapture() @once.once async def sample_failing_fn(): if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 3, "Function call incremented the counter") async def test_failing_function_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) async def sample_failing_fn(): if counter.get_incremented() < 4: + failing_line.record_next_line() raise ValueError("expected failure") return 1 - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() self.assertEqual(counter.get_incremented(), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await sample_failing_fn() # This ensures that this was a new function call, not a cached result. self.assertEqual(counter.get_incremented(), 4) @@ -985,6 +1039,7 @@ async def async_yielding_iterator(): async def test_failing_generator(self): counter = Counter() + failing_line = LineCapture() @once.once async def sample_failing_fn(): @@ -992,6 +1047,7 @@ async def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("we raise an error when result is exactly 2") # Both of these calls should return the same results. @@ -1001,9 +1057,9 @@ async def sample_failing_fn(): self.assertEqual(await anext(call2), 1) self.assertEqual(await anext(call1), 2) self.assertEqual(await anext(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call2) # These next 2 calls should also fail. call3 = sample_failing_fn() @@ -1012,13 +1068,14 @@ async def sample_failing_fn(): self.assertEqual(await anext(call4), 1) self.assertEqual(await anext(call3), 2) self.assertEqual(await anext(call4), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call3) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call4) async def test_failing_generator_retry_exceptions(self): counter = Counter() + failing_line = LineCapture() @once.once(retry_exceptions=True) async def sample_failing_fn(): @@ -1026,6 +1083,7 @@ async def sample_failing_fn(): result = counter.get_incremented() yield result if result == 2: + failing_line.record_next_line() raise ValueError("we raise an error when result is exactly 2") # Both of these calls should return the same results. @@ -1035,9 +1093,9 @@ async def sample_failing_fn(): self.assertEqual(await anext(call2), 1) self.assertEqual(await anext(call1), 2) self.assertEqual(await anext(call2), 2) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call1) - with self.assertRaises(ValueError): + with assertRaisesWithLineInStackTrace(self, ValueError, failing_line): await anext(call2) # These next 2 calls should succeed. call3 = sample_failing_fn()