Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/multi_tenancy] Make tenant awareness setting static #2968

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions common/src/main/java/org/opensearch/sdk/SdkClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@

import static org.opensearch.sdk.SdkClientUtils.unwrapAndConvertToException;

public class SdkClient implements SettingsChangeListener {
public class SdkClient {

private final SdkClientDelegate delegate;
private volatile Boolean isMultiTenancyEnabled;
private final Boolean isMultiTenancyEnabled;

public SdkClient(SdkClientDelegate delegate) {
public SdkClient(SdkClientDelegate delegate, Boolean multiTenancy) {
this.delegate = delegate;
}

@Override
public void onMultiTenancyEnabledChanged(boolean isEnabled) {
this.isMultiTenancyEnabled = isEnabled;
this.isMultiTenancyEnabled = multiTenancy;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
return CompletableFuture.completedFuture(searchResponse);
}
});
sdkClient = new SdkClient(sdkClientImpl);
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = new SdkClient(sdkClientImpl, true);
testException = new OpenSearchStatusException("Test", RestStatus.BAD_REQUEST);
interruptedException = new InterruptedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
Expand Down Expand Up @@ -119,8 +120,7 @@ public class LocalClusterIndicesClientTests {
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry));
sdkClient.onMultiTenancyEnabledChanged(false);
sdkClient = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), true);

testDataObject = new TestDataObject("foo");
}
Expand Down Expand Up @@ -559,8 +559,8 @@ public void testSearchDataObjectNotTenantAware() throws IOException {
when(mockedClient.search(any(SearchRequest.class))).thenReturn(future);
when(future.actionGet()).thenReturn(searchResponse);

sdkClient.onMultiTenancyEnabledChanged(false);
SearchDataObjectResponse response = sdkClient
SdkClient sdkClientNoTenant = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), false);
SearchDataObjectResponse response = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture()
.join();
Expand Down Expand Up @@ -608,7 +608,6 @@ public void testSearchDataObjectTenantAware() throws IOException {
when(mockedClient.search(any(SearchRequest.class))).thenReturn(future);
when(future.actionGet()).thenReturn(searchResponse);

sdkClient.onMultiTenancyEnabledChanged(true);
SearchDataObjectResponse response = sdkClient
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture()
Expand Down Expand Up @@ -655,9 +654,7 @@ public void testSearchDataObject_Exception() throws IOException {

@Test
public void testSearchDataObject_NullTenantId() throws IOException {
// Tests exception if multitenancy enabled
sdkClient.onMultiTenancyEnabledChanged(true);

// Tests exception if multitenancy enabled
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
Expand All @@ -675,4 +672,26 @@ public void testSearchDataObject_NullTenantId() throws IOException {
assertEquals(OpenSearchStatusException.class, cause.getClass());
assertEquals("Tenant ID is required when multitenancy is enabled.", cause.getMessage());
}

public void testSearchDataObject_NullTenantNoMultitenancy() throws IOException {
// Tests no status exception if multitenancy not enabled
SdkClient sdkClientNoTenant = new SdkClient(new LocalClusterIndicesClient(mockedClient, xContentRegistry), false);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
.indices(TEST_INDEX)
// null tenant Id
.searchSourceBuilder(searchSourceBuilder)
.build();

CompletableFuture<SearchDataObjectResponse> future = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture();

CompletionException ce = assertThrows(CompletionException.class, () -> future.join());
Throwable cause = ce.getCause();
assertEquals(UnsupportedOperationException.class, cause.getClass());
assertEquals("test", cause.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public void setup() {
when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor("opensearch_ml_general"));

settings = Settings.builder().build();
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), true);
mlAgentExecutor = Mockito
.spy(new MLAgentExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories, memoryMap, false));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void setUp() {
MockitoAnnotations.openMocks(this);
masterKey = new ConcurrentHashMap<>();
masterKey.put(DEFAULT_TENANT_ID, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
sdkClient = new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), true);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
Expand Down
13 changes: 10 additions & 3 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ dependencies {
implementation("software.amazon.awssdk:utils:2.25.40")
// AWS OpenSearch Service dependency
implementation("software.amazon.awssdk:apache-client:2.25.40")

configurations.all {
resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4'
resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5-h2:5.2.4'
Expand All @@ -104,7 +104,8 @@ dependencies {

publishing {
publications {
pluginZip(MavenPublication) { publication ->
pluginZip(MavenPublication) {
publication ->
pom {
name = opensearchplugin.name
description = opensearchplugin.description
Expand Down Expand Up @@ -173,7 +174,9 @@ task integTest(type: RestIntegTestTask) {
testClassesDirs = sourceSets.test.output.classesDirs
classpath = sourceSets.test.runtimeClasspath
}
tasks.named("check").configure { dependsOn(integTest) }
tasks.named("check").configure {
dependsOn(integTest)
}

integTest {
dependsOn "bundlePlugin"
Expand Down Expand Up @@ -246,6 +249,10 @@ testClusters.integTest {
environment "AWS_SECRET_ACCESS_KEY", System.getenv("AWS_SECRET_ACCESS_KEY");
environment "AWS_SESSION_TOKEN", System.getenv("AWS_SESSION_TOKEN");

if (System.getProperty("tests.rest.tenantaware") != null) {
environment "plugins.ml_commons.multi_tenancy_enabled", "true"
}

testDistribution = "ARCHIVE"
// Cluster shrink exception thrown if we try to set numberOfNodes to 1, so only apply if > 1
if (_numNodes > 1) numberOfNodes = _numNodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,6 @@ public Collection<Object> createComponents(
memoryFactoryMap,
mlFeatureEnabledSetting.isMultiTenancyEnabled()
);
// Register the sdkClient as a listener
mlFeatureEnabledSetting.addListener(sdkClient);
// Register the agentExecutor as a listener
mlFeatureEnabledSetting.addListener(agentExecutor);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(

private String getIndexName(String index) {
// System index is not supported in remote index. Replacing '.' from index name.
return index.replaceAll("\\.", "");
return (index.length() > 1 && index.charAt(0) == '.') ? index.substring(1) : index;
}

private XContentParser createParser(String json) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
) {
return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction<SearchDataObjectResponse>) () -> {
try {
log.info("Searching {}", Arrays.toString(request.indices()), null);
log.info("Searching {}", Arrays.toString(request.indices()));
// work around https://github.com/opensearch-project/opensearch-java/issues/1150
String json = SdkClientUtils
.lowerCaseEnumValues(
Expand All @@ -254,6 +254,8 @@ public CompletionStage<SearchDataObjectResponse> searchDataObjectAsync(
.filter(tenantIdFilterQuery.toQuery())
.build();
searchRequest = searchRequest.toBuilder().index(Arrays.asList(request.indices())).query(boolQuery.toQuery()).build();
} else {
searchRequest = searchRequest.toBuilder().index(Arrays.asList(request.indices())).build();
}
SearchResponse<?> searchResponse = openSearchClient.search(searchRequest, MAP_DOCTYPE);
log.info("Search returned {} hits", searchResponse.hits().total().value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.ml.sdkclient;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
import static org.opensearch.sdk.SdkClientSettings.AWS_DYNAMO_DB;
import static org.opensearch.sdk.SdkClientSettings.AWS_OPENSEARCH_SERVICE;
import static org.opensearch.sdk.SdkClientSettings.REMOTE_METADATA_ENDPOINT;
Expand Down Expand Up @@ -84,19 +85,21 @@ public static SdkClient createSdkClient(Client client, NamedXContentRegistry xCo
String remoteMetadataEndpoint = REMOTE_METADATA_ENDPOINT.get(settings);
String region = REMOTE_METADATA_REGION.get(settings);
String serviceName = REMOTE_METADATA_SERVICE_NAME.get(settings);
Boolean multiTenancy = ML_COMMONS_MULTI_TENANCY_ENABLED.get(settings);

switch (remoteMetadataType) {
case REMOTE_OPENSEARCH:
if (Strings.isBlank(remoteMetadataEndpoint)) {
throw new OpenSearchException("Remote Opensearch client requires a metadata endpoint.");
}
log.info("Using remote opensearch cluster as metadata store");
return new SdkClient(new RemoteClusterIndicesClient(createOpenSearchClient(remoteMetadataEndpoint)));
return new SdkClient(new RemoteClusterIndicesClient(createOpenSearchClient(remoteMetadataEndpoint)), multiTenancy);
case AWS_OPENSEARCH_SERVICE:
validateAwsParams(remoteMetadataType, remoteMetadataEndpoint, region, serviceName);
log.info("Using remote AWS Opensearch Service cluster as metadata store");
return new SdkClient(
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName))
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName)),
multiTenancy
);
case AWS_DYNAMO_DB:
validateAwsParams(remoteMetadataType, remoteMetadataEndpoint, region, serviceName);
Expand All @@ -105,11 +108,12 @@ public static SdkClient createSdkClient(Client client, NamedXContentRegistry xCo
new DDBOpenSearchClient(
createDynamoDbClient(region),
new RemoteClusterIndicesClient(createAwsOpenSearchServiceClient(remoteMetadataEndpoint, region, serviceName))
)
),
multiTenancy
);
default:
log.info("Using local opensearch cluster as metadata store");
return new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry));
return new SdkClient(new LocalClusterIndicesClient(client, xContentRegistry), multiTenancy);
}
}

Expand All @@ -123,8 +127,8 @@ private static void validateAwsParams(String clientType, String remoteMetadataEn
}

// Package private for testing
static SdkClient wrapSdkClientDelegate(SdkClientDelegate delegate) {
return new SdkClient(delegate);
static SdkClient wrapSdkClientDelegate(SdkClientDelegate delegate, Boolean multiTenancy) {
return new SdkClient(delegate, multiTenancy);
}

private static DynamoDbClient createDynamoDbClient(String region) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,20 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_AGENT_FRAMEWORK_ENABLED = Setting
.boolSetting("plugins.ml_commons.agent_framework_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

// Whether multi-tenancy is enabled in ML Commons.
// This is a static setting which must be set before starting OpenSearch by (in priority order):
// 1. As a command-line argument using the -E flag (overrides other options):
// ./bin/opensearch -Eplugins.ml_commons.multi_tenancy_enabled=true
// 2. As a system property using OPENSEARCH_JAVA_OPTS (overrides opensearch.yml):
// export OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true"
// ./bin/opensearch
// Or inline when starting OpenSearch:
// OPENSEARCH_JAVA_OPTS="-Dplugins.ml_commons.multi_tenancy_enabled=true" ./bin/opensearch
// 3. In the opensearch.yml configuration file:
// plugins.ml_commons.multi_tenancy_enabled: true
// After setting it, a full cluster restart is required for the changes to take effect.
public static final Setting<Boolean> ML_COMMONS_MULTI_TENANCY_ENABLED = Setting
.boolSetting("plugins.ml_commons.multi_tenancy_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
.boolSetting("plugins.ml_commons.multi_tenancy_enabled", false, Setting.Property.NodeScope);

public static final Setting<Boolean> ML_COMMONS_CONTROLLER_ENABLED = Setting
.boolSetting("plugins.ml_commons.controller_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_AGENT_FRAMEWORK_ENABLED, it -> isAgentFrameworkEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ENABLED, it -> isLocalModelEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MULTI_TENANCY_ENABLED, it -> {
isMultiTenancyEnabled = it;
notifyMultiTenancyListeners(it);
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,6 @@ public void setupSettings() throws IOException {
response = TestHelper
.makeRequest(client(), "PUT", "_cluster/settings", ImmutableMap.of(), TestHelper.toHttpEntity(jsonEntity), null);
assertEquals(200, response.getStatusLine().getStatusCode());

String multiTenancyEntity = "{\n"
+ " \"persistent\" : {\n"
+ " \"plugins.ml_commons.multi_tenancy_enabled\" : false \n"
+ " }\n"
+ "}";

response = TestHelper
.makeRequest(client(), "PUT", "_cluster/settings", ImmutableMap.of(), TestHelper.toHttpEntity(multiTenancyEntity), null);
assertEquals(200, response.getStatusLine().getStatusCode());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ public static void cleanup() {
public void setup() {
MockitoAnnotations.openMocks(this);

sdkClient = SdkClientFactory.wrapSdkClientDelegate(new DDBOpenSearchClient(dynamoDbClient, remoteClusterIndicesClient));
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new DDBOpenSearchClient(dynamoDbClient, remoteClusterIndicesClient), true);
testDataObject = new TestDataObject("foo");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ public void setup() {
.setSerializationInclusion(JsonInclude.Include.NON_NULL)
)
);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient));
sdkClient.onMultiTenancyEnabledChanged(true);
sdkClient = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient), true);
testDataObject = new TestDataObject("foo");
}

Expand Down Expand Up @@ -592,8 +591,6 @@ public void testSearchDataObject_Exception() throws IOException {

public void testSearchDataObject_NullTenant() throws IOException {
// Tests exception if multitenancy enabled
sdkClient.onMultiTenancyEnabledChanged(true);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
Expand All @@ -612,4 +609,27 @@ public void testSearchDataObject_NullTenant() throws IOException {
assertEquals(OpenSearchStatusException.class, cause.getClass());
assertEquals("Tenant ID is required when multitenancy is enabled.", cause.getMessage());
}

public void testSearchDataObject_NullTenantNoMultitenancy() throws IOException {
// Tests no status exception if multitenancy not enabled
SdkClient sdkClientNoTenant = SdkClientFactory.wrapSdkClientDelegate(new RemoteClusterIndicesClient(mockedOpenSearchClient), false);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
SearchDataObjectRequest searchRequest = SearchDataObjectRequest
.builder()
.indices(TEST_INDEX)
// null tenant Id
.searchSourceBuilder(searchSourceBuilder)
.build();

when(mockedOpenSearchClient.search(any(SearchRequest.class), any())).thenThrow(new UnsupportedOperationException("test"));
CompletableFuture<SearchDataObjectResponse> future = sdkClientNoTenant
.searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL))
.toCompletableFuture();

CompletionException ce = assertThrows(CompletionException.class, () -> future.join());
Throwable cause = ce.getCause();
assertEquals(UnsupportedOperationException.class, cause.getClass());
assertEquals("test", cause.getMessage());
}
}
Loading