Skip to content

Commit

Permalink
Add super admin auth provider
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <[email protected]>
  • Loading branch information
noCharger committed Mar 28, 2024
1 parent a38747f commit f0d72df
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ <T> OptimisticTransaction<T> startTransaction(String indexName, String dataSourc
* Create {@link IRestHighLevelClient}.
* @return {@link IRestHighLevelClient}
*/
public IRestHighLevelClient createClient();
IRestHighLevelClient createClient();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.flint.core;

import dev.failsafe.RetryPolicy;
import java.io.Serializable;
import java.util.Map;
import org.opensearch.flint.core.http.FlintRetryOptions;
Expand Down Expand Up @@ -46,10 +45,14 @@ public class FlintOptions implements Serializable {

public static final String CUSTOM_AWS_CREDENTIALS_PROVIDER = "customAWSCredentialsProvider";

public static final String SUPER_ADMIN_AWS_CREDENTIALS_PROVIDER = "superAdminAWSCredentialsProvider";

/**
* By default, customAWSCredentialsProvider is empty. use DefaultAWSCredentialsProviderChain.
* By default, customAWSCredentialsProvider and superAdminAWSCredentialsProvider are empty. use DefaultAWSCredentialsProviderChain.
*/
public static final String DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER = "";
public static final String DEFAULT_AWS_CREDENTIALS_PROVIDER = "";

public static final String SYSTEM_INDEX_KEY_NAME = "spark.flint.job.requestIndex";

/**
* Used by {@link org.opensearch.flint.core.storage.OpenSearchScrollReader}
Expand Down Expand Up @@ -121,7 +124,11 @@ public String getAuth() {
}

public String getCustomAwsCredentialsProvider() {
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, "");
return options.getOrDefault(CUSTOM_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getSuperAdminAwsCredentialsProvider() {
return options.getOrDefault(SUPER_ADMIN_AWS_CREDENTIALS_PROVIDER, DEFAULT_AWS_CREDENTIALS_PROVIDER);
}

public String getUsername() {
Expand All @@ -139,4 +146,8 @@ public int getSocketTimeoutMillis() {
public String getDataSourceName() {
return options.getOrDefault(DATA_SOURCE_NAME, "");
}

public String getSystemIndexName() {
return options.getOrDefault(SYSTEM_INDEX_KEY_NAME, "");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.http.Header;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpException;
Expand All @@ -32,10 +33,11 @@

/**
* From https://github.com/opensearch-project/sql-jdbc/blob/main/src/main/java/org/opensearch/jdbc/transport/http/auth/aws/AWSRequestSigningApacheInterceptor.java
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer}
* An {@link HttpRequestInterceptor} that signs requests using any AWS {@link Signer} for SIGV4_AUTH
* and {@link AWSCredentialsProvider}.
*/
public class AWSRequestSigningApacheInterceptor implements HttpRequestInterceptor {

/**
* The service that we're connecting to. Technically not necessary.
* Could be used by a future Signer, though.
Expand All @@ -48,22 +50,43 @@ public class AWSRequestSigningApacheInterceptor implements HttpRequestIntercepto
private final Signer signer;

/**
* The source of AWS credentials for signing.
* Provides the primary source of AWS credentials used for signing requests. These credentials are used
* for the majority of requests, except in cases where elevated permissions are required.
*/
private final AWSCredentialsProvider awsCredentialsProvider;
private final AWSCredentialsProvider primaryCredentialsProvider;

/**
* Provides a source of AWS credentials that are used for signing requests requiring elevated permissions.
* This is particularly useful for accessing resources that are restricted to super-administrative operations,
* such as certain system indices or administrative APIs. These credentials are expected to have permissions
* beyond those of the regular {@link #primaryCredentialsProvider}.
*/
private final AWSCredentialsProvider superAdminAWSCredentialsProvider;

/**
* Identifies data access operations that require super-admin credentials. This identifier can be used to
* distinguish between regular and elevated data access needs, facilitating the decision to use
* {@link #superAdminAWSCredentialsProvider} over {@link #primaryCredentialsProvider} when accessing sensitive
* or restricted resources.
*/
private final String superAdminDataAccessIdentifier;

/**
*
* @param service service that we're connecting to
* @param signer particular signer implementation
* @param awsCredentialsProvider source of AWS credentials for signing
* @param primaryCredentialsProvider source of AWS credentials for signing
*/
public AWSRequestSigningApacheInterceptor(final String service,
final Signer signer,
final AWSCredentialsProvider awsCredentialsProvider) {
final Signer signer,
final AWSCredentialsProvider primaryCredentialsProvider,
final AWSCredentialsProvider superAdminAWSCredentialsProvider,
final String superAdminDataAccessIdentifier) {
this.service = service;
this.signer = signer;
this.awsCredentialsProvider = awsCredentialsProvider;
this.primaryCredentialsProvider = primaryCredentialsProvider;
this.superAdminAWSCredentialsProvider = superAdminAWSCredentialsProvider;
this.superAdminDataAccessIdentifier = superAdminDataAccessIdentifier;
}

/**
Expand Down Expand Up @@ -106,7 +129,11 @@ public void process(final HttpRequest request, final HttpContext context)
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());
if (this.service.equals("es") && isSuperAdminDataAccess(signableRequest.getResourcePath())) {
signer.sign(signableRequest, superAdminAWSCredentialsProvider.getCredentials());
} else {
signer.sign(signableRequest, primaryCredentialsProvider.getCredentials());
}

// Now copy everything back
request.setHeaders(mapToHeaderArray(signableRequest.getHeaders()));
Expand Down Expand Up @@ -136,6 +163,15 @@ private static Map<String, List<String>> nvpToMapParams(final List<NameValuePair
return parameterMap;
}

/**
* @param resourcePath The path of the resource being accessed.
* @return true if the resource path contains the super-admin data access identifier, indicating that
* the operation requires super-admin credentials; false otherwise.
*/
private boolean isSuperAdminDataAccess(String resourcePath) {
return resourcePath.contains(superAdminDataAccessIdentifier);
}

/**
* @param headers modeled Header objects
* @return a Map of header entries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,27 +261,18 @@ public IRestHighLevelClient createClient() {
signer.setServiceName("es");
signer.setRegionName(options.getRegion());

// Use DefaultAWSCredentialsProviderChain by default.
final AtomicReference<AWSCredentialsProvider> awsCredentialsProvider =
new AtomicReference<>(new DefaultAWSCredentialsProviderChain());
String providerClass = options.getCustomAwsCredentialsProvider();
if (!Strings.isNullOrEmpty(providerClass)) {
try {
Class<?> awsCredentialsProviderClass = Class.forName(providerClass);
Constructor<?> ctor = awsCredentialsProviderClass.getDeclaredConstructor();
ctor.setAccessible(true);
awsCredentialsProvider.set((AWSCredentialsProvider) ctor.newInstance());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
// Initialize and attempt to instantiate custom AWSCredentialsProviders.
final AtomicReference<AWSCredentialsProvider> customAWSCredentialsProvider =
initializeAndInstantiateProvider(options.getCustomAwsCredentialsProvider());
final AtomicReference<AWSCredentialsProvider> superAdminAWSCredentialsProvider =
initializeAndInstantiateProvider(options.getSuperAdminAwsCredentialsProvider());

restClientBuilder.setHttpClientConfigCallback(builder -> {
HttpAsyncClientBuilder delegate =
builder.addInterceptorLast(
new AWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, awsCredentialsProvider.get()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
HttpAsyncClientBuilder delegate = builder.addInterceptorLast(
new AWSRequestSigningApacheInterceptor(
signer.getServiceName(), signer, customAWSCredentialsProvider.get(), superAdminAWSCredentialsProvider.get(), options.getSystemIndexName()));
return RetryableHttpAsyncClient.builder(delegate, options);
}
);
} else if (options.getAuth().equals(FlintOptions.BASIC_AUTH)) {
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
Expand All @@ -303,6 +294,26 @@ public IRestHighLevelClient createClient() {
return new RestHighLevelClientWrapper(new RestHighLevelClient(restClientBuilder));
}

/**
* Initializes and possibly instantiates an AWS credentials provider. If a provider class name is provided,
* this method attempts to instantiate the provider using reflection. Otherwise, it defaults to using the
* {@link DefaultAWSCredentialsProviderChain}.
*/
private AtomicReference<AWSCredentialsProvider> initializeAndInstantiateProvider(String providerClass) {
AWSCredentialsProvider provider = new DefaultAWSCredentialsProviderChain();
if (!Strings.isNullOrEmpty(providerClass)) {
try {
Class<?> clazz = Class.forName(providerClass);
Constructor<?> constructor = clazz.getDeclaredConstructor();
constructor.setAccessible(true);
provider = (AWSCredentialsProvider) constructor.newInstance();
} catch (Exception e) {
throw new RuntimeException("Failed to instantiate AWSCredentialsProvider: " + providerClass, e);
}
}
return new AtomicReference<>(provider);
}

/*
* Constructs Flint metadata with latest metadata log entry attached if it's available.
* It relies on FlintOptions to provide data source name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ object FlintSparkConf {
FlintConfig("spark.datasource.flint.customAWSCredentialsProvider")
.datasourceOption()
.doc("AWS customAWSCredentialsProvider")
.createWithDefault(FlintOptions.DEFAULT_CUSTOM_AWS_CREDENTIALS_PROVIDER)
.createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER)

val SUPER_ADMIN_AWS_CREDENTIALS_PROVIDER =
FlintConfig("spark.datasource.flint.superAdminAWSCredentialsProvider")
.datasourceOption()
.doc("AWS credentials provider for super admin permission")
.createWithDefault(FlintOptions.DEFAULT_AWS_CREDENTIALS_PROVIDER)

val DOC_ID_COLUMN_NAME = FlintConfig("spark.datasource.flint.write.id_name")
.datasourceOption()
Expand Down Expand Up @@ -221,6 +227,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
RETRYABLE_HTTP_STATUS_CODES,
REGION,
CUSTOM_AWS_CREDENTIALS_PROVIDER,
SUPER_ADMIN_AWS_CREDENTIALS_PROVIDER,
USERNAME,
PASSWORD,
SOCKET_TIMEOUT_MILLIS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.Optional

import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

Expand Down Expand Up @@ -62,6 +63,18 @@ class FlintSparkConfSuite extends FlintSuite {
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

test("test super admin AWS credentials provider option") {
withSparkConf("spark.datasource.flint.superAdminAWSCredentialsProvider") {
spark.conf.set(
"spark.datasource.flint.superAdminAWSCredentialsProvider",
"com.example.superAdminAWSCredentialsProvider")
val flintOptions = FlintSparkConf().flintOptions()
assert(flintOptions.getCustomAwsCredentialsProvider == "")
assert(
flintOptions.getSuperAdminAwsCredentialsProvider == "com.example.superAdminAWSCredentialsProvider")
}
}

/**
* Delete index `indexNames` after calling `f`.
*/
Expand Down

0 comments on commit f0d72df

Please sign in to comment.