diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java index cea0d484fe..6e27c2cb1d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.undeploy; import lombok.Getter; +import lombok.Setter; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -17,10 +19,15 @@ public class MLUndeployModelNodesRequest extends BaseNodesRequest CompletionStage executePrivilegedAsync(PrivilegedAction action, Executor executor) { + return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged(action), executor); + } +} diff --git a/common/src/main/java/org/opensearch/sdk/BulkDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/BulkDataObjectRequest.java new file mode 100644 index 0000000000..2452853087 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/BulkDataObjectRequest.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.opensearch.action.support.WriteRequest.RefreshPolicy; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.Strings; + +public class BulkDataObjectRequest { + + private final List requests = new ArrayList<>(); + private final Set indices = new HashSet<>(); + private RefreshPolicy refreshPolicy = RefreshPolicy.NONE; + private String globalIndex; + + /** + * Instantiate this request with a global index. + *

+ * For data storage implementations other than OpenSearch, an index may be referred to as a table and the id may be referred to as a primary key. + * @param globalIndex the index location for all the bulk requests as a default if not already specified + */ + public BulkDataObjectRequest(@Nullable String globalIndex) { + this.globalIndex = globalIndex; + } + + /** + * Returns the list of requests in this bulk request. + * @return the requests list + */ + public List requests() { + return List.copyOf(this.requests); + } + + /** + * Returns the indices being updated in this bulk request. + * @return the indices being updated + */ + public Set getIndices() { + return Collections.unmodifiableSet(indices); + } + + /** + * Add the given request to the {@link BulkDataObjectRequest} + * @param request The request to add + * @return the updated request object + */ + public BulkDataObjectRequest add(DataObjectRequest request) { + if (!request.isWriteRequest()) { + throw new IllegalArgumentException("No support for request [" + request.getClass().getName() + "]"); + } + if (Strings.isNullOrEmpty(request.index())) { + if (Strings.isNullOrEmpty(globalIndex)) { + throw new IllegalArgumentException( + "Either the request [" + request.getClass().getName() + "] or the bulk request must specify an index." + ); + } + indices.add(globalIndex); + request.index(globalIndex); + } else { + indices.add(request.index()); + } + requests.add(request); + return this; + } + + /** + * Should this request trigger a refresh ({@linkplain RefreshPolicy#IMMEDIATE}), wait for a refresh ( + * {@linkplain RefreshPolicy#WAIT_UNTIL}), or proceed ignore refreshes entirely ({@linkplain RefreshPolicy#NONE}, the default). + */ + public BulkDataObjectRequest setRefreshPolicy(RefreshPolicy refreshPolicy) { + this.refreshPolicy = refreshPolicy; + return this; + } + + /** + * Should this request trigger a refresh ({@linkplain RefreshPolicy#IMMEDIATE}), wait for a refresh ( + * {@linkplain RefreshPolicy#WAIT_UNTIL}), or proceed ignore refreshes entirely ({@linkplain RefreshPolicy#NONE}, the default). + */ + public RefreshPolicy getRefreshPolicy() { + return refreshPolicy; + } + + /** + * Instantiate a builder for this object + * @return a builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Class for constructing a Builder for this Request Object + */ + public static class Builder { + private String globalIndex = null; + + /** + * Empty constructor to initialize + */ + protected Builder() {} + + /** + * Add an index to this builder + * @param index the index to put the object + * @return the updated builder + */ + public Builder globalIndex(String index) { + this.globalIndex = index; + return this; + } + + /** + * Builds the request + * @return A {@link BulkDataObjectRequest} + */ + public BulkDataObjectRequest build() { + return new BulkDataObjectRequest(this.globalIndex); + } + } +} diff --git a/common/src/main/java/org/opensearch/sdk/BulkDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/BulkDataObjectResponse.java new file mode 100644 index 0000000000..2b387ce4f0 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/BulkDataObjectResponse.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import java.util.Arrays; + +import org.opensearch.core.xcontent.XContentParser; + +import static org.opensearch.action.bulk.BulkResponse.NO_INGEST_TOOK; + +public class BulkDataObjectResponse { + + private final DataObjectResponse[] responses; + private final long tookInMillis; + private final long ingestTookInMillis; + private final boolean failures; + private final XContentParser parser; + + public BulkDataObjectResponse(DataObjectResponse[] responses, long tookInMillis, boolean failures, XContentParser parser) { + this(responses, tookInMillis, NO_INGEST_TOOK, failures, parser); + } + + public BulkDataObjectResponse(DataObjectResponse[] responses, long tookInMillis, long ingestTookInMillis, boolean failures, XContentParser parser) { + this.responses = responses; + this.tookInMillis = tookInMillis; + this.ingestTookInMillis = ingestTookInMillis; + this.failures = failures; + this.parser = parser; + } + + /** + * The items representing each action performed in the bulk operation (in the same order!). + * @return the responses in the same order requested + */ + public DataObjectResponse[] getResponses() { + return responses; + } + + /** + * How long the bulk execution took. Excluding ingest preprocessing. + * @return the execution time in milliseconds + */ + public long getTookInMillis() { + return tookInMillis; + } + + /** + * If ingest is enabled returns the bulk ingest preprocessing time. in milliseconds, otherwise -1 is returned. + * @return the ingest execution time in milliseconds + */ + public long getIngestTookInMillis() { + return ingestTookInMillis; + } + + /** + * Has anything failed with the execution. + * @return true if any response failed, false otherwise + */ + public boolean hasFailures() { + return this.failures; + } + + /** + * Returns the parser + * @return the parser + */ + public XContentParser parser() { + return this.parser; + } +} diff --git a/common/src/main/java/org/opensearch/sdk/DataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DataObjectRequest.java new file mode 100644 index 0000000000..ef2bb39f0f --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/DataObjectRequest.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +public abstract class DataObjectRequest { + + private String index; + private final String id; + private String tenantId; + + /** + * Instantiate this request with an index and id. + *

+ * For data storage implementations other than OpenSearch, an index may be referred to as a table and the id may be referred to as a primary key. + * @param index the index location to delete the object + * @param id the document id + * @param tenantId the tenant id + */ + protected DataObjectRequest(String index, String id, String tenantId) { + this.index = index; + this.id = id; + this.tenantId = tenantId; + } + + /** + * Returns the index + * @return the index + */ + public String index() { + return this.index; + } + + /** + * Sets the index + * @param index The new index to set + */ + public void index(String index) { + this.index = index; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return this.id; + } + + /** + * Returns the tenant id + * @return the tenantId + */ + public String tenantId() { + return this.tenantId; + } + + /** + * Sets the tenant id + * @param tenantId The new tenant id to set + */ + public void tenantId(String tenantId) { + this.tenantId = tenantId; + } + + /** + * Returns whether the subclass can be used in a {@link BulkDataObjectRequest} + * @return whether the subclass is a write request + */ + public abstract boolean isWriteRequest(); + + /** + * Superclass for common fields in subclass builders + */ + public static class Builder> { + protected String index = null; + protected String id = null; + protected String tenantId = null; + + /** + * Empty constructor to initialize + */ + protected Builder() {} + + /** + * Add an index to this builder + * @param index the index to put the object + * @return the updated builder + */ + public T index(String index) { + this.index = index; + return self(); + } + + /** + * Add an id to this builder + * @param id the document id + * @return the updated builder + */ + public T id(String id) { + this.id = id; + return self(); + } + + /** + * Add a tenant id to this builder + * @param tenantId the tenant id + * @return the updated builder + */ + public T tenantId(String tenantId) { + this.tenantId = tenantId; + return self(); + } + + /** + * Returns this builder as the parameterized type. + */ + @SuppressWarnings("unchecked") + protected T self() { + return (T) this; + } + } +} diff --git a/common/src/main/java/org/opensearch/sdk/DataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/DataObjectResponse.java new file mode 100644 index 0000000000..2d9671a862 --- /dev/null +++ b/common/src/main/java/org/opensearch/sdk/DataObjectResponse.java @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; + +public abstract class DataObjectResponse { + private final String index; + private final String id; + private final XContentParser parser; + private final boolean failed; + private final Exception cause; + private final RestStatus status; + + /** + * Instantiate this request with an index, id, failure status, and parser representing a Response + *

+ * For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param index the index + * @param id the document id + * @param parser a parser that can be used to create a Response + * @param failed whether the request failed + * @param cause the Exception causing the failure + * @param status the RestStatus + */ + protected DataObjectResponse(String index, String id, XContentParser parser, boolean failed, Exception cause, RestStatus status) { + this.index = index; + this.id = id; + this.parser = parser; + this.failed = failed; + this.cause = cause; + this.status = status; + } + + /** + * Returns the index + * @return the index + */ + public String index() { + return this.index; + } + + /** + * Returns the document id + * @return the id + */ + public String id() { + return this.id; + } + + /** + * Returns the parser that can be used to create an IndexResponse + * @return the parser + */ + public XContentParser parser() { + return this.parser; + } + + /** + * Has anything failed with the execution. + * @return whether the corresponding bulk request failed + */ + public boolean isFailed() { + return this.failed; + } + + /** + * The actual cause of the failure. + * @return the Exception causing the failure + */ + public Exception cause() { + return this.cause; + } + + /** + * The rest status. + * @return the rest status. + */ + public RestStatus status() { + return this.status; + } + + /** + * Superclass for common fields in subclass builders + */ + public static class Builder> { + protected String index = null; + protected String id = null; + protected XContentParser parser; + protected boolean failed = false; + protected Exception cause = null; + protected RestStatus status = null; + + /** + * Empty constructor to initialize + */ + protected Builder() {} + + /** + * Add an index to this builder + * @param index the index to add + * @return the updated builder + */ + public T index(String index) { + this.index = index; + return self(); + } + + /** + * Add an id to this builder + * @param id the id to add + * @return the updated builder + */ + public T id(String id) { + this.id = id; + return self(); + } + + /** + * Add a parser to this builder + * @param parser a parser that can be used to create a Response for the subclass + * @return the updated builder + */ + public T parser(XContentParser parser) { + this.parser = parser; + return self(); + } + + /** + * Add a failed status to this builder + * @param failed whether the request failed + * @return the updated builder + */ + public T failed(boolean failed) { + this.failed = failed; + return self(); + } + + /** + * Add a cause to this builder + * @param cause the Exception + * @return the updated builder + */ + public T cause(Exception cause) { + this.cause = cause; + return self(); + } + + /** + * Add a rest status to this builder + * @param status the rest status + * @return the updated builder + */ + public T status(RestStatus status) { + this.status = status; + return self(); + } + + /** + * Returns this builder as the parameterized type. + */ + @SuppressWarnings("unchecked") + protected T self() { + return (T) this; + } + } +} diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java index 6dfed07293..e22a360d77 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectRequest.java @@ -8,11 +8,7 @@ */ package org.opensearch.sdk; -public class DeleteDataObjectRequest { - - private final String index; - private final String id; - private final String tenantId; +public class DeleteDataObjectRequest extends DataObjectRequest { /** * Instantiate this request with an index and id. @@ -23,35 +19,14 @@ public class DeleteDataObjectRequest { * @param tenantId the tenant id */ public DeleteDataObjectRequest(String index, String id, String tenantId) { - this.index = index; - this.id = id; - this.tenantId = tenantId; + super(index, id, tenantId); } - /** - * Returns the index - * @return the index - */ - public String index() { - return this.index; - } - - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; + @Override + public boolean isWriteRequest() { + return true; } - /** - * Returns the tenant id - * @return the tenantId - */ - public String tenantId() { - return this.tenantId; - } - /** * Instantiate a builder for this object * @return a builder instance @@ -63,45 +38,7 @@ public static Builder builder() { /** * Class for constructing a Builder for this Request Object */ - public static class Builder { - private String index = null; - private String id = null; - private String tenantId = null; - - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an index to this builder - * @param index the index to put the object - * @return the updated builder - */ - public Builder index(String index) { - this.index = index; - return this; - } - - /** - * Add an id to this builder - * @param id the document id - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a tenant id to this builder - * @param tenantId the tenant id - * @return the updated builder - */ - public Builder tenantId(String tenantId) { - this.tenantId = tenantId; - return this; - } + public static class Builder extends DataObjectRequest.Builder { /** * Builds the object diff --git a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectResponse.java index d7939d574b..4036a9cdaa 100644 --- a/common/src/main/java/org/opensearch/sdk/DeleteDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/DeleteDataObjectResponse.java @@ -8,40 +8,26 @@ */ package org.opensearch.sdk; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; -public class DeleteDataObjectResponse { - private final String id; - private final XContentParser parser; +public class DeleteDataObjectResponse extends DataObjectResponse { /** * Instantiate this request with an id and parser representing a DeleteResponse *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param index the index * @param id the document id * @param parser a parser that can be used to create a DeleteResponse + * @param failed whether the request failed + * @param cause the Exception causing the failure + * @param status the RestStatus */ - public DeleteDataObjectResponse(String id, XContentParser parser) { - this.id = id; - this.parser = parser; + public DeleteDataObjectResponse(String index, String id, XContentParser parser, boolean failed, Exception cause, RestStatus status) { + super(index, id, parser, failed, cause, status); } - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the parser that can be used to create a DeleteResponse - * @return the parser - */ - public XContentParser parser() { - return this.parser; - } - /** * Instantiate a builder for this object * @return a builder instance @@ -53,41 +39,14 @@ public static Builder builder() { /** * Class for constructing a Builder for this Response Object */ - public static class Builder { - private String id = null; - private XContentParser parser = null; + public static class Builder extends DataObjectResponse.Builder { - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an id to this builder - * @param id the id to add - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a parser to this builder - * @param parser a parser that can be used to create a DeleteResponse - * @return the updated builder - */ - public Builder parser(XContentParser parser) { - this.parser = parser; - return this; - } - /** * Builds the response * @return A {@link DeleteDataObjectResponse} */ public DeleteDataObjectResponse build() { - return new DeleteDataObjectResponse(this.id, this.parser); + return new DeleteDataObjectResponse(this.index, this.id, this.parser, this.failed, this.cause, this.status); } } } diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java index 05d100d380..77165aa561 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectRequest.java @@ -10,11 +10,8 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; -public class GetDataObjectRequest { +public class GetDataObjectRequest extends DataObjectRequest { - private final String index; - private final String id; - private final String tenantId; private final FetchSourceContext fetchSourceContext; /** @@ -27,36 +24,10 @@ public class GetDataObjectRequest { * @param fetchSourceContext the context to use when fetching _source */ public GetDataObjectRequest(String index, String id, String tenantId, FetchSourceContext fetchSourceContext) { - this.index = index; - this.id = id; - this.tenantId = tenantId; + super(index, id, tenantId); this.fetchSourceContext = fetchSourceContext; } - /** - * Returns the index - * @return the index - */ - public String index() { - return this.index; - } - - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the tenant id - * @return the tenantId - */ - public String tenantId() { - return this.tenantId; - } - /** * Returns the context for fetching _source * @return the fetchSourceContext @@ -65,6 +36,11 @@ public FetchSourceContext fetchSourceContext() { return this.fetchSourceContext; } + @Override + public boolean isWriteRequest() { + return false; + } + /** * Instantiate a builder for this object * @return a builder instance @@ -76,47 +52,9 @@ public static Builder builder() { /** * Class for constructing a Builder for this Request Object */ - public static class Builder { - private String index = null; - private String id = null; - private String tenantId = null; + public static class Builder extends DataObjectRequest.Builder { private FetchSourceContext fetchSourceContext; - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an index to this builder - * @param index the index to put the object - * @return the updated builder - */ - public Builder index(String index) { - this.index = index; - return this; - } - - /** - * Add an id to this builder - * @param id the document id - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a tenant id to this builder - * @param tenantId the tenant id - * @return the updated builder - */ - public Builder tenantId(String tenantId) { - this.tenantId = tenantId; - return this; - } - /** * Add a fetchSourceContext to this builder * @param fetchSourceContext the fetchSourceContext diff --git a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java index 24c4c72a83..9802ae8266 100644 --- a/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java @@ -8,46 +8,32 @@ */ package org.opensearch.sdk; -import org.opensearch.core.xcontent.XContentParser; - import java.util.Collections; import java.util.Map; -public class GetDataObjectResponse { - private final String id; - private final XContentParser parser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; + +public class GetDataObjectResponse extends DataObjectResponse { private final Map source; /** * Instantiate this request with an id and parser/map used to recreate the data object. *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param index the index * @param id the document id * @param parser a parser that can be used to create a GetResponse + * @param failed whether the request failed + * @param cause the Exception causing the failure + * @param status the RestStatus * @param source the data object as a map */ - public GetDataObjectResponse(String id, XContentParser parser, Map source) { - this.id = id; - this.parser = parser; + public GetDataObjectResponse(String index, String id, XContentParser parser, boolean failed, Exception cause, RestStatus status, Map source) { + super(index, id, parser, failed, cause, status); this.source = source; } - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the parser that can be used to create a GetResponse - * @return the parser - */ - public XContentParser parser() { - return this.parser; - } - /** * Returns the source map. This is a logical representation of the data object. * @return the source map @@ -67,36 +53,9 @@ public static Builder builder() { /** * Class for constructing a Builder for this Response Object */ - public static class Builder { - private String id = null; - private XContentParser parser = null; + public static class Builder extends DataObjectResponse.Builder { private Map source = Collections.emptyMap(); - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an id to this builder - * @param id the id to add - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a parser to this builder - * @param parser a parser that can be used to create a GetResponse - * @return the updated builder - */ - public Builder parser(XContentParser parser) { - this.parser = parser; - return this; - } - /** * Add a source map to this builder * @param source the data object as a map @@ -106,13 +65,13 @@ public Builder source(Map source) { this.source = source == null ? Collections.emptyMap() : source; return this; } - + /** * Builds the response * @return A {@link GetDataObjectResponse} */ public GetDataObjectResponse build() { - return new GetDataObjectResponse(this.id, this.parser, this.source); + return new GetDataObjectResponse(this.index, this.id, this.parser, this.failed, this.cause, this.status, this.source); } } } diff --git a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java index ff06080abf..6a0dfe5da2 100644 --- a/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/PutDataObjectRequest.java @@ -8,17 +8,12 @@ */ package org.opensearch.sdk; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; - -import java.io.IOException; import java.util.Map; -public class PutDataObjectRequest { +import org.opensearch.core.xcontent.ToXContentObject; + +public class PutDataObjectRequest extends DataObjectRequest { - private final String index; - private final String id; - private final String tenantId; private final boolean overwriteIfExists; private final ToXContentObject dataObject; @@ -30,37 +25,11 @@ public class PutDataObjectRequest { * @param dataObject the data object */ public PutDataObjectRequest(String index, String id, String tenantId, boolean overwriteIfExists, ToXContentObject dataObject) { - this.index = index; - this.id = id; - this.tenantId = tenantId; + super(index, id, tenantId); this.overwriteIfExists = overwriteIfExists; this.dataObject = dataObject; } - /** - * Returns the index - * @return the index - */ - public String index() { - return this.index; - } - - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the tenant id - * @return the tenantId - */ - public String tenantId() { - return this.tenantId; - } - /** * Returns whether to overwrite an existing document (upsert) * @return true if this request should overwrite @@ -77,6 +46,11 @@ public ToXContentObject dataObject() { return this.dataObject; } + @Override + public boolean isWriteRequest() { + return true; + } + /** * Instantiate a builder for this object * @return a builder instance @@ -88,48 +62,10 @@ public static Builder builder() { /** * Class for constructing a Builder for this Request Object */ - public static class Builder { - private String index = null; - private String id = null; - private String tenantId = null; + public static class Builder extends DataObjectRequest.Builder { private boolean overwriteIfExists = true; private ToXContentObject dataObject = null; - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an index to this builder - * @param index the index to put the object - * @return the updated builder - */ - public Builder index(String index) { - this.index = index; - return this; - } - - /** - * Add an id to this builder - * @param id the documet id - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a tenant id to this builder - * @param tenantId the tenant id - * @return the updated builder - */ - public Builder tenantId(String tenantId) { - this.tenantId = tenantId; - return this; - } - /** * Specify whether to overwrite an existing document/item (upsert). True by default. * @param overwriteIfExists whether to overwrite an existing document/item @@ -139,6 +75,7 @@ public Builder overwriteIfExists(boolean overwriteIfExists) { this.overwriteIfExists = overwriteIfExists; return this; } + /** * Add a data object to this builder * @param dataObject the data object diff --git a/common/src/main/java/org/opensearch/sdk/PutDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/PutDataObjectResponse.java index 20b1db3ca0..fea527ba9d 100644 --- a/common/src/main/java/org/opensearch/sdk/PutDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/PutDataObjectResponse.java @@ -8,40 +8,26 @@ */ package org.opensearch.sdk; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; -public class PutDataObjectResponse { - private final String id; - private final XContentParser parser; +public class PutDataObjectResponse extends DataObjectResponse { /** * Instantiate this request with an id and parser representing an IndexResponse *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param index the index * @param id the document id * @param parser a parser that can be used to create an IndexResponse + * @param failed whether the request failed + * @param cause the Exception causing the failure + * @param status the RestStatus */ - public PutDataObjectResponse(String id, XContentParser parser) { - this.id = id; - this.parser = parser; + public PutDataObjectResponse(String index, String id, XContentParser parser, boolean failed, Exception cause, RestStatus status) { + super(index, id, parser, failed, cause, status); } - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the parser that can be used to create an IndexResponse - * @return the parser - */ - public XContentParser parser() { - return this.parser; - } - /** * Instantiate a builder for this object * @return a builder instance @@ -53,41 +39,14 @@ public static Builder builder() { /** * Class for constructing a Builder for this Response Object */ - public static class Builder { - private String id = null; - private XContentParser parser = null; + public static class Builder extends DataObjectResponse.Builder { - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an id to this builder - * @param id the id to add - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a parser to this builder - * @param parser a parser that can be used to create an IndexResponse - * @return the updated builder - */ - public Builder parser(XContentParser parser) { - this.parser = parser; - return this; - } - /** * Builds the response * @return A {@link PutDataObjectResponse} */ public PutDataObjectResponse build() { - return new PutDataObjectResponse(this.id, this.parser); + return new PutDataObjectResponse(this.index, this.id, this.parser, this.failed, this.cause, this.status); } } } diff --git a/common/src/main/java/org/opensearch/sdk/SdkClient.java b/common/src/main/java/org/opensearch/sdk/SdkClient.java index 97bfd93dd3..5c7313fbf8 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClient.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClient.java @@ -8,6 +8,7 @@ */ package org.opensearch.sdk; +import java.util.List; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; @@ -20,10 +21,10 @@ import static org.opensearch.sdk.SdkClientUtils.unwrapAndConvertToException; public class SdkClient { - + private final SdkClientDelegate delegate; private final Boolean isMultiTenancyEnabled; - + public SdkClient(SdkClientDelegate delegate, Boolean multiTenancy) { this.delegate = delegate; this.isMultiTenancyEnabled = multiTenancy; @@ -167,6 +168,42 @@ public DeleteDataObjectResponse deleteDataObject(DeleteDataObjectRequest request } } + /** + * Perform a bulk request for multiple data objects/documents in potentially multiple tables/indices. + * + * @param request A request identifying the bulk requests to execute + * @param executor the executor to use for asynchronous execution + * @return A completion stage encapsulating the response or exception + */ + public CompletionStage bulkDataObjectAsync(BulkDataObjectRequest request, Executor executor) { + validateTenantIds(request.requests()); + return delegate.bulkDataObjectAsync(request, executor, isMultiTenancyEnabled); + } + + /** + * Perform a bulk request for multiple data objects/documents in potentially multiple tables/indices. + * + * @param request A request identifying the bulk requests to execute + * @return A completion stage encapsulating the response or exception + */ + public CompletionStage bulkDataObjectAsync(BulkDataObjectRequest request) { + return bulkDataObjectAsync(request, ForkJoinPool.commonPool()); + } + + /** + * Perform a bulk request for multiple data objects/documents in potentially multiple tables/indices. + * + * @param request A request identifying the bulk requests to execute + * @return A response on success. Throws unchecked exceptions or {@link OpenSearchException} wrapping the cause on checked exception. + */ + public BulkDataObjectResponse bulkDataObject(BulkDataObjectRequest request) { + try { + return bulkDataObjectAsync(request).toCompletableFuture().join(); + } catch (CompletionException e) { + throw ExceptionsHelper.convertToRuntime(unwrapAndConvertToException(e)); + } + } + /** * Search for data objects/documents in a table/index. * @@ -219,4 +256,16 @@ private void validateTenantId(String tenantId) { throw new IllegalArgumentException("A tenant ID is required when multitenancy is enabled."); } } + + /** + * Throw exception if tenantId is null for any bulk request and multitenancy is enabled + * @param tenantId The tenantId from the request + */ + private void validateTenantIds(List requests) { + if (Boolean.TRUE.equals(isMultiTenancyEnabled) && requests.stream().map(DataObjectRequest::tenantId).anyMatch(Strings::isNullOrEmpty)) { + throw new IllegalArgumentException("A tenant ID is required for every bulk request when multitenancy is enabled."); + } + } + + } diff --git a/common/src/main/java/org/opensearch/sdk/SdkClientDelegate.java b/common/src/main/java/org/opensearch/sdk/SdkClientDelegate.java index 7e0ab0a3ee..5b036f08ba 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClientDelegate.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClientDelegate.java @@ -20,7 +20,11 @@ public interface SdkClientDelegate { * @param isMultiTenancyEnabled whether multitenancy is enabled * @return A completion stage encapsulating the response or exception */ - CompletionStage putDataObjectAsync(PutDataObjectRequest request, Executor executor, Boolean isMultiTenancyEnabled); + CompletionStage putDataObjectAsync( + PutDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); /** * Read/Get a data object/document from a table/index. @@ -30,7 +34,11 @@ public interface SdkClientDelegate { * @param isMultiTenancyEnabled whether multitenancy is enabled * @return A completion stage encapsulating the response or exception */ - CompletionStage getDataObjectAsync(GetDataObjectRequest request, Executor executor, Boolean isMultiTenancyEnabled); + CompletionStage getDataObjectAsync( + GetDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); /** * Update a data object/document in a table/index. @@ -40,7 +48,11 @@ public interface SdkClientDelegate { * @param isMultiTenancyEnabled whether multitenancy is enabled * @return A completion stage encapsulating the response or exception */ - CompletionStage updateDataObjectAsync(UpdateDataObjectRequest request, Executor executor, Boolean isMultiTenancyEnabled); + CompletionStage updateDataObjectAsync( + UpdateDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); /** * Delete a data object/document from a table/index. @@ -50,7 +62,25 @@ public interface SdkClientDelegate { * @param isMultiTenancyEnabled whether multitenancy is enabled * @return A completion stage encapsulating the response or exception */ - CompletionStage deleteDataObjectAsync(DeleteDataObjectRequest request, Executor executor, Boolean isMultiTenancyEnabled); + CompletionStage deleteDataObjectAsync( + DeleteDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); + + /** + * Perform a bulk request for multiple data objects/documents in potentially multiple tables/indices. + * + * @param request A request identifying the requests to process in bulk + * @param executor the executor to use for asynchronous execution + * @param isMultiTenancyEnabled whether multitenancy is enabled + * @return A completion stage encapsulating the response or exception + */ + CompletionStage bulkDataObjectAsync( + BulkDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); /** * Search for data objects/documents in a table/index. @@ -60,5 +90,9 @@ public interface SdkClientDelegate { * @param isMultiTenancyEnabled whether multitenancy is enabled * @return A completion stage encapsulating the response or exception */ - CompletionStage searchDataObjectAsync(SearchDataObjectRequest request, Executor executor, Boolean isMultiTenancyEnabled); + CompletionStage searchDataObjectAsync( + SearchDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ); } diff --git a/common/src/main/java/org/opensearch/sdk/SdkClientSettings.java b/common/src/main/java/org/opensearch/sdk/SdkClientSettings.java index a5dfa10e83..e49a4e18f3 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClientSettings.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClientSettings.java @@ -8,11 +8,11 @@ */ package org.opensearch.sdk; +import java.util.Set; + import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Setting.Property; -import java.util.Set; - /** Settings applicable to the SdkClient */ public class SdkClientSettings { @@ -27,19 +27,23 @@ public class SdkClientSettings { private static final String AOSS_SERVICE_NAME = "aoss"; /** Service Names compatible with AWS SDK v2. */ public static final Set VALID_AWS_OPENSEARCH_SERVICE_NAMES = Set.of(AOS_SERVICE_NAME, AOSS_SERVICE_NAME); - + /** The value for remote metadata type for a remote cluster on AWS Dynamo DB and Zero-ETL replication to OpenSearch */ public static final String AWS_DYNAMO_DB = "AWSDynamoDB"; - + /** The key for remote metadata endpoint, applicable to remote clusters or Zero-ETL DynamoDB sinks */ public static final String REMOTE_METADATA_ENDPOINT_KEY = "plugins.ml_commons.remote_metadata_endpoint"; /** The key for remote metadata region, applicable for AWS SDK v2 connections */ public static final String REMOTE_METADATA_REGION_KEY = "plugins.ml_commons.remote_metadata_region"; /** The key for remote metadata service name used by service-specific SDKs */ public static final String REMOTE_METADATA_SERVICE_NAME_KEY = "plugins.ml_commons.remote_metadata_service_name"; - - public static final Setting REMOTE_METADATA_TYPE = Setting.simpleString(REMOTE_METADATA_TYPE_KEY, Property.NodeScope, Property.Final); - public static final Setting REMOTE_METADATA_ENDPOINT = Setting.simpleString(REMOTE_METADATA_ENDPOINT_KEY, Property.NodeScope, Property.Final); - public static final Setting REMOTE_METADATA_REGION = Setting.simpleString(REMOTE_METADATA_REGION_KEY, Property.NodeScope, Property.Final); - public static final Setting REMOTE_METADATA_SERVICE_NAME = Setting.simpleString(REMOTE_METADATA_SERVICE_NAME_KEY, Property.NodeScope, Property.Final); + + public static final Setting REMOTE_METADATA_TYPE = Setting + .simpleString(REMOTE_METADATA_TYPE_KEY, Property.NodeScope, Property.Final); + public static final Setting REMOTE_METADATA_ENDPOINT = Setting + .simpleString(REMOTE_METADATA_ENDPOINT_KEY, Property.NodeScope, Property.Final); + public static final Setting REMOTE_METADATA_REGION = Setting + .simpleString(REMOTE_METADATA_REGION_KEY, Property.NodeScope, Property.Final); + public static final Setting REMOTE_METADATA_SERVICE_NAME = Setting + .simpleString(REMOTE_METADATA_SERVICE_NAME_KEY, Property.NodeScope, Property.Final); } diff --git a/common/src/main/java/org/opensearch/sdk/SdkClientUtils.java b/common/src/main/java/org/opensearch/sdk/SdkClientUtils.java index c82bc0dd48..e1b658d82f 100644 --- a/common/src/main/java/org/opensearch/sdk/SdkClientUtils.java +++ b/common/src/main/java/org/opensearch/sdk/SdkClientUtils.java @@ -64,11 +64,11 @@ public static Throwable getRethrownExecutionExceptionRootCause(Throwable throwab public static String lowerCaseEnumValues(String field, String json) { // Use a matcher to find and replace the field value in lowercase Matcher matcher = Pattern.compile("(\"" + Pattern.quote(field) + "\"):(\"[A-Z_]+\")").matcher(json); - StringBuffer sb = new StringBuffer(); + StringBuffer sb = new StringBuffer(); while (matcher.find()) { matcher.appendReplacement(sb, matcher.group(1) + ":" + matcher.group(2).toLowerCase(Locale.ROOT)); } - matcher.appendTail(sb); + matcher.appendTail(sb); return sb.toString(); } } diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java index 65518b2755..1a28c4bd2e 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectRequest.java @@ -8,18 +8,16 @@ */ package org.opensearch.sdk; +import java.io.IOException; +import java.util.Map; + import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; -import java.io.IOException; -import java.util.Map; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; -public class UpdateDataObjectRequest { +public class UpdateDataObjectRequest extends DataObjectRequest { - private final String index; - private final String id; - private final String tenantId; private final Long ifSeqNo; private final Long ifPrimaryTerm; private final int retryOnConflict; @@ -46,39 +44,13 @@ public UpdateDataObjectRequest( int retryOnConflict, ToXContentObject dataObject ) { - this.index = index; - this.id = id; - this.tenantId = tenantId; + super(index, id, tenantId); this.ifSeqNo = ifSeqNo; this.ifPrimaryTerm = ifPrimaryTerm; this.retryOnConflict = retryOnConflict; this.dataObject = dataObject; } - /** - * Returns the index - * @return the index - */ - public String index() { - return this.index; - } - - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the tenant id - * @return the tenantId - */ - public String tenantId() { - return this.tenantId; - } - /** * Returns the sequence number to match, or null if no match required * @return the ifSeqNo @@ -102,7 +74,7 @@ public Long ifPrimaryTerm() { public int retryOnConflict() { return retryOnConflict; } - + /** * Returns the data object * @return the data object @@ -111,6 +83,11 @@ public ToXContentObject dataObject() { return this.dataObject; } + @Override + public boolean isWriteRequest() { + return true; + } + /** * Instantiate a builder for this object * @return a builder instance @@ -122,50 +99,12 @@ public static Builder builder() { /** * Class for constructing a Builder for this Request Object */ - public static class Builder { - private String index = null; - private String id = null; - private String tenantId = null; + public static class Builder extends DataObjectRequest.Builder { private Long ifSeqNo = null; private Long ifPrimaryTerm = null; private int retryOnConflict = 0; private ToXContentObject dataObject = null; - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an index to this builder - * @param index the index to put the object - * @return the updated builder - */ - public Builder index(String index) { - this.index = index; - return this; - } - - /** - * Add an id to this builder - * @param id the document id - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a tenant ID to this builder - * @param tenantId the tenant id - * @return the updated builder - */ - public Builder tenantId(String tenantId) { - this.tenantId = tenantId; - return this; - } - /** * Only perform this update request if the document's modification was assigned the given * sequence number. Must be used in combination with {@link #ifPrimaryTerm(long)} diff --git a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java index 6b3514e05f..c4179c5f18 100644 --- a/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java +++ b/common/src/main/java/org/opensearch/sdk/UpdateDataObjectResponse.java @@ -8,40 +8,26 @@ */ package org.opensearch.sdk; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; -public class UpdateDataObjectResponse { - private final String id; - private final XContentParser parser; +public class UpdateDataObjectResponse extends DataObjectResponse { /** * Instantiate this request with an id and parser representing an UpdateResponse *

* For data storage implementations other than OpenSearch, the id may be referred to as a primary key. + * @param index the index * @param id the document id * @param parser a parser that can be used to create an UpdateResponse + * @param failed whether the request failed + * @param cause the Exception causing the failure + * @param status the RestStatus */ - public UpdateDataObjectResponse(String id, XContentParser parser) { - this.id = id; - this.parser = parser; + public UpdateDataObjectResponse(String index, String id, XContentParser parser, boolean failed, Exception cause, RestStatus status) { + super(index, id, parser, failed, cause, status); } - /** - * Returns the document id - * @return the id - */ - public String id() { - return this.id; - } - - /** - * Returns the parser that can be used to create an UpdateResponse - * @return the parser - */ - public XContentParser parser() { - return this.parser; - } - /** * Instantiate a builder for this object * @return a builder instance @@ -53,41 +39,14 @@ public static Builder builder() { /** * Class for constructing a Builder for this Response Object */ - public static class Builder { - private String id = null; - private XContentParser parser = null; + public static class Builder extends DataObjectResponse.Builder { - /** - * Empty Constructor for the Builder object - */ - private Builder() {} - - /** - * Add an id to this builder - * @param id the id to add - * @return the updated builder - */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Add a parser to this builder - * @param parser a parser that can be used to create an UpdateResponse - * @return the updated builder - */ - public Builder parser(XContentParser parser) { - this.parser = parser; - return this; - } - /** * Builds the response * @return A {@link UpdateDataObjectResponse} */ public UpdateDataObjectResponse build() { - return new UpdateDataObjectResponse(this.id, this.parser); + return new UpdateDataObjectResponse(this.index, this.id, this.parser, this.failed, this.cause, this.status); } } } diff --git a/common/src/main/java/org/opensearch/sdk/client/LocalClusterIndicesClient.java b/common/src/main/java/org/opensearch/sdk/client/LocalClusterIndicesClient.java index aae00a2dc2..625be8c553 100644 --- a/common/src/main/java/org/opensearch/sdk/client/LocalClusterIndicesClient.java +++ b/common/src/main/java/org/opensearch/sdk/client/LocalClusterIndicesClient.java @@ -8,21 +8,17 @@ */ package org.opensearch.sdk.client; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.util.Arrays; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteRequest.OpType; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; @@ -30,12 +26,9 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; -import org.opensearch.common.action.ActionFuture; import org.opensearch.common.inject.Inject; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.Strings; @@ -48,11 +41,15 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.CommonValue; +import org.opensearch.sdk.AbstractSdkClient; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; +import org.opensearch.sdk.DataObjectRequest; +import org.opensearch.sdk.DataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -60,27 +57,32 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; -import org.opensearch.sdk.SdkClientDelegate; import org.opensearch.sdk.SearchDataObjectRequest; import org.opensearch.sdk.SearchDataObjectResponse; import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.sdk.UpdateDataObjectResponse; import org.opensearch.search.builder.SearchSourceBuilder; +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + import lombok.extern.log4j.Log4j2; /** - * An implementation of {@link SdkClient} that stores data in a local OpenSearch cluster using the Node Client. + * An implementation of {@link SdkClient} that stores data in a local OpenSearch + * cluster using the Node Client. */ @Log4j2 -public class LocalClusterIndicesClient implements SdkClientDelegate { +public class LocalClusterIndicesClient extends AbstractSdkClient { private final Client client; private final NamedXContentRegistry xContentRegistry; /** * Instantiate this object with an OpenSearch client. - * @param client The client to wrap + * + * @param client The client to wrap * @param xContentRegistry the registry of XContent objects */ @Inject @@ -95,27 +97,32 @@ public CompletionStage putDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { - try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + return executePrivilegedAsync(() -> { + try { log.info("Indexing data object in {}", request.index()); - IndexRequest indexRequest = new IndexRequest(request.index()) - .opType(request.overwriteIfExists() ? OpType.INDEX : OpType.CREATE) - .setRefreshPolicy(IMMEDIATE) - .source(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)); - if (!Strings.isNullOrEmpty(request.id())) { - indexRequest.id(request.id()); - } + IndexRequest indexRequest = createIndexRequest(request).setRefreshPolicy(IMMEDIATE); IndexResponse indexResponse = client.index(indexRequest).actionGet(); log.info("Creation status for id {}: {}", indexResponse.getId(), indexResponse.getResult()); return PutDataObjectResponse.builder().id(indexResponse.getId()).parser(createParser(indexResponse)).build(); } catch (IOException e) { - // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException( "Failed to parse data object to put in index " + request.index(), RestStatus.BAD_REQUEST ); } - }), executor); + }, executor); + } + + private IndexRequest createIndexRequest(PutDataObjectRequest putDataObjectRequest) throws IOException { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(putDataObjectRequest.index()) + .opType(putDataObjectRequest.overwriteIfExists() ? OpType.INDEX : OpType.CREATE) + .source(putDataObjectRequest.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)); + if (!Strings.isNullOrEmpty(putDataObjectRequest.id())) { + indexRequest.id(putDataObjectRequest.id()); + } + return indexRequest; + } } @Override @@ -124,7 +131,7 @@ public CompletionStage getDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { GetResponse getResponse = client .get(new GetRequest(request.index(), request.id()).fetchSourceContext(request.fetchSourceContext())) @@ -145,7 +152,7 @@ public CompletionStage getDataObjectAsync( RestStatus.INTERNAL_SERVER_ERROR ); } - }), executor); + }, executor); } @Override @@ -154,20 +161,10 @@ public CompletionStage updateDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { - try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + return executePrivilegedAsync(() -> { + try { log.info("Updating {} from {}", request.id(), request.index()); - UpdateRequest updateRequest = new UpdateRequest(request.index(), request.id()) - .doc(request.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)); - if (request.ifSeqNo() != null) { - updateRequest.setIfSeqNo(request.ifSeqNo()); - } - if (request.ifPrimaryTerm() != null) { - updateRequest.setIfPrimaryTerm(request.ifPrimaryTerm()); - } - if (request.retryOnConflict() > 0) { - updateRequest.retryOnConflict(request.retryOnConflict()); - } + UpdateRequest updateRequest = createUpdateRequest(request); UpdateResponse updateResponse = client.update(updateRequest).actionGet(); if (updateResponse == null) { log.info("Null UpdateResponse"); @@ -177,19 +174,34 @@ public CompletionStage updateDataObjectAsync( return UpdateDataObjectResponse.builder().id(updateResponse.getId()).parser(createParser(updateResponse)).build(); } catch (VersionConflictEngineException vcee) { log.error("Document version conflict updating {} in {}: {}", request.id(), request.index(), vcee.getMessage(), vcee); - // Rethrow throw new OpenSearchStatusException( "Document version conflict updating " + request.id() + " in index " + request.index(), RestStatus.CONFLICT ); } catch (IOException e) { - // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException( "Failed to parse data object to update in index " + request.index(), RestStatus.BAD_REQUEST ); } - }), executor); + }, executor); + } + + private UpdateRequest createUpdateRequest(UpdateDataObjectRequest updateDataObjectRequest) throws IOException { + try (XContentBuilder sourceBuilder = XContentFactory.jsonBuilder()) { + UpdateRequest updateRequest = new UpdateRequest(updateDataObjectRequest.index(), updateDataObjectRequest.id()) + .doc(updateDataObjectRequest.dataObject().toXContent(sourceBuilder, EMPTY_PARAMS)); + if (updateDataObjectRequest.ifSeqNo() != null) { + updateRequest.setIfSeqNo(updateDataObjectRequest.ifSeqNo()); + } + if (updateDataObjectRequest.ifPrimaryTerm() != null) { + updateRequest.setIfPrimaryTerm(updateDataObjectRequest.ifPrimaryTerm()); + } + if (updateDataObjectRequest.retryOnConflict() > 0) { + updateRequest.retryOnConflict(updateDataObjectRequest.retryOnConflict()); + } + return updateRequest; + } } @Override @@ -198,20 +210,100 @@ public CompletionStage deleteDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { log.info("Deleting {} from {}", request.id(), request.index()); - DeleteResponse deleteResponse = client.delete(new DeleteRequest(request.index(), request.id()).setRefreshPolicy(IMMEDIATE)).actionGet(); + DeleteRequest deleteRequest = createDeleteRequest(request).setRefreshPolicy(IMMEDIATE); + DeleteResponse deleteResponse = client.delete(deleteRequest).actionGet(); log.info("Deletion status for id {}: {}", deleteResponse.getId(), deleteResponse.getResult()); return DeleteDataObjectResponse.builder().id(deleteResponse.getId()).parser(createParser(deleteResponse)).build(); } catch (IOException e) { - // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException( "Failed to parse data object to deletion response in index " + request.index(), RestStatus.INTERNAL_SERVER_ERROR ); } - }), executor); + }, executor); + } + + private DeleteRequest createDeleteRequest(DeleteDataObjectRequest deleteDataObjectRequest) { + return new DeleteRequest(deleteDataObjectRequest.index(), deleteDataObjectRequest.id()); + } + + @Override + public CompletionStage bulkDataObjectAsync( + BulkDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ) { + return executePrivilegedAsync(() -> { + try { + log.info("Performing {} bulk actions on indices {}", request.requests().size(), request.getIndices()); + BulkRequest bulkRequest = new BulkRequest(); + + for (DataObjectRequest dataObjectRequest : request.requests()) { + if (dataObjectRequest instanceof PutDataObjectRequest) { + bulkRequest.add(createIndexRequest((PutDataObjectRequest) dataObjectRequest)); + } else if (dataObjectRequest instanceof UpdateDataObjectRequest) { + bulkRequest.add(createUpdateRequest((UpdateDataObjectRequest) dataObjectRequest)); + } else if (dataObjectRequest instanceof DeleteDataObjectRequest) { + bulkRequest.add(createDeleteRequest((DeleteDataObjectRequest) dataObjectRequest)); + } + } + + BulkResponse bulkResponse = client.bulk(bulkRequest.setRefreshPolicy(IMMEDIATE)).actionGet(); + return bulkResponseToDataObjectResponse(bulkResponse); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException("Failed to parse data object in a bulk response", RestStatus.INTERNAL_SERVER_ERROR); + } + }, executor); + } + + private BulkDataObjectResponse bulkResponseToDataObjectResponse(BulkResponse bulkResponse) throws IOException { + int responseCount = bulkResponse.getItems().length; + log.info("Bulk action complete for {} items: {}", responseCount, bulkResponse.hasFailures() ? "has failures" : "success"); + DataObjectResponse[] responses = new DataObjectResponse[responseCount]; + for (int i = 0; i < responseCount; i++) { + BulkItemResponse itemResponse = bulkResponse.getItems()[i]; + responses[i] = createDataObjectResponse(itemResponse); + } + return new BulkDataObjectResponse( + responses, + bulkResponse.getTook().millis(), + bulkResponse.getIngestTookInMillis(), + bulkResponse.hasFailures(), + createParser(bulkResponse) + ); + } + + private DataObjectResponse createDataObjectResponse(BulkItemResponse itemResponse) throws IOException { + switch (itemResponse.getOpType()) { + case INDEX: + case CREATE: + return PutDataObjectResponse + .builder() + .id(itemResponse.getId()) + .parser(createParser(itemResponse)) + .failed(itemResponse.isFailed()) + .build(); + case UPDATE: + return UpdateDataObjectResponse + .builder() + .id(itemResponse.getId()) + .parser(createParser(itemResponse)) + .failed(itemResponse.isFailed()) + .build(); + case DELETE: + return DeleteDataObjectResponse + .builder() + .id(itemResponse.getId()) + .parser(createParser(itemResponse)) + .failed(itemResponse.isFailed()) + .build(); + default: + throw new OpenSearchStatusException("Invalid operation type for bulk response", RestStatus.INTERNAL_SERVER_ERROR); + } } @Override @@ -242,22 +334,20 @@ public CompletionStage searchDataObjectAsync( log.debug("Adding tenant id to search query", Arrays.toString(request.indices())); } log.info("Searching {}", Arrays.toString(request.indices())); - ActionFuture searchResponseFuture = AccessController - .doPrivileged((PrivilegedAction>) () -> { - return client.search(new SearchRequest(request.indices(), searchSource)); + return executePrivilegedAsync(() -> client.search(new SearchRequest(request.indices(), searchSource)), executor) + .thenCompose(searchResponseFuture -> CompletableFuture.supplyAsync(searchResponseFuture::actionGet, executor)) + .thenApply(searchResponse -> { + log.info("Search returned {} hits", searchResponse.getHits().getTotalHits()); + try { + return SearchDataObjectResponse.builder().parser(createParser(searchResponse)).build(); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException( + "Failed to search indices " + Arrays.toString(request.indices()), + RestStatus.INTERNAL_SERVER_ERROR + ); + } }); - return CompletableFuture.supplyAsync(searchResponseFuture::actionGet, executor).thenApply(searchResponse -> { - log.info("Search returned {} hits", searchResponse.getHits().getTotalHits()); - try { - return SearchDataObjectResponse.builder().parser(createParser(searchResponse)).build(); - } catch (IOException e) { - // Rethrow unchecked exception on XContent parsing error - throw new OpenSearchStatusException( - "Failed to search indices " + Arrays.toString(request.indices()), - RestStatus.INTERNAL_SERVER_ERROR - ); - } - }); } private XContentParser createParser(ToXContent obj) throws IOException { diff --git a/common/src/test/java/org/opensearch/sdk/BulkDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/BulkDataObjectRequestTests.java new file mode 100644 index 0000000000..a6f8650c3f --- /dev/null +++ b/common/src/test/java/org/opensearch/sdk/BulkDataObjectRequestTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.support.WriteRequest.RefreshPolicy; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class BulkDataObjectRequestTests { + private String testIndex; + private String testGlobalIndex; + private String testTenantId; + + @Before + public void setUp() { + testIndex = "test-index"; + testGlobalIndex = "test-global-index"; + testTenantId = "test-tenant-id"; + } + + @Test + public void testBulkDataObjectRequest() { + BulkDataObjectRequest request = BulkDataObjectRequest + .builder() + .globalIndex(testGlobalIndex) + .build() + .add(PutDataObjectRequest.builder().index(testIndex).build()) + .add(UpdateDataObjectRequest.builder().build()) + .add(DeleteDataObjectRequest.builder().index(testIndex).tenantId(testTenantId).build()) + .setRefreshPolicy(RefreshPolicy.IMMEDIATE); + + assertEquals(Set.of(testIndex, testGlobalIndex), request.getIndices()); + assertEquals(3, request.requests().size()); + assertEquals(RefreshPolicy.IMMEDIATE, request.getRefreshPolicy()); + + DataObjectRequest r0 = request.requests().get(0); + assertTrue(r0 instanceof PutDataObjectRequest); + assertEquals(testIndex, r0.index()); + assertNull(r0.tenantId()); + + DataObjectRequest r1 = request.requests().get(1); + assertTrue(r1 instanceof UpdateDataObjectRequest); + assertEquals(testGlobalIndex, r1.index()); + assertNull(r1.tenantId()); + + DataObjectRequest r2 = request.requests().get(2); + assertTrue(r2 instanceof DeleteDataObjectRequest); + assertEquals(testIndex, r2.index()); + assertEquals(testTenantId, r2.tenantId()); + } + + @Test + public void testBulkDataObjectRequest_Tenant() { + BulkDataObjectRequest request = BulkDataObjectRequest + .builder() + .build() + .add(PutDataObjectRequest.builder().index(testIndex).tenantId(testTenantId).build()) + .add(DeleteDataObjectRequest.builder().index(testIndex).tenantId(testTenantId).build()); + + assertEquals(Set.of(testIndex), request.getIndices()); + assertEquals(2, request.requests().size()); + + DataObjectRequest r0 = request.requests().get(0); + assertTrue(r0 instanceof PutDataObjectRequest); + assertEquals(testIndex, r0.index()); + assertEquals(testTenantId, r0.tenantId()); + + DataObjectRequest r1 = request.requests().get(1); + assertTrue(r1 instanceof DeleteDataObjectRequest); + assertEquals(testIndex, r1.index()); + assertEquals(testTenantId, r1.tenantId()); + } + + @Test + public void testBulkDataObjectRequest_Exceptions() { + PutDataObjectRequest nullIndexRequest = PutDataObjectRequest.builder().build(); + GetDataObjectRequest badTypeRequest = GetDataObjectRequest.builder().index(testIndex).build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().build(); + assertThrows(IllegalArgumentException.class, () -> bulkRequest.add(nullIndexRequest)); + assertThrows(IllegalArgumentException.class, () -> bulkRequest.add(badTypeRequest)); + } +} diff --git a/common/src/test/java/org/opensearch/sdk/BulkDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/BulkDataObjectResponseTests.java new file mode 100644 index 0000000000..fee4b360d7 --- /dev/null +++ b/common/src/test/java/org/opensearch/sdk/BulkDataObjectResponseTests.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.sdk; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +public class BulkDataObjectResponseTests { + @Mock + XContentParser parser; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + } + + @Test + public void testBulkDataObjectResponse() { + DataObjectResponse[] responses = List + .of( + PutDataObjectResponse.builder().build(), + UpdateDataObjectResponse.builder().build(), + DeleteDataObjectResponse.builder().build() + ) + .toArray(new DataObjectResponse[0]); + + BulkDataObjectResponse response = new BulkDataObjectResponse(responses, 1L, false, parser); + + assertEquals(3, response.getResponses().length); + assertEquals(1L, response.getTookInMillis()); + assertEquals(-1L, response.getIngestTookInMillis()); + assertFalse(response.hasFailures()); + assertSame(parser, response.parser()); + } + + @Test + public void testBulkDataObjectRequest_Failures() { + DataObjectResponse[] responses = List + .of(PutDataObjectResponse.builder().build(), DeleteDataObjectResponse.builder().failed(true).build()) + .toArray(new DataObjectResponse[0]); + + BulkDataObjectResponse response = new BulkDataObjectResponse(responses, 1L, true, parser); + + assertEquals(2, response.getResponses().length); + assertEquals(1L, response.getTookInMillis()); + assertEquals(-1L, response.getIngestTookInMillis()); + assertTrue(response.hasFailures()); + assertSame(parser, response.parser()); + } +} diff --git a/common/src/test/java/org/opensearch/sdk/DeleteDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/DeleteDataObjectResponseTests.java index 5ca9a5784f..97aeb30283 100644 --- a/common/src/test/java/org/opensearch/sdk/DeleteDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/DeleteDataObjectResponseTests.java @@ -10,27 +10,49 @@ import org.junit.Before; import org.junit.Test; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.mockito.Mockito.mock; public class DeleteDataObjectResponseTests { + private String testIndex; private String testId; private XContentParser testParser; + private boolean testFailed; + private Exception testCause; + private RestStatus testStatus; @Before public void setUp() { + testIndex = "test-index"; testId = "test-id"; testParser = mock(XContentParser.class); + testFailed = true; + testCause = mock(RuntimeException.class); + testStatus = RestStatus.BAD_REQUEST; } @Test public void testDeleteDataObjectResponse() { - DeleteDataObjectResponse response = DeleteDataObjectResponse.builder().id(testId).parser(testParser).build(); + DeleteDataObjectResponse response = DeleteDataObjectResponse + .builder() + .index(testIndex) + .id(testId) + .parser(testParser) + .failed(testFailed) + .cause(testCause) + .status(testStatus) + .build(); + assertEquals(testIndex, response.index()); assertEquals(testId, response.id()); - assertEquals(testParser, response.parser()); + assertSame(testParser, response.parser()); + assertEquals(testFailed, response.isFailed()); + assertSame(testCause, response.cause()); + assertEquals(testStatus, response.status()); } } diff --git a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java index 8f25e68829..469409bf29 100644 --- a/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/GetDataObjectResponseTests.java @@ -10,31 +10,54 @@ import org.junit.Before; import org.junit.Test; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.mockito.Mockito.mock; public class GetDataObjectResponseTests { + private String testIndex; private String testId; private XContentParser testParser; + private boolean testFailed; + private Exception testCause; + private RestStatus testStatus; private Map testSource; @Before public void setUp() { + testIndex = "test-index"; testId = "test-id"; testParser = mock(XContentParser.class); + testFailed = true; + testCause = mock(RuntimeException.class); + testStatus = RestStatus.BAD_REQUEST; testSource = Map.of("foo", "bar"); } @Test public void testGetDataObjectResponse() { - GetDataObjectResponse response = GetDataObjectResponse.builder().id(testId).parser(testParser).source(testSource).build(); + GetDataObjectResponse response = GetDataObjectResponse + .builder() + .index(testIndex) + .id(testId) + .parser(testParser) + .failed(testFailed) + .cause(testCause) + .status(testStatus) + .source(testSource) + .build(); + assertEquals(testIndex, response.index()); assertEquals(testId, response.id()); - assertEquals(testParser, response.parser()); + assertSame(testParser, response.parser()); + assertEquals(testFailed, response.isFailed()); + assertSame(testCause, response.cause()); + assertEquals(testStatus, response.status()); assertEquals(testSource, response.source()); } } diff --git a/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java b/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java index 1377542203..5dec3bbc7a 100644 --- a/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java +++ b/common/src/test/java/org/opensearch/sdk/PutDataObjectRequestTests.java @@ -14,6 +14,7 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -79,7 +80,7 @@ public void testPutDataObjectRequestWithMap() throws IOException { contentBuilder.flush(); BytesReference bytes = BytesReference.bytes(contentBuilder); - Map resultingMap = XContentHelper.convertToMap(bytes, false, XContentType.JSON).v2(); + Map resultingMap = XContentHelper.convertToMap(bytes, false, (MediaType) XContentType.JSON).v2(); assertEquals(dataObjectMap, resultingMap); } diff --git a/common/src/test/java/org/opensearch/sdk/PutDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/PutDataObjectResponseTests.java index b94a1cd588..0bd62667ce 100644 --- a/common/src/test/java/org/opensearch/sdk/PutDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/PutDataObjectResponseTests.java @@ -10,27 +10,48 @@ import org.junit.Before; import org.junit.Test; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.mockito.Mockito.mock; public class PutDataObjectResponseTests { + private String testIndex; private String testId; private XContentParser testParser; + private boolean testFailed; + private Exception testCause; + private RestStatus testStatus; @Before public void setUp() { + testId = "test-index"; testId = "test-id"; testParser = mock(XContentParser.class); + testFailed = true; + testCause = mock(RuntimeException.class); + testStatus = RestStatus.BAD_REQUEST; } @Test public void testPutDataObjectResponse() { - PutDataObjectResponse response = PutDataObjectResponse.builder().id(testId).parser(testParser).build(); + PutDataObjectResponse response = PutDataObjectResponse.builder() + .index(testIndex) + .id(testId) + .parser(testParser) + .failed(testFailed) + .cause(testCause) + .status(testStatus) + .build(); + assertEquals(testIndex, response.index()); assertEquals(testId, response.id()); - assertEquals(testParser, response.parser()); + assertSame(testParser, response.parser()); + assertEquals(testFailed, response.isFailed()); + assertSame(testCause, response.cause()); + assertEquals(testStatus, response.status()); } } diff --git a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java index 7af96cd370..3a7305df05 100644 --- a/common/src/test/java/org/opensearch/sdk/SdkClientTests.java +++ b/common/src/test/java/org/opensearch/sdk/SdkClientTests.java @@ -16,9 +16,14 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.core.rest.RestStatus; +import java.security.PrivilegedAction; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -26,7 +31,9 @@ import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -34,7 +41,7 @@ public class SdkClientTests { private static final String TENANT_ID = "test_id"; private SdkClient sdkClient; - private SdkClientDelegate sdkClientImpl; + private AbstractSdkClient sdkClientImpl; @Mock private PutDataObjectRequest putRequest; @@ -53,6 +60,10 @@ public class SdkClientTests { @Mock private DeleteDataObjectResponse deleteResponse; @Mock + private BulkDataObjectRequest bulkRequest; + @Mock + private BulkDataObjectResponse bulkResponse; + @Mock private SearchDataObjectRequest searchRequest; @Mock private SearchDataObjectResponse searchResponse; @@ -69,7 +80,7 @@ public void setUp() { when(deleteRequest.tenantId()).thenReturn(TENANT_ID); when(searchRequest.tenantId()).thenReturn(TENANT_ID); - sdkClientImpl = spy(new SdkClientDelegate() { + sdkClientImpl = spy(new AbstractSdkClient() { @Override public CompletionStage putDataObjectAsync( PutDataObjectRequest request, @@ -106,6 +117,15 @@ public CompletionStage deleteDataObjectAsync( return CompletableFuture.completedFuture(deleteResponse); } + @Override + public CompletionStage bulkDataObjectAsync( + BulkDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ) { + return CompletableFuture.completedFuture(bulkResponse); + } + @Override public CompletionStage searchDataObjectAsync( SearchDataObjectRequest request, @@ -131,30 +151,26 @@ public void testPutDataObjectNullTenantId() { when(putRequest.tenantId()).thenReturn(null); assertThrows(IllegalArgumentException.class, () -> sdkClient.putDataObject(putRequest)); } - + @Test public void testPutDataObjectException() { when(sdkClientImpl.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(testException)); + .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { - sdkClient.putDataObject(putRequest); - }); + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.putDataObject(putRequest); }); assertEquals(testException, exception); - assertFalse(Thread.interrupted()); + assertFalse(Thread.interrupted()); verify(sdkClientImpl).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class), anyBoolean()); } @Test public void testPutDataObjectInterrupted() { when(sdkClientImpl.putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(interruptedException)); + .thenReturn(CompletableFuture.failedFuture(interruptedException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { - sdkClient.putDataObject(putRequest); - }); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.putDataObject(putRequest); }); assertEquals(interruptedException, exception.getCause()); - assertTrue(Thread.interrupted()); + assertTrue(Thread.interrupted()); verify(sdkClientImpl).putDataObjectAsync(any(PutDataObjectRequest.class), any(Executor.class), anyBoolean()); } @@ -169,30 +185,26 @@ public void testGetDataObjectNullTenantId() { when(getRequest.tenantId()).thenReturn(null); assertThrows(IllegalArgumentException.class, () -> sdkClient.getDataObject(getRequest)); } - + @Test public void testGetDataObjectException() { when(sdkClientImpl.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(testException)); + .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { - sdkClient.getDataObject(getRequest); - }); + OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { sdkClient.getDataObject(getRequest); }); assertEquals(testException, exception); - assertFalse(Thread.interrupted()); + assertFalse(Thread.interrupted()); verify(sdkClientImpl).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class), anyBoolean()); } @Test public void testGetDataObjectInterrupted() { when(sdkClientImpl.getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(interruptedException)); + .thenReturn(CompletableFuture.failedFuture(interruptedException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { - sdkClient.getDataObject(getRequest); - }); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.getDataObject(getRequest); }); assertEquals(interruptedException, exception.getCause()); - assertTrue(Thread.interrupted()); + assertTrue(Thread.interrupted()); verify(sdkClientImpl).getDataObjectAsync(any(GetDataObjectRequest.class), any(Executor.class), anyBoolean()); } @@ -201,6 +213,7 @@ public void testUpdateDataObjectSuccess() { assertEquals(updateResponse, sdkClient.updateDataObject(updateRequest)); verify(sdkClientImpl).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class), anyBoolean()); } + @Test public void testUpdateDataObjectNullTenantId() { when(updateRequest.tenantId()).thenReturn(null); @@ -210,24 +223,23 @@ public void testUpdateDataObjectNullTenantId() { @Test public void testUpdateDataObjectException() { when(sdkClientImpl.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { - sdkClient.updateDataObject(updateRequest); - }); + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows( + OpenSearchStatusException.class, + () -> { sdkClient.updateDataObject(updateRequest); } + ); assertEquals(testException, exception); - assertFalse(Thread.interrupted()); + assertFalse(Thread.interrupted()); verify(sdkClientImpl).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class), anyBoolean()); } @Test public void testUpdateDataObjectInterrupted() { when(sdkClientImpl.updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(interruptedException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { - sdkClient.updateDataObject(updateRequest); - }); + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.updateDataObject(updateRequest); }); assertEquals(interruptedException, exception.getCause()); - assertTrue(Thread.interrupted()); + assertTrue(Thread.interrupted()); verify(sdkClientImpl).updateDataObjectAsync(any(UpdateDataObjectRequest.class), any(Executor.class), anyBoolean()); } @@ -246,60 +258,117 @@ public void testDeleteDataObjectNullTenantId() { @Test public void testDeleteDataObjectException() { when(sdkClientImpl.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { - sdkClient.deleteDataObject(deleteRequest); - }); + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows( + OpenSearchStatusException.class, + () -> { sdkClient.deleteDataObject(deleteRequest); } + ); assertEquals(testException, exception); - assertFalse(Thread.interrupted()); + assertFalse(Thread.interrupted()); verify(sdkClientImpl).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class), anyBoolean()); } @Test public void testDeleteDataObjectInterrupted() { when(sdkClientImpl.deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(interruptedException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { - sdkClient.deleteDataObject(deleteRequest); - }); + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.deleteDataObject(deleteRequest); }); assertEquals(interruptedException, exception.getCause()); - assertTrue(Thread.interrupted()); + assertTrue(Thread.interrupted()); verify(sdkClientImpl).deleteDataObjectAsync(any(DeleteDataObjectRequest.class), any(Executor.class), anyBoolean()); } + @Test + public void testBulkDataObjectSuccess() { + assertEquals(bulkResponse, sdkClient.bulkDataObject(bulkRequest)); + verify(sdkClientImpl).bulkDataObjectAsync(any(BulkDataObjectRequest.class), any(Executor.class), anyBoolean()); + } + + @Test + public void testBulkDataObjectNullTenantId() { + DeleteDataObjectRequest deleteRequest = mock(DeleteDataObjectRequest.class); + when(deleteRequest.tenantId()).thenReturn(null); + when(bulkRequest.requests()).thenReturn(List.of(deleteRequest)); + assertThrows(IllegalArgumentException.class, () -> sdkClient.bulkDataObject(bulkRequest)); + } + + @Test + public void testBulkDataObjectException() { + when(sdkClientImpl.bulkDataObjectAsync(any(BulkDataObjectRequest.class), any(Executor.class), anyBoolean())) + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows( + OpenSearchStatusException.class, + () -> { sdkClient.bulkDataObject(bulkRequest); } + ); + assertEquals(testException, exception); + assertFalse(Thread.interrupted()); + verify(sdkClientImpl).bulkDataObjectAsync(any(BulkDataObjectRequest.class), any(Executor.class), anyBoolean()); + } + + @Test + public void testBulkDataObjectInterrupted() { + when(sdkClientImpl.bulkDataObjectAsync(any(BulkDataObjectRequest.class), any(Executor.class), anyBoolean())) + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.bulkDataObject(bulkRequest); }); + assertEquals(interruptedException, exception.getCause()); + assertTrue(Thread.interrupted()); + verify(sdkClientImpl).bulkDataObjectAsync(any(BulkDataObjectRequest.class), any(Executor.class), anyBoolean()); + } + @Test public void testSearchDataObjectSuccess() { assertEquals(searchResponse, sdkClient.searchDataObject(searchRequest)); verify(sdkClientImpl).searchDataObjectAsync(any(SearchDataObjectRequest.class), any(Executor.class), anyBoolean()); } + @Test public void testSearchDataObjectNullTenantId() { when(searchRequest.tenantId()).thenReturn(null); assertThrows(IllegalArgumentException.class, () -> sdkClient.searchDataObject(searchRequest)); } - @Test public void testSearchDataObjectException() { when(sdkClientImpl.searchDataObjectAsync(any(SearchDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(testException)); - OpenSearchStatusException exception = assertThrows(OpenSearchStatusException.class, () -> { - sdkClient.searchDataObject(searchRequest); - }); + .thenReturn(CompletableFuture.failedFuture(testException)); + OpenSearchStatusException exception = assertThrows( + OpenSearchStatusException.class, + () -> { sdkClient.searchDataObject(searchRequest); } + ); assertEquals(testException, exception); - assertFalse(Thread.interrupted()); + assertFalse(Thread.interrupted()); verify(sdkClientImpl).searchDataObjectAsync(any(SearchDataObjectRequest.class), any(Executor.class), anyBoolean()); } @Test public void testSearchDataObjectInterrupted() { when(sdkClientImpl.searchDataObjectAsync(any(SearchDataObjectRequest.class), any(Executor.class), anyBoolean())) - .thenReturn(CompletableFuture.failedFuture(interruptedException)); - OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { - sdkClient.searchDataObject(searchRequest); - }); + .thenReturn(CompletableFuture.failedFuture(interruptedException)); + OpenSearchException exception = assertThrows(OpenSearchException.class, () -> { sdkClient.searchDataObject(searchRequest); }); assertEquals(interruptedException, exception.getCause()); - assertTrue(Thread.interrupted()); + assertTrue(Thread.interrupted()); verify(sdkClientImpl).searchDataObjectAsync(any(SearchDataObjectRequest.class), any(Executor.class), anyBoolean()); } + + @Test + public void testExecutePrivilegedAsync() throws Exception { + PrivilegedAction action = () -> "Test Result"; + Executor executor = Executors.newCachedThreadPool(); + CompletionStage result = sdkClientImpl.executePrivilegedAsync(action, executor); + CompletableFuture future = result.toCompletableFuture(); + verify(sdkClientImpl, timeout(1000)).executePrivilegedAsync(any(), any()); + assertEquals("Test Result", future.get(5, TimeUnit.SECONDS)); + assertFalse(future.isCompletedExceptionally()); + } + + @Test + public void testExecutePrivilegedAsyncWithException() throws Exception { + PrivilegedAction action = () -> { throw new RuntimeException("Test Exception"); }; + Executor executor = Executors.newCachedThreadPool(); + CompletionStage result = sdkClientImpl.executePrivilegedAsync(action, executor); + CompletableFuture future = result.toCompletableFuture(); + verify(sdkClientImpl, timeout(1000)).executePrivilegedAsync(any(), any()); + assertThrows(ExecutionException.class, () -> future.get(5, TimeUnit.SECONDS)); + assertTrue(future.isCompletedExceptionally()); + } } diff --git a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java index 7762305f5f..1d03ed599c 100644 --- a/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java +++ b/common/src/test/java/org/opensearch/sdk/UpdateDataObjectResponseTests.java @@ -10,27 +10,50 @@ import org.junit.Before; import org.junit.Test; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.mockito.Mockito.mock; public class UpdateDataObjectResponseTests { + private String testIndex; private String testId; private XContentParser testParser; + private boolean testFailed; + private Exception testCause; + private RestStatus testStatus; @Before public void setUp() { + testIndex = "test-index"; testId = "test-id"; testParser = mock(XContentParser.class); + testFailed = true; + testCause = mock(RuntimeException.class); + testStatus = RestStatus.BAD_REQUEST; + } @Test public void testUpdateDataObjectResponse() { - UpdateDataObjectResponse response = UpdateDataObjectResponse.builder().id(testId).parser(testParser).build(); + UpdateDataObjectResponse response = UpdateDataObjectResponse + .builder() + .index(testIndex) + .id(testId) + .parser(testParser) + .failed(testFailed) + .cause(testCause) + .status(testStatus) + .build(); + assertEquals(testIndex, response.index()); assertEquals(testId, response.id()); - assertEquals(testParser, response.parser()); + assertSame(testParser, response.parser()); + assertEquals(testFailed, response.isFailed()); + assertSame(testCause, response.cause()); + assertEquals(testStatus, response.status()); } } diff --git a/common/src/test/java/org/opensearch/sdk/client/LocalClusterIndicesClientTests.java b/common/src/test/java/org/opensearch/sdk/client/LocalClusterIndicesClientTests.java index d71264e5b0..a912067236 100644 --- a/common/src/test/java/org/opensearch/sdk/client/LocalClusterIndicesClientTests.java +++ b/common/src/test/java/org/opensearch/sdk/client/LocalClusterIndicesClientTests.java @@ -8,18 +8,6 @@ */ package org.opensearch.sdk.client; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - import java.io.IOException; import java.util.EnumSet; import java.util.Map; @@ -38,6 +26,9 @@ import org.opensearch.action.DocWriteRequest.OpType; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; @@ -69,6 +60,8 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -86,6 +79,18 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + public class LocalClusterIndicesClientTests { // Copied constants from MachineLearningPlugin.java @@ -135,7 +140,8 @@ public void testPutDataObject() throws IOException { PutDataObjectRequest putRequest = PutDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .overwriteIfExists(false) .dataObject(testDataObject) .build(); @@ -166,7 +172,12 @@ public void testPutDataObject() throws IOException { @Test public void testPutDataObject_Exception() throws IOException { - PutDataObjectRequest putRequest = PutDataObjectRequest.builder().index(TEST_INDEX).tenantId(TEST_TENANT_ID).dataObject(testDataObject).build(); + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .index(TEST_INDEX) + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); when(mockedClient.index(any(IndexRequest.class))).thenThrow(new UnsupportedOperationException("test")); @@ -188,7 +199,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IOException("test"); } }; - PutDataObjectRequest putRequest = PutDataObjectRequest.builder().index(TEST_INDEX).tenantId(TEST_TENANT_ID).dataObject(badDataObject).build(); + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .index(TEST_INDEX) + .tenantId(TEST_TENANT_ID) + .dataObject(badDataObject) + .build(); CompletableFuture future = sdkClient .putDataObjectAsync(putRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) @@ -300,7 +316,8 @@ public void testUpdateDataObject() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .retryOnConflict(3) .dataObject(testDataObject) .build(); @@ -344,7 +361,8 @@ public void testUpdateDataObjectWithMap() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .dataObject(Map.of("foo", "bar")) .build(); @@ -377,7 +395,8 @@ public void testUpdateDataObject_NotFound() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .dataObject(testDataObject) .build(); @@ -419,7 +438,8 @@ public void testUpdateDataObject_Null() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .dataObject(testDataObject) .build(); @@ -445,7 +465,8 @@ public void testUpdateDataObject_Exception() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .dataObject(testDataObject) .build(); @@ -467,7 +488,8 @@ public void testUpdateDataObject_VersionCheck() throws IOException { UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest .builder() .index(TEST_INDEX) - .id(TEST_ID).tenantId(TEST_TENANT_ID) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) .dataObject(testDataObject) .ifSeqNo(5) .ifPrimaryTerm(2) @@ -493,7 +515,12 @@ public void testUpdateDataObject_VersionCheck() throws IOException { @Test public void testDeleteDataObject() throws IOException { - DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().index(TEST_INDEX).id(TEST_ID).tenantId(TEST_TENANT_ID).build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) + .build(); DeleteResponse deleteResponse = new DeleteResponse(new ShardId(TEST_INDEX, "_na_", 0), TEST_ID, 1, 0, 2, true); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -517,7 +544,12 @@ public void testDeleteDataObject() throws IOException { @Test public void testDeleteDataObject_Exception() throws IOException { - DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().index(TEST_INDEX).id(TEST_ID).tenantId(TEST_TENANT_ID).build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) + .build(); ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); when(mockedClient.delete(deleteRequestCaptor.capture())).thenThrow(new UnsupportedOperationException("test")); @@ -532,6 +564,140 @@ public void testDeleteDataObject_Exception() throws IOException { assertEquals("test", cause.getMessage()); } + @Test + public void testBulkDataObject() throws IOException { + PutDataObjectRequest putRequest = PutDataObjectRequest.builder().id(TEST_ID + "1").tenantId(TEST_TENANT_ID).dataObject(testDataObject).build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest.builder().id(TEST_ID + "2").tenantId(TEST_TENANT_ID).dataObject(testDataObject).build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").tenantId(TEST_TENANT_ID).build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest) + .add(deleteRequest); + + ShardId shardId = new ShardId(TEST_INDEX, "_na_", 0); + ShardInfo shardInfo = new ShardInfo(1, 1); + + IndexResponse indexResponse = new IndexResponse(shardId, TEST_ID + "1", 1, 1, 1, true); + indexResponse.setShardInfo(shardInfo); + + UpdateResponse updateResponse = new UpdateResponse(shardId, TEST_ID + "2", 1, 1, 1, DocWriteResponse.Result.UPDATED); + updateResponse.setShardInfo(shardInfo); + + DeleteResponse deleteResponse = new DeleteResponse(shardId, TEST_ID + "3", 1, 1, 1, true); + deleteResponse.setShardInfo(shardInfo); + + BulkResponse bulkResponse = new BulkResponse( + new BulkItemResponse[] { + new BulkItemResponse(0, OpType.INDEX, indexResponse), + new BulkItemResponse(1, OpType.UPDATE, updateResponse), + new BulkItemResponse(2, OpType.DELETE, deleteResponse) }, + 100L + ); + + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.bulk(any(BulkRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(bulkResponse); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(mockedClient, times(1)).bulk(requestCaptor.capture()); + assertEquals(3, requestCaptor.getValue().numberOfActions()); + + assertEquals(3, response.getResponses().length); + assertEquals(100L, response.getTookInMillis()); + + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + + assertEquals(TEST_ID + "1", response.getResponses()[0].id()); + assertEquals(TEST_ID + "2", response.getResponses()[1].id()); + assertEquals(TEST_ID + "3", response.getResponses()[2].id()); + } + + @Test + public void testBulkDataObject_WithFailures() throws IOException { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .id(TEST_ID + "1").tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(TEST_ID + "2").tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").tenantId(TEST_TENANT_ID).build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest) + .add(deleteRequest); + + BulkResponse bulkResponse = new BulkResponse( + new BulkItemResponse[] { + new BulkItemResponse(0, OpType.INDEX, new IndexResponse(new ShardId(TEST_INDEX, "_na_", 0), TEST_ID + "1", 1, 1, 1, true)), + new BulkItemResponse( + 1, + OpType.UPDATE, + new BulkItemResponse.Failure(TEST_INDEX, TEST_ID + "2", new Exception("Update failed")) + ), + new BulkItemResponse(0, OpType.DELETE, new IndexResponse(new ShardId(TEST_INDEX, "_na_", 0), TEST_ID + "3", 1, 1, 1, true)) + }, + 100L + ); + + @SuppressWarnings("unchecked") + ActionFuture future = mock(ActionFuture.class); + when(mockedClient.bulk(any(BulkRequest.class))).thenReturn(future); + when(future.actionGet()).thenReturn(bulkResponse); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(3, response.getResponses().length); + assertFalse(response.getResponses()[0].isFailed()); + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1].isFailed()); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertFalse(response.getResponses()[2].isFailed()); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + } + + @Test + public void testBulkDataObject_Exception() { + PutDataObjectRequest putRequest = PutDataObjectRequest.builder().index(TEST_INDEX).id(TEST_ID).tenantId(TEST_TENANT_ID).dataObject(testDataObject).build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().build().add(putRequest); + + when(mockedClient.bulk(any(BulkRequest.class))) + .thenThrow(new OpenSearchStatusException("Failed to parse data object in a bulk response", RestStatus.INTERNAL_SERVER_ERROR)); + + CompletableFuture future = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(OpenSearchStatusException.class, cause.getClass()); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((OpenSearchStatusException) cause).status()); + assertEquals("Failed to parse data object in a bulk response", cause.getMessage()); + } + @Test public void testSearchDataObjectNotTenantAware() throws IOException { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); @@ -642,7 +808,7 @@ public void testSearchDataObject_Exception() throws IOException { PlainActionFuture exceptionalFuture = PlainActionFuture.newFuture(); exceptionalFuture.onFailure(new UnsupportedOperationException("test")); when(mockedClient.search(any(SearchRequest.class))).thenReturn(exceptionalFuture); - + CompletableFuture future = sdkClient .searchDataObjectAsync(searchRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) .toCompletableFuture(); @@ -652,7 +818,7 @@ public void testSearchDataObject_Exception() throws IOException { assertEquals(UnsupportedOperationException.class, cause.getClass()); assertEquals("test", cause.getMessage()); } - + @Test public void testSearchDataObject_NullTenantNoMultitenancy() throws IOException { // Tests no status exception if multitenancy not enabled diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index 6456039774..9037118056 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import java.io.IOException; import java.util.ArrayList; @@ -17,12 +18,10 @@ import java.util.stream.Collectors; import org.opensearch.action.FailedNodeException; -import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -44,6 +43,10 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.SdkClientUtils; +import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -60,6 +63,7 @@ public class TransportUndeployModelAction extends private final Client client; private final DiscoveryNodeHelper nodeFilter; private final MLStats mlStats; + private final SdkClient sdkClient; @Inject public TransportUndeployModelAction( @@ -69,6 +73,7 @@ public TransportUndeployModelAction( ClusterService clusterService, ThreadPool threadPool, Client client, + SdkClient sdkClient, DiscoveryNodeHelper nodeFilter, MLStats mlStats ) { @@ -87,6 +92,7 @@ public TransportUndeployModelAction( this.clusterService = clusterService; this.client = client; + this.sdkClient = sdkClient; this.nodeFilter = nodeFilter; this.mlStats = mlStats; } @@ -94,12 +100,13 @@ public TransportUndeployModelAction( @Override protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener listener) { ActionListener wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> { - processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener); + processUndeployModelResponseAndUpdate(request, undeployModelNodesResponse, listener); }, listener::onFailure); super.doExecute(task, request, wrappedListener); } void processUndeployModelResponseAndUpdate( + MLUndeployModelNodesRequest undeployModelNodesRequest, MLUndeployModelNodesResponse undeployModelNodesResponse, ActionListener listener ) { @@ -145,11 +152,11 @@ void processUndeployModelResponseAndUpdate( MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (actualRemovedNodesMap.size() > 0) { - BulkRequest bulkRequest = new BulkRequest(); + if (!actualRemovedNodesMap.isEmpty()) { + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build(); + String tenantId = undeployModelNodesRequest.getTenantId(); Map deployToAllNodes = new HashMap<>(); for (String modelId : actualRemovedNodesMap.keySet()) { - UpdateRequest updateRequest = new UpdateRequest(); List removedNodes = actualRemovedNodesMap.get(modelId); int removedNodeCount = removedNodes.size(); /** @@ -178,7 +185,12 @@ void processUndeployModelResponseAndUpdate( updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, newPlanningWorkerNodes.size()); deployToAllNodes.put(modelId, false); } - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(modelId) + .tenantId(tenantId) + .dataObject(updateDocument) + .build(); bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); } syncUpInput.setDeployToAllNodes(deployToAllNodes); @@ -189,10 +201,35 @@ void processUndeployModelResponseAndUpdate( Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0])) ); }, e -> { log.error("Failed to update model state as undeployed", e); }); - client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> { + ActionListener wrappedListener = ActionListener.runAfter(actionListener, () -> { syncUpUndeployedModels(syncUpRequest); listener.onResponse(undeployModelNodesResponse); - })); + }); + sdkClient + .bulkDataObjectAsync(bulkRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to execute BulkDataObject request", cause); + wrappedListener.onFailure(cause); + } else { + try { + BulkResponse bulkResponse = BulkResponse.fromXContent(r.parser()); + log + .info( + "Executed {} bulk operations with {} failures, Took: {}", + bulkResponse.getItems().length, + bulkResponse.hasFailures() + ? Arrays.stream(bulkResponse.getItems()).filter(i -> i.isFailed()).count() + : 0, + bulkResponse.getTook() + ); + wrappedListener.onResponse(bulkResponse); + } catch (Exception e) { + wrappedListener.onFailure(e); + } + } + }); } else { syncUpUndeployedModels(syncUpRequest); listener.onResponse(undeployModelNodesResponse); diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index ba82d91e00..9092b6fd0d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -131,7 +131,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermissionToUndeploy) { - undeployModels(targetNodeIds, modelIds, listener); + undeployModels(targetNodeIds, modelIds, tenantId, listener); } else { listener.onFailure(new IllegalArgumentException("No permission to undeploy model " + modelId)); } @@ -157,9 +157,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener !hiddenModelIds.contains(modelId)) .toArray(String[]::new); - undeployModels(targetNodeIds, modelsIDsToUndeploy, listener); + undeployModels(targetNodeIds, modelsIDsToUndeploy, tenantId, listener); } else { - undeployModels(targetNodeIds, modelIds, listener); + undeployModels(targetNodeIds, modelIds, tenantId, listener); } }, e -> { log.error("Failed to search model index", e); @@ -169,8 +169,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + private void undeployModels( + String[] targetNodeIds, + String[] modelIds, + String tenantId, + ActionListener listener + ) { MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(targetNodeIds, modelIds); + mlUndeployModelNodesRequest.setTenantId(tenantId); client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { listener.onResponse(new MLUndeployModelsResponse(r)); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 9e704b4866..dfddb9367f 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -102,7 +102,7 @@ private void startSyncModelRoutingCron() { log.info("Starting ML sync up job..."); syncModelRoutingCron = threadPool .scheduleWithFixedDelay( - new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting), + new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL ); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 0c232364a3..5bbea3cb9c 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.time.Instant; @@ -23,12 +24,10 @@ import java.util.stream.Collectors; import org.opensearch.action.DocWriteRequest; -import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -36,6 +35,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -49,6 +49,11 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.SdkClient; +import org.opensearch.sdk.SdkClientUtils; +import org.opensearch.sdk.SearchDataObjectRequest; +import org.opensearch.sdk.UpdateDataObjectRequest; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -62,6 +67,7 @@ public class MLSyncUpCron implements Runnable { public static final int DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS = 20_000; private final Client client; + private final SdkClient sdkClient; private final ClusterService clusterService; private final DiscoveryNodeHelper nodeHelper; private final MLIndicesHandler mlIndicesHandler; @@ -73,6 +79,7 @@ public class MLSyncUpCron implements Runnable { public MLSyncUpCron( Client client, + SdkClient sdkClient, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, @@ -80,6 +87,7 @@ public MLSyncUpCron( MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; @@ -262,7 +270,6 @@ void refreshModelState(Map> modelWorkerNodes, Map> modelWorkerNodes, Map> modelWorkerNodes, Map { - SearchHit[] hits = res.getHits().getHits(); - Map newModelStates = new HashMap<>(); - Map> newPlanningWorkerNodes = new HashMap<>(); - for (SearchHit hit : hits) { - String modelId = hit.getId(); - Map sourceAsMap = hit.getSourceAsMap(); - FunctionName functionName = FunctionName.from((String) sourceAsMap.get(MLModel.ALGORITHM_FIELD)); - MLModelState state = MLModelState.from((String) sourceAsMap.get(MLModel.MODEL_STATE_FIELD)); - Long lastUpdateTime = sourceAsMap.containsKey(MLModel.LAST_UPDATED_TIME_FIELD) - ? (Long) sourceAsMap.get(MLModel.LAST_UPDATED_TIME_FIELD) - : null; - int planningWorkerNodeCount = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) - ? (int) sourceAsMap.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) - : 0; - int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) - ? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) - : 0; - boolean deployToAllNodes = sourceAsMap.containsKey(MLModel.DEPLOY_TO_ALL_NODES_FIELD) - && (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD); - List planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD) - ? (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD) - : new ArrayList<>(); - if (deployToAllNodes) { - DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName); - planningWorkerNodeCount = eligibleNodes.length; - List eligibleNodeIds = Arrays.stream(eligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toList()); - if (eligibleNodeIds.size() != planningWorkNodes.size() || !eligibleNodeIds.containsAll(planningWorkNodes)) { - newPlanningWorkerNodes.put(modelId, eligibleNodeIds); + SearchDataObjectRequest searchRequest = SearchDataObjectRequest + .builder() + .indices(ML_MODEL_INDEX) + .searchSourceBuilder(sourceBuilder) + .build(); + sdkClient + .searchDataObjectAsync(searchRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse res = SearchResponse.fromXContent(r.parser()); + SearchHit[] hits = res.getHits().getHits(); + Map tenantIds = new HashMap<>(); + Map newModelStates = new HashMap<>(); + Map> newPlanningWorkerNodes = new HashMap<>(); + for (SearchHit hit : hits) { + String modelId = hit.getId(); + Map sourceAsMap = hit.getSourceAsMap(); + if (sourceAsMap.containsKey(CommonValue.TENANT_ID)) { + tenantIds.put(modelId, (String) sourceAsMap.get(CommonValue.TENANT_ID)); + } + FunctionName functionName = FunctionName.from((String) sourceAsMap.get(MLModel.ALGORITHM_FIELD)); + MLModelState state = MLModelState.from((String) sourceAsMap.get(MLModel.MODEL_STATE_FIELD)); + Long lastUpdateTime = sourceAsMap.containsKey(MLModel.LAST_UPDATED_TIME_FIELD) + ? (Long) sourceAsMap.get(MLModel.LAST_UPDATED_TIME_FIELD) + : null; + int planningWorkerNodeCount = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) + ? (int) sourceAsMap.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) + : 0; + int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) + ? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) + : 0; + boolean deployToAllNodes = sourceAsMap.containsKey(MLModel.DEPLOY_TO_ALL_NODES_FIELD) + && (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD); + List planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD) + ? (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD) + : new ArrayList<>(); + if (deployToAllNodes) { + DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName); + planningWorkerNodeCount = eligibleNodes.length; + List eligibleNodeIds = Arrays + .stream(eligibleNodes) + .map(DiscoveryNode::getId) + .collect(Collectors.toList()); + if (eligibleNodeIds.size() != planningWorkNodes.size() + || !eligibleNodeIds.containsAll(planningWorkNodes)) { + newPlanningWorkerNodes.put(modelId, eligibleNodeIds); + } + } + MLModelState mlModelState = getNewModelState( + deployingModels, + modelWorkerNodes, + modelId, + state, + lastUpdateTime, + planningWorkerNodeCount, + currentWorkerNodeCountInIndex + ); + if (mlModelState != null) { + newModelStates.put(modelId, mlModelState); + } + } + bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes, tenantIds); + } catch (Exception e) { + log.error("Failed to parse model search response", e); + updateModelStateSemaphore.release(); } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + updateModelStateSemaphore.release(); + log.error("Failed to search models", e); } - MLModelState mlModelState = getNewModelState( - deployingModels, - modelWorkerNodes, - modelId, - state, - lastUpdateTime, - planningWorkerNodeCount, - currentWorkerNodeCountInIndex - ); - if (mlModelState != null) { - newModelStates.put(modelId, mlModelState); - } - } - bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes); - }, e -> { - updateModelStateSemaphore.release(); - log.error("Failed to search models", e); - })); + }); } catch (Exception e) { updateModelStateSemaphore.release(); log.error("Failed to refresh model state", e); @@ -404,40 +435,49 @@ private MLModelState getNewModelState( private void bulkUpdateModelState( Map> modelWorkerNodes, Map newModelStates, - Map> newPlanningWorkNodes + Map> newPlanningWorkNodes, + Map tenantIds ) { Set updatedModelIds = new HashSet<>(); updatedModelIds.addAll(newModelStates.keySet()); updatedModelIds.addAll(newPlanningWorkNodes.keySet()); if (!updatedModelIds.isEmpty()) { - BulkRequest bulkUpdateRequest = new BulkRequest(); + BulkDataObjectRequest bulkUpdateRequest = BulkDataObjectRequest.builder().globalIndex(ML_MODEL_INDEX).build(); for (String modelId : updatedModelIds) { - UpdateRequest updateRequest = new UpdateRequest(); Instant now = Instant.now(); - ImmutableMap.Builder builder = ImmutableMap.builder(); + Map updateDocument = new HashMap<>(); if (newModelStates.containsKey(modelId)) { - builder.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name()); + updateDocument.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name()); } if (newPlanningWorkNodes.containsKey(modelId)) { - builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkNodes.get(modelId)); - builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkNodes.get(modelId).size()); + updateDocument.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkNodes.get(modelId)); + updateDocument.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkNodes.get(modelId).size()); } - builder.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli()); + updateDocument.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli()); Set workerNodes = modelWorkerNodes.get(modelId); int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size(); - builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount); - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build()); + updateDocument.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .tenantId(tenantIds.get(modelId)) + .id(modelId) + .dataObject(updateDocument) + .build(); bulkUpdateRequest.add(updateRequest); } log.info("Refresh model state: {}", newModelStates); - client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> { - updateModelStateSemaphore.release(); - log.debug("Refresh model state successfully"); - }, e -> { - updateModelStateSemaphore.release(); - log.error("Failed to bulk update model state", e); - })); + sdkClient + .bulkDataObjectAsync(bulkUpdateRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) + .whenComplete((r, throwable) -> { + updateModelStateSemaphore.release(); + if (throwable != null) { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to bulk update model state", e); + } else { + log.debug("Refresh model state successfully"); + } + }); } else { updateModelStateSemaphore.release(); } diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index c2621606f9..844adb5fc0 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -14,30 +14,46 @@ import static org.opensearch.ml.common.CommonValue.TENANT_ID; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteRequest.OpType; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.replication.ReplicationResponse.ShardInfo; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.sdkclient.util.JsonTransformer; +import org.opensearch.sdk.AbstractSdkClient; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; +import org.opensearch.sdk.DataObjectRequest; +import org.opensearch.sdk.DataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -45,7 +61,7 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; -import org.opensearch.sdk.SdkClientDelegate; +import org.opensearch.sdk.SdkClientUtils; import org.opensearch.sdk.SearchDataObjectRequest; import org.opensearch.sdk.SearchDataObjectResponse; import org.opensearch.sdk.UpdateDataObjectRequest; @@ -73,7 +89,7 @@ * */ @Log4j2 -public class DDBOpenSearchClient implements SdkClientDelegate { +public class DDBOpenSearchClient extends AbstractSdkClient { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); @@ -120,7 +136,7 @@ public CompletionStage putDataObjectAsync( final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; final String tableName = request.index(); final GetItemRequest getItemRequest = buildGetItemRequest(tenantId, id, request.index()); - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); Long sequenceNumber = initOrIncrementSeqNo(getItemResponse); @@ -157,7 +173,7 @@ public CompletionStage putDataObjectAsync( // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException("Failed to parse data object " + request.id(), RestStatus.BAD_REQUEST); } - }), executor); + }, executor); } /** @@ -172,7 +188,7 @@ public CompletionStage getDataObjectAsync( Boolean isMultiTenancyEnabled ) { final GetItemRequest getItemRequest = buildGetItemRequest(request.tenantId(), request.id(), request.index()); - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { final GetItemResponse getItemResponse = dynamoDbClient.getItem(getItemRequest); ObjectNode sourceObject; @@ -213,7 +229,7 @@ public CompletionStage getDataObjectAsync( // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException("Failed to parse response", RestStatus.INTERNAL_SERVER_ERROR); } - }), executor); + }, executor); } /** @@ -228,7 +244,7 @@ public CompletionStage updateDataObjectAsync( Boolean isMultiTenancyEnabled ) { final String tenantId = request.tenantId() != null ? request.tenantId() : DEFAULT_TENANT; - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); JsonNode jsonNode = OBJECT_MAPPER.readTree(source); @@ -250,7 +266,7 @@ public CompletionStage updateDataObjectAsync( RestStatus.BAD_REQUEST ); } - }), executor); + }, executor); } private Long updateItemWithRetryOnConflict(String tenantId, JsonNode jsonNode, UpdateDataObjectRequest request) { @@ -330,7 +346,7 @@ public CompletionStage deleteDataObjectAsync( ) ) .build(); - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { DeleteItemResponse deleteItemResponse = dynamoDbClient.deleteItem(deleteItemRequest); Long sequenceNumber = null; @@ -349,7 +365,140 @@ public CompletionStage deleteDataObjectAsync( // Rethrow unchecked exception on XContent parsing error throw new OpenSearchStatusException("Failed to parse response", RestStatus.INTERNAL_SERVER_ERROR); } - }), executor); + }, executor); + } + + @Override + public CompletionStage bulkDataObjectAsync( + BulkDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ) { + return executePrivilegedAsync(() -> { + log.info("Performing {} bulk actions on table {}", request.requests().size(), request.getIndices()); + + List responses = new ArrayList<>(); + + // TODO: Ideally if we only have put and delete requests we can use DynamoDB BatchWriteRequest. + long startNanos = System.nanoTime(); + for (DataObjectRequest dataObjectRequest : request.requests()) { + try { + if (dataObjectRequest instanceof PutDataObjectRequest) { + responses + .add( + putDataObjectAsync((PutDataObjectRequest) dataObjectRequest, executor, isMultiTenancyEnabled) + .toCompletableFuture() + .join() + ); + } else if (dataObjectRequest instanceof UpdateDataObjectRequest) { + responses + .add( + updateDataObjectAsync((UpdateDataObjectRequest) dataObjectRequest, executor, isMultiTenancyEnabled) + .toCompletableFuture() + .join() + ); + } else if (dataObjectRequest instanceof DeleteDataObjectRequest) { + responses + .add( + deleteDataObjectAsync((DeleteDataObjectRequest) dataObjectRequest, executor, isMultiTenancyEnabled) + .toCompletableFuture() + .join() + ); + } + } catch (CompletionException e) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(e); + RestStatus status = ExceptionsHelper.status(cause); + if (dataObjectRequest instanceof PutDataObjectRequest) { + responses + .add( + new PutDataObjectResponse.Builder() + .index(dataObjectRequest.index()) + .id(dataObjectRequest.id()) + .failed(true) + .cause(cause) + .status(status) + .build() + ); + } else if (dataObjectRequest instanceof UpdateDataObjectRequest) { + responses + .add( + new UpdateDataObjectResponse.Builder() + .index(dataObjectRequest.index()) + .id(dataObjectRequest.id()) + .failed(true) + .cause(cause) + .status(status) + .build() + ); + } else if (dataObjectRequest instanceof DeleteDataObjectRequest) { + responses + .add( + new DeleteDataObjectResponse.Builder() + .index(dataObjectRequest.index()) + .id(dataObjectRequest.id()) + .failed(true) + .cause(cause) + .status(status) + .build() + ); + } + log.error("Error in bulk operation for id {}: {}", dataObjectRequest.id(), e.getCause().getMessage(), e.getCause()); + } + } + long endNanos = System.nanoTime(); + long tookMillis = TimeUnit.NANOSECONDS.toMillis(endNanos - startNanos); + + log.info("Bulk action complete for {} items, took {} ms", responses.size(), tookMillis); + return buildBulkDataObjectResponse(responses, tookMillis); + }, executor); + } + + private BulkDataObjectResponse buildBulkDataObjectResponse(List responses, long tookMillis) { + // Reconstruct BulkResponse to leverage its parser and hasFailed methods + BulkItemResponse[] responseArray = new BulkItemResponse[responses.size()]; + try { + for (int id = 0; id < responses.size(); id++) { + responseArray[id] = buildBulkItemResponse(responses, id); + } + BulkResponse br = new BulkResponse(responseArray, tookMillis); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + br.toXContent(builder, ToXContent.EMPTY_PARAMS); + return new BulkDataObjectResponse( + responses.toArray(new DataObjectResponse[0]), + tookMillis, + br.hasFailures(), + createParser(builder.toString()) + ); + } + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException("Failed to parse bulk response", RestStatus.INTERNAL_SERVER_ERROR); + } + } + + private BulkItemResponse buildBulkItemResponse(List responses, int bulkId) throws IOException { + DataObjectResponse response = responses.get(bulkId); + OpType opType = null; + if (response instanceof PutDataObjectResponse) { + opType = OpType.INDEX; + } else if (response instanceof UpdateDataObjectResponse) { + opType = OpType.UPDATE; + } else if (response instanceof DeleteDataObjectResponse) { + opType = OpType.DELETE; + } + // If failed, parser is null, so shortcut response here + if (response.isFailed()) { + return new BulkItemResponse(bulkId, opType, new BulkItemResponse.Failure(response.index(), response.id(), response.cause())); + } + DocWriteResponse writeResponse = null; + if (response instanceof PutDataObjectResponse) { + writeResponse = IndexResponse.fromXContent(response.parser()); + } else if (response instanceof UpdateDataObjectResponse) { + writeResponse = UpdateResponse.fromXContent(response.parser()); + } else if (response instanceof DeleteDataObjectResponse) { + writeResponse = DeleteResponse.fromXContent(response.parser()); + } + return new BulkItemResponse(bulkId, opType, writeResponse); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java index 73ce589249..9dd2959043 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClient.java @@ -13,11 +13,10 @@ import java.io.IOException; import java.io.StringReader; import java.io.StringWriter; -import java.security.AccessController; -import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; @@ -28,10 +27,13 @@ import org.opensearch.client.opensearch._types.FieldValue; import org.opensearch.client.opensearch._types.OpType; import org.opensearch.client.opensearch._types.OpenSearchException; +import org.opensearch.client.opensearch._types.Refresh; import org.opensearch.client.opensearch._types.query_dsl.BoolQuery; import org.opensearch.client.opensearch._types.query_dsl.MatchAllQuery; import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch._types.query_dsl.TermQuery; +import org.opensearch.client.opensearch.core.BulkRequest; +import org.opensearch.client.opensearch.core.BulkResponse; import org.opensearch.client.opensearch.core.DeleteRequest; import org.opensearch.client.opensearch.core.DeleteResponse; import org.opensearch.client.opensearch.core.GetRequest; @@ -43,6 +45,8 @@ import org.opensearch.client.opensearch.core.UpdateRequest; import org.opensearch.client.opensearch.core.UpdateRequest.Builder; import org.opensearch.client.opensearch.core.UpdateResponse; +import org.opensearch.client.opensearch.core.bulk.BulkOperation; +import org.opensearch.client.opensearch.core.bulk.BulkResponseItem; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.json.JsonXContent; @@ -56,6 +60,11 @@ import org.opensearch.index.query.MatchPhraseQueryBuilder; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.sdkclient.util.JsonTransformer; +import org.opensearch.sdk.AbstractSdkClient; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; +import org.opensearch.sdk.DataObjectRequest; +import org.opensearch.sdk.DataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -63,7 +72,6 @@ import org.opensearch.sdk.PutDataObjectRequest; import org.opensearch.sdk.PutDataObjectResponse; import org.opensearch.sdk.SdkClient; -import org.opensearch.sdk.SdkClientDelegate; import org.opensearch.sdk.SdkClientUtils; import org.opensearch.sdk.SearchDataObjectRequest; import org.opensearch.sdk.SearchDataObjectResponse; @@ -75,10 +83,11 @@ import lombok.extern.log4j.Log4j2; /** - * An implementation of {@link SdkClient} that stores data in a remote OpenSearch cluster using the OpenSearch Java Client. + * An implementation of {@link SdkClient} that stores data in a remote + * OpenSearch cluster using the OpenSearch Java Client. */ @Log4j2 -public class RemoteClusterIndicesClient implements SdkClientDelegate { +public class RemoteClusterIndicesClient extends AbstractSdkClient { @SuppressWarnings("unchecked") private static final Class> MAP_DOCTYPE = (Class>) (Class) Map.class; @@ -88,6 +97,7 @@ public class RemoteClusterIndicesClient implements SdkClientDelegate { /** * Instantiate this object with an OpenSearch Java client. + * * @param openSearchClient The client to wrap */ public RemoteClusterIndicesClient(OpenSearchClient openSearchClient) { @@ -101,7 +111,7 @@ public CompletionStage putDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { IndexRequest.Builder builder = new IndexRequest.Builder<>() .index(request.index()) @@ -124,7 +134,7 @@ public CompletionStage putDataObjectAsync( RestStatus.BAD_REQUEST ); } - }), executor); + }, executor); } @Override @@ -133,7 +143,7 @@ public CompletionStage getDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { GetRequest getRequest = new GetRequest.Builder().index(request.index()).id(request.id()).build(); log.info("Getting {} from {}", request.id(), request.index()); @@ -149,7 +159,7 @@ public CompletionStage getDataObjectAsync( RestStatus.INTERNAL_SERVER_ERROR ); } - }), executor); + }, executor); } @Override @@ -158,7 +168,7 @@ public CompletionStage updateDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { request.dataObject().toXContent(builder, ToXContent.EMPTY_PARAMS); Map docMap = JsonXContent.jsonXContent @@ -199,7 +209,7 @@ public CompletionStage updateDataObjectAsync( RestStatus.BAD_REQUEST ); } - }), executor); + }, executor); } @Override @@ -208,7 +218,7 @@ public CompletionStage deleteDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { DeleteRequest deleteRequest = new DeleteRequest.Builder().index(request.index()).id(request.id()).build(); log.info("Deleting {} from {}", request.id(), request.index()); @@ -223,7 +233,151 @@ public CompletionStage deleteDataObjectAsync( RestStatus.INTERNAL_SERVER_ERROR ); } - }), executor); + }, executor); + } + + @Override + public CompletionStage bulkDataObjectAsync( + BulkDataObjectRequest request, + Executor executor, + Boolean isMultiTenancyEnabled + ) { + return executePrivilegedAsync(() -> { + try { + log.info("Performing {} bulk actions on indices {}", request.requests().size(), request.getIndices()); + List operations = new ArrayList<>(); + for (DataObjectRequest dataObjectRequest : request.requests()) { + addBulkOperation(dataObjectRequest, operations); + } + BulkRequest bulkRequest = new BulkRequest.Builder().operations(operations).refresh(Refresh.True).build(); + BulkResponse bulkResponse = openSearchClient.bulk(bulkRequest); + log + .info( + "Bulk action complete for {} items: {}", + bulkResponse.items().size(), + bulkResponse.errors() ? "has failures" : "success" + ); + DataObjectResponse[] responses = bulkResponseItemsToArray(bulkResponse.items()); + return bulkResponse.ingestTook() == null + ? new BulkDataObjectResponse(responses, bulkResponse.took(), bulkResponse.errors(), createParser(bulkResponse)) + : new BulkDataObjectResponse( + responses, + bulkResponse.took(), + bulkResponse.ingestTook().longValue(), + bulkResponse.errors(), + createParser(bulkResponse) + ); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException("Failed to parse data object in a bulk response", RestStatus.INTERNAL_SERVER_ERROR); + } + }, executor); + } + + private void addBulkOperation(DataObjectRequest dataObjectRequest, List operations) { + if (dataObjectRequest instanceof PutDataObjectRequest) { + addBulkPutOperation((PutDataObjectRequest) dataObjectRequest, operations); + } else if (dataObjectRequest instanceof UpdateDataObjectRequest) { + addBulkUpdateOperation((UpdateDataObjectRequest) dataObjectRequest, operations); + } else if (dataObjectRequest instanceof DeleteDataObjectRequest) { + addBulkDeleteOperation((DeleteDataObjectRequest) dataObjectRequest, operations); + } else { + throw new IllegalArgumentException("Invalid type for bulk request"); + } + } + + private void addBulkPutOperation(PutDataObjectRequest putRequest, List operations) { + if (putRequest.overwriteIfExists()) { + // Use index operation + operations.add(BulkOperation.of(op -> op.index(i -> { + i + .index(putRequest.index()) + .document(putRequest.dataObject()) + .tDocumentSerializer(new JsonTransformer.XContentObjectJsonpSerializer()); + if (!Strings.isNullOrEmpty(putRequest.id())) { + i.id(putRequest.id()); + } + return i; + }))); + } else { + // Use create operation + operations.add(BulkOperation.of(op -> op.create(c -> { + c + .index(putRequest.index()) + .document(putRequest.dataObject()) + .tDocumentSerializer(new JsonTransformer.XContentObjectJsonpSerializer()); + if (!Strings.isNullOrEmpty(putRequest.id())) { + c.id(putRequest.id()); + } + return c; + }))); + } + } + + private void addBulkUpdateOperation(UpdateDataObjectRequest updateRequest, List operations) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + updateRequest.dataObject().toXContent(builder, ToXContent.EMPTY_PARAMS); + Map docMap = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, builder.toString()) + .map(); + operations.add(BulkOperation.of(op -> op.update(u -> { + u.index(updateRequest.index()).id(updateRequest.id()).document(docMap); + if (updateRequest.ifSeqNo() != null) { + u.ifSeqNo(updateRequest.ifSeqNo()); + } + if (updateRequest.ifPrimaryTerm() != null) { + u.ifPrimaryTerm(updateRequest.ifPrimaryTerm()); + } + if (updateRequest.retryOnConflict() > 0) { + u.retryOnConflict(updateRequest.retryOnConflict()); + } + return u; + }))); + } catch (IOException e) { + // Rethrow unchecked exception on XContent parsing error + throw new OpenSearchStatusException("Failed to parse data object in a bulk update request", RestStatus.INTERNAL_SERVER_ERROR); + } + } + + private void addBulkDeleteOperation(DeleteDataObjectRequest deleteRequest, List operations) { + operations.add(BulkOperation.of(op -> op.delete(d -> d.index(deleteRequest.index()).id(deleteRequest.id())))); + } + + private DataObjectResponse[] bulkResponseItemsToArray(List items) throws IOException { + DataObjectResponse[] responses = new DataObjectResponse[items.size()]; + int i = 0; + for (BulkResponseItem itemResponse : items) { + switch (itemResponse.operationType()) { + case Index: + case Create: + responses[i++] = PutDataObjectResponse + .builder() + .id(itemResponse.id()) + .parser(createParser(itemResponse)) + .failed(itemResponse.error() != null) + .build(); + break; + case Update: + responses[i++] = UpdateDataObjectResponse + .builder() + .id(itemResponse.id()) + .parser(createParser(itemResponse)) + .failed(itemResponse.error() != null) + .build(); + break; + case Delete: + responses[i++] = DeleteDataObjectResponse + .builder() + .id(itemResponse.id()) + .parser(createParser(itemResponse)) + .failed(itemResponse.error() != null) + .build(); + break; + default: + throw new OpenSearchStatusException("Invalid operation type for bulk response", RestStatus.INTERNAL_SERVER_ERROR); + } + } + return responses; } @Override @@ -232,7 +386,7 @@ public CompletionStage searchDataObjectAsync( Executor executor, Boolean isMultiTenancyEnabled ) { - return CompletableFuture.supplyAsync(() -> AccessController.doPrivileged((PrivilegedAction) () -> { + return executePrivilegedAsync(() -> { try { log.info("Searching {}", Arrays.toString(request.indices())); // work around https://github.com/opensearch-project/opensearch-java/issues/1150 @@ -271,7 +425,7 @@ public CompletionStage searchDataObjectAsync( RestStatus.INTERNAL_SERVER_ERROR ); } - }), executor); + }, executor); } private XContentParser createParser(JsonpSerializable obj) throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java index 87d42f3847..81e4bc37a8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelActionTests.java @@ -28,6 +28,7 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.Spy; import org.opensearch.Version; @@ -56,8 +57,10 @@ import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.sdkclient.SdkClientFactory; import org.opensearch.ml.stats.MLStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.sdk.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -82,6 +85,7 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; @Mock ClusterState clusterState; @@ -132,6 +136,7 @@ public class TransportUndeployModelActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, xContentRegistry, settings)); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -150,6 +155,7 @@ public void setup() throws IOException { clusterService, threadPool, client, + sdkClient, nodeFilter, mlStats ) @@ -242,7 +248,7 @@ public void testDoExecuteTransportUndeployedModelAction() { public void testProcessUndeployModelResponseAndUpdateNullResponse() { when(undeployModelNodesResponse.getNodes()).thenReturn(null); - action.processUndeployModelResponseAndUpdate(undeployModelNodesResponse, actionListener); + action.processUndeployModelResponseAndUpdate(mock(), undeployModelNodesResponse, actionListener); } public void testProcessUndeployModelResponseAndUpdateResponse() { @@ -276,7 +282,7 @@ public void testProcessUndeployModelResponseAndUpdateResponse() { return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -310,7 +316,7 @@ public void testProcessUndeployModelResponseAndUpdateBulkException() { return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -344,7 +350,7 @@ public void testProcessUndeployModelResponseAndUpdateSyncUpException() { return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -379,7 +385,7 @@ public void testProcessUndeployModelResponseAndUpdateResponseDeployStatusWrong() return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -416,7 +422,7 @@ public void testProcessUndeployModelResponseAndUpdateResponseUndeployPartialNode return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -451,7 +457,7 @@ public void testProcessUndeployModelResponseAndUpdateResponseUndeployEmptyNodes( return null; }).when(client).execute(any(), any(MLSyncUpNodesRequest.class), any()); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); verify(actionListener).onResponse(response); } @@ -474,7 +480,7 @@ public void testProcessUndeployModelResponseAndUpdateResponseUndeployNodeEntrySe final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); } public void testProcessUndeployModelResponseAndUpdateResponseUndeployModelWorkerNodeBeforeRemovalNull() { @@ -494,7 +500,7 @@ public void testProcessUndeployModelResponseAndUpdateResponseUndeployModelWorker final List failures = new ArrayList<>(); final MLUndeployModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); - action.processUndeployModelResponseAndUpdate(response, actionListener); + action.processUndeployModelResponseAndUpdate(nodesRequest, response, actionListener); } public void testNewResponseWithNotFoundModelStatus() { diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 46cc5b8877..1b50e33fcd 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -7,18 +7,21 @@ import static java.util.Collections.emptyMap; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; @@ -31,21 +34,26 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -56,12 +64,17 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.action.connector.TransportCreateConnectorActionTests; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; @@ -71,8 +84,10 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.sdkclient.SdkClientFactory; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.sdk.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -80,6 +95,8 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -87,8 +104,22 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { + private static TestThreadPool testThreadPool = new TestThreadPool( + TransportCreateConnectorActionTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock private Client client; + private SdkClient sdkClient; + @Mock + NamedXContentRegistry xContentRegistry; @Mock private ClusterService clusterService; @Mock @@ -120,7 +151,6 @@ public void setup() throws IOException { mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); encryptor = spy(new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=")); - syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting); testState = setupTestClusterState(); when(clusterService.state()).thenReturn(testState); @@ -132,10 +162,18 @@ public void setup() throws IOException { }).when(mlIndicesHandler).initMLConfigIndex(any()); Settings settings = Settings.builder().build(); + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, xContentRegistry, settings)); threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + syncUpCron = new MLSyncUpCron(client, sdkClient, clusterService, nodeHelper, mlIndicesHandler, encryptor, mlFeatureEnabledSetting); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testInitMlConfig_MasterKeyNotExist() { @@ -236,75 +274,68 @@ public void testRun_Failure() { public void testRefreshModelState_NoSemaphore() throws InterruptedException { syncUpCron.updateModelStateSemaphore.acquire(); syncUpCron.refreshModelState(null, null); - verify(client, never()).search(any(), any()); + verify(client, Mockito.after(1000).never()).search(any()); syncUpCron.updateModelStateSemaphore.release(); } - public void testRefreshModelState_SearchException() { - doThrow(new RuntimeException("test exception")).when(client).search(any(), any()); + public void testRefreshModelState_SearchException() throws Exception { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("test exception")); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(null, null); - verify(client, times(1)).search(any(), any()); - assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); + assertBusy(() -> { assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); }, 5, TimeUnit.SECONDS); syncUpCron.updateModelStateSemaphore.release(); } - public void testRefreshModelState_SearchFailed() { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new RuntimeException("search error")); - return null; - }).when(client).search(any(), any()); + public void testRefreshModelState_SearchFailed() throws Exception { + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("search error")); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(null, null); - verify(client, times(1)).search(any(), any()); - assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); + assertBusy(() -> { assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); }, 5, TimeUnit.SECONDS); syncUpCron.updateModelStateSemaphore.release(); } public void testRefreshModelState_EmptySearchResponse() { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); - SearchResponseSections searchSections = new SearchResponseSections( - hits, - InternalAggregations.EMPTY, - null, - true, - false, - null, - 1 - ); - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 11, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - actionListener.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, true, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(searchResponse); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(new HashMap<>(), new HashMap<>()); - verify(client, times(1)).search(any(), any()); - verify(client, never()).bulk(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); + verify(client, Mockito.after(1000).never()).bulk(any()); assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); syncUpCron.updateModelStateSemaphore.release(); } - public void testRefreshModelState_ResetAsDeployFailed() { + public void testRefreshModelState_ResetAsDeployFailed() throws IOException { Map> modelWorkerNodes = new HashMap<>(); Map> deployingModels = new HashMap<>(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYED, 2, null, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, null, Instant.now().toEpochMilli())); + when(client.search(any(SearchRequest.class))).thenReturn(future); + syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + verify(client, timeout(1000).times(1)).bulk(bulkRequestCaptor.capture()); BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); assertEquals(1, bulkRequest.requests().size()); @@ -315,19 +346,18 @@ public void testRefreshModelState_ResetAsDeployFailed() { assertEquals(ML_MODEL_INDEX, updateRequest.index()); } - public void testRefreshModelState_ResetAsPartiallyDeployed() { + public void testRefreshModelState_ResetAsPartiallyDeployed() throws IOException { Map> modelWorkerNodes = new HashMap<>(); modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); Map> deployingModels = new HashMap<>(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYED, 2, 0, Instant.now().toEpochMilli())); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + verify(client, timeout(1000).times(1)).bulk(bulkRequestCaptor.capture()); BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); assertEquals(1, bulkRequest.requests().size()); @@ -338,20 +368,21 @@ public void testRefreshModelState_ResetAsPartiallyDeployed() { assertEquals(ML_MODEL_INDEX, updateRequest.index()); } - public void testRefreshModelState_ResetCurrentWorkerNodeCountForPartiallyDeployed() { + public void testRefreshModelState_ResetCurrentWorkerNodeCountForPartiallyDeployed() throws IOException { Map> modelWorkerNodes = new HashMap<>(); modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); Map> deployingModels = new HashMap<>(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener - .onResponse(createSearchModelResponse("modelId", MLModelState.PARTIALLY_DEPLOYED, 3, 2, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future + .onResponse( + createSearchModelResponse("modelId", "tenantId", MLModelState.PARTIALLY_DEPLOYED, 3, 2, Instant.now().toEpochMilli()) + ); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + verify(client, timeout(1000).times(1)).bulk(bulkRequestCaptor.capture()); BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); assertEquals(1, bulkRequest.requests().size()); @@ -362,20 +393,19 @@ public void testRefreshModelState_ResetCurrentWorkerNodeCountForPartiallyDeploye assertEquals(ML_MODEL_INDEX, updateRequest.index()); } - public void testRefreshModelState_ResetAsDeploying() { + public void testRefreshModelState_ResetAsDeploying() throws IOException { Map> modelWorkerNodes = new HashMap<>(); modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); Map> deployingModels = new HashMap<>(); deployingModels.put("modelId", ImmutableSet.of("node2")); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOY_FAILED, 2, 0, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOY_FAILED, 2, 0, Instant.now().toEpochMilli())); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + verify(client, timeout(1000).times(1)).bulk(bulkRequestCaptor.capture()); BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); assertEquals(1, bulkRequest.requests().size()); @@ -386,31 +416,29 @@ public void testRefreshModelState_ResetAsDeploying() { assertEquals(ML_MODEL_INDEX, updateRequest.index()); } - public void testRefreshModelState_NotResetState_DeployingModelTaskRunning() { + public void testRefreshModelState_NotResetState_DeployingModelTaskRunning() throws IOException { Map> modelWorkerNodes = new HashMap<>(); Map> deployingModels = new HashMap<>(); deployingModels.put("modelId", ImmutableSet.of("node2")); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); - verify(client, never()).bulk(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); + verify(client, Mockito.after(1000).never()).bulk(any()); } - public void testRefreshModelState_NotResetState_DeployingInGraceTime() { + public void testRefreshModelState_NotResetState_DeployingInGraceTime() throws IOException { Map> modelWorkerNodes = new HashMap<>(); Map> deployingModels = new HashMap<>(); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); - return null; - }).when(client).search(any(), any()); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(createSearchModelResponse("modelId", "tenantId", MLModelState.DEPLOYING, 2, null, Instant.now().toEpochMilli())); + when(client.search(any(SearchRequest.class))).thenReturn(future); syncUpCron.refreshModelState(modelWorkerNodes, deployingModels); - verify(client, times(1)).search(any(), any()); - verify(client, never()).bulk(any(), any()); + // Need a small delay due to multithreading + verify(client, timeout(1000).times(1)).search(any(SearchRequest.class)); + verify(client, Mockito.after(1000).never()).bulk(any()); } private void mockSyncUp_GatherRunningTasks() { @@ -448,6 +476,7 @@ private void mockSyncUp_GatherRunningTasks_Failure() { private SearchResponse createSearchModelResponse( String modelId, + String tenantId, MLModelState state, Integer planningWorkerNodeCount, Integer currentWorkerNodeCount, @@ -455,6 +484,7 @@ private SearchResponse createSearchModelResponse( ) throws IOException { XContentBuilder content = TestHelper.builder(); content.startObject(); + content.field(CommonValue.TENANT_ID, tenantId); content.field(MLModel.MODEL_STATE_FIELD, state); content.field(MLModel.ALGORITHM_FIELD, FunctionName.KMEANS); content.field(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodeCount); diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index 0964bc50ac..8edfc89436 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -9,6 +9,7 @@ package org.opensearch.ml.sdkclient; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; @@ -50,6 +51,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.CommonValue; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -670,6 +673,119 @@ public void updateDataObjectAsync_VersionCheckRetryFailure() { assertEquals(RestStatus.CONFLICT, ((OpenSearchStatusException) cause).status()); } + @Test + public void testBulkDataObject_HappyCase() { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .id(TEST_ID + "1") + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(TEST_ID + "2") + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").tenantId(TENANT_ID).build(); + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest) + .add(deleteRequest); + + when(dynamoDbClient.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + GetItemResponse getItemResponse = GetItemResponse + .builder() + .item( + Map + .ofEntries( + Map.entry(SOURCE, AttributeValue.builder().m(Map.of("data", AttributeValue.builder().s("foo").build())).build()), + Map.entry(SEQ_NUM, AttributeValue.builder().n("0").build()) + ) + ) + .build(); + when(dynamoDbClient.getItem(any(GetItemRequest.class))).thenReturn(getItemResponse); + when(dynamoDbClient.updateItem(any(UpdateItemRequest.class))).thenReturn(UpdateItemResponse.builder().build()); + when(dynamoDbClient.deleteItem(any(DeleteItemRequest.class))).thenReturn(DeleteItemResponse.builder().build()); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(3, response.getResponses().length); + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + + assertEquals(TEST_ID + "1", response.getResponses()[0].id()); + assertEquals(TEST_ID + "2", response.getResponses()[1].id()); + assertEquals(TEST_ID + "3", response.getResponses()[2].id()); + + assertTrue(response.getTookInMillis() >= 0); + } + + @Test + public void testBulkDataObject_WithFailures() { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .id(TEST_ID + "1") + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(TEST_ID + "2") + .tenantId(TENANT_ID) + .dataObject(testDataObject) + .build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").tenantId(TENANT_ID).build(); + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest) + .add(deleteRequest); + + when(dynamoDbClient.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + GetItemResponse getItemResponse = GetItemResponse + .builder() + .item( + Map + .ofEntries( + Map.entry(SOURCE, AttributeValue.builder().m(Map.of("data", AttributeValue.builder().s("foo").build())).build()), + Map.entry(SEQ_NUM, AttributeValue.builder().n("0").build()) + ) + ) + .build(); + when(dynamoDbClient.getItem(any(GetItemRequest.class))).thenReturn(getItemResponse); + Exception cause = new OpenSearchStatusException("Update failed with conflict", RestStatus.CONFLICT); + when(dynamoDbClient.updateItem(any(UpdateItemRequest.class))).thenThrow(cause); + when(dynamoDbClient.deleteItem(any(DeleteItemRequest.class))).thenReturn(DeleteItemResponse.builder().build()); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(3, response.getResponses().length); + assertFalse(response.getResponses()[0].isFailed()); + assertNull(response.getResponses()[0].cause()); + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1].isFailed()); + assertTrue(response.getResponses()[1].cause() instanceof OpenSearchStatusException); + assertEquals("Update failed with conflict", response.getResponses()[1].cause().getMessage()); + assertEquals(RestStatus.CONFLICT, response.getResponses()[1].status()); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertFalse(response.getResponses()[2].isFailed()); + assertNull(response.getResponses()[0].cause()); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + } + @Test public void searchDataObjectAsync_HappyCase() { SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource(); @@ -729,5 +845,4 @@ private Map getComplexDataSource() { Map.entry("testObject", AttributeValue.builder().m(Map.of("data", AttributeValue.builder().s("foo").build())).build()) ); } - } diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java index 9e214c6b36..812376cb24 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/RemoteClusterIndicesClientTests.java @@ -9,6 +9,8 @@ package org.opensearch.ml.sdkclient; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -18,6 +20,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -40,6 +43,8 @@ import org.opensearch.client.opensearch._types.Result; import org.opensearch.client.opensearch._types.ShardStatistics; import org.opensearch.client.opensearch._types.query_dsl.Query; +import org.opensearch.client.opensearch.core.BulkRequest; +import org.opensearch.client.opensearch.core.BulkResponse; import org.opensearch.client.opensearch.core.DeleteRequest; import org.opensearch.client.opensearch.core.DeleteResponse; import org.opensearch.client.opensearch.core.GetRequest; @@ -50,6 +55,8 @@ import org.opensearch.client.opensearch.core.SearchResponse; import org.opensearch.client.opensearch.core.UpdateRequest; import org.opensearch.client.opensearch.core.UpdateResponse; +import org.opensearch.client.opensearch.core.bulk.BulkResponseItem; +import org.opensearch.client.opensearch.core.bulk.OperationType; import org.opensearch.client.opensearch.core.search.HitsMetadata; import org.opensearch.client.opensearch.core.search.TotalHits; import org.opensearch.client.opensearch.core.search.TotalHitsRelation; @@ -63,6 +70,8 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sdk.BulkDataObjectRequest; +import org.opensearch.sdk.BulkDataObjectResponse; import org.opensearch.sdk.DeleteDataObjectRequest; import org.opensearch.sdk.DeleteDataObjectResponse; import org.opensearch.sdk.GetDataObjectRequest; @@ -558,6 +567,187 @@ public void testDeleteDataObject_Exception() throws IOException { assertEquals(OpenSearchStatusException.class, ce.getCause().getClass()); } + public void testBulkDataObject() throws IOException { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .id(TEST_ID + "1") + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(TEST_ID + "2") + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").tenantId(TEST_TENANT_ID).build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest) + .add(deleteRequest); + + BulkResponse bulkResponse = new BulkResponse.Builder() + .took(100L) + .items( + Arrays + .asList( + new BulkResponseItem.Builder() + .id(TEST_ID + "1") + .index(TEST_INDEX) + .operationType(OperationType.Index) + .result(Result.Created.jsonValue()) + .status(RestStatus.OK.getStatus()) + .build(), + new BulkResponseItem.Builder() + .id(TEST_ID + "2") + .index(TEST_INDEX) + .operationType(OperationType.Update) + .result(Result.Updated.jsonValue()) + .status(RestStatus.OK.getStatus()) + .build(), + new BulkResponseItem.Builder() + .id(TEST_ID + "3") + .index(TEST_INDEX) + .operationType(OperationType.Delete) + .result(Result.Deleted.jsonValue()) + .status(RestStatus.OK.getStatus()) + .build() + ) + ) + .errors(false) + .build(); + + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + when(mockedOpenSearchClient.bulk(bulkRequestCaptor.capture())).thenReturn(bulkResponse); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(3, bulkRequestCaptor.getValue().operations().size()); + assertEquals(3, response.getResponses().length); + assertEquals(100L, response.getTookInMillis()); + + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + + assertEquals(TEST_ID + "1", response.getResponses()[0].id()); + assertEquals(TEST_ID + "2", response.getResponses()[1].id()); + assertEquals(TEST_ID + "3", response.getResponses()[2].id()); + } + + public void testBulkDataObject_WithFailures() throws IOException { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .id(TEST_ID + "1") + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + UpdateDataObjectRequest updateRequest = UpdateDataObjectRequest + .builder() + .id(TEST_ID + "2") + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + DeleteDataObjectRequest deleteRequest = DeleteDataObjectRequest.builder().id(TEST_ID + "3").build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest + .builder() + .globalIndex(TEST_INDEX) + .build() + .add(putRequest) + .add(updateRequest); + + BulkResponse bulkResponse = new BulkResponse.Builder() + .took(100L) + .items( + Arrays + .asList( + new BulkResponseItem.Builder() + .id(TEST_ID + "1") + .index(TEST_INDEX) + .operationType(OperationType.Index) + .result(Result.Created.jsonValue()) + .status(RestStatus.OK.getStatus()) + .build(), + new BulkResponseItem.Builder() + .id(TEST_ID + "2") + .index(TEST_INDEX) + .operationType(OperationType.Update) + .error(new ErrorCause.Builder().type("update_error").reason("Update failed").build()) + .status(RestStatus.INTERNAL_SERVER_ERROR.getStatus()) + .build(), + new BulkResponseItem.Builder() + .id(TEST_ID + "3") + .index(TEST_INDEX) + .operationType(OperationType.Delete) + .result(Result.Deleted.jsonValue()) + .status(RestStatus.OK.getStatus()) + .build() + ) + ) + .errors(true) + .build(); + + when(mockedOpenSearchClient.bulk(any(BulkRequest.class))).thenReturn(bulkResponse); + + BulkDataObjectResponse response = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture() + .join(); + + assertEquals(3, response.getResponses().length); + assertFalse(response.getResponses()[0].isFailed()); + assertTrue(response.getResponses()[0] instanceof PutDataObjectResponse); + assertTrue(response.getResponses()[1].isFailed()); + assertTrue(response.getResponses()[1] instanceof UpdateDataObjectResponse); + assertFalse(response.getResponses()[2].isFailed()); + assertTrue(response.getResponses()[2] instanceof DeleteDataObjectResponse); + } + + public void testBulkDataObject_Exception() throws OpenSearchException, IOException { + PutDataObjectRequest putRequest = PutDataObjectRequest + .builder() + .index(TEST_INDEX) + .id(TEST_ID) + .tenantId(TEST_TENANT_ID) + .dataObject(testDataObject) + .build(); + + BulkDataObjectRequest bulkRequest = BulkDataObjectRequest.builder().build().add(putRequest); + + when(mockedOpenSearchClient.bulk(any(BulkRequest.class))) + .thenThrow( + new OpenSearchException( + new ErrorResponse.Builder() + .error( + new ErrorCause.Builder() + .type("parse_exception") + .reason("Failed to parse data object in a bulk response") + .build() + ) + .status(RestStatus.INTERNAL_SERVER_ERROR.getStatus()) + .build() + ) + ); + + CompletableFuture future = sdkClient + .bulkDataObjectAsync(bulkRequest, testThreadPool.executor(GENERAL_THREAD_POOL)) + .toCompletableFuture(); + + CompletionException ce = assertThrows(CompletionException.class, () -> future.join()); + Throwable cause = ce.getCause(); + assertEquals(OpenSearchException.class, cause.getClass()); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), ((OpenSearchException) cause).status()); + assertTrue(cause.getMessage().contains("Failed to parse data object in a bulk response")); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testSearchDataObject() throws IOException { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();