Skip to content

Commit

Permalink
Use aws-lambda-java-serialization library, which is available by de…
Browse files Browse the repository at this point in the history
…fault, while deserializing input and serializing output (#11868)
  • Loading branch information
serkan-ozal authored Aug 12, 2024
1 parent d480f15 commit e2cfe37
Show file tree
Hide file tree
Showing 13 changed files with 1,008 additions and 256 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.ApiGatewayProxyRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.AwsLambdaFunctionInstrumenter;
Expand Down Expand Up @@ -67,49 +68,55 @@ protected TracingRequestStreamHandler(
@Override
public void handleRequest(InputStream input, OutputStream output, Context context)
throws IOException {

ApiGatewayProxyRequest proxyRequest = ApiGatewayProxyRequest.forStream(input);
AwsLambdaRequest request =
AwsLambdaRequest.create(context, proxyRequest, proxyRequest.getHeaders());
AwsLambdaRequest request = createRequest(input, context, proxyRequest);
io.opentelemetry.context.Context parentContext = instrumenter.extract(request);

if (!instrumenter.shouldStart(parentContext, request)) {
doHandleRequest(proxyRequest.freshStream(), output, context);
doHandleRequest(proxyRequest.freshStream(), output, context, request);
return;
}

io.opentelemetry.context.Context otelContext = instrumenter.start(parentContext, request);
Throwable error = null;
try (Scope ignored = otelContext.makeCurrent()) {
doHandleRequest(
proxyRequest.freshStream(),
new OutputStreamWrapper(output, otelContext, request, openTelemetrySdk),
context);
new OutputStreamWrapper(output, otelContext),
context,
request);
} catch (Throwable t) {
instrumenter.end(otelContext, request, null, t);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
error = t;
throw t;
} finally {
instrumenter.end(otelContext, request, null, error);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
}
}

protected AwsLambdaRequest createRequest(
InputStream input, Context context, ApiGatewayProxyRequest proxyRequest) throws IOException {
return AwsLambdaRequest.create(context, proxyRequest, proxyRequest.getHeaders());
}

protected void doHandleRequest(
InputStream input, OutputStream output, Context context, AwsLambdaRequest request)
throws IOException {
doHandleRequest(input, output, context);
}

protected abstract void doHandleRequest(InputStream input, OutputStream output, Context context)
throws IOException;

private class OutputStreamWrapper extends OutputStream {
private static class OutputStreamWrapper extends OutputStream {

private final OutputStream delegate;
private final io.opentelemetry.context.Context otelContext;
private final AwsLambdaRequest request;
private final OpenTelemetrySdk openTelemetrySdk;

private OutputStreamWrapper(
OutputStream delegate,
io.opentelemetry.context.Context otelContext,
AwsLambdaRequest request,
OpenTelemetrySdk openTelemetrySdk) {
OutputStream delegate, io.opentelemetry.context.Context otelContext) {
this.delegate = delegate;
this.otelContext = otelContext;
this.request = request;
this.openTelemetrySdk = openTelemetrySdk;
}

@Override
Expand All @@ -135,8 +142,8 @@ public void flush() throws IOException {
@Override
public void close() throws IOException {
delegate.close();
instrumenter.end(otelContext, request, null, null);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
Span span = Span.fromContext(otelContext);
span.addEvent("Output stream closed");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public class TracingRequestStreamWrapper extends TracingRequestStreamHandler {

private final WrappedLambda wrappedLambda;
protected final WrappedLambda wrappedLambda;

public TracingRequestStreamWrapper() {
this(
Expand All @@ -32,7 +32,8 @@ public TracingRequestStreamWrapper() {
}

// Visible for testing
TracingRequestStreamWrapper(OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
protected TracingRequestStreamWrapper(
OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
super(openTelemetrySdk, WrapperConfiguration.flushTimeout());
this.wrappedLambda = wrappedLambda;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ dependencies {
// in public API.
library("com.amazonaws:aws-lambda-java-events:2.2.1")

// By default, "aws-lambda-java-serialization" library is enabled in the classpath
// at the AWS Lambda environment except "java8" runtime which is deprecated.
// But it is available at "java8.al2" runtime, so it is still can be used
// by Java 8 based Lambda functions.
// So that is the reason that why we add it as compile only dependency.
compileOnly("com.amazonaws:aws-lambda-java-serialization:1.1.5")

// We need Jackson for wrappers to reproduce the serialization does when Lambda invokes a RequestHandler with event
// since Lambda will only be able to invoke the wrapper itself with a generic Object.
// Note that Lambda itself uses Jackson, but does not expose it to the function so we need to include it here.
Expand All @@ -33,6 +40,7 @@ dependencies {
testImplementation("io.opentelemetry:opentelemetry-sdk-extension-autoconfigure")
testImplementation("io.opentelemetry:opentelemetry-extension-trace-propagators")
testImplementation("com.google.guava:guava")
testImplementation("com.amazonaws:aws-lambda-java-serialization:1.1.5")

testImplementation(project(":instrumentation:aws-lambda:aws-lambda-events-2.2:testing"))
testImplementation("uk.org.webcompere:system-stubs-jupiter")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package io.opentelemetry.instrumentation.awslambdaevents.v2_2;

import com.amazonaws.services.lambda.runtime.Context;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.function.BiFunction;

Expand All @@ -27,5 +28,35 @@ static <T> Object[] toArray(
return parameters;
}

static <T> Object[] toParameters(Method targetMethod, T input, Context context) {
Class<?>[] parameterTypes = targetMethod.getParameterTypes();
Object[] parameters = new Object[parameterTypes.length];
for (int i = 0; i < parameterTypes.length; i++) {
Class<?> clazz = parameterTypes[i];
boolean isContext = clazz.equals(Context.class);
if (isContext) {
parameters[i] = context;
} else if (i == 0) {
parameters[0] = input;
}
}
return parameters;
}

static Object toInput(
Method targetMethod,
InputStream inputStream,
BiFunction<InputStream, Class<?>, Object> mapper) {
Class<?>[] parameterTypes = targetMethod.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
Class<?> clazz = parameterTypes[i];
boolean isContext = clazz.equals(Context.class);
if (i == 0 && !isContext) {
return mapper.apply(inputStream, clazz);
}
}
return null;
}

private LambdaParameters() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.WrappedLambda;
import io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal.SerializationUtil;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import java.util.function.BiFunction;

Expand All @@ -35,12 +35,7 @@ public TracingRequestApiGatewayWrapper() {

// Visible for testing
static <T> T map(APIGatewayProxyRequestEvent event, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(event.getBody(), clazz);
} catch (JsonProcessingException e) {
throw new IllegalStateException(
"Could not map API Gateway event body to requested parameter type: " + clazz, e);
}
return SerializationUtil.fromJson(event.getBody(), clazz);
}

@Override
Expand All @@ -52,12 +47,8 @@ protected APIGatewayProxyResponseEvent doHandleRequest(
if (result instanceof APIGatewayProxyResponseEvent) {
event = (APIGatewayProxyResponseEvent) result;
} else {
try {
event = new APIGatewayProxyResponseEvent();
event.setBody(OBJECT_MAPPER.writeValueAsString(result));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Could not serialize return value.", e);
}
event = new APIGatewayProxyResponseEvent();
event.setBody(SerializationUtil.toJson(result));
}
return event;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,90 @@

package io.opentelemetry.instrumentation.awslambdaevents.v2_2;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.AwsLambdaRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestStreamWrapper;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.ApiGatewayProxyRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.MapUtils;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.WrappedLambda;
import io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal.SerializationUtil;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import java.util.function.BiFunction;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Map;

/**
* Wrapper for {@link io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestHandler}.
* Allows for wrapping a regular lambda, not proxied through API Gateway. Therefore, HTTP headers
* propagation is not supported.
* Wrapper for {@link com.amazonaws.services.lambda.runtime.RequestHandler} based Lambda handlers.
*/
public class TracingRequestWrapper extends TracingRequestWrapperBase<Object, Object> {
public class TracingRequestWrapper extends TracingRequestStreamWrapper {
public TracingRequestWrapper() {
super(TracingRequestWrapper::map);
super();
}

// Visible for testing
TracingRequestWrapper(
OpenTelemetrySdk openTelemetrySdk,
WrappedLambda wrappedLambda,
BiFunction<Object, Class<?>, Object> mapper) {
super(openTelemetrySdk, wrappedLambda, mapper);
TracingRequestWrapper(OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
super(openTelemetrySdk, wrappedLambda);
}

@Override
protected final AwsLambdaRequest createRequest(
InputStream inputStream, Context context, ApiGatewayProxyRequest proxyRequest) {
Method targetMethod = wrappedLambda.getRequestTargetMethod();
Object input = LambdaParameters.toInput(targetMethod, inputStream, TracingRequestWrapper::map);
return AwsLambdaRequest.create(context, input, extractHeaders(input));
}

protected Map<String, String> extractHeaders(Object input) {
if (input instanceof APIGatewayProxyRequestEvent) {
return MapUtils.emptyIfNull(((APIGatewayProxyRequestEvent) input).getHeaders());
}
return Collections.emptyMap();
}

@Override
protected final void doHandleRequest(
InputStream input, OutputStream output, Context context, AwsLambdaRequest request) {
Method targetMethod = wrappedLambda.getRequestTargetMethod();
Object[] parameters = LambdaParameters.toParameters(targetMethod, request.getInput(), context);
try {
Object result = targetMethod.invoke(wrappedLambda.getTargetObject(), parameters);
SerializationUtil.toJson(output, result);
} catch (IllegalAccessException e) {
throw new IllegalStateException("Method is inaccessible", e);
} catch (InvocationTargetException e) {
throw (e.getCause() instanceof RuntimeException
? (RuntimeException) e.getCause()
: new IllegalStateException(e.getTargetException()));
}
}

@SuppressWarnings({"unchecked", "TypeParameterUnusedInFormals"})
// Used for testing
<INPUT, OUTPUT> OUTPUT handleRequest(INPUT input, Context context) throws IOException {
byte[] inputJsonData = SerializationUtil.toJsonData(input);
ByteArrayInputStream inputStream = new ByteArrayInputStream(inputJsonData);
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

super.handleRequest(inputStream, outputStream, context);

byte[] outputJsonData = outputStream.toByteArray();
return (OUTPUT)
SerializationUtil.fromJson(
new ByteArrayInputStream(outputJsonData),
wrappedLambda.getRequestTargetMethod().getReturnType());
}

// Visible for testing
static <T> T map(Object jsonMap, Class<T> clazz) {
static <T> T map(InputStream inputStream, Class<T> clazz) {
try {
return OBJECT_MAPPER.convertValue(jsonMap, clazz);
return SerializationUtil.fromJson(inputStream, clazz);
} catch (IllegalArgumentException e) {
throw new IllegalStateException(
"Could not map input to requested parameter type: " + clazz, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.opentelemetry.instrumentation.api.internal.HttpConstants;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestHandler;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.MapUtils;
Expand All @@ -29,10 +27,6 @@
*/
abstract class TracingRequestWrapperBase<I, O> extends TracingRequestHandler<I, O> {

protected static final ObjectMapper OBJECT_MAPPER =
new ObjectMapper()
.registerModule(new CustomJodaModule())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
private final WrappedLambda wrappedLambda;
private final Method targetMethod;
private final BiFunction<I, Class<?>, Object> parameterMapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.awslambdaevents.v2_2;
package io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
Expand Down
Loading

0 comments on commit e2cfe37

Please sign in to comment.