From e231a97d991f5f5ed358641e9a640122f5871b05 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 21 Nov 2023 16:44:26 +0530 Subject: [PATCH] get model group API Signed-off-by: Bhavana Ramaram --- .../model_group/MLModelGroupGetAction.java | 15 ++ .../model_group/MLModelGroupGetRequest.java | 80 +++++++++++ .../model_group/MLModelGroupGetResponse.java | 67 +++++++++ plugin/build.gradle | 4 +- .../GetModelGroupTransportAction.java | 131 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 6 + .../ml/rest/RestMLGetModelGroupAction.java | 67 +++++++++ 7 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java new file mode 100644 index 0000000000..2a8177eda5 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model_group; + +import org.opensearch.action.ActionType; + +public class MLModelGroupGetAction extends ActionType { + public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get"; + + private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);} +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java new file mode 100644 index 0000000000..65fad3b78d --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model_group; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLModelGroupGetRequest extends ActionRequest { + + String modelGroupId; + boolean returnContent; + + @Builder + public MLModelGroupGetRequest(String modelGroupId, boolean returnContent) { + this.modelGroupId = modelGroupId; + this.returnContent = returnContent; + } + + public MLModelGroupGetRequest(StreamInput in) throws IOException { + super(in); + this.modelGroupId = in.readString(); + this.returnContent = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.modelGroupId); + out.writeBoolean(returnContent); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.modelGroupId == null) { + exception = addValidationError("Model group id can't be null", exception); + } + + return exception; + } + + public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLModelGroupGetRequest) { + return (MLModelGroupGetRequest)actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelGroupGetRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java new file mode 100644 index 0000000000..90775e09c4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model_group; + +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLModelGroup; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +@Getter +@ToString +public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject { + + MLModelGroup mlModelGroup; + + @Builder + public MLModelGroupGetResponse(MLModelGroup mlModelGroup) { + this.mlModelGroup = mlModelGroup; + } + + + public MLModelGroupGetResponse(StreamInput in) throws IOException { + super(in); + mlModelGroup = mlModelGroup.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException{ + mlModelGroup.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return mlModelGroup.toXContent(xContentBuilder, params); + } + + public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLModelGroupGetResponse) { + return (MLModelGroupGetResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLModelGroupGetResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e); + } + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index b8a4d47d22..993bb64699 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -297,7 +297,9 @@ List jacocoExclusions = [ 'org.opensearch.ml.cluster.MLSyncUpCron', 'org.opensearch.ml.model.MLModelGroupManager', 'org.opensearch.ml.helper.ModelAccessControlHelper', - 'org.opensearch.ml.action.models.DeleteModelTransportAction.2' + 'org.opensearch.ml.action.models.DeleteModelTransportAction.2', + 'org.opensearch.ml.action.model_group.GetModelGroupTransportAction', + 'org.opensearch.ml.rest.RestMLGetModelGroupAction' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java new file mode 100644 index 0000000000..e6578a7296 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.model_group; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(level = AccessLevel.PRIVATE) +public class GetModelGroupTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + ClusterService clusterService; + + ModelAccessControlHelper modelAccessControlHelper; + + @Inject + public GetModelGroupTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ClusterService clusterService, + ModelAccessControlHelper modelAccessControlHelper + ) { + super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.clusterService = clusterService; + this.modelAccessControlHelper = modelAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.fromActionRequest(request); + String modelGroupId = mlModelGroupGetRequest.getModelGroupId(); + FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGroupGetRequest.isReturnContent()); + GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId).fetchSourceContext(fetchSourceContext); + User user = RestActionUtils.getUserContext(client); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + GetResponse getResponse = r; + + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new MLValidationException( + "User doesn't have privilege to perform this operation on this model group" + ) + ); + } else { + wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); + } + }, e -> { + log.error("Failed to validate access for Model Group " + modelGroupId, e); + wrappedListener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Failed to parse ml model group" + r.getId(), e); + wrappedListener.onFailure(e); + } + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model group with the provided model group id: " + modelGroupId, + RestStatus.NOT_FOUND + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group index")); + } else { + log.error("Failed to get ML model group" + modelGroupId, e); + wrappedListener.onFailure(e); + } + })); + } catch (Exception e) { + log.error("Failed to get ML model group " + modelGroupId, e); + actionListener.onFailure(e); + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ad3b4dfc44..f5ba454c4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -47,6 +47,7 @@ import org.opensearch.ml.action.forward.TransportForwardAction; import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.action.model_group.DeleteModelGroupTransportAction; +import org.opensearch.ml.action.model_group.GetModelGroupTransportAction; import org.opensearch.ml.action.model_group.SearchModelGroupTransportAction; import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; import org.opensearch.ml.action.model_group.TransportUpdateModelGroupAction; @@ -103,6 +104,7 @@ import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -153,6 +155,7 @@ import org.opensearch.ml.rest.RestMLExecuteAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetModelAction; +import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLProfileAction; @@ -290,6 +293,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class), new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), new ActionHandler<>(MLUpdateModelGroupAction.INSTANCE, TransportUpdateModelGroupAction.class), + new ActionHandler<>(MLModelGroupGetAction.INSTANCE, GetModelGroupTransportAction.class), new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class), new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class), new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), @@ -539,6 +543,7 @@ public List getRestHandlers( RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings); RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(); RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); + RestMLGetModelGroupAction restMLGetModelGroupAction = new RestMLGetModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); @@ -574,6 +579,7 @@ public List getRestHandlers( restMLUploadModelChunkAction, restMLCreateModelGroupAction, restMLUpdateModelGroupAction, + restMLGetModelGroupAction, restMLSearchModelGroupAction, restMLDeleteModelGroupAction, restMLCreateConnectorAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java new file mode 100644 index 0000000000..c6c51e5bc5 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.returnContent; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetModelGroupAction extends BaseRestHandler { + private static final String ML_GET_MODEL_GROUP_ACTION = "ml_get_model_group_action"; + + /** + * Constructor + */ + public RestMLGetModelGroupAction() {} + + @Override + public String getName() { + return ML_GET_MODEL_GROUP_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/model_groups/{%s}", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID)) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLModelGroupGetRequest mlModelGroupGetRequest = getRequest(request); + return channel -> client.execute(MLModelGroupGetAction.INSTANCE, mlModelGroupGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLModelGroupGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLModelGroupGetRequest + */ + @VisibleForTesting + MLModelGroupGetRequest getRequest(RestRequest request) throws IOException { + String modelGroupId = getParameterId(request, PARAMETER_MODEL_GROUP_ID); + boolean returnContent = returnContent(request); + + return new MLModelGroupGetRequest(modelGroupId, returnContent); + } +}