Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SigV4 signature when connecting to OpenSearchServerless #473

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ public class FlintOptions implements Serializable {

public static final String SCHEME = "scheme";

public static final String AUTH = "auth";
/**
* Service name used for SigV4 signature.
* `es`: Amazon OpenSearch Service
* `aoss`: Amazon OpenSearch Serverless
*/
public static final String SERVICE_NAME = "auth.servicename";
public static final String SERVICE_NAME_ES = "es";
public static final String SERVICE_NAME_AOSS = "aoss";

public static final String AUTH = "auth";
public static final String NONE_AUTH = "noauth";

public static final String SIGV4_AUTH = "sigv4";

public static final String BASIC_AUTH = "basic";

public static final String USERNAME = "auth.username";

public static final String PASSWORD = "auth.password";

public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";
Expand Down Expand Up @@ -131,6 +135,10 @@ public String getAuth() {
return options.getOrDefault(AUTH, NONE_AUTH);
}

public String getServiceName() {
return options.getOrDefault(SERVICE_NAME, SERVICE_NAME_ES);
}

public String getCustomAwsCredentialsProvider() {
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.flint.core.auth;

import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_CONTENT_SHA256;
import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;
import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_AOSS;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentialsProvider;
Expand All @@ -31,6 +33,7 @@
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;
import org.opensearch.flint.core.storage.OpenSearchClientUtils;

/**
* From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java
Expand Down Expand Up @@ -74,13 +77,6 @@ public AWSRequestSigningApacheInterceptor(final String service,
@Override
public void process(final HttpRequest request, final HttpContext context)
throws HttpException, IOException {
URIBuilder uriBuilder;
try {
uriBuilder = new URIBuilder(request.getRequestLine().getUri());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI" , e);
}

// Copy Apache HttpRequest to AWS DefaultRequest
DefaultRequest<?> signableRequest = new DefaultRequest<>(service);

Expand All @@ -91,7 +87,10 @@ public void process(final HttpRequest request, final HttpContext context)
final HttpMethodName httpMethod =
HttpMethodName.fromValue(request.getRequestLine().getMethod());
signableRequest.setHttpMethod(httpMethod);

URIBuilder uriBuilder;
try {
uriBuilder = new URIBuilder(request.getRequestLine().getUri());
signableRequest.setResourcePath(uriBuilder.build().getRawPath());
} catch (URISyntaxException e) {
throw new IOException("Invalid URI" , e);
Expand All @@ -110,6 +109,10 @@ public void process(final HttpRequest request, final HttpContext context)
signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams()));
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

if (SERVICE_NAME_AOSS.equals(service)) {
enableContentBodySignature(signableRequest);
}

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());

Expand All @@ -126,6 +129,11 @@ public void process(final HttpRequest request, final HttpContext context)
}
}

private void enableContentBodySignature(DefaultRequest<?> signableRequest) {
// AWS4Signer will add `x-amz-content-sha256` header when this header is set
signableRequest.addHeader(X_AMZ_CONTENT_SHA256, "required");
}

/**
*
* @param params list of HTTP query params as NameValuePairs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.flint.core.auth;

import static org.opensearch.flint.core.FlintOptions.SERVICE_NAME_ES;

import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentialsProvider;
import org.apache.http.HttpException;
Expand All @@ -14,6 +16,7 @@
import org.apache.http.protocol.HttpContext;
import org.jetbrains.annotations.TestOnly;
import org.opensearch.common.Strings;
import org.opensearch.flint.core.storage.OpenSearchClientUtils;
import software.amazon.awssdk.authcrt.signer.AwsCrtV4aSigner;

import java.io.IOException;
Expand Down Expand Up @@ -84,7 +87,7 @@ public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
@Override
public void process(HttpRequest request, HttpContext context) throws HttpException, IOException {
String resourcePath = parseUriToPath(request);
if ("es".equals(this.service) && isMetadataAccess(resourcePath)) {
if (SERVICE_NAME_ES.equals(this.service) && isMetadataAccess(resourcePath)) {
metadataAccessInterceptor.process(request, context);
} else {
primaryInterceptor.process(request, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
*/
public class OpenSearchClientUtils {

private static final String SERVICE_NAME = "es";

/**
* Metadata log index name prefix
*/
Expand Down Expand Up @@ -90,7 +88,7 @@ private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClient
restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new ResourceBasedAWSRequestSigningApacheInterceptor(
SERVICE_NAME, options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName));
options.getServiceName(), options.getRegion(), customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), systemIndexName));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.opensearch.flint.core.auth;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.awssdk.auth.signer.internal.SignerConstant.X_AMZ_CONTENT_SHA256;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.util.IOUtils;
import java.io.IOException;
import java.net.URI;
import org.apache.http.HttpHost;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.message.BasicHttpEntityEnclosingRequest;
import org.apache.http.protocol.BasicHttpContext;
import org.apache.http.protocol.HttpCoreContext;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.utils.StringInputStream;

@ExtendWith(MockitoExtension.class)
class AWSRequestSigningApacheInterceptorTest {

@Mock
AWSCredentialsProvider awsCredentialsProvider;
@Mock Signer signer;
@Mock
AWSCredentials awsCredentials;

@Captor
ArgumentCaptor<DefaultRequest<?>> signableRequestCaptor;

@Test
public void testProcessWithServiceIsEs() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequestWithEntity();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals(new URI("http://hello.world"), signableRequest.getEndpoint());
assertEquals(HttpMethodName.POST, signableRequest.getHttpMethod());
assertEquals("/path", signableRequest.getResourcePath());
assertEquals("ENTITY", IOUtils.toString(signableRequest.getContent()));
assertEquals("HeaderValue", signableRequest.getHeaders().get("Test-Header"));
assertEquals("value0", signableRequest.getParameters().get("param0").get(0));
}

@Test
public void testProcessWithoutEntity() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("es", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequest();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals("", IOUtils.toString(signableRequest.getContent()));
}

@NotNull
private static BasicHttpContext getContext() {
BasicHttpContext context = new BasicHttpContext();
context.setAttribute(HttpCoreContext.HTTP_TARGET_HOST, new HttpHost("hello.world"));
return context;
}

@Test
public void testProcessWithServiceIsAoss() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = getRequest();
final BasicHttpContext context = getContext();
when(awsCredentialsProvider.getCredentials()).thenReturn(awsCredentials);

awsRequestSigningApacheInterceptor.process(request, context);

verify(signer).sign(signableRequestCaptor.capture(), eq(awsCredentials));
DefaultRequest<?> signableRequest = signableRequestCaptor.getValue();
assertEquals("required", signableRequest.getHeaders().get(X_AMZ_CONTENT_SHA256));
}

@Test
public void testInvalidURI() throws Exception {
AWSRequestSigningApacheInterceptor awsRequestSigningApacheInterceptor = new AWSRequestSigningApacheInterceptor("aoss", signer, awsCredentialsProvider);
final BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "::INVALID_URI::");
final BasicHttpContext context = getContext();

assertThrows(IOException.class, () -> {
awsRequestSigningApacheInterceptor.process(request, context);
});
}

@NotNull
private static BasicHttpEntityEnclosingRequest getRequestWithEntity() {
BasicHttpEntityEnclosingRequest request = getRequest();
BasicHttpEntity basicHttpEntity = new BasicHttpEntity();
basicHttpEntity.setContent(new StringInputStream("ENTITY"));
request.setEntity(basicHttpEntity);
request.setHeader("content-length", "6");
return request;
}

@NotNull
private static BasicHttpEntityEnclosingRequest getRequest() {
BasicHttpEntityEnclosingRequest request = new BasicHttpEntityEnclosingRequest("POST", "https://hello.world/path?param0=value0");
request.setHeader("Test-Header", "HeaderValue");
request.setHeader("content-length", "0");
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ object FlintSparkConf {
"noauth(no auth), sigv4(sigv4 auth), basic(basic auth)")
.createWithDefault(FlintOptions.NONE_AUTH)

val SERVICE_NAME = FlintConfig("spark.datasource.flint.auth.servicename")
.datasourceOption()
.doc("service name used for SigV4 signature. " +
"es (AWS OpenSearch Service), aoss (Amazon OpenSearch Serverless)")
.createWithDefault(FlintOptions.SERVICE_NAME_ES)

val USERNAME = FlintConfig("spark.datasource.flint.auth.username")
.datasourceOption()
.doc("basic auth username")
Expand Down Expand Up @@ -267,6 +273,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
RETRYABLE_HTTP_STATUS_CODES,
REGION,
CUSTOM_AWS_CREDENTIALS_PROVIDER,
SERVICE_NAME,
USERNAME,
PASSWORD,
SOCKET_TIMEOUT_MILLIS,
Expand Down
Loading