Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWS credentials provider for metadata access #285

Merged
merged 5 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourc
* @return {@link FlintWriter}
*/
FlintWriter createWriter(String indexName);

/**
* Create {@link IRestHighLevelClient}.
* @return {@link IRestHighLevelClient}
*/
public IRestHighLevelClient createClient();
IRestHighLevelClient createClient();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.flint.core;

import dev.failsafe.RetryPolicy;
import java.io.Serializable;
import java.util.Map;
import org.opensearch.flint.core.http.FlintRetryOptions;
Expand Down Expand Up @@ -46,10 +45,14 @@ public class FlintOptions implements Serializable {

public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";

public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";

/**
* By default, customAWSCredentialsProvider is empty. use DefaultAWSCredentialsProviderChain.
* By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty, in which case DefaultAWSCredentialsProviderChain is used.
*/
public static final String DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER = "";
public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
Expand Down Expand Up @@ -121,7 +124,11 @@ public String getAuth() {
}

public String getCustomAwsCredentialsProvider() {
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, "");
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

/**
 * Returns the credentials provider class name configured for metadata access.
 *
 * @return the value stored under {@link #METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER}, or
 *         {@link #DEFAULT_AWS_CREDENTIALS_PROVIDER} when the option is not set
 */
public String getMetadataAccessAwsCredentialsProvider() {
  final String configuredProvider =
      options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
  return configuredProvider;
}

public String getUsername() {
Expand All @@ -139,4 +146,8 @@ public int getSocketTimeoutMillis() {
/**
 * Returns the configured data source name.
 *
 * @return the value stored under {@code DATA_SOURCE_NAME}, or an empty string when the option
 *         is not set
 */
public String getDataSourceName() {
  final String dataSourceName = options.getOrDefault(DATA_SOURCE_NAME, "");
  return dataSourceName;
}

/**
 * Returns the configured system index name.
 *
 * @return the value stored under {@link #SYSTEM_INDEX_KEY_NAME}, or an empty string when the
 *         option is not set
 */
public String getSystemIndexName() {
  final String systemIndexName = options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
  return systemIndexName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.protocol.HttpContext;

import java.io.IOException;
import java.net.URISyntaxException;

/**
 * Intercepts HTTP requests to sign them for AWS authentication, delegating to one of two
 * signing interceptors depending on whether the request path targets the metadata resource.
 *
 * <p>A request is routed to the metadata-access interceptor only when the service is
 * {@code "es"} and the raw request path contains the configured metadata access identifier.
 * Every other request is signed by the primary interceptor.
 */
public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequestInterceptor {

  private final String service;
  private final String metadataAccessIdentifier;
  // Package-private so tests in the same package can verify which delegate was invoked.
  final AWSRequestSigningApacheInterceptor primaryInterceptor;
  final AWSRequestSigningApacheInterceptor metadataAccessInterceptor;

  /**
   * Constructs an interceptor for AWS request signing with optional metadata access.
   *
   * @param service                           The AWS service name.
   * @param signer                            The AWS request signer.
   * @param primaryCredentialsProvider        The credentials provider for general access.
   * @param metadataAccessCredentialsProvider The credentials provider for metadata access.
   * @param metadataAccessIdentifier          Identifier for operations requiring metadata access;
   *                                          when {@code null} or empty, no request is treated as
   *                                          metadata access.
   */
  public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
                                                         final Signer signer,
                                                         final AWSCredentialsProvider primaryCredentialsProvider,
                                                         final AWSCredentialsProvider metadataAccessCredentialsProvider,
                                                         final String metadataAccessIdentifier) {
    this(service,
        new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider),
        new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider),
        metadataAccessIdentifier);
  }

  /**
   * Test constructor allowing injection of mock interceptors.
   *
   * @param service                   The AWS service name; {@code null} is stored as "unknown".
   * @param primaryInterceptor        Delegate used for general access requests.
   * @param metadataAccessInterceptor Delegate used for metadata access requests.
   * @param metadataAccessIdentifier  Identifier for operations requiring metadata access.
   */
  ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
                                                  final AWSRequestSigningApacheInterceptor primaryInterceptor,
                                                  final AWSRequestSigningApacheInterceptor metadataAccessInterceptor,
                                                  final String metadataAccessIdentifier) {
    this.service = service == null ? "unknown" : service;
    this.primaryInterceptor = primaryInterceptor;
    this.metadataAccessInterceptor = metadataAccessInterceptor;
    this.metadataAccessIdentifier = metadataAccessIdentifier;
  }

  /**
   * Processes an HTTP request, signing it according to whether it requires metadata access.
   *
   * @param request The HTTP request to process.
   * @param context The context in which the HTTP request is being processed.
   * @throws HttpException If processing the HTTP request results in an exception.
   * @throws IOException   If an I/O error occurs or the request URI cannot be parsed.
   */
  @Override
  public void process(HttpRequest request, HttpContext context) throws HttpException, IOException {
    String resourcePath = parseUriToPath(request);
    // Resource-based routing only applies to the "es" service; all other services always use
    // the primary interceptor.
    if ("es".equals(this.service) && isMetadataAccess(resourcePath)) {
      metadataAccessInterceptor.process(request, context);
    } else {
      primaryInterceptor.process(request, context);
    }
  }

  /**
   * Extracts and returns the path component of a URI from an HTTP request.
   *
   * @param request The HTTP request from which to extract the URI path.
   * @return The raw path component of the URI; may be {@code null} when the URI has no path.
   * @throws IOException If an error occurs parsing the URI.
   */
  private String parseUriToPath(HttpRequest request) throws IOException {
    try {
      URIBuilder uriBuilder = new URIBuilder(request.getRequestLine().getUri());
      return uriBuilder.build().getRawPath();
    } catch (URISyntaxException e) {
      throw new IOException("Invalid URI", e);
    }
  }

  /**
   * Determines whether the accessed resource requires metadata credentials.
   *
   * <p>A {@code null} or empty identifier never matches. Without this guard an empty identifier
   * would match every path (since {@code String.contains("")} is always true) and a
   * {@code null} identifier or path would raise a {@link NullPointerException}.
   *
   * @param resourcePath The path of the resource being accessed.
   * @return true if the operation requires metadata access credentials, false otherwise.
   */
  private boolean isMetadataAccess(String resourcePath) {
    if (resourcePath == null
        || metadataAccessIdentifier == null
        || metadataAccessIdentifier.isEmpty()) {
      return false;
    }
    return resourcePath.contains(metadataAccessIdentifier);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import org.opensearch.flint.core.FlintClient;
import org.opensearch.flint.core.FlintOptions;
import org.opensearch.flint.core.IRestHighLevelClient;
import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.http.RetryableHttpAsyncClient;
import org.opensearch.flint.core.metadata.FlintMetadata;
import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction;
Expand Down Expand Up @@ -262,26 +262,30 @@ public IRestHighLevelClient createClient() {
signer.setRegionName(options.getRegion());

// Use DefaultAWSCredentialsProviderChain by default.
final AtomicReference<AWSCredentialsProvider> awsCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
String providerClass = options.getCustomAwsCredentialsProvider();
if (!Strings.isNullOrEmpty(providerClass)) {
try {
Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
ctor.setAccessible(true);
awsCredentialsProvider.set((AWSCredentialsProvider) ctor.newInstance());
} catch (Exception e) {
throw new RuntimeException(e);
}
final AtomicReference<AWSCredentialsProvider> customAWSCredentialsProvider =
noCharger marked this conversation as resolved.
Show resolved Hide resolved
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
String customProviderClass = options.getCustomAwsCredentialsProvider();
if (!Strings.isNullOrEmpty(customProviderClass)) {
instantiateProvider(customProviderClass, customAWSCredentialsProvider);
}

// Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility
// unless a specific metadata access provider class name is provided
String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider();
final AtomicReference<AWSCredentialsProvider> metadataAccessAWSCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
if (Strings.isNullOrEmpty(metadataAccessProviderClass)) {
metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get());
} else {
instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider);
}

restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate =
builder.addInterceptorLast(
new AWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, awsCredentialsProvider.get()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new ResourceBasedAWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);
} else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) {
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
Expand All @@ -303,6 +307,20 @@ public IRestHighLevelClient createClient() {
return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder));
}

/**
 * Reflectively creates an AWS credentials provider and stores it in the given reference.
 *
 * <p>The class is looked up by name and its no-argument constructor is invoked, after being
 * made accessible. Any failure along the way is rethrown as a {@link RuntimeException} that
 * names the offending class and carries the original cause.
 *
 * @param providerClass fully qualified class name of the credentials provider to create
 * @param provider      reference that receives the newly created provider instance
 */
private void instantiateProvider(String providerClass, AtomicReference<AWSCredentialsProvider> provider) {
  try {
    Constructor<?> noArgConstructor = Class.forName(providerClass).getDeclaredConstructor();
    noArgConstructor.setAccessible(true);
    Object instance = noArgConstructor.newInstance();
    // The cast stays inside the try block so a class of the wrong type is reported the same
    // way as any other instantiation failure.
    provider.set((AWSCredentialsProvider) instance);
  } catch (Exception e) {
    throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e);
  }
}

/*
* Constructs Flint metadata with latest metadata log entry attached if it's available.
* It relies on FlintOptions to provide data source name.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import org.apache.http.HttpRequest;
import org.apache.http.message.BasicHttpRequest;
import org.apache.http.protocol.HttpContext;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.Mockito.*;

public class ResourceBasedAWSRequestSigningApacheInterceptorTest {

@Mock
private Signer mockSigner;
@Mock
private AWSCredentialsProvider mockPrimaryCredentialsProvider;
@Mock
private AWSCredentialsProvider mockMetadataAccessCredentialsProvider;
@Mock
private HttpContext mockContext;
@Captor
private ArgumentCaptor<HttpRequest> httpRequestCaptor;

private ResourceBasedAWSRequestSigningApacheInterceptor interceptor;

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider));
AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider));

interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor(
"es",
primaryInterceptorSpy,
metadataInterceptorSpy,
"/metadata");
}

@Test
public void testProcessWithMetadataAccess() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource");

interceptor.process(request, mockContext);

verify(interceptor.metadataAccessInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
verify(interceptor.primaryInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
assert httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
}

@Test
public void testProcessWithoutMetadataAccess() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", "/es/regular/resource");

interceptor.process(request, mockContext);

verify(interceptor.primaryInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
verify(interceptor.metadataAccessInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
assert !httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ object FlintSparkConf {
FlintConfig("spark.datasource.flint.customAWSCredentialsProvider")
.datasourceOption()
.doc("AWS customAWSCredentialsProvider")
.createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER)
.createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER)

val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name")
.datasourceOption()
Expand Down Expand Up @@ -174,6 +174,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.inactivityLimitMillis")
.doc("inactivity timeout")
.createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS))
val METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER =
FlintConfig("spark.metadata.accessAWSCredentialsProvider")
.doc("AWS credentials provider for metadata access permission")
.createOptional()
}

/**
Expand Down Expand Up @@ -234,6 +238,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
DATA_SOURCE_NAME,
SESSION_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
EXCLUDE_JOB_IDS)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.flatMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.Optional

import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

Expand Down Expand Up @@ -62,6 +63,18 @@ class FlintSparkConfSuite extends FlintSuite {
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

// Verifies that a value set under "spark.metadata.accessAWSCredentialsProvider" is exposed via
// FlintOptions.getMetadataAccessAwsCredentialsProvider while the custom credentials provider
// option stays at its empty default.
test("test metadata access AWS credentials provider option") {
// NOTE(review): withSparkConf is defined outside this view; presumably it cleans up the listed
// conf key once the body finishes - confirm against the suite helper.
withSparkConf("spark.metadata.accessAWSCredentialsProvider") {
spark.conf.set(
"spark.metadata.accessAWSCredentialsProvider",
"com.example.MetadataAccessCredentialsProvider")
val flintOptions = FlintSparkConf().flintOptions()
// Only the metadata access provider was configured, so the custom provider must be empty.
assert(flintOptions.getCustomAwsCredentialsProvider == "")
assert(
flintOptions.getMetadataAccessAwsCredentialsProvider == "com.example.MetadataAccessCredentialsProvider")
}
}

/**
* Delete index `indexNames` after calling `f`.
*/
Expand Down
Loading