Skip to content

Commit

Permalink
Fix to better handle lambda responses when they are empty or null or …
Browse files Browse the repository at this point in the history
…not a valid json (opensearch-project#5211)

* Fix to better handle lambda responses when they are empty, null, or not valid JSON

Signed-off-by: Santhosh Gandhe <[email protected]>

* UTs for strict mode response comparison

Signed-off-by: Santhosh Gandhe <[email protected]>

* Additional UTs for strict mode and aggregate mode

Signed-off-by: Santhosh Gandhe <[email protected]>

* removed unused method and better method naming

Signed-off-by: Santhosh Gandhe <[email protected]>

* doExecute method testing

Signed-off-by: Santhosh Gandhe <[email protected]>

* better exception message

Signed-off-by: Santhosh Gandhe <[email protected]>

* Add IT for aggregate mode cases

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Testing with presence of tags

Signed-off-by: Santhosh Gandhe <[email protected]>

* Add IT to test behaviour for different lambda responses

Signed-off-by: Srikanth Govindarajan <[email protected]>

* removed unused imports

Signed-off-by: Santhosh Gandhe <[email protected]>

* fix checkstyle

Signed-off-by: Srikanth Govindarajan <[email protected]>

* fix checkstyle

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comments

Signed-off-by: Srikanth Govindarajan <[email protected]>

---------

Signed-off-by: Santhosh Gandhe <[email protected]>
Signed-off-by: Srikanth Govindarajan <[email protected]>
Co-authored-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
san81 and srikanthjg authored Nov 22, 2024
1 parent 6c9cdeb commit 539d599
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 110 deletions.
29 changes: 28 additions & 1 deletion data-prepper-plugins/aws-lambda/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,35 @@ The following command runs the integration tests:
-Dtests.lambda.processor.region="us-east-1" \
-Dtests.lambda.processor.functionName="test-lambda-processor" \
-Dtests.lambda.processor.sts_role_arn="arn:aws:iam::<>:role/lambda-role"
```

Lambda handler used to test:
```
def lambda_handler(event, context):
    """Test Lambda used by the aws_lambda processor integration tests.

    The processor sends a batch of events under the 'osi_key' key. When the
    batch holds exactly one record, a marker key on that record selects a
    degenerate response (None, a bare string, an empty array, the literal
    string "null", ...) so the ITs can exercise the processor's response
    handling. Any other batch is transformed record-by-record and echoed
    back as a list.
    """
    input_records = event.get('osi_key', [])
    output = []
    if len(input_records) == 1:
        record = input_records[0]
        # Marker keys are checked in this fixed priority order.
        if "returnNone" in record:
            return  # Lambda responds with no usable body
        if "returnString" in record:
            return "RandomString"  # non-JSON-array response
        if "returnObject" in record:
            return record  # single object, not wrapped in an array
        if "returnEmptyArray" in record:
            return output  # still [] at this point
        if "returnNull" in record:
            return "null"  # the literal string "null"
        if "returnEmptyMapinArray" in record:
            return [{}]
    # Default path: tag each record and upper-case every string value
    # (including the tag itself, which is added before the loop runs).
    for record in input_records:
        record["_out_"] = "transformed"
        for key, value in record.items():
            if isinstance(value, str):
                record[key] = value.upper()
        output.append(record)
    return output
```


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,24 @@

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import org.mockito.Mock;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
Expand Down Expand Up @@ -50,16 +62,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class LambdaProcessorIT {
Expand Down Expand Up @@ -95,6 +97,7 @@ public void setup() {
lambdaRegion = System.getProperty("tests.lambda.processor.region");
functionName = System.getProperty("tests.lambda.processor.functionName");
role = System.getProperty("tests.lambda.processor.sts_role_arn");

pluginMetrics = mock(PluginMetrics.class);
pluginSetting = mock(PluginSetting.class);
when(pluginSetting.getPipelineName()).thenReturn("pipeline");
Expand Down Expand Up @@ -232,6 +235,87 @@ public void testWithFailureTags() throws Exception {
}
}

@ParameterizedTest
@ValueSource(strings = {"returnNull", "returnEmptyArray", "returnString", "returnEmptyMapinArray", "returnNone"})
public void testAggregateMode_WithVariousResponses(String input) {
    // Configure the processor in aggregate mode (response_events_match = false)
    // against the real Lambda function, tagging failures with "lambda_failure".
    when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
    when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
    when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(false); // Aggregate mode
    when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure"));
    lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig);
    final List<Record<Event>> inputRecords = createRecord(input);

    final Collection<Record<Event>> results = lambdaProcessor.doExecute(inputRecords);

    if ("returnEmptyMapinArray".equals(input)) {
        // A [{}] payload still parses, so one (empty) event comes through.
        assertEquals(1, results.size(), "Should have one event in result for empty map in array");
        assertTrue(results.stream().allMatch(record -> record.getData().toMap().isEmpty()),
                "Result should be an empty map");
    } else if ("returnNull".equals(input) || "returnEmptyArray".equals(input)
            || "returnString".equals(input) || "returnNone".equals(input)) {
        // Unusable payloads drop the originals entirely in aggregate mode.
        assertTrue(results.isEmpty(), "Events should be dropped for null, empty array, or string response");
    } else {
        fail("Unexpected input: " + input);
    }
}

@ParameterizedTest
@ValueSource(strings = {"returnNone", "returnString", "returnObject", "returnEmptyArray", "returnNull", "returnEmptyMapinArray"})
public void testStrictMode_WithVariousResponses(String input) {
    // Strict mode (response_events_match = true): the Lambda response must map
    // one-to-one onto the input events, otherwise the originals come back tagged.
    when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
    when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
    when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true); // Strict mode
    when(lambdaProcessorConfig.getTagsOnFailure()).thenReturn(Collections.singletonList("lambda_failure"));
    lambdaProcessor = createObjectUnderTest(lambdaProcessorConfig);
    final List<Record<Event>> inputRecords = createRecord(input);

    final Collection<Record<Event>> results = lambdaProcessor.doExecute(inputRecords);

    if ("returnObject".equals(input)) {
        // One object back for one event in: counts match, data echoed unchanged.
        assertEquals(1, results.size(), "Should return one record");
        assertEquals(inputRecords.get(0).getData().toMap(), results.iterator().next().getData().toMap(),
                "Returned record should match input record");
    } else if ("returnEmptyMapinArray".equals(input)) {
        // [{}] also satisfies the count, yielding a single empty event.
        assertEquals(1, results.size(), "Should return one record");
        assertTrue(results.iterator().next().getData().toMap().isEmpty(),
                "Returned record should be an empty map");
    } else {
        // returnNone / returnString / returnEmptyArray / returnNull: the response
        // is unusable or the counts mismatch, so the original is tagged instead.
        assertEquals(1, results.size(), "Should return original record with failure tag");
        assertTrue(results.iterator().next().getData().getMetadata().getTags().contains("lambda_failure"),
                "Result should contain lambda_failure tag");
    }
}

/**
 * Builds a single-record batch whose event body is {@code {<input>: 42}}; the
 * key doubles as the behaviour selector understood by the test Lambda function.
 */
private List<Record<Event>> createRecord(String input) {
    final Map<String, Object> data = new HashMap<>();
    data.put(input, 42);
    final EventMetadata eventMetadata = DefaultEventMetadata.builder()
            .withEventType("event")
            .build();
    final Event event = JacksonEvent.builder()
            .withData(data)
            .withEventType("event")
            .withEventMetadata(eventMetadata)
            .build();
    // Mutable list, matching what callers may append to.
    return new ArrayList<>(Collections.singletonList(new Record<>(event)));
}


private void validateResultsForAggregateMode(Collection<Record<Event>> results) {
List<Record<Event>> resultRecords = new ArrayList<>(results);
for (int i = 0; i < resultRecords.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.micrometer.core.instrument.Timer;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
Expand All @@ -25,6 +26,7 @@
import org.opensearch.dataprepper.model.sink.OutputCodecContext;
import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodecConfig;
import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;
import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
Expand All @@ -48,9 +50,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;

@DataPrepperPlugin(name = "aws_lambda", pluginType = Processor.class, pluginConfigurationType = LambdaProcessorConfig.class)
public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {

Expand All @@ -61,6 +60,8 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Eve
public static final String LAMBDA_LATENCY_METRIC = "lambdaFunctionLatency";
public static final String REQUEST_PAYLOAD_SIZE = "requestPayloadSize";
public static final String RESPONSE_PAYLOAD_SIZE = "responsePayloadSize";
public static final String LAMBDA_RESPONSE_RECORDS_COUNTER = "lambdaResponseRecordsCounter";
private static final String NO_RETURN_RESPONSE = "null";

private static final Logger LOG = LoggerFactory.getLogger(LambdaProcessor.class);
final PluginSetting codecPluginSetting;
Expand All @@ -72,6 +73,7 @@ public class LambdaProcessor extends AbstractProcessor<Record<Event>, Record<Eve
private final Counter numberOfRecordsFailedCounter;
private final Counter numberOfRequestsSuccessCounter;
private final Counter numberOfRequestsFailedCounter;
private final Counter lambdaResponseRecordsCounter;
private final Timer lambdaLatencyMetric;
private final List<String> tagsOnFailure;
private final LambdaAsyncClient lambdaAsyncClient;
Expand Down Expand Up @@ -102,6 +104,7 @@ public LambdaProcessor(final PluginFactory pluginFactory, final PluginSetting pl
this.lambdaLatencyMetric = pluginMetrics.timer(LAMBDA_LATENCY_METRIC);
this.requestPayloadMetric = pluginMetrics.summary(REQUEST_PAYLOAD_SIZE);
this.responsePayloadMetric = pluginMetrics.summary(RESPONSE_PAYLOAD_SIZE);
this.lambdaResponseRecordsCounter = pluginMetrics.counter(LAMBDA_RESPONSE_RECORDS_COUNTER);
this.whenCondition = lambdaProcessorConfig.getWhenCondition();
this.tagsOnFailure = lambdaProcessorConfig.getTagsOnFailure();

Expand Down Expand Up @@ -163,6 +166,8 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
new OutputCodecContext());
} catch (Exception e) {
LOG.error(NOISY, "Error while sending records to Lambda", e);
numberOfRecordsFailedCounter.increment(recordsToLambda.size());
numberOfRequestsFailedCounter.increment();
resultRecords.addAll(addFailureTags(recordsToLambda));
}

Expand Down Expand Up @@ -211,28 +216,19 @@ List<Record<Event>> convertLambdaResponseToEvent(Buffer flushedBuffer,
List<Event> parsedEvents = new ArrayList<>();

SdkBytes payload = lambdaResponse.payload();
// Handle null or empty payload
if (payload == null || payload.asByteArray().length == 0) {
LOG.warn(NOISY,
"Lambda response payload is null or empty, dropping the original events");
return responseStrategy.handleEvents(parsedEvents, originalRecords);
// Considering "null" payload as empty response from lambda and not parsing it.
if (!(NO_RETURN_RESPONSE.equals(payload.asUtf8String()))) {
//Convert using response codec
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});
}

//Convert using response codec
InputStream inputStream = new ByteArrayInputStream(payload.asByteArray());
responseCodec.parse(inputStream, record -> {
Event event = record.getData();
parsedEvents.add(event);
});

if (parsedEvents.isEmpty()) {
throw new RuntimeException(
"Lambda Response could not be parsed, returning original events");
}

LOG.debug("Parsed Event Size:{}, FlushedBuffer eventCount:{}, " +
"FlushedBuffer size:{}", parsedEvents.size(), flushedBuffer.getEventCount(),
flushedBuffer.getSize());
lambdaResponseRecordsCounter.increment(parsedEvents.size());
return responseStrategy.handleEvents(parsedEvents, originalRecords);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.ResponseEventHandlingStrategy;
import org.opensearch.dataprepper.plugins.lambda.processor.exception.StrictResponseModeNotRespectedException;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -19,8 +20,13 @@ public class StrictResponseEventHandlingStrategy implements ResponseEventHandlin
public List<Record<Event>> handleEvents(List<Event> parsedEvents,
List<Record<Event>> originalRecords) {
if (parsedEvents.size() != originalRecords.size()) {
throw new RuntimeException(
"Response Processing Mode is configured as Strict mode but behavior is aggregate mode. Event count mismatch.");
throw new StrictResponseModeNotRespectedException(
"Event count mismatch. The aws_lambda processor is configured with response_events_match set to true. " +
"The Lambda function responded with a different number of events. " +
"Either set response_events_match to false or investigate your " +
"Lambda function to ensure that it returns the same number of " +
"events and provided as input. parsedEvents size = " + parsedEvents.size() +
", Original events size = " + originalRecords.size());
}

List<Record<Event>> resultRecords = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.opensearch.dataprepper.plugins.lambda.processor.exception;

/**
 * Thrown by the aws_lambda processor when {@code response_events_match} is
 * {@code true} but the Lambda function returns a different number of events
 * than it received, so the strict one-to-one response mapping cannot be honored.
 */
public class StrictResponseModeNotRespectedException extends RuntimeException {

    /**
     * @param message description of the event-count mismatch
     */
    public StrictResponseModeNotRespectedException(final String message) {
        super(message);
    }

    /**
     * Preserves the underlying failure for diagnostics instead of dropping it.
     *
     * @param message description of the event-count mismatch
     * @param cause   the underlying exception that triggered this failure
     */
    public StrictResponseModeNotRespectedException(final String message, final Throwable cause) {
        super(message, cause);
    }
}
Loading

0 comments on commit 539d599

Please sign in to comment.