Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Better Lambda web request input parameter validation. #2653

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private class FunctionDetails
public AwsLambdaEventType EventType { get; private set; } = AwsLambdaEventType.Unknown;

public bool HasContext() => ContextIdx != -1;
public bool HasInputObject() => InputIdx != -1;
private bool HasInputObject() => InputIdx != -1;

public void SetContext(object lambdaContext, int contextIdx)
{
Expand Down Expand Up @@ -112,6 +112,41 @@ public object GetInputObject(InstrumentedMethodCall instrumentedMethodCall)
}

public bool IsWebRequest => EventType is AwsLambdaEventType.APIGatewayProxyRequest or AwsLambdaEventType.APIGatewayHttpApiV2ProxyRequest or AwsLambdaEventType.ApplicationLoadBalancerRequest;

public bool ValidateWebRequestParameters(InstrumentedMethodCall instrumentedMethodCall)
{
if (HasInputObject() && IsWebRequest)
{
dynamic input = GetInputObject(instrumentedMethodCall);

// make sure the request includes Http Method and Path
switch (EventType)
{
case AwsLambdaEventType.APIGatewayHttpApiV2ProxyRequest:
{
if (input.RequestContext != null)
{
dynamic requestContext = input.RequestContext;

return !string.IsNullOrEmpty(requestContext.Http.Method) && !string.IsNullOrEmpty(requestContext.Http.Path);
}

return false;
}
case AwsLambdaEventType.APIGatewayProxyRequest:
case AwsLambdaEventType.ApplicationLoadBalancerRequest:
{
dynamic webReq = input;
return !string.IsNullOrEmpty(webReq.HttpMethod) && !string.IsNullOrEmpty(webReq.Path);
}
default:
return true;
}

}

return false;
}
}

private List<string> _webResponseHeaders = ["content-type", "content-length"];
Expand Down Expand Up @@ -194,8 +229,19 @@ public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall ins
}
}

if (_functionDetails!.IsWebRequest)
{
if (!_functionDetails.ValidateWebRequestParameters(instrumentedMethodCall))
{
agent.Logger.Debug($"Invalid or missing web request parameters. HttpMethod and Path are required for {_functionDetails.EventType}. Not instrumenting this function invocation.");
return Delegates.NoOp;
}
}



var isAsync = instrumentedMethodCall.IsAsync;
string requestId = _functionDetails!.GetRequestId(instrumentedMethodCall);
string requestId = _functionDetails.GetRequestId(instrumentedMethodCall);
var inputObject = _functionDetails.GetInputObject(instrumentedMethodCall);

transaction = agent.CreateTransaction(
Expand Down Expand Up @@ -314,7 +360,7 @@ private void CaptureResponseData(ITransaction transaction, object response, IAge
{
// copy and lowercase the headers
Dictionary<string, string> copiedHeaders = new Dictionary<string, string>();
foreach(var kvp in responseHeaders)
foreach (var kvp in responseHeaders)
copiedHeaders.Add(kvp.Key.ToLower(), kvp.Value);

foreach (var header in _webResponseHeaders) // only capture specific headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ public abstract class AgentLogBase
// Serverless payloads
public const string ServerlessPayloadLogLineRegex = FinestLogLinePrefixRegex + @"Serverless payload: (.*)";

// Invalid serverless web request
public const string InvalidServerlessWebRequestLogLineRegex = DebugLogLinePrefixRegex + @"Invalid or missing web request parameters. (.*)";

public AgentLogBase(ITestOutputHelper testLogger)
{
_testLogger = testLogger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ protected AwsLambdaAPIGatewayHttpApiV2ProxyRequestTest(T fixture, ITestOutputHel
_fixture.EnqueueAPIGatewayHttpApiV2ProxyRequest();
_fixture.EnqueueAPIGatewayHttpApiV2ProxyRequestWithDTHeaders(TestTraceId, TestParentSpanId);
_fixture.EnqueueMinimalAPIGatewayHttpApiV2ProxyRequest();
_fixture.EnqueueInvalidAPIGatewayHttpApiV2ProxyRequest();
_fixture.AgentLog.WaitForLogLines(AgentLogBase.ServerlessPayloadLogLineRegex, TimeSpan.FromMinutes(1), 3);
}
);
Expand All @@ -47,13 +48,18 @@ public void Test()
var serverlessPayloads = _fixture.AgentLog.GetServerlessPayloads().ToList();

Assert.Multiple(
// the fourth exerciser invocation should result in a NoOpDelegate, so there will only be 3 payloads
() => Assert.Equal(3, serverlessPayloads.Count),
// validate the first 2 payloads separately from the 3rd
() => Assert.All(serverlessPayloads.GetRange(0, 2), ValidateServerlessPayload),
() => ValidateMinimalRequestPayload(serverlessPayloads[2]),
() => ValidateTraceHasNoParent(serverlessPayloads[0]),
() => ValidateTraceHasParent(serverlessPayloads[1])
);

// verify that the invalid request payload generated the expected log line
var logLines = _fixture.AgentLog.TryGetLogLines(AgentLogBase.InvalidServerlessWebRequestLogLineRegex);
Assert.Single(logLines);
}

private void ValidateServerlessPayload(ServerlessPayload serverlessPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ protected AwsLambdaAPIGatewayProxyRequestTest(T fixture, ITestOutputHelper outpu
_fixture.EnqueueAPIGatewayProxyRequest();
_fixture.EnqueueAPIGatewayProxyRequestWithDTHeaders(TestTraceId, TestParentSpanId);
_fixture.EnqueueMinimalAPIGatewayProxyRequest();
_fixture.EnqueueInvalidAPIGatewayProxyRequest();
_fixture.AgentLog.WaitForLogLines(AgentLogBase.ServerlessPayloadLogLineRegex, TimeSpan.FromMinutes(1), 3);
}
);
Expand All @@ -47,13 +48,18 @@ public void Test()
var serverlessPayloads = _fixture.AgentLog.GetServerlessPayloads().ToList();

Assert.Multiple(
// the fourth exerciser invocation should result in a NoOpDelegate, so there will only be 3 payloads
() => Assert.Equal(3, serverlessPayloads.Count),
// validate the first 2 payloads separately from the 3rd
() => Assert.All(serverlessPayloads.GetRange(0, 2), ValidateServerlessPayload),
() => ValidateMinimalRequestPayload(serverlessPayloads[2]),
() => ValidateTraceHasNoParent(serverlessPayloads[0]),
() => ValidateTraceHasParent(serverlessPayloads[1])
);

// verify that the invalid request payload generated the expected log line
var logLines = _fixture.AgentLog.TryGetLogLines(AgentLogBase.InvalidServerlessWebRequestLogLineRegex);
Assert.Single(logLines);
}

private void ValidateServerlessPayload(ServerlessPayload serverlessPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ protected AwsLambdaApplicationLoadBalancerRequestTest(T fixture, ITestOutputHelp
{
_fixture.EnqueueApplicationLoadBalancerRequest();
_fixture.EnqueueApplicationLoadBalancerRequestWithDTHeaders(TestTraceId, TestParentSpanId);
_fixture.EnqueueInvalidLoadBalancerRequestyRequest();
_fixture.AgentLog.WaitForLogLines(AgentLogBase.ServerlessPayloadLogLineRegex, TimeSpan.FromMinutes(1), 2);
}
);
Expand All @@ -46,11 +47,16 @@ public void Test()
var serverlessPayloads = _fixture.AgentLog.GetServerlessPayloads().ToList();

Assert.Multiple(
// the third exerciser invocation should result in a NoOpDelegate, so there will only be 2 payloads
() => Assert.Equal(2, serverlessPayloads.Count),
() => Assert.All(serverlessPayloads, ValidateServerlessPayload),
() => ValidateTraceHasNoParent(serverlessPayloads[0]),
() => ValidateTraceHasParent(serverlessPayloads[1])
);

// verify that the invalid request payload generated the expected log line
var logLines = _fixture.AgentLog.TryGetLogLines(AgentLogBase.InvalidServerlessWebRequestLogLineRegex);
Assert.Single(logLines);
}

private void ValidateServerlessPayload(ServerlessPayload serverlessPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,19 @@ public void EnqueueMinimalAPIGatewayHttpApiV2ProxyRequest()
""";
EnqueueLambdaEvent(apiGatewayProxyRequestJson);
}

/// <summary>
/// An invalid payload to validate the fix for https://github.com/newrelic/newrelic-dotnet-agent/issues/2652
/// </summary>
public void EnqueueInvalidAPIGatewayHttpApiV2ProxyRequest()
{
var invalidApiGatewayHttpApiV2ProxyRequestJson = $$"""
{
"foo": "bar"
}
""";
EnqueueLambdaEvent(invalidApiGatewayHttpApiV2ProxyRequestJson);
}
}

public class LambdaAPIGatewayHttpApiV2ProxyRequestTriggerFixtureNet6 : LambdaAPIGatewayHttpApiV2ProxyRequestTriggerFixtureBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ public void EnqueueMinimalAPIGatewayProxyRequest()
""";
EnqueueLambdaEvent(apiGatewayProxyRequestJson);
}

/// <summary>
/// An invalid payload to validate the fix for https://github.com/newrelic/newrelic-dotnet-agent/issues/2652
/// </summary>
public void EnqueueInvalidAPIGatewayProxyRequest()
{
var invalidApiGatewayProxyRequestJson = $$"""
{
"foo": "bar"
}
""";
EnqueueLambdaEvent(invalidApiGatewayProxyRequestJson);
}
}

public class LambdaAPIGatewayProxyRequestTriggerFixtureNet6 : LambdaAPIGatewayProxyRequestTriggerFixtureBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ public void EnqueueApplicationLoadBalancerRequestWithDTHeaders(string traceId, s
""";
EnqueueLambdaEvent(ApplicationLoadBalancerRequestJson);
}

/// <summary>
/// An invalid payload to validate the fix for https://github.com/newrelic/newrelic-dotnet-agent/issues/2652
/// </summary>
public void EnqueueInvalidLoadBalancerRequestyRequest()
{
var invalidLoadBalancerRequestJson = $$"""
{
"foo": "bar"
}
""";
EnqueueLambdaEvent(invalidLoadBalancerRequestJson);
}
}

public class LambdaApplicationLoadBalancerRequestTriggerFixtureNet6 : LambdaApplicationLoadBalancerRequestTriggerFixtureBase
Expand Down
Loading