From 0e350f85ec3b5f7ffd820d2778e67cb3cbf29c9d Mon Sep 17 00:00:00 2001 From: Arjun kumar Giri Date: Wed, 12 Jun 2024 17:24:02 -0700 Subject: [PATCH] AWS DDB SDK client support for remote data store Signed-off-by: Arjun kumar Giri --- .../sdk/DeleteDataObjectRequest.java | 17 +- .../opensearch/sdk/GetDataObjectRequest.java | 16 +- .../opensearch/sdk/PutDataObjectRequest.java | 28 +- plugin/build.gradle | 22 ++ .../ml/sdkclient/DDBOpenSearchClient.java | 125 ++++++++ .../ml/sdkclient/SdkClientModule.java | 86 +++++- .../sdkclient/DDBOpenSearchClientTests.java | 274 ++++++++++++++++++ .../ml/sdkclient/SdkClientModuleTests.java | 19 +- 8 files changed, 569 insertions(+), 18 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java create mode 100644 plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java index 31d560815d..4cbe587f75 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java @@ -13,6 +13,8 @@ public class DeleteDataObjectRequest { private final String index; private final String id; + private final String tenantId; + /** * Instantiate this request with an index and id. *

@@ -20,9 +22,10 @@ public class DeleteDataObjectRequest { * @param index the index location to delete the object * @param id the document id */ - public DeleteDataObjectRequest(String index, String id) { + public DeleteDataObjectRequest(String index, String id, String tenantId) { this.index = index; this.id = id; + this.tenantId = tenantId; } /** @@ -41,12 +44,17 @@ public String id() { return this.id; } + public String tenantId() { + return this.tenantId; + } + /** * Class for constructing a Builder for this Request Object */ public static class Builder { private String index = null; private String id = null; + private String tenantId = null; /** * Empty Constructor for the Builder object @@ -73,12 +81,17 @@ public Builder id(String id) { return this; } + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Builds the object * @return A {@link DeleteDataObjectRequest} */ public DeleteDataObjectRequest build() { - return new DeleteDataObjectRequest(this.index, this.id); + return new DeleteDataObjectRequest(this.index, this.id, this.tenantId); } } } diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java index 8edbb99f39..3d282dbf04 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java @@ -14,6 +14,7 @@ public class GetDataObjectRequest { private final String index; private final String id; + private final String tenantId; private final FetchSourceContext fetchSourceContext; /** @@ -24,9 +25,10 @@ public class GetDataObjectRequest { * @param id the document id * @param fetchSourceContext the context to use when fetching _source */ - public GetDataObjectRequest(String index, String id, FetchSourceContext fetchSourceContext) { + public GetDataObjectRequest(String index, String id, String tenantId, FetchSourceContext fetchSourceContext) { this.index = index; this.id = id; + this.tenantId = tenantId; this.fetchSourceContext = fetchSourceContext; } @@ -46,6 +48,10 @@ public String id() { return this.id; } + public String tenantId() { + return this.tenantId; + } + /** * Returns the context for fetching _source * @return the fetchSourceContext @@ -60,6 +66,7 @@ public FetchSourceContext fetchSourceContext() { public static class Builder { private String index = null; private String id = null; + private String tenantId = null; private FetchSourceContext fetchSourceContext; /** @@ -87,6 +94,11 @@ public Builder id(String id) { return this; } + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Add a fetchSourceContext to this builder * @param fetchSourceContext the fetchSourceContext @@ -102,7 +114,7 @@ public Builder fetchSourceContext(FetchSourceContext fetchSourceContext) { * @return A {@link GetDataObjectRequest} */ public GetDataObjectRequest build() { - return new GetDataObjectRequest(this.index, this.id, this.fetchSourceContext); + return new GetDataObjectRequest(this.index, this.id, this.tenantId, this.fetchSourceContext); } } } diff --git a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java index 2d6d0a5d07..bb36150de0 100644 --- a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java @@ -13,6 +13,8 @@ public class PutDataObjectRequest { private final String index; + private final String id; + private final String tenantId; private final ToXContentObject dataObject; /** @@ -22,8 +24,10 @@ public class PutDataObjectRequest { * @param index the index location to put the object * @param dataObject the data object */ - public PutDataObjectRequest(String index, ToXContentObject dataObject) { + public PutDataObjectRequest(String index, String id, String tenantId, ToXContentObject dataObject) { this.index = index; + this.id = id; + this.tenantId = tenantId; this.dataObject = dataObject; } @@ -35,6 +39,14 @@ public String index() { return this.index; } + public String id() { + return this.id; + } + + public String tenantId() { + return this.tenantId; + } + /** * Returns the data object * @return the data object @@ -48,6 +60,8 @@ public ToXContentObject dataObject() { */ public static class Builder { private String index = null; + private String id = null; + private String tenantId = null; private ToXContentObject dataObject = null; /** @@ -65,6 +79,16 @@ public Builder index(String index) { return this; } + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder tenantId(String tenantId) { + this.tenantId = tenantId; + return this; + } + /** * Add a data object to this builder * @param dataObject the data object @@ -80,7 +104,7 @@ public Builder dataObject(ToXContentObject dataObject) { * @return A {@link PutDataObjectRequest} */ public PutDataObjectRequest build() { - return new PutDataObjectRequest(this.index, this.dataObject); + return new PutDataObjectRequest(this.index, this.id, this.tenantId, this.dataObject); } } } diff --git a/plugin/build.gradle b/plugin/build.gradle index 947ed22c5a..5b4f1ee3fb 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -63,6 +63,28 @@ dependencies { implementation 'com.jayway.jsonpath:json-path:2.9.0' implementation "org.opensearch.client:opensearch-java:2.10.2" + // Dynamo dependencies + implementation("software.amazon.awssdk:sdk-core:2.25.40") + implementation("software.amazon.awssdk:aws-core:2.25.40") + implementation "software.amazon.awssdk:aws-json-protocol:2.25.40" + implementation("software.amazon.awssdk:auth:2.25.40") + implementation("software.amazon.awssdk:checksums:2.25.40") + implementation("software.amazon.awssdk:checksums-spi:2.25.40") + implementation("software.amazon.awssdk:dynamodb:2.25.40") + implementation("software.amazon.awssdk:endpoints-spi:2.25.40") + implementation("software.amazon.awssdk:http-auth-aws:2.25.40") + implementation("software.amazon.awssdk:http-auth-spi:2.25.40") + implementation("software.amazon.awssdk:http-client-spi:2.25.40") + implementation("software.amazon.awssdk:identity-spi:2.25.40") + implementation "software.amazon.awssdk:json-utils:2.25.40" + implementation "software.amazon.awssdk:metrics-spi:2.25.40" + implementation("software.amazon.awssdk:profiles:2.25.40") + implementation "software.amazon.awssdk:protocol-core:2.25.40" + implementation("software.amazon.awssdk:regions:2.25.40") + implementation "software.amazon.awssdk:third-party-jackson-core:2.25.40" + implementation("software.amazon.awssdk:url-connection-client:2.25.40") + implementation("software.amazon.awssdk:utils:2.25.40") + configurations.all { resolutionStrategy.force 'org.apache.httpcomponents.core5:httpcore5:5.2.4' diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java new file mode 100644 index 0000000000..2c3039c2bb --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -0,0 +1,125 @@ +package org.opensearch.ml.sdkclient; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchException; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sdk.DeleteDataObjectRequest; +import org.opensearch.sdk.DeleteDataObjectResponse; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.GetDataObjectResponse; +import org.opensearch.sdk.PutDataObjectRequest; +import org.opensearch.sdk.PutDataObjectResponse; +import org.opensearch.sdk.SdkClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; + +@AllArgsConstructor +@Log4j2 +public class DDBOpenSearchClient implements SdkClient { + + private static final String DEFAULT_TENANT = "DEFAULT_TENANT"; + + private static final String HASH_KEY = "tenant_id"; + private static final String RANGE_KEY = "id"; + private static final String SOURCE = "source"; + + private DynamoDbClient dynamoDbClient; + @Override + public CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor) { + final String id = request.id() != null ? request.id() : UUID.randomUUID().toString(); + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final String tableName = getTableName(request.index()); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + XContentBuilder builder = request.dataObject().toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + String source = builder.toString(); + + final Map item = Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), + Map.entry(SOURCE, AttributeValue.builder().s(source).build()) + ); + final PutItemRequest putItemRequest = PutItemRequest.builder() + .tableName(tableName) + .item(item) + .build(); + + dynamoDbClient.putItem(putItemRequest); + return new PutDataObjectResponse.Builder().id(id).created(true).build(); + } catch (Exception e){ + log.error("Exception while inserting data into DDB: " + e.getMessage(), e); + throw new OpenSearchException(e); + } + }), executor); + } + + @Override + public CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor) { + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final GetItemRequest getItemRequest = GetItemRequest.builder() + .tableName(getTableName(request.index())) + .key(Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) + )) + .build(); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try { + final GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); + if (getItemResponse == null || getItemResponse.item() == null || getItemResponse.item().isEmpty()) { + return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.empty()).build(); + } + + String source = getItemResponse.item().get(SOURCE).s(); + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); + return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); + } catch (Exception e) { + log.error("Exception while fetching data from DDB: " + e.getMessage(), e); + throw new OpenSearchException(e); + } + }), executor); + } + + @Override + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + final DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder() + .tableName(getTableName(request.index())) + .key(Map.ofEntries( + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), + Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) + )).build(); + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + dynamoDbClient.deleteItem(deleteItemRequest); + return new DeleteDataObjectResponse.Builder().id(request.id()).deleted(true).build(); + }), executor); + } + + private String getTableName(String index) { + // Table name will be same as index name. As DDB table name does not support dot(.) + // it will be removed form name. + return index.replaceAll("\\.", ""); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index fb7d1d3119..504397cb30 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -8,36 +8,79 @@ */ package org.opensearch.ml.sdkclient; +import lombok.extern.log4j.Log4j2; +import lombok.extern.slf4j.Slf4j; import org.apache.http.HttpHost; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; import org.opensearch.client.RestClient; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; -import org.opensearch.core.common.Strings; +import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.profiles.ProfileFileSystemSetting; + +import java.security.AccessController; +import java.security.PrivilegedAction; /** * A module for binding this plugin's desired implementation of {@link SdkClient}. */ +@Log4j2 public class SdkClientModule extends AbstractModule { + public static final String REMOTE_METADATA_TYPE = "REMOTE_METADATA_TYPE"; public static final String REMOTE_METADATA_ENDPOINT = "REMOTE_METADATA_ENDPOINT"; public static final String REGION = "REGION"; + public static final String REMOTE_OPENSEARCH = "RemoteOpenSearch"; + public static final String AWS_DYNAMO_DB = "AWSDynamoDB"; + private final String remoteStoreType; private final String remoteMetadataEndpoint; private final String region; // not using with RestClient + static { + // Aws v2 sdk tries to load a default profile from home path which is restricted. Hence, setting these to random valid paths. + // @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") + if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { + SocketAccess.doPrivileged( + () -> System.setProperty( + ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), + System.getProperty("opensearch.path.conf") + ) + ); + } + if (ProfileFileSystemSetting.AWS_CONFIG_FILE.getStringValue().isEmpty()) { + SocketAccess.doPrivileged( + () -> System.setProperty(ProfileFileSystemSetting.AWS_CONFIG_FILE.property(), System.getProperty("opensearch.path.conf")) + ); + } + } + + private static final class SocketAccess { + private SocketAccess() {} + + public static T doPrivileged(PrivilegedAction operation) { + SpecialPermission.check(); + return AccessController.doPrivileged(operation); + } + } + /** * Instantiate this module using environment variables */ public SdkClientModule() { - this(System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); + this(System.getenv(REMOTE_METADATA_TYPE), System.getenv(REMOTE_METADATA_ENDPOINT), System.getenv(REGION)); } /** @@ -45,19 +88,44 @@ public SdkClientModule() { * @param remoteMetadataEndpoint The remote endpoint * @param region The region */ - SdkClientModule(String remoteMetadataEndpoint, String region) { + SdkClientModule(String remoteStoreType, String remoteMetadataEndpoint, String region) { + this.remoteStoreType = remoteStoreType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; - this.region = region; + this.region = region == null ? "us-west-2" : region; } @Override - protected void configure() { - boolean local = Strings.isNullOrEmpty(remoteMetadataEndpoint); - if (local) { + protected void configure() {/* + if (this.remoteStoreType == null) { + log.info("Using local opensearch cluster as metadata store"); bind(SdkClient.class).to(LocalClusterIndicesClient.class); - } else { - bind(SdkClient.class).toInstance(new RemoteClusterIndicesClient(createOpenSearchClient())); + return; } + + switch (this.remoteStoreType) { + case REMOTE_OPENSEARCH: + log.info("Using remote opensearch cluster as metadata store"); + bind(SdkClient.class).toInstance(new RemoteClusterIndicesClient(createOpenSearchClient())); + return; + case AWS_DYNAMO_DB: + log.info("Using dynamo DB as metadata store"); + bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + return; + default: + log.info("Using local opensearch cluster as metadata store"); + bind(SdkClient.class).to(LocalClusterIndicesClient.class); + }*/ + bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + } + + private DynamoDbClient createDynamoDbClient() { + if (this.region == null) { + throw new IllegalStateException("REGION environment variable needs to be set!"); + } + + return DynamoDbClient.builder() + .region(Region.of(this.region)) + .build(); } private OpenSearchClient createOpenSearchClient() { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java new file mode 100644 index 0000000000..0da4bc21a5 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -0,0 +1,274 @@ +package org.opensearch.ml.sdkclient; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchException; +import org.opensearch.client.opensearch.core.IndexRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.sdk.DeleteDataObjectRequest; +import org.opensearch.sdk.DeleteDataObjectResponse; +import org.opensearch.sdk.GetDataObjectRequest; +import org.opensearch.sdk.GetDataObjectResponse; +import org.opensearch.sdk.PutDataObjectRequest; +import org.opensearch.sdk.PutDataObjectResponse; +import org.opensearch.sdk.SdkClient; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.TimeUnit; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; + + +public class DDBOpenSearchClientTests extends OpenSearchTestCase { + + private static final String TEST_ID = "123"; + private static final String TENANT_ID = "TEST_TENANT_ID"; + private static final String TEST_INDEX = "test_index"; + private SdkClient sdkClient; + + @Mock + private DynamoDbClient dynamoDbClient; + @Captor + private ArgumentCaptor putItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor getItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor deleteItemRequestArgumentCaptor; + private TestDataObject testDataObject; + + + private static TestThreadPool testThreadPool = new TestThreadPool( + LocalClusterIndicesClientTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + sdkClient = new DDBOpenSearchClient(dynamoDbClient); + testDataObject = new TestDataObject("foo"); + } + + @Test + public void testPutDataObject_HappyCase() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertEquals(true, response.created()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, putItemRequest.tableName()); + Assert.assertEquals(TEST_ID, putItemRequest.item().get("id").s()); + Assert.assertEquals(TENANT_ID, putItemRequest.item().get("tenant_id").s()); + XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); + XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + Assert.assertEquals(builder.toString(), putItemRequest.item().get("source").s()); + } + + @Test + public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get("tenant_id").s()); + } + + @Test + public void testPutDataObject_NullId_SetsDefaultTenantId() throws IOException { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient.putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertNotNull(putItemRequest.item().get("id").s()); + Assert.assertNotNull(response.id()); + } + + @Test + public void testPutDataObject_DDBException_ThrowsException() { + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject).build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))) + .thenThrow(new RuntimeException("Test exception")); + CompletableFuture future = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); + } + + @Test + public void testGetDataObject_HappyCase() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); + XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); + GetItemResponse getItemResponse = GetItemResponse.builder().item(Map.ofEntries( + Map.entry("source", AttributeValue.builder().s(builder.toString()).build()) + )).build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); + GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, getItemRequest.tableName()); + Assert.assertEquals(TENANT_ID, getItemRequest.key().get("tenant_id").s()); + Assert.assertEquals(TEST_ID, getItemRequest.key().get("id").s()); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertTrue(response.parser().isPresent()); + Assert.assertEquals("foo", response.parser().get().map().get("data")); + } + + @Test + public void testGetDataObject_NoExistingDoc() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + GetItemResponse getItemResponse = GetItemResponse.builder().build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Assert.assertEquals(TEST_ID, response.id()); + Assert.assertFalse(response.parser().isPresent()); + } + + @Test + public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID).build(); + GetItemResponse getItemResponse = GetItemResponse.builder().build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); + GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get("tenant_id").s()); + } + + @Test + public void testGetDataObject_DDBException_ThrowsOSException() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID).build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))) + .thenThrow(new RuntimeException("Test exception")); + CompletableFuture future = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchException.class, ce.getCause().getClass()); + } + + @Test + public void testDeleteDataObject_HappyCase() { + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() + .id(TEST_ID).index(TEST_INDEX).tenantId(TENANT_ID).build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) + .thenReturn(DeleteItemResponse.builder().build()); + DeleteDataObjectResponse deleteResponse = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture().join(); + DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); + Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName()); + Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get("tenant_id").s()); + Assert.assertEquals(TEST_ID, deleteItemRequest.key().get("id").s()); + Assert.assertEquals(TEST_ID, deleteResponse.id()); + Assert.assertTrue(deleteResponse.deleted()); + } + + @Test + public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { + DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder() + .id(TEST_ID).index(TEST_INDEX).build(); + Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())) + .thenReturn(DeleteItemResponse.builder().build()); + DeleteDataObjectResponse deleteResponse = sdkClient + .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture().join(); + DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get("tenant_id").s()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index 707ddd46f6..b1cd9d7db6 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.mock; +import org.junit.Before; import org.opensearch.common.inject.AbstractModule; import org.opensearch.common.inject.Guice; import org.opensearch.common.inject.Injector; @@ -22,6 +23,11 @@ @ThreadLeakScope(ThreadLeakScope.Scope.NONE) // remote http client is never closed public class SdkClientModuleTests extends OpenSearchTestCase { + @Before + public void setup() { + System.setProperty("opensearch.path.conf", "/tmp"); + } + private Module localClientModule = new AbstractModule() { @Override protected void configure() { @@ -30,16 +36,23 @@ protected void configure() { }; public void testLocalBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(null, null), localClientModule); + Injector injector = Guice.createInjector(new SdkClientModule(null, null, null), localClientModule); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof LocalClusterIndicesClient); } - public void testRemoteBinding() { - Injector injector = Guice.createInjector(new SdkClientModule("http://example.org", "eu-west-3")); + public void testRemoteOpenSearchBinding() { + Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.REMOTE_OPENSEARCH, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof RemoteClusterIndicesClient); } + + public void testDDBBinding() { + Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, null, "eu-west-3")); + + SdkClient sdkClient = injector.getInstance(SdkClient.class); + assertTrue(sdkClient instanceof DDBOpenSearchClient); + } }