diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index bb126c01c2..5441c3c27b 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -11,6 +11,10 @@ import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -38,7 +42,12 @@ import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.sdk.UpdateDataObjectResponse; -import lombok.AllArgsConstructor; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.annotations.VisibleForTesting; + import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -46,22 +55,28 @@ import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; /** * DDB implementation of {@link SdkClient}. DDB table name will be mapped to index name. * */ -@AllArgsConstructor @Log4j2 public class DDBOpenSearchClient implements SdkClient { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); private static final String DEFAULT_TENANT = "DEFAULT_TENANT"; private static final String HASH_KEY = "tenant_id"; private static final String RANGE_KEY = "id"; - private static final String SOURCE = "source"; private DynamoDbClient dynamoDbClient; + private RemoteClusterIndicesClient remoteClusterIndicesClient; + + public DDBOpenSearchClient(DynamoDbClient dynamoDbClient, RemoteClusterIndicesClient remoteClusterIndicesClient) { + this.dynamoDbClient = dynamoDbClient; + this.remoteClusterIndicesClient = remoteClusterIndicesClient; + } /** * DDB implementation to write data objects to DDB table. Tenant ID will be used as hash key and document ID will @@ -76,16 +91,18 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe final String tableName = getTableName(request.index()); return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); - final Map item = Map - .ofEntries( - Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), - Map.entry(RANGE_KEY, AttributeValue.builder().s(id).build()), - Map.entry(SOURCE, AttributeValue.builder().s(source).build()) - ); - final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); + try { + JsonNode jsonNode = OBJECT_MAPPER.readTree(source); + Map item = convertJsonObjectToItem(jsonNode); + item.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); + item.put(RANGE_KEY, AttributeValue.builder().s(id).build()); + final PutItemRequest putItemRequest = PutItemRequest.builder().tableName(tableName).item(item).build(); - dynamoDbClient.putItem(putItemRequest); - return new PutDataObjectResponse.Builder().id(id).created(true).build(); + dynamoDbClient.putItem(putItemRequest); + return new PutDataObjectResponse.Builder().id(id).created(true).build(); + } catch (IOException e) { + throw new OpenSearchStatusException("Failed to parse data object " + request.id(), RestStatus.BAD_REQUEST); + } }), executor); } @@ -114,7 +131,8 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.empty()).build(); } - String source = getItemResponse.item().get(SOURCE).s(); + final ObjectNode sourceObject = convertToObjectNode((getItemResponse.item())); + final String source = OBJECT_MAPPER.writeValueAsString(sourceObject); XContentParser parser = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); @@ -125,10 +143,38 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe }), executor); } + + /** + * Makes use of DDB update request to update data object. + * + */ @Override public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { - // TODO: Implement update - return null; + final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try { + String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); + JsonNode jsonNode = OBJECT_MAPPER.readTree(source); + Map updateItem = convertJsonObjectToItem(jsonNode); + updateItem.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); + updateItem.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); + UpdateItemRequest updateItemRequest = UpdateItemRequest + .builder() + .tableName(getTableName(request.index())) + .key(updateItem) + .build(); + dynamoDbClient.updateItem(updateItemRequest); + + return new UpdateDataObjectResponse.Builder().id(request.id()).shardId(request.index()).updated(true).build(); + } catch (IOException e) { + log.error("Error updating {} in {}: {}", request.id(), request.index(), e.getMessage(), e); + // Rethrow unchecked exception on update IOException + throw new OpenSearchStatusException( + "Parsing error updating data object " + request.id() + " in index " + request.index(), + RestStatus.BAD_REQUEST + ); + } + }), executor); } /** @@ -155,11 +201,17 @@ public CompletionStage deleteDataObjectAsync(DeleteDat }), executor); } + /** + * DDB data needs to be synced with opensearch cluster. {@link RemoteClusterIndicesClient} will then be used to + * search data in opensearch cluster. + * + * @param request + * @param executor + * @return Search data object response + */ @Override public CompletionStage searchDataObjectAsync(SearchDataObjectRequest request, Executor executor) { - // TODO will implement this later. - - return null; + return this.remoteClusterIndicesClient.searchDataObjectAsync(request, executor); } private String getTableName(String index) { @@ -167,4 +219,122 @@ private String getTableName(String index) { // it will be removed from name. return index.replaceAll("\\.", ""); } + + @VisibleForTesting + static Map convertJsonObjectToItem(JsonNode jsonNode) { + Map item = new HashMap<>(); + Iterator> fields = jsonNode.fields(); + + while (fields.hasNext()) { + Map.Entry field = fields.next(); + + if (field.getValue().isTextual()) { + item.put(field.getKey(), AttributeValue.builder().s(field.getValue().asText()).build()); + } else if (field.getValue().isNumber()) { + item.put(field.getKey(), AttributeValue.builder().n(field.getValue().asText()).build()); + } else if (field.getValue().isBoolean()) { + item.put(field.getKey(), AttributeValue.builder().bool(field.getValue().asBoolean()).build()); + } else if (field.getValue().isNull()) { + item.put(field.getKey(), AttributeValue.builder().nul(true).build()); + } else if (field.getValue().isObject()) { + item.put(field.getKey(), AttributeValue.builder().m(convertJsonObjectToItem(field.getValue())).build()); + } else if (field.getValue().isArray()) { + item.put(field.getKey(), AttributeValue.builder().l(convertJsonArrayToAttributeValueList(field.getValue())).build()); + } else { + throw new IllegalArgumentException("Unsupported field type: " + field.getValue()); + } + } + + return item; + } + + @VisibleForTesting + static List convertJsonArrayToAttributeValueList(JsonNode jsonArray) { + List attributeValues = new ArrayList<>(); + + for (JsonNode element : jsonArray) { + if (element.isTextual()) { + attributeValues.add(AttributeValue.builder().s(element.asText()).build()); + } else if (element.isNumber()) { + attributeValues.add(AttributeValue.builder().n(element.asText()).build()); + } else if (element.isBoolean()) { + attributeValues.add(AttributeValue.builder().bool(element.asBoolean()).build()); + } else if (element.isNull()) { + attributeValues.add(AttributeValue.builder().nul(true).build()); + } else if (element.isObject()) { + attributeValues.add(AttributeValue.builder().m(convertJsonObjectToItem(element)).build()); + } else if (element.isArray()) { + attributeValues.add(AttributeValue.builder().l(convertJsonArrayToAttributeValueList(element)).build()); + } else { + throw new IllegalArgumentException("Unsupported field type: " + element); + } + + } + + return attributeValues; + } + + @VisibleForTesting + static ObjectNode convertToObjectNode(Map item) { + ObjectNode objectNode = OBJECT_MAPPER.createObjectNode(); + + item.forEach((key, value) -> { + switch (value.type()) { + case S: + objectNode.put(key, value.s()); + break; + case N: + objectNode.put(key, value.n()); + break; + case BOOL: + objectNode.put(key, value.bool()); + break; + case L: + objectNode.put(key, convertToArrayNode(value.l())); + break; + case M: + objectNode.set(key, convertToObjectNode(value.m())); + break; + case NUL: + objectNode.putNull(key); + break; + default: + throw new IllegalArgumentException("Unsupported AttributeValue type: " + value.type()); + } + }); + + return objectNode; + + } + + @VisibleForTesting + static ArrayNode convertToArrayNode(final List attributeValueList) { + ArrayNode arrayNode = OBJECT_MAPPER.createArrayNode(); + attributeValueList.forEach(attribute -> { + switch (attribute.type()) { + case S: + arrayNode.add(attribute.s()); + break; + case N: + arrayNode.add(attribute.n()); + break; + case BOOL: + arrayNode.add(attribute.bool()); + break; + case L: + arrayNode.add(convertToArrayNode(attribute.l())); + break; + case M: + arrayNode.add(convertToObjectNode(attribute.m())); + break; + case NUL: + arrayNode.add((JsonNode) null); + break; + default: + throw new IllegalArgumentException("Unsupported AttributeValue type: " + attribute.type()); + } + }); + return arrayNode; + + } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java index 6be851aa39..8f46f969bd 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/SdkClientModule.java @@ -8,6 +8,10 @@ */ package org.opensearch.ml.sdkclient; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import lombok.extern.log4j.Log4j2; import org.apache.http.HttpHost; import org.apache.http.conn.ssl.NoopHostnameVerifier; import org.opensearch.OpenSearchException; @@ -17,12 +21,6 @@ import org.opensearch.client.transport.rest_client.RestClientTransport; import org.opensearch.common.inject.AbstractModule; import org.opensearch.sdk.SdkClient; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategies; - -import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.auth.credentials.AwsCredentialsProviderChain; import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; @@ -62,7 +60,7 @@ public SdkClientModule() { SdkClientModule(String remoteMetadataType, String remoteMetadataEndpoint, String region) { this.remoteMetadataType = remoteMetadataType; this.remoteMetadataEndpoint = remoteMetadataEndpoint; - this.region = region; + this.region = region == null ? "us-west-2" : region; } @Override @@ -80,7 +78,8 @@ protected void configure() { return; case AWS_DYNAMO_DB: log.info("Using dynamo DB as metadata store"); - bind(SdkClient.class).toInstance(new DDBOpenSearchClient(createDynamoDbClient())); + bind(SdkClient.class) + .toInstance(new DDBOpenSearchClient(createDynamoDbClient(), new RemoteClusterIndicesClient(createOpenSearchClient()))); return; default: log.info("Using local opensearch cluster as metadata store"); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/ComplexDataObject.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/ComplexDataObject.java new file mode 100644 index 0000000000..b8e2bae2cb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/ComplexDataObject.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.ml.sdkclient; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +@AllArgsConstructor +@Builder +@Getter +public class ComplexDataObject implements ToXContentObject { + private String testString; + private long testNumber; + private boolean testBool; + private List testList; + private TestDataObject testObject; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field("testString", this.testString); + xContentBuilder.field("testNumber", this.testNumber); + xContentBuilder.field("testBool", this.testBool); + xContentBuilder.field("testList", this.testList); + xContentBuilder.field("testObject", this.testObject); + return xContentBuilder.endObject(); + } + + public static ComplexDataObject parse(XContentParser parser) throws IOException { + ComplexDataObjectBuilder builder = ComplexDataObject.builder(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if ("testString".equals(fieldName)) { + builder.testString(parser.text()); + } else if ("testNumber".equals(fieldName)) { + builder.testNumber(parser.longValue()); + } else if ("testBool".equals(fieldName)) { + builder.testBool(parser.booleanValue()); + } else if ("testList".equals(fieldName)) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + List list = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + list.add(parser.text()); + } + builder.testList(list); + } else if ("testObject".equals(fieldName)) { + builder.testObject(TestDataObject.parse(parser)); + } + } + return builder.build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index aed028e42c..57fb5496b9 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -9,13 +9,16 @@ package org.opensearch.ml.sdkclient; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; import org.junit.AfterClass; @@ -30,9 +33,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -40,11 +41,21 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.SearchDataObjectRequest; +import org.opensearch.sdk.SearchDataObjectResponse; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableMap; + import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; @@ -53,6 +64,8 @@ import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse; public class DDBOpenSearchClientTests extends OpenSearchTestCase { @@ -63,12 +76,16 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { @Mock private DynamoDbClient dynamoDbClient; + @Mock + private RemoteClusterIndicesClient remoteClusterIndicesClient; @Captor private ArgumentCaptor putItemRequestArgumentCaptor; @Captor private ArgumentCaptor getItemRequestArgumentCaptor; @Captor private ArgumentCaptor deleteItemRequestArgumentCaptor; + @Captor + private ArgumentCaptor updateItemRequestArgumentCaptor; private TestDataObject testDataObject; private static TestThreadPool testThreadPool = new TestThreadPool( @@ -91,7 +108,7 @@ public static void cleanup() { public void setup() { MockitoAnnotations.openMocks(this); - sdkClient = new DDBOpenSearchClient(dynamoDbClient); + sdkClient = new DDBOpenSearchClient(dynamoDbClient, remoteClusterIndicesClient); testDataObject = new TestDataObject("foo"); } @@ -116,9 +133,39 @@ public void testPutDataObject_HappyCase() throws IOException { Assert.assertEquals(TEST_INDEX, putItemRequest.tableName()); Assert.assertEquals(TEST_ID, putItemRequest.item().get("id").s()); Assert.assertEquals(TENANT_ID, putItemRequest.item().get("tenant_id").s()); - XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); - XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); - Assert.assertEquals(builder.toString(), putItemRequest.item().get("source").s()); + Assert.assertEquals("foo", putItemRequest.item().get("data").s()); + } + + @Test + public void testPutDataObject_WithComplexData() throws IOException { + ComplexDataObject complexDataObject = ComplexDataObject + .builder() + .testString("testString") + .testNumber(123) + .testBool(true) + .testList(Arrays.asList("123", "hello", null)) + .testObject(testDataObject) + .build(); + PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TENANT_ID) + .dataObject(complexDataObject) + .build(); + Mockito.when(dynamoDbClient.putItem(Mockito.any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + PutDataObjectResponse response = sdkClient + .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); + PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); + Assert.assertEquals("testString", putItemRequest.item().get("testString").s()); + Assert.assertEquals("123", putItemRequest.item().get("testNumber").n()); + Assert.assertEquals(true, putItemRequest.item().get("testBool").bool()); + Assert.assertEquals("123", putItemRequest.item().get("testList").l().get(0).s()); + Assert.assertEquals("hello", putItemRequest.item().get("testList").l().get(1).s()); + Assert.assertEquals(null, putItemRequest.item().get("testList").l().get(2).s()); + Assert.assertEquals("foo", putItemRequest.item().get("testObject").m().get("data").s()); } @Test @@ -170,11 +217,9 @@ public void testPutDataObject_DDBException_ThrowsException() { @Test public void testGetDataObject_HappyCase() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).tenantId(TENANT_ID).build(); - XContentBuilder sourceBuilder = XContentFactory.jsonBuilder(); - XContentBuilder builder = testDataObject.toXContent(sourceBuilder, ToXContent.EMPTY_PARAMS); GetItemResponse getItemResponse = GetItemResponse .builder() - .item(Map.ofEntries(Map.entry("source", AttributeValue.builder().s(builder.toString()).build()))) + .item(Map.ofEntries(Map.entry("data", AttributeValue.builder().s("foo").build()))) .build(); Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenReturn(getItemResponse); GetDataObjectResponse response = sdkClient @@ -188,7 +233,51 @@ public void testGetDataObject_HappyCase() throws IOException { Assert.assertEquals(TEST_ID, getItemRequest.key().get("id").s()); Assert.assertEquals(TEST_ID, response.id()); Assert.assertTrue(response.parser().isPresent()); - Assert.assertEquals("foo", response.parser().get().map().get("data")); + XContentParser parser = response.parser().get(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + assertEquals("foo", TestDataObject.parse(parser).data()); + } + + @Test + public void testGetDataObject_ComplexDataObject() throws IOException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).tenantId(TENANT_ID).build(); + GetItemResponse getItemResponse = GetItemResponse + .builder() + .item( + Map + .ofEntries( + Map.entry("testString", AttributeValue.builder().s("testString").build()), + Map.entry("testNumber", AttributeValue.builder().n("123").build()), + Map.entry("testBool", AttributeValue.builder().bool(true).build()), + Map + .entry( + "testList", + AttributeValue.builder().l(Arrays.asList(AttributeValue.builder().s("testString").build())).build() + ), + Map + .entry( + "testObject", + AttributeValue.builder().m(ImmutableMap.of("data", AttributeValue.builder().s("foo").build())).build() + ) + ) + ) + .build(); + Mockito.when(dynamoDbClient.getItem(Mockito.any(GetItemRequest.class))).thenReturn(getItemResponse); + GetDataObjectResponse response = sdkClient + .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); + GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); + Assert.assertTrue(response.parser().isPresent()); + XContentParser parser = response.parser().get(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ComplexDataObject complexDataObject = ComplexDataObject.parse(parser); + assertEquals("testString", complexDataObject.getTestString()); + assertEquals(123, complexDataObject.getTestNumber()); + assertEquals("testString", complexDataObject.getTestList().get(0)); + assertEquals("foo", complexDataObject.getTestObject().data()); + assertEquals(true, complexDataObject.isTestBool()); } @Test @@ -260,4 +349,98 @@ public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get("tenant_id").s()); } + + @Test + public void updateDataObjectAsync_HappyCase() { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .id(TEST_ID) + .index(TEST_INDEX) + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build()); + UpdateDataObjectResponse updateResponse = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + assertEquals(TEST_ID, updateResponse.id()); + UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); + assertEquals(TEST_ID, updateRequest.id()); + assertEquals(TEST_INDEX, updateItemRequest.tableName()); + assertEquals(TEST_ID, updateItemRequest.key().get("id").s()); + assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + assertEquals("foo", updateItemRequest.key().get("data").s()); + + } + + @Test + public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .id(TEST_ID) + .index(TEST_INDEX) + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build()); + sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); + UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); + assertEquals(TENANT_ID, updateItemRequest.key().get("tenant_id").s()); + } + + @Test + public void searchDataObjectAsync_HappyCase() { + SearchDataObjectRequest searchDataObjectRequest = new SearchDataObjectRequest.Builder() + .indices(TEST_INDEX) + .tenantId(TENANT_ID) + .build(); + CompletionStage searchDataObjectResponse = Mockito.mock(CompletionStage.class); + Mockito + .when(remoteClusterIndicesClient.searchDataObjectAsync(Mockito.eq(searchDataObjectRequest), Mockito.any())) + .thenReturn(searchDataObjectResponse); + CompletionStage searchResponse = sdkClient.searchDataObjectAsync(searchDataObjectRequest); + + assertEquals(searchDataObjectResponse, searchResponse); + } + + @Test + public void convertJsonArrayToAttributeValueList_TestMultipleJsonType() throws Exception { + ObjectMapper objectMapper = new ObjectMapper(); + + JsonNode jsonNode = objectMapper.readTree("[\"testString\", 123, true, [\"test1\", \"test2\"], {\"hello\": \"all\"}]"); + DDBOpenSearchClient.convertJsonArrayToAttributeValueList(jsonNode); + } + + @Test + public void convertToObjectNode_TestNullInput() { + AttributeValue nullAttribute = AttributeValue.builder().nul(true).build(); + ObjectNode response = DDBOpenSearchClient.convertToObjectNode(ImmutableMap.of("test", nullAttribute)); + Assert.assertTrue(response.get("test").isNull()); + } + + @Test + public void convertToObjectNode_TestInvalidInput() { + AttributeValue nsAttribute = AttributeValue.builder().ns("123").build(); + Assert + .assertThrows( + IllegalArgumentException.class, + () -> DDBOpenSearchClient.convertToObjectNode(ImmutableMap.of("test", nsAttribute)) + ); + } + + @Test + public void convertToArrayNode_MultipleDataTypes() { + ArrayNode arrayNode = DDBOpenSearchClient + .convertToArrayNode( + Arrays + .asList( + AttributeValue.builder().s("test").build(), + AttributeValue.builder().n("123").build(), + AttributeValue.builder().l(AttributeValue.builder().s("testList").build()).build(), + AttributeValue.builder().nul(true).build(), + AttributeValue.builder().m(ImmutableMap.of("key", AttributeValue.builder().s("testMap").build())).build() + ) + ); + assertEquals(5, arrayNode.size()); + + } } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java index 4c1b3e71ff..8667450d9c 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/SdkClientModuleTests.java @@ -44,7 +44,7 @@ public void testRemoteOpenSearchBinding() { } public void testDDBBinding() { - Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, null, "eu-west-3")); + Injector injector = Guice.createInjector(new SdkClientModule(SdkClientModule.AWS_DYNAMO_DB, "http://example.org", "eu-west-3")); SdkClient sdkClient = injector.getInstance(SdkClient.class); assertTrue(sdkClient instanceof DDBOpenSearchClient);