Skip to content

Commit

Permalink
Merge branch 'main' into improve-index-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
dai-chen committed Apr 17, 2024
2 parents 7731309 + b69f38b commit 196ee3e
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 27 deletions.
1 change: 1 addition & 0 deletions MAINTAINERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
| Sean Kao | [seankao-az](https://github.com/seankao-az) | Amazon |
| Anirudha Jadhav | [anirudha](https://github.com/anirudha) | Amazon |
| Kaituo Li | [kaituo](https://github.com/kaituo) | Amazon |
| Louis Chu | [noCharger](https://github.com/noCharger) | Amazon |
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourc
* @return {@link FlintWriter}
*/
FlintWriter createWriter(String indexName);

/**
* Create {@link IRestHighLevelClient}.
* @return {@link IRestHighLevelClient}
*/
public IRestHighLevelClient createClient();
IRestHighLevelClient createClient();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.flint.core;

import dev.failsafe.RetryPolicy;
import java.io.Serializable;
import java.util.Map;
import org.opensearch.flint.core.http.FlintRetryOptions;
Expand Down Expand Up @@ -46,10 +45,14 @@ public class FlintOptions implements Serializable {

public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";

public static final String METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER = "spark.metadata.accessAWSCredentialsProvider";

/**
* By default, customAWSCredentialsProvider is empty. use DefaultAWSCredentialsProviderChain.
* By default, customAWSCredentialsProvider and accessAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain.
*/
public static final String DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER = "";
public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
Expand Down Expand Up @@ -121,7 +124,11 @@ public String getAuth() {
}

public String getCustomAwsCredentialsProvider() {
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, "");
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getMetadataAccessAwsCredentialsProvider() {
return options.getOrDefault(METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getUsername() {
Expand All @@ -139,4 +146,8 @@ public int getSocketTimeoutMillis() {
public String getDataSourceName() {
return options.getOrDefault(DATA_SOURCE_NAME, "");
}

public String getSystemIndexName() {
return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import org.apache.http.HttpException;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.protocol.HttpContext;

import java.io.IOException;
import java.net.URISyntaxException;

/**
* Intercepts HTTP requests to sign them for AWS authentication, adjusting the signing process
* based on whether the request accesses metadata or not.
*/
public class ResourceBasedAWSRequestSigningApacheInterceptor implements HttpRequestInterceptor {

private final String service;
private final String metadataAccessIdentifier;
final AWSRequestSigningApacheInterceptor primaryInterceptor;
final AWSRequestSigningApacheInterceptor metadataAccessInterceptor;

/**
* Constructs an interceptor for AWS request signing with optional metadata access.
*
* @param service The AWS service name.
* @param signer The AWS request signer.
* @param primaryCredentialsProvider The credentials provider for general access.
* @param metadataAccessCredentialsProvider The credentials provider for metadata access.
* @param metadataAccessIdentifier Identifier for operations requiring metadata access.
*/
public ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AWSCredentialsProvider primaryCredentialsProvider,
final AWSCredentialsProvider metadataAccessCredentialsProvider,
final String metadataAccessIdentifier) {
this(service,
new AWSRequestSigningApacheInterceptor(service, signer, primaryCredentialsProvider),
new AWSRequestSigningApacheInterceptor(service, signer, metadataAccessCredentialsProvider),
metadataAccessIdentifier);
}

// Test constructor allowing injection of mock interceptors
ResourceBasedAWSRequestSigningApacheInterceptor(final String service,
final AWSRequestSigningApacheInterceptor primaryInterceptor,
final AWSRequestSigningApacheInterceptor metadataAccessInterceptor,
final String metadataAccessIdentifier) {
this.service = service == null ? "unknown" : service;
this.primaryInterceptor = primaryInterceptor;
this.metadataAccessInterceptor = metadataAccessInterceptor;
this.metadataAccessIdentifier = metadataAccessIdentifier;
}

/**
* Processes an HTTP request, signing it according to whether it requires metadata access.
*
* @param request The HTTP request to process.
* @param context The context in which the HTTP request is being processed.
* @throws HttpException If processing the HTTP request results in an exception.
* @throws IOException If an I/O error occurs.
*/
@Override
public void process(HttpRequest request, HttpContext context) throws HttpException, IOException {
String resourcePath = parseUriToPath(request);
if ("es".equals(this.service) && isMetadataAccess(resourcePath)) {
metadataAccessInterceptor.process(request, context);
} else {
primaryInterceptor.process(request, context);
}
}

/**
* Extracts and returns the path component of a URI from an HTTP request.
*
* @param request The HTTP request from which to extract the URI path.
* @return The path component of the URI.
* @throws IOException If an error occurs parsing the URI.
*/
private String parseUriToPath(HttpRequest request) throws IOException {
try {
URIBuilder uriBuilder = new URIBuilder(request.getRequestLine().getUri());
return uriBuilder.build().getRawPath();
} catch (URISyntaxException e) {
throw new IOException("Invalid URI", e);
}
}

/**
* Determines whether the accessed resource requires metadata credentials.
*
* @param resourcePath The path of the resource being accessed.
* @return true if the operation requires metadata access credentials, false otherwise.
*/
private boolean isMetadataAccess(String resourcePath) {
return resourcePath.contains(metadataAccessIdentifier);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import org.opensearch.flint.core.FlintClient;
import org.opensearch.flint.core.FlintOptions;
import org.opensearch.flint.core.IRestHighLevelClient;
import org.opensearch.flint.core.auth.AWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.http.RetryableHttpAsyncClient;
import org.opensearch.flint.core.metadata.FlintMetadata;
import org.opensearch.flint.core.metadata.log.DefaultOptimisticTransaction;
Expand Down Expand Up @@ -262,26 +262,30 @@ public IRestHighLevelClient createClient() {
signer.setRegionName(options.getRegion());

// Use DefaultAWSCredentialsProviderChain by default.
final AtomicReference<AWSCredentialsProvider> awsCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
String providerClass = options.getCustomAwsCredentialsProvider();
if (!Strings.isNullOrEmpty(providerClass)) {
try {
Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
ctor.setAccessible(true);
awsCredentialsProvider.set((AWSCredentialsProvider) ctor.newInstance());
} catch (Exception e) {
throw new RuntimeException(e);
}
final AtomicReference<AWSCredentialsProvider> customAWSCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
String customProviderClass = options.getCustomAwsCredentialsProvider();
if (!Strings.isNullOrEmpty(customProviderClass)) {
instantiateProvider(customProviderClass, customAWSCredentialsProvider);
}

// Set metadataAccessAWSCredentialsProvider to customAWSCredentialsProvider by default for backwards compatibility
// unless a specific metadata access provider class name is provided
String metadataAccessProviderClass = options.getMetadataAccessAwsCredentialsProvider();
final AtomicReference<AWSCredentialsProvider> metadataAccessAWSCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
if (Strings.isNullOrEmpty(metadataAccessProviderClass)) {
metadataAccessAWSCredentialsProvider.set(customAWSCredentialsProvider.get());
} else {
instantiateProvider(metadataAccessProviderClass, metadataAccessAWSCredentialsProvider);
}

restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate =
builder.addInterceptorLast(
new AWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, awsCredentialsProvider.get()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new ResourceBasedAWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, customAWSCredentialsProvider.get(), metadataAccessAWSCredentialsProvider.get(), options.getSystemIndexName()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);
} else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) {
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
Expand All @@ -303,6 +307,20 @@ public IRestHighLevelClient createClient() {
return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder));
}

/**
* Attempts to instantiate the AWS credential provider using reflection.
*/
private void instantiateProvider(String providerClass, AtomicReference<AWSCredentialsProvider> provider) {
try {
Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
ctor.setAccessible(true);
provider.set((AWSCredentialsProvider) ctor.newInstance());
} catch (Exception e) {
throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e);
}
}

/*
* Constructs Flint metadata with latest metadata log entry attached if it's available.
* It relies on FlintOptions to provide data source name.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.flint.core.auth;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import org.apache.http.HttpRequest;
import org.apache.http.message.BasicHttpRequest;
import org.apache.http.protocol.HttpContext;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.Mockito.*;

public class ResourceBasedAWSRequestSigningApacheInterceptorTest {

@Mock
private Signer mockSigner;
@Mock
private AWSCredentialsProvider mockPrimaryCredentialsProvider;
@Mock
private AWSCredentialsProvider mockMetadataAccessCredentialsProvider;
@Mock
private HttpContext mockContext;
@Captor
private ArgumentCaptor<HttpRequest> httpRequestCaptor;

private ResourceBasedAWSRequestSigningApacheInterceptor interceptor;

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
AWSRequestSigningApacheInterceptor primaryInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockPrimaryCredentialsProvider));
AWSRequestSigningApacheInterceptor metadataInterceptorSpy = spy(new AWSRequestSigningApacheInterceptor("es", mockSigner, mockMetadataAccessCredentialsProvider));

interceptor = new ResourceBasedAWSRequestSigningApacheInterceptor(
"es",
primaryInterceptorSpy,
metadataInterceptorSpy,
"/metadata");
}

@Test
public void testProcessWithMetadataAccess() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", "/es/metadata/resource");

interceptor.process(request, mockContext);

verify(interceptor.metadataAccessInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
verify(interceptor.primaryInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
assert httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
}

@Test
public void testProcessWithoutMetadataAccess() throws Exception {
HttpRequest request = new BasicHttpRequest("GET", "/es/regular/resource");

interceptor.process(request, mockContext);

verify(interceptor.primaryInterceptor).process(httpRequestCaptor.capture(), eq(mockContext));
verify(interceptor.metadataAccessInterceptor, never()).process(any(HttpRequest.class), any(HttpContext.class));
assert !httpRequestCaptor.getValue().getRequestLine().getUri().contains("/metadata");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ object FlintSparkConf {
FlintConfig("spark.datasource.flint.customAWSCredentialsProvider")
.datasourceOption()
.doc("AWS customAWSCredentialsProvider")
.createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER)
.createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER)

val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name")
.datasourceOption()
Expand Down Expand Up @@ -174,6 +174,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.inactivityLimitMillis")
.doc("inactivity timeout")
.createWithDefault(String.valueOf(FlintOptions.DEFAULT_INACTIVITY_LIMIT_MILLIS))
val METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER =
FlintConfig("spark.metadata.accessAWSCredentialsProvider")
.doc("AWS credentials provider for metadata access permission")
.createOptional()
}

/**
Expand Down Expand Up @@ -234,6 +238,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
DATA_SOURCE_NAME,
SESSION_ID,
REQUEST_INDEX,
METADATA_ACCESS_AWS_CREDENTIALS_PROVIDER,
EXCLUDE_JOB_IDS)
.map(conf => (conf.optionKey, conf.readFrom(reader)))
.flatMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.Optional

import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

Expand Down Expand Up @@ -62,6 +63,18 @@ class FlintSparkConfSuite extends FlintSuite {
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

test("test metadata access AWS credentials provider option") {
withSparkConf("spark.metadata.accessAWSCredentialsProvider") {
spark.conf.set(
"spark.metadata.accessAWSCredentialsProvider",
"com.example.MetadataAccessCredentialsProvider")
val flintOptions = FlintSparkConf().flintOptions()
assert(flintOptions.getCustomAwsCredentialsProvider == "")
assert(
flintOptions.getMetadataAccessAwsCredentialsProvider == "com.example.MetadataAccessCredentialsProvider")
}
}

/**
* Delete index `indexNames` after calling `f`.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ object FlintJob extends Logging with FlintJobExecutor {
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(conf.get(FlintSparkConf.QUERY.key, ""))
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
throw new IllegalArgumentException(s"Query undefined for the ${jobType} job.")
}
Expand Down
Loading

0 comments on commit 196ee3e

Please sign in to comment.