Skip to content

Commit

Permalink
Introduce aws sigv4a request signer
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger committed Apr 4, 2024
1 parent 77d0078 commit b8a70b7
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 21 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ lazy val flintCore = (project in file("flint-core"))
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"software.amazon.awssdk" % "auth-crt" % "2.25.23",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWSSessionCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.glue.model.InvalidStateException;
import org.apache.http.Header;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;
import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.nvpToMapParams;
import static org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor.skipHeader;

/**
* Interceptor for signing AWS requests according to Signature Version 4A.
* This interceptor processes HTTP requests, signs them with AWS credentials,
* and updates the request headers to include the signature.
*/
public class AWSRequestSigV4ASigningApacheInterceptor implements HttpRequestInterceptor {
private static final String HTTPS_PROTOCOL = "https";
private static final int HTTPS_PORT = 443;

private final String service;
private final String region;
private final Signer signer;
private final AWSCredentialsProvider awsCredentialsProvider;

/**
* Constructs an interceptor for AWS request signing with metadata access.
*
* @param service The AWS service name.
* @param region The AWS region for signing.
* @param signer The signer implementation.
* @param awsCredentialsProvider The credentials provider for metadata access.
*/
public AWSRequestSigV4ASigningApacheInterceptor(String service, String region, Signer signer, AWSCredentialsProvider awsCredentialsProvider) {
this.service = service;
this.region = region;
this.signer = signer;
this.awsCredentialsProvider = awsCredentialsProvider;
}

/**
* Processes the HTTP request, signs it, and updates its headers.
*
* @param request The HTTP request to process and sign.
* @param context The HTTP context.
*/
@Override
public void process(HttpRequest request, HttpContext context) throws IOException {
SdkHttpFullRequest requestToSign = buildSdkHttpRequest(request, context);
SdkHttpFullRequest signedRequest = signRequest(requestToSign);
updateRequestHeaders(request, signedRequest.headers());
}

private SdkHttpFullRequest buildSdkHttpRequest(HttpRequest request, HttpContext context) throws IOException {
URIBuilder uriBuilder = parseUri(request);
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder()
.method(SdkHttpMethod.fromValue(request.getRequestLine().getMethod()))
.port(HTTPS_PORT)
.protocol(HTTPS_PROTOCOL)
.headers(headerArrayToMap(request.getAllHeaders()))
.rawQueryParameters(nvpToMapParams(uriBuilder.getQueryParams()));

HttpHost host = (HttpHost) context.getAttribute(HTTP_TARGET_HOST);
if (host == null) {
throw new InvalidStateException("host must not be null");
}
builder.host(host.toURI());
try {
builder.encodedPath(uriBuilder.build().getRawPath());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI" , e);
}
return builder.build();
}

private URIBuilder parseUri(HttpRequest request) throws IOException {
try {
return new URIBuilder(request.getRequestLine().getUri());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI", e);
}
}

private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
AWSSessionCredentials sessionCredentials = (AWSSessionCredentials) awsCredentialsProvider.getCredentials();
AwsSessionCredentials awsCredentials = AwsSessionCredentials.create(
sessionCredentials.getAWSAccessKeyId(),
sessionCredentials.getAWSSecretKey(),
sessionCredentials.getSessionToken()
);

ExecutionAttributes executionAttributes = new ExecutionAttributes()
.putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, awsCredentials)
.putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, service)
.putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, Region.of(region));

return signer.sign(request, executionAttributes);
}

private void updateRequestHeaders(HttpRequest request, Map<String, List<String>> signedHeaders) {
Header[] headers = convertHeaderMapToArray(signedHeaders);
request.setHeaders(headers);
}

/**
* Converts an array of Headers into a map, accumulating multiple values for the same header name into a single list.
*
* @param headers The array of Headers to convert.
* @return A map of header names to their corresponding list of values.
*/
private static Map<String, List<String>> headerArrayToMap(final Header[] headers) {
Map<String, List<String>> headersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
for (Header header : headers) {
if (!skipHeader(header)) {
headersMap.computeIfAbsent(header.getName(), k -> new ArrayList<>()).add(header.getValue());
}
}
return headersMap;
}

/**
* Converts a map of header names to lists of values back into an array of Headers.
*
* @param mapHeaders The map of headers to convert.
* @return An array of Headers.
*/
private Header[] convertHeaderMapToArray(final Map<String, List<String>> mapHeaders) {
return mapHeaders.entrySet().stream()
.map(entry -> new BasicHeader(entry.getKey(), String.join(",", entry.getValue())))
.toArray(Header[]::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void process(final HttpRequest request, final HttpContext context)
* @param params list of HTTP query params as NameValuePairs
* @return a multimap of HTTP query params
*/
private static Map<String, List<String>> nvpToMapParams(final List<NameValuePair> params) {
static Map<String, List<String>> nvpToMapParams(final List<NameValuePair> params) {
Map<String, List<String>> parameterMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
for (NameValuePair nvp : params) {
List<String> argsList =
Expand Down Expand Up @@ -154,7 +154,7 @@ private static Map<String, String> headerArrayToMap(final Header[] headers) {
* @param header header line to check
* @return true if the given header should be excluded when signing
*/
private static boolean skipHeader(final Header header) {
static boolean skipHeader(final Header header) {
return ("content-length".equalsIgnoreCase(header.getName())
&& "0".equals(header.getValue())) // Strip Content-Length: 0
|| "host".equalsIgnoreCase(header.getName()); // Host comes from endpoint
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.protocol.HttpContext;
import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner;

import java.io.IOException;
import java.net.URISyntaxException;
Expand All @@ -19,33 +25,36 @@ public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequ

private final String service;
private final String metadataAccessIdentifier;
final AWSRequestSigningApacheInterceptor primaryInterceptor;
final AWSRequestSigningApacheInterceptor metadataAccessInterceptor;
final HttpRequestInterceptor primaryInterceptor;
final HttpRequestInterceptor metadataAccessInterceptor;

/**
* Constructs an interceptor for AWS request signing with optional metadata access.
*
* @param service The AWS service name.
* @param signer The AWS request signer.
* @param region The AWS region for signing.
* @param primaryCredentialsProvider The credentials provider for general access.
* @param metadataAccessCredentialsProvider The credentials provider for metadata access.
* @param metadataAccessIdentifier Identifier for operations requiring metadata access.
*/
public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
final Signer signer,
final String region,
final AWSCredentialsProvider primaryCredentialsProvider,
final AWSCredentialsProvider metadataAccessCredentialsProvider,
final String metadataAccessIdentifier) {
this(service,
new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider),
new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider),
metadataAccessIdentifier);
this.service = service == null ? "unknown" : service;
AWS4Signer signer = new AWS4Signer();
signer.setServiceName(this.service);
signer.setRegionName(region);
this.primaryInterceptor = new AWSRequestSigningApacheInterceptor(this.service, signer, primaryCredentialsProvider);
this.metadataAccessInterceptor = new AWSRequestSigV4ASigningApacheInterceptor(this.service, region, AwsCrtV4aSigner.builder().build(), metadataAccessCredentialsProvider);
this.metadataAccessIdentifier = metadataAccessIdentifier;
}

// Test constructor allowing injection of mock interceptors
ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
final AWSRequestSigningApacheInterceptor primaryInterceptor,
final AWSRequestSigningApacheInterceptor metadataAccessInterceptor,
final HttpRequestInterceptor primaryInterceptor,
final HttpRequestInterceptor metadataAccessInterceptor,
final String metadataAccessIdentifier) {
this.service = service == null ? "unknown" : service;
this.primaryInterceptor = primaryInterceptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,6 @@ public IRestHighLevelClient createClient() {

// SigV4 support
if (options.getAuth().equals(FlintOptions.SIGV4_AUTH)) {
AWS4Signer signer = new AWS4Signer();
signer.setServiceName("es");
signer.setRegionName(options.getRegion());

// Use DefaultAWSCredentialsProviderChain by default.
final AtomicReference<AWSCredentialsProvider> customAWSCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
Expand All @@ -283,7 +279,7 @@ public IRestHighLevelClient createClient() {
restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new ResourceBasedAWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName()));
"es", options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.core.auth;

import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.message.BasicHttpRequest;
import org.apache.http.protocol.HttpContext;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSSessionCredentials;

import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class AWSRequestSigV4ASigningApacheInterceptorTest {
@Mock
private Signer mockSigner;
@Mock
private AWSCredentialsProvider mockCredentialsProvider;
@Mock
private HttpContext mockContext;
@Mock
private AWSSessionCredentials mockSessionCredentials;
@Mock
private SdkHttpFullRequest mockSdkHttpFullRequest;
@Captor
private ArgumentCaptor<SdkHttpFullRequest> sdkHttpFullRequestCaptor;

private AWSRequestSigV4ASigningApacheInterceptor interceptor;

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
when(mockCredentialsProvider.getCredentials()).thenReturn(mockSessionCredentials);
when(mockSessionCredentials.getAWSAccessKeyId()).thenReturn("ACCESS_KEY_ID");
when(mockSessionCredentials.getAWSSecretKey()).thenReturn("SECRET_ACCESS_KEY");
when(mockSessionCredentials.getSessionToken()).thenReturn("SESSION_TOKEN");
interceptor = new AWSRequestSigV4ASigningApacheInterceptor("s3", "us-west-2", mockSigner, mockCredentialsProvider);
when(mockContext.getAttribute(HTTP_TARGET_HOST)).thenReturn(new HttpHost("localhost", 443, "https"));
when(mockSigner.sign(any(), any())).thenReturn(mockSdkHttpFullRequest);
}

@Test
public void testSigningProcess() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource");
interceptor.process(request, mockContext);

verify(mockSigner).sign(sdkHttpFullRequestCaptor.capture(), any());
SdkHttpFullRequest signedRequest = sdkHttpFullRequestCaptor.getValue();

assertEquals(SdkHttpMethod.GET, signedRequest.method());
assertEquals("/path/to/resource", signedRequest.encodedPath());
}

@Test(expected = IOException.class)
public void testInvalidUriHandling() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", ":///this/is/not/a/valid/uri");
interceptor.process(request, mockContext);
}

@Test
public void testHeaderUpdateAfterSigning() throws Exception {
// Setup mock signer to return a new SdkHttpFullRequest with an "Authorization" header
when(mockSigner.sign(any(SdkHttpFullRequest.class), any())).thenAnswer(invocation -> {
SdkHttpFullRequest originalRequest = invocation.getArgument(0);
Map<String, List<String>> modifiedHeaders = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
modifiedHeaders.putAll(originalRequest.headers());
modifiedHeaders.put("Authorization", List.of("AWS4-HMAC-SHA256 Credential=..."));

// Build a new SdkHttpFullRequest with the modified headers
return SdkHttpFullRequest.builder()
.method(originalRequest.method())
.uri(originalRequest.getUri())
.headers(modifiedHeaders)
.build();
});

HttpRequest request = new BasicHttpRequest("GET", "/path/to/resource");
interceptor.process(request, mockContext);

// Now verify that the HttpRequest has been updated with the new headers from the signed request
assertTrue("The request does not contain the expected 'Authorization' header",
request.containsHeader("Authorization"));
assertEquals("AWS4-HMAC-SHA256 Credential=...",
request.getFirstHeader("Authorization").getValue());
}
}
Loading

0 comments on commit b8a70b7

Please sign in to comment.