From 9044be13467ee8a0fde8c2b4430581b6e590870c Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 21 Jun 2024 13:44:47 -0700 Subject: [PATCH] Update GetDataObjectResponse to include full GetResponse parser Signed-off-by: Daniel Widdis --- .../opensearch/sdk/GetDataObjectResponse.java | 21 +++-- .../sdk/GetDataObjectResponseTests.java | 6 +- .../GetModelGroupTransportAction.java | 82 ++++++++++++----- .../TransportUpdateModelGroupAction.java | 66 ++++++++------ .../tasks/DeleteTaskTransportAction.java | 87 +++++++++++-------- .../action/tasks/GetTaskTransportAction.java | 38 +++++--- .../helper/ConnectorAccessControlHelper.java | 46 ++++++---- .../ml/sdkclient/DDBOpenSearchClient.java | 32 +++++-- .../sdkclient/LocalClusterIndicesClient.java | 18 ++-- .../sdkclient/RemoteClusterIndicesClient.java | 34 +++----- .../sdkclient/DDBOpenSearchClientTests.java | 23 ++++- .../LocalClusterIndicesClientTests.java | 31 +++++-- .../RemoteClusterIndicesClientTests.java | 21 +++-- 13 files changed, 325 insertions(+), 180 deletions(-) diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java index b884cc9eb4..3fa5f8e398 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java @@ -12,11 +12,10 @@ import java.util.Collections; import java.util.Map; -import java.util.Optional; public class GetDataObjectResponse { private final String id; - private final Optional parser; + private final XContentParser parser; private final Map source; /** @@ -24,10 +23,10 @@ public class GetDataObjectResponse { *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. * @param id the document id - * @param parser an optional XContentParser that can be used to create the data object if present. + * @param parser a parser that can be used to create a GetResponse * @param source the data object as a map */ - public GetDataObjectResponse(String id, Optional parser, Map source) { + public GetDataObjectResponse(String id, XContentParser parser, Map source) { this.id = id; this.parser = parser; this.source = source; @@ -42,10 +41,10 @@ public String id() { } /** - * Returns the parser optional. If present, is a representation of the data object that may be parsed. - * @return the parser optional + * Returns the parser that can be used to create a GetResponse + * @return the parser */ - public Optional parser() { + public XContentParser parser() { return this.parser; } @@ -62,7 +61,7 @@ public Map source() { */ public static class Builder { private String id = null; - private Optional parser = Optional.empty(); + private XContentParser parser = null; private Map source = Collections.emptyMap(); /** @@ -81,11 +80,11 @@ public Builder id(String id) { } /** - * Add an optional parser to this builder - * @param parser an {@link Optional} which may contain the data object parser + * Add a parser to this builder + * @param parser a parser that can be used to create a GetResponse * @return the updated builder */ - public Builder parser(Optional parser) { + public Builder parser(XContentParser parser) { this.parser = parser; return this; } diff --git a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java index 9e79593dd8..6b318e4f58 100644 --- a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java @@ -13,8 +13,6 @@ import org.opensearch.core.xcontent.XContentParser; import java.util.Map; -import java.util.Optional; - import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; @@ -33,10 +31,10 @@ public void setUp() { @Test public void testGetDataObjectResponse() { - GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(Optional.of(testParser)).source(testSource).build(); + GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(testParser).source(testSource).build(); assertEquals(testId, response.id()); - assertEquals(testParser, response.parser().get()); + assertEquals(testParser, response.parser()); assertEquals(testSource, response.source()); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index 8a46a2a423..17def2dc78 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -5,18 +5,21 @@ package org.opensearch.ml.action.model_group; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -134,6 +137,38 @@ private void handleThrowable(Throwable throwable, String modelGroupId, ActionLis } } + /* + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model group", + RestStatus.FORBIDDEN + ) + ); + } else { + wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); + } + }, e -> { + log.error("Failed to validate access for Model Group " + modelGroupId, e); + wrappedListener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Failed to parse ml model group" + r.getId(), e); + wrappedListener.onFailure(e); + } + } else { + + */ + private void processResponse( GetDataObjectResponse getDataObjectResponse, String modelGroupId, @@ -141,28 +176,35 @@ private void processResponse( User user, ActionListener wrappedListener ) { - if (getDataObjectResponse != null && getDataObjectResponse.parser().isPresent()) { - try { - XContentParser parser = getDataObjectResponse.parser().get(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); - - if (TenantAwareHelper - .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), wrappedListener)) { - validateModelGroupAccess(user, modelGroupId, mlModelGroup, wrappedListener); + try { + GetResponse r = GetResponse.fromXContent(getDataObjectResponse.parser()); + if (r != null && r.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + + if (TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), wrappedListener)) { + validateModelGroupAccess(user, modelGroupId, mlModelGroup, wrappedListener); + } + } catch (Exception e) { + log.error("Failed to parse ml connector {}", getDataObjectResponse.id(), e); + wrappedListener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to parse ml connector {}", getDataObjectResponse.id(), e); - wrappedListener.onFailure(e); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model group with the provided model group id: " + modelGroupId, + RestStatus.NOT_FOUND + ) + ); } - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Failed to find model group with the provided model group id: " + modelGroupId, - RestStatus.NOT_FOUND - ) - ); + } catch (Exception e) { + wrappedListener.onFailure(e); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 5eb4c5ed74..566f40325b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.model_group; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -16,6 +17,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.update.UpdateRequest; @@ -23,6 +25,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -128,37 +131,44 @@ protected void doExecute(Task task, ActionRequest request, ActionListener handleDeleteResponse( - response, - delThrowable, - tenantId, - actionListener + try { + GetResponse gr = GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLTask mlTask = MLTask.parse(parser); + if (!TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlTask.getTenantId(), actionListener)) { + return; + } + MLTaskState mlTaskState = mlTask.getState(); + if (mlTaskState.equals(MLTaskState.RUNNING)) { + actionListener + .onFailure(new Exception("Task cannot be deleted in running state. Try after sometime")); + } else { + DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId); + try { + sdkClient + .deleteDataObjectAsync( + new DeleteDataObjectRequest.Builder() + .index(deleteRequest.index()) + .id(deleteRequest.id()) + .build(), + client.threadPool().executor(GENERAL_THREAD_POOL) ) - ); - } catch (Exception e) { - log.error("Failed to delete ML task: {}", taskId, e); - actionListener.onFailure(e); + .whenComplete( + (response, delThrowable) -> handleDeleteResponse( + response, + delThrowable, + tenantId, + actionListener + ) + ); + } catch (Exception e) { + log.error("Failed to delete ML task: {}", taskId, e); + actionListener.onFailure(e); + } } + } catch (Exception e) { + log.error("Failed to parse ml task {}", r.id(), e); + wrappedListener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to parse ml task {}", r.id(), e); - wrappedListener.onFailure(e); + } else { + wrappedListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND)); } - } else { - wrappedListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND)); + } catch (Exception e) { + wrappedListener.onFailure(e); } } }); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 9a19997eaa..bda65762bd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -5,17 +5,20 @@ package org.opensearch.ml.action.tasks; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; @@ -95,22 +98,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener getDataObjectAsync(GetDataObjectRe return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { final GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); + String source; + boolean found; if (getItemResponse == null || getItemResponse.item() == null || getItemResponse.item().isEmpty()) { - return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.empty()).build(); + found = false; + source = null; + } else { + found = true; + source = getItemResponse.item().get(SOURCE).s(); } - - String source = getItemResponse.item().get(SOURCE).s(); + String simulatedGetResponse = "{\"_index\":\"" + + request.index() + + "\",\"_id\":\"" + + request.id() + + "\",\"_version\":1,\"_seq_no\":-2,\"_primary_term\":0,\"found\":" + + found + + ",\"_source\":" + + source + + "}"; XContentParser parser = JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, source); - return new GetDataObjectResponse.Builder().id(request.id()).parser(Optional.of(parser)).build(); + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, simulatedGetResponse); + // This would consume parser content so we need to create a new parser for the map + Map sourceAsMap = GetResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, simulatedGetResponse) + ) + .getSourceAsMap(); + return new GetDataObjectResponse.Builder().id(request.id()).parser(parser).source(sourceAsMap).build(); } catch (IOException e) { // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException("Failed to parse data object " + request.id(), RestStatus.BAD_REQUEST); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 3d822ecb17..8a29b709fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -14,12 +14,13 @@ import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Arrays; -import java.util.Optional; +import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; @@ -41,7 +42,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.get.GetResult; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -108,14 +109,19 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) .actionGet(); if (getResponse == null || !getResponse.isExists()) { - return new GetDataObjectResponse.Builder().id(request.id()).build(); + getResponse = new GetResponse( + new GetResult(request.index(), request.id(), UNASSIGNED_SEQ_NO, 0, 1, false, null, null, null) + ); + return new GetDataObjectResponse.Builder() + .id(request.id()) + .parser(jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.toString())) + .source(Collections.emptyMap()) + .build(); } - XContentParser parser = jsonXContent - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()); log.info("Retrieved data object"); return new GetDataObjectResponse.Builder() .id(getResponse.getId()) - .parser(Optional.of(parser)) + .parser(jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.toString())) .source(getResponse.getSource()) .build(); } catch (IOException e) { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 3ed85cb42b..f2e6585f82 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -20,7 +20,6 @@ import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Map; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; @@ -47,7 +46,6 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -60,9 +58,6 @@ import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.sdk.UpdateDataObjectResponse; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.databind.ObjectMapper; - import jakarta.json.stream.JsonGenerator; import jakarta.json.stream.JsonParser; import lombok.extern.log4j.Log4j2; @@ -73,17 +68,19 @@ @Log4j2 public class RemoteClusterIndicesClient implements SdkClient { - private OpenSearchClient openSearchClient; - @SuppressWarnings("unchecked") private static final Class> MAP_DOCTYPE = (Class>) (Class) Map.class; + private OpenSearchClient openSearchClient; + private JsonpMapper mapper; + /** * Instantiate this object with an OpenSearch Java client. * @param openSearchClient The client to wrap */ public RemoteClusterIndicesClient(OpenSearchClient openSearchClient) { this.openSearchClient = openSearchClient; + this.mapper = openSearchClient._transport().jsonpMapper(); } @Override @@ -112,18 +109,13 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe try { GetRequest getRequest = new GetRequest.Builder().index(request.index()).id(request.id()).build(); log.info("Getting {} from {}", request.id(), request.index()); - @SuppressWarnings("rawtypes") - GetResponse getResponse = openSearchClient.get(getRequest, Map.class); - if (!getResponse.found()) { - return new GetDataObjectResponse.Builder().id(getResponse.id()).build(); - } - // Since we use the JacksonJsonBMapper we know this is String-Object map - @SuppressWarnings("unchecked") + GetResponse> getResponse = openSearchClient.get(getRequest, MAP_DOCTYPE); Map source = getResponse.source(); - String json = new ObjectMapper().setSerializationInclusion(Include.NON_NULL).writeValueAsString(source); - XContentParser parser = JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(Optional.of(parser)).source(source).build(); + return new GetDataObjectResponse.Builder() + .id(getResponse.id()) + .parser(jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, toJson(getResponse))) + .source(source) + .build(); } catch (IOException e) { log.error("Error getting data object {} from {}: {}", request.id(), request.index(), e.getMessage(), e); // Rethrow unchecked exception on XContent parser creation error @@ -206,7 +198,6 @@ public CompletionStage searchDataObjectAsync(SearchDat return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { log.info("Searching {}", Arrays.toString(request.indices()), null); - JsonpMapper mapper = openSearchClient._transport().jsonpMapper(); JsonParser parser = mapper.jsonProvider().createParser(new StringReader(request.searchSourceBuilder().toString())); SearchRequest searchRequest = SearchRequest._DESERIALIZER.deserialize(parser, mapper); searchRequest = searchRequest.toBuilder().index(Arrays.asList(request.indices())).build(); @@ -215,8 +206,7 @@ public CompletionStage searchDataObjectAsync(SearchDat log.info("Search returned {} hits", searchResponse.hits().total().value()); return new SearchDataObjectResponse.Builder() .parser( - jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, toJson(searchResponse, mapper)) + jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, toJson(searchResponse)) ) .build(); } catch (IOException e) { @@ -230,7 +220,7 @@ public CompletionStage searchDataObjectAsync(SearchDat }), executor); } - private String toJson(JsonpSerializable obj, JsonpMapper mapper) { + private String toJson(JsonpSerializable obj) { StringWriter stringWriter = new StringWriter(); try (JsonGenerator generator = mapper.jsonProvider().createGenerator(stringWriter)) { mapper.serialize(obj, generator); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index aed028e42c..6413b0bc04 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -9,6 +9,7 @@ package org.opensearch.ml.sdkclient; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; @@ -27,12 +28,18 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -187,8 +194,17 @@ public void testGetDataObject_HappyCase() throws IOException { Assert.assertEquals(TENANT_ID, getItemRequest.key().get("tenant_id").s()); Assert.assertEquals(TEST_ID, getItemRequest.key().get("id").s()); Assert.assertEquals(TEST_ID, response.id()); - Assert.assertTrue(response.parser().isPresent()); - Assert.assertEquals("foo", response.parser().get().map().get("data")); + Assert.assertEquals("foo", response.source().get("data")); + XContentParser parser = response.parser(); + XContentParser dataParser = XContentHelper + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + GetResponse.fromXContent(parser).getSourceAsBytesRef(), + XContentType.JSON + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, dataParser.nextToken(), dataParser); + Assert.assertEquals("foo", TestDataObject.parse(dataParser).data()); } @Test @@ -201,7 +217,8 @@ public void testGetDataObject_NoExistingDoc() throws IOException { .toCompletableFuture() .join(); Assert.assertEquals(TEST_ID, response.id()); - Assert.assertFalse(response.parser().isPresent()); + assertNull(response.source()); + assertFalse(GetResponse.fromXContent(response.parser()).isExists()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index e9a98e5d48..33bf76d7f3 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -18,6 +18,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.Collections; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; @@ -47,7 +48,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -155,6 +158,16 @@ public void testGetDataObject() throws IOException { when(getResponse.getId()).thenReturn(TEST_ID); String json = testDataObject.toJson(); when(getResponse.getSourceAsString()).thenReturn(json); + when(getResponse.toString()) + .thenReturn( + "{\"_index\":\"" + + TEST_INDEX + + "\",\"_id\":\"" + + TEST_ID + + "\",\"_version\":1,\"_seq_no\":-2,\"_primary_term\":0,\"found\":true,\"_source\":" + + json + + "}" + ); when(getResponse.getSource()).thenReturn(XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false)); @SuppressWarnings("unchecked") ActionFuture future = mock(ActionFuture.class); @@ -171,10 +184,16 @@ public void testGetDataObject() throws IOException { assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); assertEquals("foo", response.source().get("data")); - assertTrue(response.parser().isPresent()); - XContentParser parser = response.parser().get(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - assertEquals("foo", TestDataObject.parse(parser).data()); + XContentParser parser = response.parser(); + XContentParser dataParser = XContentHelper + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + GetResponse.fromXContent(parser).getSourceAsBytesRef(), + XContentType.JSON + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, dataParser.nextToken(), dataParser); + assertEquals("foo", TestDataObject.parse(dataParser).data()); } public void testGetDataObject_NotFound() throws IOException { @@ -182,6 +201,7 @@ public void testGetDataObject_NotFound() throws IOException { GetResponse getResponse = mock(GetResponse.class); when(getResponse.isExists()).thenReturn(false); + when(getResponse.toString()).thenReturn("{\"found\":false,\"_source\":{}}"); @SuppressWarnings("unchecked") ActionFuture future = mock(ActionFuture.class); when(mockedClient.get(any(GetRequest.class))).thenReturn(future); @@ -196,7 +216,8 @@ public void testGetDataObject_NotFound() throws IOException { verify(mockedClient, times(1)).get(requestCaptor.capture()); assertEquals(TEST_INDEX, requestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); - assertFalse(response.parser().isPresent()); + assertEquals(Collections.emptyMap(), response.source()); + assertFalse(GetResponse.fromXContent(response.parser()).isExists()); } public void testGetDataObject_Exception() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 03d26a2a8f..1604c0ef3e 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -50,6 +50,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -209,10 +213,16 @@ public void testGetDataObject() throws IOException { assertEquals(TEST_INDEX, getRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); assertEquals("foo", response.source().get("data")); - assertTrue(response.parser().isPresent()); - XContentParser parser = response.parser().get(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - assertEquals("foo", TestDataObject.parse(parser).data()); + XContentParser parser = response.parser(); + XContentParser dataParser = XContentHelper + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + org.opensearch.action.get.GetResponse.fromXContent(parser).getSourceAsBytesRef(), + XContentType.JSON + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, dataParser.nextToken(), dataParser); + assertEquals("foo", TestDataObject.parse(dataParser).data()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -232,7 +242,8 @@ public void testGetDataObject_NotFound() throws IOException { assertEquals(TEST_INDEX, getRequestCaptor.getValue().index()); assertEquals(TEST_ID, response.id()); - assertFalse(response.parser().isPresent()); + assertNull(response.source()); + assertFalse(org.opensearch.action.get.GetResponse.fromXContent(response.parser()).isExists()); } @SuppressWarnings({ "unchecked", "rawtypes" })