diff --git a/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpCorsMiddleware.cs b/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpCorsMiddleware.cs index 3a0ac3e..0280898 100644 --- a/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpCorsMiddleware.cs +++ b/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpCorsMiddleware.cs @@ -21,6 +21,8 @@ public HttpCorsMiddleware(CorsOptions corsOptions) this.corsOptions = corsOptions; } + public bool InterruptsExecution => false; + private const string DefaultAccessControlAllowOrigin = "*"; private const string AllowOriginHeader = "Access-Control-Allow-Origin"; private const string AllowHeadersHeader = "Access-Control-Allow-Headers"; diff --git a/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpV2CorsMiddleware.cs b/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpV2CorsMiddleware.cs index 0a7f93e..bf0746c 100644 --- a/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpV2CorsMiddleware.cs +++ b/src/Voxel.MiddyNet.HttpCorsMiddleware/HttpV2CorsMiddleware.cs @@ -22,6 +22,8 @@ public HttpV2CorsMiddleware(CorsOptions corsOptions) this.corsOptions = corsOptions; } + public bool InterruptsExecution => false; + private const string DefaultAccessControlAllowOrigin = "*"; private const string AllowOriginHeader = "Access-Control-Allow-Origin"; private const string AllowHeadersHeader = "Access-Control-Allow-Headers"; diff --git a/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpJsonBodyParserMiddleware.cs b/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpJsonBodyParserMiddleware.cs index 0570970..6d89ff3 100644 --- a/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpJsonBodyParserMiddleware.cs +++ b/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpJsonBodyParserMiddleware.cs @@ -12,6 +12,16 @@ public abstract class HttpJsonBodyParserMiddleware } public class HttpJsonBodyParserMiddleware : HttpJsonBodyParserMiddleware, ILambdaMiddleware where T : new() { + + public HttpJsonBodyParserMiddleware(bool interruptsExecution) + { + InterruptsExecution = interruptsExecution; + } + + public HttpJsonBodyParserMiddleware() : this(false){} + + public bool InterruptsExecution { get; } + public Task Before(APIGatewayProxyRequest lambdaEvent, MiddyNetContext context) { if (!HasJsonContentHeaders(lambdaEvent)) diff --git a/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpV2JsonBodyParserMiddleware.cs b/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpV2JsonBodyParserMiddleware.cs index 12593fe..207baaa 100644 --- a/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpV2JsonBodyParserMiddleware.cs +++ b/src/Voxel.MiddyNet.HttpJsonMiddleware/HttpV2JsonBodyParserMiddleware.cs @@ -8,6 +8,15 @@ namespace Voxel.MiddyNet.HttpJsonMiddleware { public class HttpV2JsonBodyParserMiddleware : HttpJsonBodyParserMiddleware, ILambdaMiddleware { + public HttpV2JsonBodyParserMiddleware(bool interruptsExecution) + { + InterruptsExecution = interruptsExecution; + } + + public HttpV2JsonBodyParserMiddleware() : this(false) { } + + public bool InterruptsExecution { get; } + public Task Before(APIGatewayHttpApiV2ProxyRequest lambdaEvent, MiddyNetContext context) { if (!HasJsonContentHeaders(lambdaEvent)) diff --git a/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddleware.cs b/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddleware.cs index 3a74766..376cc50 100644 --- a/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddleware.cs +++ b/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddleware.cs @@ -10,6 +10,8 @@ public class ProblemDetailsMiddleware : ILambdaMiddleware this.options = options ?? new ProblemDetailsMiddlewareOptions(); + public bool InterruptsExecution => false; + public Task Before(APIGatewayProxyRequest lambdaEvent, MiddyNetContext context) => Task.CompletedTask; public Task After(APIGatewayProxyResponse lambdaResponse, MiddyNetContext context) diff --git a/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddlewareV2.cs b/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddlewareV2.cs index c036cbf..066f0ce 100644 --- a/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddlewareV2.cs +++ b/src/Voxel.MiddyNet.ProblemDetailsMiddleware/ProblemDetailsMiddlewareV2.cs @@ -10,6 +10,8 @@ public class ProblemDetailsMiddlewareV2 : ILambdaMiddleware this.options = options ?? new ProblemDetailsMiddlewareOptions(); + public bool InterruptsExecution => false; + public Task Before(APIGatewayHttpApiV2ProxyRequest lambdaEvent, MiddyNetContext context) => Task.CompletedTask; public Task After(APIGatewayHttpApiV2ProxyResponse lambdaResponse, MiddyNetContext context) diff --git a/src/Voxel.MiddyNet.SSMMiddleware/SSMMiddleware.cs b/src/Voxel.MiddyNet.SSMMiddleware/SSMMiddleware.cs index be540e7..12c2dbe 100644 --- a/src/Voxel.MiddyNet.SSMMiddleware/SSMMiddleware.cs +++ b/src/Voxel.MiddyNet.SSMMiddleware/SSMMiddleware.cs @@ -25,6 +25,8 @@ public SSMMiddleware(SSMOptions ssmOptions, Func this.timeProvider = timeProvider; } + public bool InterruptsExecution => false; + public async Task Before(TReq lambdaEvent, MiddyNetContext context) { foreach (var parameter in ssmOptions.ParametersToGet) diff --git a/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayHttpApiV2TracingMiddleware.cs b/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayHttpApiV2TracingMiddleware.cs index 27f8f9d..7738e60 100644 --- a/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayHttpApiV2TracingMiddleware.cs +++ b/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayHttpApiV2TracingMiddleware.cs @@ -12,6 +12,8 @@ public class ApiGatewayHttpApiV2TracingMiddleware : ILambdaMiddleware false; + public Task Before(APIGatewayHttpApiV2ProxyRequest apiGatewayEvent, MiddyNetContext context) { var traceParentHeaderValue = string.Empty; diff --git a/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayTracingMiddleware.cs b/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayTracingMiddleware.cs index cb8db9b..6438f59 100644 --- a/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayTracingMiddleware.cs +++ b/src/Voxel.MiddyNet.Tracing.ApiGatewayMiddleware/ApiGatewayTracingMiddleware.cs @@ -12,6 +12,8 @@ public class ApiGatewayTracingMiddleware : ILambdaMiddleware false; + public Task Before(APIGatewayProxyRequest apiGatewayEvent, MiddyNetContext context) { var traceParentHeaderValue = string.Empty; diff --git a/src/Voxel.MiddyNet.Tracing.SNSMiddleware/SNSTracingMiddleware.cs b/src/Voxel.MiddyNet.Tracing.SNSMiddleware/SNSTracingMiddleware.cs index c5df7e5..779ccde 100644 --- a/src/Voxel.MiddyNet.Tracing.SNSMiddleware/SNSTracingMiddleware.cs +++ b/src/Voxel.MiddyNet.Tracing.SNSMiddleware/SNSTracingMiddleware.cs @@ -11,6 +11,8 @@ public class SNSTracingMiddleware : ILambdaMiddleware private const string TraceStateHeaderName = "tracestate"; private const string TraceIdHeaderName = "trace-id"; + public bool InterruptsExecution => false; + public Task Before(SNSEvent snsEvent, MiddyNetContext context) { var snsMessage = snsEvent.Records.First().Sns; diff --git a/src/Voxel.MiddyNet.Tracing.SQSMiddleware/SQSTracingMiddleware.cs b/src/Voxel.MiddyNet.Tracing.SQSMiddleware/SQSTracingMiddleware.cs index 037e1c5..159490d 100644 --- a/src/Voxel.MiddyNet.Tracing.SQSMiddleware/SQSTracingMiddleware.cs +++ b/src/Voxel.MiddyNet.Tracing.SQSMiddleware/SQSTracingMiddleware.cs @@ -11,6 +11,8 @@ public class SQSTracingMiddleware : ILambdaMiddleware private const string TraceStateHeaderName = "tracestate"; private const string TraceIdHeaderName = "trace-id"; + public bool InterruptsExecution => false; + public Task Before(SQSEvent sqsEvent, MiddyNetContext context) { var sqsMessage = sqsEvent.Records.First(); diff --git a/src/Voxel.MiddyNet/ILambdaMiddleware.cs b/src/Voxel.MiddyNet/ILambdaMiddleware.cs index 354770e..d601f1d 100644 --- a/src/Voxel.MiddyNet/ILambdaMiddleware.cs +++ b/src/Voxel.MiddyNet/ILambdaMiddleware.cs @@ -7,5 +7,7 @@ public interface ILambdaMiddleware Task Before(TReq lambdaEvent, MiddyNetContext context); Task After(TRes lambdaResponse, MiddyNetContext context); + + bool InterruptsExecution { get; } } } diff --git a/src/Voxel.MiddyNet/MiddyNet.cs b/src/Voxel.MiddyNet/MiddyNet.cs index 62b7077..10ba2aa 100644 --- a/src/Voxel.MiddyNet/MiddyNet.cs +++ b/src/Voxel.MiddyNet/MiddyNet.cs @@ -15,9 +15,14 @@ public async Task Handler(TReq lambdaEvent, ILambdaContext context) { InitialiseMiddyContext(context); - await ExecuteBeforeMiddlewares(lambdaEvent); + var beforeMiddlewaresExecutedWithoutErrors = await ExecuteBeforeMiddlewares(lambdaEvent); - var response = await SafeHandleLambdaEvent(lambdaEvent).ConfigureAwait(false); + var response = default(TRes); + + if (beforeMiddlewaresExecutedWithoutErrors) + { + response = await SafeHandleLambdaEvent(lambdaEvent).ConfigureAwait(false); + } response = await ExecuteAfterMiddlewares(response); @@ -66,7 +71,7 @@ private async Task ExecuteAfterMiddlewares(TRes response) return response; } - private async Task ExecuteBeforeMiddlewares(TReq lambdaEvent) + private async Task ExecuteBeforeMiddlewares(TReq lambdaEvent) { foreach (var middleware in middlewares) { @@ -77,8 +82,11 @@ private async Task ExecuteBeforeMiddlewares(TReq lambdaEvent) catch (Exception ex) { MiddyContext.MiddlewareBeforeExceptions.Add(ex); + if (middleware.InterruptsExecution) return false; } } + + return true; } private void InitialiseMiddyContext(ILambdaContext context) diff --git a/test/Voxel.MiddyNet.Tests/MiddyNetShould.cs b/test/Voxel.MiddyNet.Tests/MiddyNetShould.cs index dedea25..1e7ed49 100644 --- a/test/Voxel.MiddyNet.Tests/MiddyNetShould.cs +++ b/test/Voxel.MiddyNet.Tests/MiddyNetShould.cs @@ -23,17 +23,17 @@ public class TestLambdaFunction : MiddyNet { private readonly bool withFailingHandler; - public TestLambdaFunction(List logLines, List contextLogLines, int numberOfMiddlewares, bool withFailingMiddleware = false, bool withFailingHandler = false) - : this(logLines, contextLogLines, numberOfMiddlewares, withFailingMiddleware, withFailingMiddleware, withFailingHandler) { } + public TestLambdaFunction(List logLines, List contextLogLines, int numberOfMiddlewares, bool withFailingMiddleware = false, bool withFailingHandler = false, bool withInterruptingMiddleware = false) + : this(logLines, contextLogLines, numberOfMiddlewares, withFailingMiddleware, withFailingMiddleware, withFailingHandler, withInterruptingMiddleware) { } - public TestLambdaFunction(List logLines, List contextLogLines, int numberOfMiddlewares, bool withFailingBeforeMiddleware, bool withFailingAfterMiddleware, bool withFailingHandler) + public TestLambdaFunction(List logLines, List contextLogLines, int numberOfMiddlewares, bool withFailingBeforeMiddleware, bool withFailingAfterMiddleware, bool withFailingHandler, bool withInterruptingMiddleware) { LogLines = logLines; ContextLogLines = contextLogLines; for (var i = 0; i < numberOfMiddlewares; i++) { - Use(new TestBeforeMiddleware(logLines, i + 1, withFailingBeforeMiddleware)); + Use(new TestBeforeMiddleware(logLines, i + 1, withFailingBeforeMiddleware, withInterruptingMiddleware)); Use(new TestAfterMiddleware(logLines, i + 1, withFailingAfterMiddleware)); } @@ -62,12 +62,16 @@ public class MiddlewareException : Exception { } public class TestBeforeMiddleware : ILambdaMiddleware { private readonly int position; + private readonly bool interrupts = false; public List LogLines { get; } public bool Failing { get; } - public TestBeforeMiddleware(List logLines, int position, bool failing) + public bool InterruptsExecution => interrupts; + + public TestBeforeMiddleware(List logLines, int position, bool failing, bool interrupts) { this.position = position; + this.interrupts = interrupts; LogLines = logLines; Failing = failing; } @@ -94,6 +98,8 @@ public class TestAfterMiddleware : ILambdaMiddleware public List LogLines { get; } public bool Failing { get; } + public bool InterruptsExecution => false; + public TestAfterMiddleware(List logLines, int position, bool failing) { this.position = position; @@ -205,7 +211,7 @@ public void IncludeHandlerExceptionOnAfterErrorNotifications(int numberOfMiddlew [InlineData(false, false, true)] public void ThrowSpecificExceptionWhenOnlyOnePresent(bool throwBeforeException, bool throwAfterException, bool throwHandlerException) { - var lambdaFunction = new TestLambdaFunction(logLines, contextLines, 1, throwBeforeException, throwAfterException, throwHandlerException); + var lambdaFunction = new TestLambdaFunction(logLines, contextLines, 1, throwBeforeException, throwAfterException, throwHandlerException, false); Func act = async () => await lambdaFunction.Handler(0, new FakeLambdaContext()); act.Should().NotThrow(); @@ -223,6 +229,30 @@ public async Task LetMiddlewaresChangeTheFunctionResult() result.Should().Be(9); } + [Fact] + public void StopEvaluatingBeforeMiddlewaresIfInterruptExecutionSetToTrueAndExceptionHappens() + { + var lambdaFunction = new TestLambdaFunction(logLines, contextLines, 2, true, false, true); + Func act = async () => await lambdaFunction.Handler(0, new FakeLambdaContext()); + + act.Should().Throw(); + + logLines.Should().Contain($"{MiddlewareBeforeLog}-1"); + logLines.Should().NotContain($"{MiddlewareBeforeLog}-2"); + + } + + [Fact] + public void NotExecuteLambdaIfInterruptExecutionSetToTrueAndExceptionHappens() + { + var lambdaFunction = new TestLambdaFunction(logLines, contextLines, 2, true, false, true); + Func act = async () => await lambdaFunction.Handler(0, new FakeLambdaContext()); + + act.Should().Throw(); + logLines.Should().NotContain(FunctionLog); + + } + private class AddsTwo : MiddyNet { protected override Task Handle(int lambdaEvent, MiddyNetContext context) @@ -233,6 +263,8 @@ protected override Task Handle(int lambdaEvent, MiddyNetContext context) private class SquareIt : ILambdaMiddleware { + public bool InterruptsExecution => false; + public Task After(int lambdaResponse, MiddyNetContext context) { return Task.FromResult(lambdaResponse * lambdaResponse);