Skip to content

Commit

Permalink
Fix SigV4 signature when connecting to OpenSearchServerless
Browse files Browse the repository at this point in the history
Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 committed Jul 30, 2024
1 parent c6ab291 commit 41e89e5
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ public class FlintOptions implements Serializable {

public static final String SCHEME = "scheme";

public static final String AUTH = "auth";
public static final String INDEX_TYPE = "index.type";
public static final String INDEX_TYPE_AOS = "aos";
public static final String INDEX_TYPE_AOSS = "aoss";

public static final String AUTH = "auth";
public static final String NONE_AUTH = "noauth";

public static final String SIGV4_AUTH = "sigv4";

public static final String BASIC_AUTH = "basic";

public static final String USERNAME = "auth.username";

public static final String PASSWORD = "auth.password";

public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";
Expand Down Expand Up @@ -131,6 +130,10 @@ public String getAuth() {
return options.getOrDefault(AUTH, NONE_AUTH);
}

public String getIndexType() {
return options.getOrDefault(INDEX_TYPE, INDEX_TYPE_AOS);
}

public String getCustomAwsCredentialsProvider() {
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.flint.core.auth;

import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_CONTENT_SHA256;
import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;

import com.amazonaws.DefaultRequest;
Expand All @@ -31,6 +32,7 @@
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import org.opensearch.flint.core.storage.OpenSearchClientUtils;

/**
* From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java
Expand Down Expand Up @@ -74,13 +76,6 @@ public AWSRequestSigningApacheInterceptor(final String service,
@Override
public void process(final HttpRequest request, final HttpContext context)
throws HttpException, IOException {
URIBuilder uriBuilder;
try {
uriBuilder = new URIBuilder(request.getRequestLine().getUri());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI" , e);
}

// Copy Apache HttpRequest to AWS DefaultRequest
DefaultRequest<?> signableRequest = new DefaultRequest<>(service);

Expand All @@ -91,7 +86,10 @@ public void process(final HttpRequest request, final HttpContext context)
final HttpMethodName httpMethod =
HttpMethodName.fromValue(request.getRequestLine().getMethod());
signableRequest.setHttpMethod(httpMethod);

URIBuilder uriBuilder;
try {
uriBuilder = new URIBuilder(request.getRequestLine().getUri());
signableRequest.setResourcePath(uriBuilder.build().getRawPath());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI" , e);
Expand All @@ -110,6 +108,10 @@ public void process(final HttpRequest request, final HttpContext context)
signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams()));
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

if (OpenSearchClientUtils.AOSS_SIGV4_SERVICE_NAME.equals(service)) {
enableContentBodySignature(signableRequest);
}

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());

Expand All @@ -126,6 +128,11 @@ public void process(final HttpRequest request, final HttpContext context)
}
}

private void enableContentBodySignature(DefaultRequest<?> signableRequest) {
// AWS4Signer will add `x-amz-content-sha256` header when this header is set
signableRequest.addHeader(X_AMZ_CONTENT_SHA256, "required");
}

/**
*
* @param params list of HTTP query params as NameValuePairs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.http.protocol.HttpContext;
import org.jetbrains.annotations.TestOnly;
import org.opensearch.common.Strings;
import org.opensearch.flint.core.storage.OpenSearchClientUtils;
import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner;

import java.io.IOException;
Expand Down Expand Up @@ -84,7 +85,7 @@ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
@Override
public void process(HttpRequest request, HttpContext context) throws HttpException, IOException {
String resourcePath = parseUriToPath(request);
if ("es".equals(this.service) && isMetadataAccess(resourcePath)) {
if (OpenSearchClientUtils.AOS_SIGV4_SERVICE_NAME.equals(this.service) && isMetadataAccess(resourcePath)) {
metadataAccessInterceptor.process(request, context);
} else {
primaryInterceptor.process(request, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.lang.reflect.Constructor;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
Expand All @@ -30,13 +34,19 @@
*/
public class OpenSearchClientUtils {

private static final String SERVICE_NAME = "es";

/**
* Metadata log index name prefix
*/
public final static String META_LOG_NAME_PREFIX = ".query_execution_request";

public static final String AOS_SIGV4_SERVICE_NAME = "es";
public static final String AOSS_SIGV4_SERVICE_NAME = "aoss";

private static final Map<String, String> INDEX_TYPE_SERVICE_NAME_MAPPING = ImmutableMap.<String, String>builder()
.put(FlintOptions.INDEX_TYPE_AOS, AOS_SIGV4_SERVICE_NAME)
.put(FlintOptions.INDEX_TYPE_AOSS, AOSS_SIGV4_SERVICE_NAME)
.build();

/**
* Used in IT.
*/
Expand Down Expand Up @@ -87,17 +97,24 @@ private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClient
instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider);
}

String serviceName = getServiceNameForSigV4(options);
restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new ResourceBasedAWSRequestSigningApacheInterceptor(
SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName));
serviceName, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);

return restClientBuilder;
}

@VisibleForTesting
static String getServiceNameForSigV4(FlintOptions options) {
String indexType = options.getIndexType();
return Objects.requireNonNull(INDEX_TYPE_SERVICE_NAME_MAPPING.get(indexType), "Unknown index type was specified: " + indexType);
}

private static RestClientBuilder configureBasicAuth(RestClientBuilder restClientBuilder, FlintOptions options) {
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.opensearch.flint.core.auth;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.awssdk.auth.signer.internal.SignerConstant.X_AMZ_CONTENT_SHA256;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.util.IOUtils;
import java.io.IOException;
import java.net.URI;
import org.apache.http.HttpHost;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHttpEntityEnclosingRequest;
import org.apache.http.protocol.BasicHttpContext;
import org.apache.http.protocol.HttpCoreContext;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.utils.StringInputStream;

@ExtendWith(MockitoExtension.class)
class AWSRequestSigningApacheInterceptorTest {

@Mock
AWSCredentialsProvider awsCredentialsProvider;
@Mock Signer signer;
@Mock
AWSCredentials awsCredentials;

@Captor
ArgumentCaptor<DefaultRequest<?>> signableRequestCaptor;

@Test
public void testProcessWithServiceIsEs() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequestWithEntity();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals(new URI("http://hello.world"), signableRequest.getEndpoint());
assertEquals(HttpMethodName.POST, signableRequest.getHttpMethod());
assertEquals("/path", signableRequest.getResourcePath());
assertEquals("ENTITY", IOUtils.toString(signableRequest.getContent()));
assertEquals("HeaderValue", signableRequest.getHeaders().get("Test-Header"));
assertEquals("value0", signableRequest.getParameters().get("param0").get(0));
}

@Test
public void testProcessWithoutEntity() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequest();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals("", IOUtils.toString(signableRequest.getContent()));
}

@NotNull
private static BasicHttpContext getContext() {
BasicHttpContext context = new BasicHttpContext();
context.setAttribute(HttpCoreContext.HTTP_TARGET_HOST, new HttpHost("hello.world"));
return context;
}

@Test
public void testProcessWithServiceIsAoss() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequest();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals("required", signableRequest.getHeaders().get(X_AMZ_CONTENT_SHA256));
}

@Test
public void testInvalidURI() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "::INVALID_URI::");
final BasicHttpContext context = getContext();

assertThrows(IOException.class, () -> {
awsRequestSigningApacheInterceptor.process(request, context);
});
}

@NotNull
private static BasicHttpEntityEnclosingRequest getRequestWithEntity() {
BasicHttpEntityEnclosingRequest request = getRequest();
BasicHttpEntity basicHttpEntity = new BasicHttpEntity();
basicHttpEntity.setContent(new StringInputStream("ENTITY"));
request.setEntity(basicHttpEntity);
request.setHeader("content-length", "6");
return request;
}

@NotNull
private static BasicHttpEntityEnclosingRequest getRequest() {
BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "https://hello.world/path?param0=value0");
request.setHeader("Test-Header", "HeaderValue");
request.setHeader("content-length", "0");
return request;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.opensearch.flint.core.storage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;
import org.opensearch.flint.core.FlintOptions;

class OpenSearchClientUtilsTest {

@Test
public void testGetServiceNameForSigV4() {
assertEquals("es",
OpenSearchClientUtils.getServiceNameForSigV4(getFlintOptions(FlintOptions.INDEX_TYPE_AOS)));

assertEquals("aoss", OpenSearchClientUtils.getServiceNameForSigV4(
getFlintOptions(FlintOptions.INDEX_TYPE_AOSS)));

assertThrows(NullPointerException.class, () -> OpenSearchClientUtils.getServiceNameForSigV4(
getFlintOptions("INVALID_INDEX_TYPE")));
}

private FlintOptions getFlintOptions(String indexType) {
return new FlintOptions(ImmutableMap.of(FlintOptions.INDEX_TYPE, indexType));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ object FlintSparkConf {
"noauth(no auth), sigv4(sigv4 auth), basic(basic auth)")
.createWithDefault(FlintOptions.NONE_AUTH)

val INDEX_TYPE = FlintConfig("spark.datasource.flint.index.type")
.datasourceOption()
.doc("type of index storage. supported value: " +
"aos (AWS OpenSearch Service), aoss (Amazon OpenSearch Serverless)")
.createWithDefault(FlintOptions.INDEX_TYPE_AOS)

val USERNAME = FlintConfig("spark.datasource.flint.auth.username")
.datasourceOption()
.doc("basic auth username")
Expand Down Expand Up @@ -267,6 +273,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
RETRYABLE_HTTP_STATUS_CODES,
REGION,
CUSTOM_AWS_CREDENTIALS_PROVIDER,
INDEX_TYPE,
USERNAME,
PASSWORD,
SOCKET_TIMEOUT_MILLIS,
Expand Down

0 comments on commit 41e89e5

Please sign in to comment.