diff --git a/common/src/main/java/org/opensearch/sdk/SdkClient.java b/common/src/main/java/org/opensearch/sdk/SdkClient.java index 9fb195e13f..78f3d8b9a5 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClient.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClient.java @@ -42,11 +42,7 @@ default PutDataObjectResponse putDataObject(PutDataObjectRequest request) { try { return putDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); } } @@ -76,11 +72,37 @@ default GetDataObjectResponse getDataObject(GetDataObjectRequest request) { try { return getDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); + } + } + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @param executor the executor to use for asynchronous execution + * @return A completion stage encapsulating the response or exception + */ + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor); + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @return A completion stage encapsulating the response or exception + */ + default CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request) { + return updateDataObjectAsync(request, ForkJoinPool.commonPool()); + } + + /** + * Update a data object/document in a table/index. + * @param request A request identifying the data object to update + * @return A response on success. Throws {@link OpenSearchException} wrapping the cause on exception. + */ + default UpdateDataObjectResponse updateDataObject(UpdateDataObjectRequest request) { + try { + return updateDataObjectAsync(request).toCompletableFuture().join(); + } catch (CompletionException e) { + throw unwrapAndConvertToRuntime(e); } } @@ -110,11 +132,18 @@ default DeleteDataObjectResponse deleteDataObject(DeleteDataObjectRequest reques try { return deleteDataObjectAsync(request).toCompletableFuture().join(); } catch (CompletionException e) { - Throwable cause = e.getCause(); - if (cause instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new OpenSearchException(cause); + throw unwrapAndConvertToRuntime(e); + } + } + + private static RuntimeException unwrapAndConvertToRuntime(CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (cause instanceof RuntimeException) { + return (RuntimeException) cause; } + return new OpenSearchException(cause); } } diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java new file mode 100644 index 0000000000..25891a1167 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.opensearch.core.xcontent.ToXContentObject; + +public class UpdateDataObjectRequest { + + private final String index; + private final String id; + private final ToXContentObject dataObject; + + /** + * Instantiate this request with an index and data object. + *

+ * For data storage implementations other than OpenSearch, an index may be referred to as a table and the data object may be referred to as an item. + * @param index the index location to update the object + * @param id the document id + * @param dataObject the data object + */ + public UpdateDataObjectRequest(String index, String id, ToXContentObject dataObject) { + this.index = index; + this.id = id; + this.dataObject = dataObject; + } + + /** + * Returns the index + * @return the index + */ + public String index() { + return this.index; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return this.id; + } + + /** + * Returns the data object + * @return the data object + */ + public ToXContentObject dataObject() { + return this.dataObject; + } + + /** + * Class for constructing a Builder for this Request Object + */ + public static class Builder { + private String index = null; + private String id = null; + private ToXContentObject dataObject = null; + + /** + * Empty Constructor for the Builder object + */ + public Builder() {} + + /** + * Add an index to this builder + * @param index the index to put the object + * @return the updated builder + */ + public Builder index(String index) { + this.index = index; + return this; + } + + /** + * Add an id to this builder + * @param id the document id + * @return the updated builder + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Add a data object to this builder + * @param dataObject the data object + * @return the updated builder + */ + public Builder dataObject(ToXContentObject dataObject) { + this.dataObject = dataObject; + return this; + } + + /** + * Builds the request + * @return A {@link UpdateDataObjectRequest} + */ + public UpdateDataObjectRequest build() { + return new UpdateDataObjectRequest(this.index, this.id, this.dataObject); + } + } +} diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java new file mode 100644 index 0000000000..56711c60d6 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; + +public class UpdateDataObjectResponse { + private final String id; + private final ShardId shardId; + private final ShardInfo shardInfo; + private final boolean updated; + + /** + * Instantiate this request with an id and update status. + *

+ * For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param id the document id + * @param shardId the shard id + * @param shardInfo the shard info + * @param updated Whether the object was updated. + */ + public UpdateDataObjectResponse(String id, ShardId shardId, ShardInfo shardInfo, boolean updated) { + this.id = id; + this.shardId = shardId; + this.shardInfo = shardInfo; + this.updated = updated; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return id; + } + + /** + * Returns the shard id. + * @return the shard id, or a generated id if shards are not applicable + */ + public ShardId shardId() { + return shardId; + } + + /** + * Returns the shard info. + * @return the shard info, or generated info if shards are not applicable + */ + public ShardInfo shardInfo() { + return shardInfo; + } + + /** + * Returns whether update was successful + * @return true if update was successful + */ + public boolean updated() { + return updated; + } + + /** + * Class for constructing a Builder for this Response Object + */ + public static class Builder { + private String id = null; + private ShardId shardId = null; + private ShardInfo shardInfo = null; + private boolean updated = false; + + /** + * Empty Constructor for the Builder object + */ + public Builder() {} + + /** + * Add an id to this builder + * @param id the id to add + * @return the updated builder + */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Adds a shard id to this builder + * @param shardId the shard id to add + * @return the updated builder + */ + public Builder shardId(ShardId shardId) { + this.shardId = shardId; + return this; + } + + /** + * Adds a generated shard id to this builder + * @param indexName the index name to generate a shard id + * @return the updated builder + */ + public Builder shardId(String indexName) { + this.shardId = new ShardId(indexName, Strings.UNKNOWN_UUID_VALUE, 0); + return this; + } + + /** + * Adds shard information (statistics) to this builder + * @param shardInfo the shard info to add + * @return the updated builder + */ + public Builder shardInfo(ShardInfo shardInfo) { + this.shardInfo = shardInfo; + return this; + } + /** + * Add a updated status to this builder + * @param updated the updated status to add + * @return the updated builder + */ + public Builder updated(boolean updated) { + this.updated = updated; + return this; + } + + /** + * Builds the object + * @return A {@link UpdateDataObjectResponse} + */ + public UpdateDataObjectResponse build() { + return new UpdateDataObjectResponse(this.id, this.shardId, this.shardInfo, this.updated); + } + } +} diff --git a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java index 08b3732a42..141c63775f 100644 --- a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java +++ b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java @@ -13,6 +13,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.rest.RestStatus; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -40,11 +42,15 @@ public class SdkClientTests { @Mock private GetDataObjectResponse getResponse; @Mock + private UpdateDataObjectRequest updateRequest; + @Mock + private UpdateDataObjectResponse updateResponse; + @Mock private DeleteDataObjectRequest deleteRequest; @Mock private DeleteDataObjectResponse deleteResponse; - private RuntimeException testException; + private OpenSearchStatusException testException; private InterruptedException interruptedException; @Before @@ -61,12 +67,17 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe return CompletableFuture.completedFuture(getResponse); } + @Override + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.completedFuture(updateResponse); + } + @Override public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { return CompletableFuture.completedFuture(deleteResponse); } }); - testException = new RuntimeException(); + testException = new OpenSearchStatusException("Test", RestStatus.BAD_REQUEST); interruptedException = new InterruptedException(); } @@ -81,10 +92,10 @@ public void testPutDataObjectException() { when(sdkClient.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.putDataObject(putRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class)); } @@ -113,10 +124,10 @@ public void testGetDataObjectException() { when(sdkClient.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.getDataObject(getRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } @@ -134,6 +145,37 @@ public void testGetDataObjectInterrupted() { verify(sdkClient).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class)); } + + @Test + public void testUpdateDataObjectSuccess() { + assertEquals(updateResponse, sdkClient.updateDataObject(updateRequest)); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + + @Test + public void testUpdateDataObjectException() { + when(sdkClient.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class))) + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { + sdkClient.updateDataObject(updateRequest); + }); + assertEquals(testException, exception); + assertFalse(Thread.interrupted()); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + + @Test + public void testUpdateDataObjectInterrupted() { + when(sdkClient.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class))) + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + sdkClient.updateDataObject(updateRequest); + }); + assertEquals(interruptedException, exception.getCause()); + assertTrue(Thread.interrupted()); + verify(sdkClient).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class)); + } + @Test public void testDeleteDataObjectSuccess() { assertEquals(deleteResponse, sdkClient.deleteDataObject(deleteRequest)); @@ -144,10 +186,10 @@ public void testDeleteDataObjectSuccess() { public void testDeleteDataObjectException() { when(sdkClient.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class))) .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.deleteDataObject(deleteRequest); }); - assertEquals(testException, exception.getCause()); + assertEquals(testException, exception); assertFalse(Thread.interrupted()); verify(sdkClient).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class)); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index b4220fef95..b5e6fb81e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -14,6 +14,7 @@ import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; @@ -123,7 +124,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A sourceBuilder.query(QueryBuilders.matchQuery(TENANT_ID, tenantId)); } searchRequest.source(sourceBuilder); - // TODO: User SDK client not client. + // TODO: Use SDK client not client. client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { @@ -186,8 +187,13 @@ private void handleDeleteResponse( ActionListener actionListener ) { if (throwable != null) { - log.error("Failed to delete ML connector: {}", connectorId, throwable); - actionListener.onFailure(new RuntimeException(throwable)); + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + log.error("Failed to delete ML connector: {}", connectorId, cause); + if (cause instanceof Exception) { + actionListener.onFailure((Exception) cause); + } else { + actionListener.onFailure(new OpenSearchException(cause)); + } } else { log.info("Connector deletion result: {}, connector id: {}", response.deleted(), response.id()); DeleteResponse deleteResponse = new DeleteResponse(response.shardId(), response.id(), 0, 0, 0, response.deleted()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index 8221b33886..9d09bfffbc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -93,7 +93,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener handleConnectorAccessValidationFailure(connectorId, e, actionListener) ) ); - } catch (Exception e) { log.error("Failed to get ML connector {}", connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index e7cb3c40ac..7f4253e59d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -12,6 +12,7 @@ import java.util.HashSet; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -147,7 +148,13 @@ private void indexConnector(Connector connector, ActionListener { context.restore(); if (throwable != null) { - listener.onFailure(new RuntimeException(throwable)); + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + log.error("Failed to create ML connector", cause); + if (cause instanceof Exception) { + listener.onFailure((Exception) cause); + } else { + listener.onFailure(new OpenSearchException(cause)); + } } else { log.info("Connector creation result: {}, connector id: {}", r.created(), r.id()); MLCreateConnectorResponse response = new MLCreateConnectorResponse(r.id()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index f227b22e2a..7b4fa7e787 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -7,31 +7,30 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.DocWriteResponse.Result; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -46,6 +45,8 @@ import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.sdk.GetDataObjectRequest; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -60,7 +61,8 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; - private final SdkClient sdkClient; + SdkClient sdkClient; + ConnectorAccessControlHelper connectorAccessControlHelper; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; MLModelManager mlModelManager; @@ -122,10 +124,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error("Unable to find the connector with ID {}. Details: {}", connectorId, exception); + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); listener.onFailure(exception); })); } catch (Exception e) { @@ -147,7 +151,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, ThreadContext.StoredContext context ) { @@ -159,10 +163,15 @@ private void updateUndeployedConnector( sourceBuilder.query(boolQueryBuilder); searchRequest.source(sourceBuilder); + // TODO: Use SDK client not client. client.search(searchRequest, ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + sdkClient + .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener, context)); + }); } else { log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); List modelIds = new ArrayList<>(); @@ -181,15 +190,36 @@ private void updateUndeployedConnector( } }, e -> { if (e instanceof IndexNotFoundException) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + sdkClient + .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener, context)); + }); return; } log.error("Failed to update ML connector: " + connectorId, e); listener.onFailure(e); - })); } + private void handleUpdateDataObjectCompletionStage( + UpdateDataObjectResponse r, + Throwable throwable, + ActionListener updateListener + ) { + if (throwable != null) { + Throwable cause = throwable.getCause() == null ? throwable : throwable.getCause(); + if (cause instanceof Exception) { + updateListener.onFailure((Exception) cause); + } else { + updateListener.onFailure(new OpenSearchException(cause)); + } + } else { + log.info("Connector update result: {}, connector id: {}", r.updated(), r.id()); + updateListener.onResponse(new UpdateResponse(r.shardId(), r.id(), 0, 0, 0, r.updated() ? Result.UPDATED : Result.CREATED)); + } + } + private ActionListener getUpdateResponseListener( String connectorId, ActionListener actionListener, diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index 2b9935f82e..76b58e7695 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.client.Client; @@ -176,8 +177,12 @@ public void getConnector( log.error("Failed to get connector index", cause); listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { - log.error("Failed to find connector {}", connectorId, cause); - listener.onFailure(new RuntimeException(cause)); + log.error("Failed to get ML connector " + connectorId, cause); + if (cause instanceof Exception) { + listener.onFailure((Exception) cause); + } else { + listener.onFailure(new OpenSearchException(cause)); + } } } else { if (r != null && r.parser().isPresent()) { diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java index 2af15f6bb3..dca67bd15b 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClient.java @@ -10,10 +10,12 @@ import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.DELETED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Optional; @@ -21,7 +23,6 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; @@ -29,13 +30,15 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -43,6 +46,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import lombok.extern.log4j.Log4j2; @@ -79,8 +84,12 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe .actionGet(); log.info("Creation status for id {}: {}", indexResponse.getId(), indexResponse.getResult()); return new PutDataObjectResponse.Builder().id(indexResponse.getId()).created(indexResponse.getResult() == CREATED).build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to put in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -90,7 +99,9 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { try { log.info("Getting {} from {}", request.id(), request.index()); - GetResponse getResponse = client.get(new GetRequest(request.index(), request.id())).actionGet(); + GetResponse getResponse = client + .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) + .actionGet(); if (getResponse == null || !getResponse.isExists()) { return new GetDataObjectResponse.Builder().id(request.id()).build(); } @@ -102,30 +113,55 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe .parser(Optional.of(parser)) .source(getResponse.getSource()) .build(); - } catch (OpenSearchStatusException | IndexNotFoundException notFound) { - throw notFound; - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parser creation error + throw new OpenSearchStatusException( + "Failed to create parser for data object retrieved from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); } }), executor); } @Override - public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { - try { - log.info("Deleting {} from {}", request.id(), request.index()); - DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id())).actionGet(); - log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); - return new DeleteDataObjectResponse.Builder() - .id(deleteResponse.getId()) - .shardId(deleteResponse.getShardId()) - .shardInfo(deleteResponse.getShardInfo()) - .deleted(deleteResponse.getResult() == DELETED) + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + log.info("Updating {} from {}", request.id(), request.index()); + UpdateResponse updateResponse = client + .update( + new UpdateRequest(request.index(), request.id()).doc(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)) + ) + .actionGet(); + log.info("Update status for id {}: {}", updateResponse.getId(), updateResponse.getResult()); + return new UpdateDataObjectResponse.Builder() + .id(updateResponse.getId()) + .shardId(updateResponse.getShardId()) + .shardInfo(updateResponse.getShardInfo()) + .updated(updateResponse.getResult() == UPDATED) .build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to update in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } + + @Override + public CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + log.info("Deleting {} from {}", request.id(), request.index()); + DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id())).actionGet(); + log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); + return new DeleteDataObjectResponse.Builder() + .id(deleteResponse.getId()) + .shardId(deleteResponse.getShardId()) + .shardInfo(deleteResponse.getShardInfo()) + .deleted(deleteResponse.getResult() == DELETED) + .build(); + }), executor); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 8284b95f4e..35c5fb5743 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -10,7 +10,9 @@ import static org.opensearch.client.opensearch._types.Result.Created; import static org.opensearch.client.opensearch._types.Result.Deleted; +import static org.opensearch.client.opensearch._types.Result.Updated; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Map; @@ -19,7 +21,7 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch.core.DeleteRequest; @@ -28,9 +30,15 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.client.opensearch.core.UpdateRequest; +import org.opensearch.client.opensearch.core.UpdateResponse; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; @@ -39,6 +47,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.databind.ObjectMapper; @@ -70,8 +80,12 @@ public CompletionStage putDataObjectAsync(PutDataObjectRe IndexResponse indexResponse = openSearchClient.index(indexRequest); log.info("Creation status for id {}: {}", indexResponse.id(), indexResponse.result()); return new PutDataObjectResponse.Builder().id(indexResponse.id()).created(indexResponse.result() == Created).build(); - } catch (Exception e) { - throw new OpenSearchException("Error occurred while indexing data object", e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to parse data object to put in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -94,8 +108,50 @@ public CompletionStage getDataObjectAsync(GetDataObjectRe XContentParser parser = JsonXContent.jsonXContent .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); return new GetDataObjectResponse.Builder().id(getResponse.id()).parser(Optional.of(parser)).source(source).build(); - } catch (Exception e) { - throw new OpenSearchException(e); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parser creation error + throw new OpenSearchStatusException( + "Failed to create parser for data object retrieved from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); + } + }), executor); + } + + @Override + public CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + @SuppressWarnings("unchecked") + Class> documentType = (Class>) (Class) Map.class; + request.dataObject().toXContent(builder, ToXContent.EMPTY_PARAMS); + Map docMap = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, builder.toString()) + .map(); + UpdateRequest, ?> updateRequest = new UpdateRequest.Builder, Map>() + .index(request.index()) + .id(request.id()) + .doc(docMap) + .build(); + log.info("Updating {} in {}", request.id(), request.index()); + UpdateResponse> updateResponse = openSearchClient.update(updateRequest, documentType); + log.info("Update status for id {}: {}", updateResponse.id(), updateResponse.result()); + ShardInfo shardInfo = new ShardInfo( + updateResponse.shards().total().intValue(), + updateResponse.shards().successful().intValue() + ); + return new UpdateDataObjectResponse.Builder() + .id(updateResponse.id()) + .shardId(updateResponse.index()) + .shardInfo(shardInfo) + .updated(updateResponse.result() == Updated) + .build(); + } catch (IOException e) { + // Rethrow unchecked exception on update IOException + throw new OpenSearchStatusException( + "Parsing error updating data object " + request.id() + " in index " + request.index(), + RestStatus.BAD_REQUEST + ); } }), executor); } @@ -118,8 +174,12 @@ public CompletionStage deleteDataObjectAsync(DeleteDat .shardInfo(shardInfo) .deleted(deleteResponse.result() == Deleted) .build(); - } catch (Exception e) { - throw new OpenSearchException("Error occurred while deleting data object", e); + } catch (IOException e) { + // Rethrow unchecked exception on deletion IOException + throw new OpenSearchStatusException( + "IOException occurred while deleting data object " + request.id() + " from index " + request.index(), + RestStatus.INTERNAL_SERVER_ERROR + ); } }), executor); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index 6f10e1d715..3cf3595e2a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -354,9 +354,7 @@ public void testDeleteConnector_ResourceNotFoundException() throws IOException, ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - // TODO: fix all this exception nesting - // java.util.concurrent.CompletionException: OpenSearchException[ResourceNotFoundException[errorMessage]]; nested: ResourceNotFoundException[errorMessage]; - assertEquals("errorMessage", argumentCaptor.getValue().getCause().getCause().getCause().getMessage()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } public void test_ValidationFailedException() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index ba7189df13..f125321d15 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -186,7 +186,6 @@ public void testGetConnector_NullResponse() throws InterruptedException { assertEquals("Failed to find connector with the provided connector id: connector_id", argumentCaptor.getValue().getMessage()); } - @Test public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, InterruptedException { when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); when(connectorAccessControlHelper.hasPermission(any(), any())).thenReturn(true); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index d4e929ea25..fe89661967 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.*; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -20,19 +21,25 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -75,8 +82,6 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { - private UpdateConnectorTransportAction updateConnectorTransportAction; - private static TestThreadPool testThreadPool = new TestThreadPool( UpdateConnectorTransportActionTests.class.getName(), new ScalingExecutorBuilder( @@ -88,6 +93,8 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { ) ); + private UpdateConnectorTransportAction updateConnectorTransportAction; + @Mock private ConnectorAccessControlHelper connectorAccessControlHelper; @@ -98,6 +105,9 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private Client client; private SdkClient sdkClient; + @Mock + private NamedXContentRegistry xContentRegistry; + @Mock private ThreadPool threadPool; @@ -113,13 +123,9 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private ActionFilters actionFilters; - @Mock - NamedXContentRegistry xContentRegistry; - @Mock private MLUpdateConnectorRequest updateRequest; - @Mock private UpdateResponse updateResponse; @Mock @@ -138,13 +144,16 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private MLEngine mlEngine; + private static final String TEST_CONNECTOR_ID = "test_connector_id"; private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + sdkClient = new LocalClusterIndicesClient(client, xContentRegistry); + settings = Settings .builder() .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) @@ -163,14 +172,13 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); - String connector_id = "test_connector_id"; MLCreateConnectorInput updateContent = MLCreateConnectorInput .builder() .updateConnector(true) .version("2") .description("updated description") .build(); - when(updateRequest.getConnectorId()).thenReturn(connector_id); + when(updateRequest.getConnectorId()).thenReturn(TEST_CONNECTOR_ID); when(updateRequest.getUpdateContent()).thenReturn(updateContent); SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); @@ -236,8 +244,13 @@ public void setup() throws IOException { }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); } + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + @Test - public void testExecuteConnectorAccessControlSuccess() { + public void testExecuteConnectorAccessControlSuccess() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -246,14 +259,16 @@ public void testExecuteConnectorAccessControlSuccess() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + verify(actionListener).onResponse(any(UpdateResponse.class)); } @Test @@ -294,7 +309,7 @@ public void testExecuteConnectorAccessControlException() { } @Test - public void testExecuteUpdateWrongStatus() { + public void testExecuteUpdateWrongStatus() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -303,19 +318,23 @@ public void testExecuteUpdateWrongStatus() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, Result.CREATED); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(Result.CREATED, argumentCaptor.getValue().getResult()); } @Test - public void testExecuteUpdateException() { + public void testExecuteUpdateException() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -324,13 +343,13 @@ public void testExecuteUpdateException() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("update document failure")); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + when(client.update(any(UpdateRequest.class))).thenThrow(new RuntimeException("update document failure")); + + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("update document failure", argumentCaptor.getValue().getMessage()); @@ -371,7 +390,7 @@ public void testExecuteSearchResponseError() { } @Test - public void testExecuteSearchIndexNotFoundError() { + public void testExecuteSearchIndexNotFoundError() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -409,14 +428,18 @@ public void testExecuteSearchIndexNotFoundError() { return null; }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + ActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + updateConnectorTransportAction.doExecute(task, updateRequest, latchedActionListener); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(Result.UPDATED, argumentCaptor.getValue().getResult()); } private SearchResponse noneEmptySearchResponse() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 68bef15120..696fd634cf 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -52,6 +52,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -443,11 +444,7 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException .id("connectorId") .build(); GetResponse getResponse = prepareConnector(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); future.onResponse(getResponse); when(client.get(any(GetRequest.class))).thenReturn(future); @@ -470,6 +467,65 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException assertEquals(CommonValue.ML_CONNECTOR_INDEX, requestCaptor.getValue().index()); } + @Test + public void testGetConnectorException() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(CommonValue.ML_CONNECTOR_INDEX) + .id("connectorId") + .build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("Failed to get connector")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetConnectorIndexNotFound() throws IOException, InterruptedException { + GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder() + .index(CommonValue.ML_CONNECTOR_INDEX) + .id("connectorId") + .build(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new IndexNotFoundException("Index not found")); + when(client.get(any(GetRequest.class))).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); + connectorAccessControlHelper + .getConnector( + sdkClient, + client, + client.threadPool().getThreadContext().newStoredContext(true), + getRequest, + "connectorId", + latchedActionListener + ); + latch.await(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); + assertEquals(RestStatus.NOT_FOUND, argumentCaptor.getValue().status()); + } + private GetResponse createGetResponse(List backendRoles) { HttpConnector httpConnector = HttpConnector .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java index ffd0e0072b..8b13419ac0 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/LocalClusterIndicesClientTests.java @@ -9,7 +9,6 @@ package org.opensearch.ml.sdkclient; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -28,7 +27,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; @@ -38,6 +36,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; @@ -45,8 +45,6 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.action.ActionResponse; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.sdk.DeleteDataObjectRequest; @@ -56,6 +54,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -125,18 +125,17 @@ public void testPutDataObject() throws IOException { public void testPutDataObject_Exception() throws IOException { PutDataObjectRequest putRequest = new PutDataObjectRequest.Builder().index(TEST_INDEX).dataObject(testDataObject).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).index(any(IndexRequest.class), any()); + ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + when(mockedClient.index(indexRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } public void testGetDataObject() throws IOException { @@ -194,18 +193,91 @@ public void testGetDataObject_NotFound() throws IOException { public void testGetDataObject_Exception() throws IOException { GetDataObjectRequest getRequest = new GetDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).get(any(GetRequest.class), any()); + ArgumentCaptor getRequestCaptor = ArgumentCaptor.forClass(GetRequest.class); + when(mockedClient.get(getRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); + } + + public void testUpdateDataObject() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse updateResponse = mock(UpdateResponse.class); + when(updateResponse.getId()).thenReturn(TEST_ID); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.update(any(UpdateRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(mockedClient, times(1)).update(requestCaptor.capture()); + assertEquals(TEST_INDEX, requestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertTrue(response.updated()); + } + + public void testUpdateDataObject_NotFound() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse updateResponse = mock(UpdateResponse.class); + when(updateResponse.getId()).thenReturn(TEST_ID); + when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED); + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.update(any(UpdateRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + verify(mockedClient, times(1)).update(requestCaptor.capture()); + assertEquals(TEST_INDEX, requestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertFalse(response.updated()); + } + + public void testUpdateDataObject_Exception() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + ArgumentCaptor updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedClient.update(updateRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } public void testDeleteDataObject() throws IOException { @@ -236,17 +308,16 @@ public void testDeleteDataObject() throws IOException { public void testDeleteDataObject_Exception() throws IOException { DeleteDataObjectRequest deleteRequest = new DeleteDataObjectRequest.Builder().index(TEST_INDEX).id(TEST_ID).build(); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new IOException("test")); - return null; - }).when(mockedClient).delete(any(DeleteRequest.class), any()); + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); + when(mockedClient.delete(deleteRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); CompletableFuture future = sdkClient .deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + Throwable cause = ce.getCause(); + assertEquals(UnsupportedOperationException.class, cause.getClass()); + assertEquals("test", cause.getMessage()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index ee7ee53f98..2018a69a0b 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.ml.sdkclient; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -24,7 +25,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.Result; import org.opensearch.client.opensearch._types.ShardStatistics; @@ -34,6 +35,8 @@ import org.opensearch.client.opensearch.core.GetResponse; import org.opensearch.client.opensearch.core.IndexRequest; import org.opensearch.client.opensearch.core.IndexResponse; +import org.opensearch.client.opensearch.core.UpdateRequest; +import org.opensearch.client.opensearch.core.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -45,6 +48,8 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.UpdateDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -148,7 +153,7 @@ public void testPutDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } @SuppressWarnings({ "unchecked", "rawtypes" }) @@ -213,7 +218,87 @@ public void testGetDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); + } + + public void testUpdateDataObject() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse> updateResponse = new UpdateResponse.Builder>() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Updated) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + @SuppressWarnings("unchecked") + ArgumentCaptor, ?>> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertTrue(response.updated()); + } + + public void testUpdateDataObject_NotFound() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + UpdateResponse> updateResponse = new UpdateResponse.Builder>() + .id(TEST_ID) + .index(TEST_INDEX) + .primaryTerm(0) + .result(Result.Created) + .seqNo(0) + .shards(new ShardStatistics.Builder().failed(0).successful(1).total(1).build()) + .version(0) + .build(); + + @SuppressWarnings("unchecked") + ArgumentCaptor, ?>> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenReturn(updateResponse); + + UpdateDataObjectResponse response = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(TEST_INDEX, updateRequestCaptor.getValue().index()); + assertEquals(TEST_ID, response.id()); + assertFalse(response.updated()); + } + + public void testtUpdateDataObject_Exception() throws IOException { + UpdateDataObjectRequest updateRequest = new UpdateDataObjectRequest.Builder() + .index(TEST_INDEX) + .id(TEST_ID) + .dataObject(testDataObject) + .build(); + + ArgumentCaptor> updateRequestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + when(mockedOpenSearchClient.update(updateRequestCaptor.capture(), any())).thenThrow(new IOException("test")); + + CompletableFuture future = sdkClient + .updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } public void testDeleteDataObject() throws IOException { @@ -285,6 +370,6 @@ public void testDeleteDataObject_Exception() throws IOException { .toCompletableFuture(); CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); - assertEquals(OpenSearchException.class, ce.getCause().getClass()); + assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } }