diff --git a/rollbar/test/test_rollbar.py b/rollbar/test/test_rollbar.py index 319c1c5c..01b74722 100644 --- a/rollbar/test/test_rollbar.py +++ b/rollbar/test/test_rollbar.py @@ -140,6 +140,100 @@ def _raise(): self.assertNotIn('keywordspec', payload['data']['body']['trace']['frames'][-1]) self.assertNotIn('locals', payload['data']['body']['trace']['frames'][-1]) + @mock.patch('rollbar._post_api') + def test_lambda_function_good(self, _post_api): + rollbar.SETTINGS['handler'] = 'thread' + fake_event = {'a': 42} + fake_context = MockLambdaContext(99) + @rollbar.lambda_function + def my_lambda_func(event, context): + return [event['a'], context.x] + + result = my_lambda_func(fake_event, fake_context) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0], 42) + self.assertEqual(result[1], 99) + self.assertEqual(_post_api.called, False) + + rollbar._CURRENT_LAMBDA_CONTEXT = None + rollbar.SETTINGS['handler'] = 'blocking' + + @mock.patch('rollbar._post_api') + def test_lambda_function_bad(self, _post_api): + rollbar.SETTINGS['handler'] = 'thread' + fake_event = {'a': 42} + fake_context = MockLambdaContext(99) + @rollbar.lambda_function + def my_lambda_func(event, context): + raise event['a'] + + result = None + try: + result = my_lambda_func(fake_event, fake_context) + except: + pass + + self.assertEqual(result, None) + self.assertEqual(_post_api.called, True) + + rollbar._CURRENT_LAMBDA_CONTEXT = None + rollbar.SETTINGS['handler'] = 'blocking' + + @mock.patch('rollbar._post_api') + def test_lambda_function_method_good(self, _post_api): + rollbar.SETTINGS['handler'] = 'thread' + fake_event = {'a': 42} + fake_context = MockLambdaContext(99) + + class LambdaClass(object): + def __init__(self): + self.a = 13 + + def my_lambda_func(self, event, context): + return [event['a'], context.x, self.a] + + app = LambdaClass() + app.my_lambda_func = rollbar.lambda_function(app.my_lambda_func) + result = app.my_lambda_func(fake_event, fake_context) + + self.assertEqual(len(result), 3) + self.assertEqual(result[0], 42) + self.assertEqual(result[1], 99) + self.assertEqual(result[2], 13) + self.assertEqual(_post_api.called, False) + + rollbar._CURRENT_LAMBDA_CONTEXT = None + rollbar.SETTINGS['handler'] = 'blocking' + + @mock.patch('rollbar._post_api') + def test_lambda_function_method_bad(self, _post_api): + rollbar.SETTINGS['handler'] = 'thread' + fake_event = {'a': 42} + fake_context = MockLambdaContext(99) + + class LambdaClass(object): + def __init__(self): + self.a = 13 + + def my_lambda_func(self, event, context): + raise self.a + + app = LambdaClass() + app.my_lambda_func = rollbar.lambda_function(app.my_lambda_func) + + result = None + try: + result = app.my_lambda_func(fake_event, fake_context) + except: + pass + + self.assertEqual(result, None) + self.assertEqual(_post_api.called, True) + + rollbar._CURRENT_LAMBDA_CONTEXT = None + rollbar.SETTINGS['handler'] = 'blocking' + @mock.patch('rollbar.send_payload') def test_report_exception_with_cause(self, send_payload): def _raise_cause(): @@ -1123,6 +1217,16 @@ def content(self): def json(self): return self.json_data +class MockLambdaContext(object): + def __init__(self, x): + self.function_name = 1 + self.function_version = 2 + self.invoked_function_arn = 3 + self.aws_request_id = 4 + self.x = x + + def get_remaining_time_in_millis(self): + 42 if __name__ == '__main__': unittest.main()