Skip to content

Commit

Permalink
get model group API
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Nov 21, 2023
1 parent 9f438d3 commit e231a97
Show file tree
Hide file tree
Showing 7 changed files with 369 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.opensearch.action.ActionType;

public class MLModelGroupGetAction extends ActionType<MLModelGroupGetResponse> {
public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction();
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get";

private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLModelGroupGetRequest extends ActionRequest {

String modelGroupId;
boolean returnContent;

@Builder
public MLModelGroupGetRequest(String modelGroupId, boolean returnContent) {
this.modelGroupId = modelGroupId;
this.returnContent = returnContent;
}

public MLModelGroupGetRequest(StreamInput in) throws IOException {
super(in);
this.modelGroupId = in.readString();
this.returnContent = in.readBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.modelGroupId);
out.writeBoolean(returnContent);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.modelGroupId == null) {
exception = addValidationError("Model group id can't be null", exception);
}

return exception;
}

public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLModelGroupGetRequest) {
return (MLModelGroupGetRequest)actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModelGroup;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
@ToString
public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject {

MLModelGroup mlModelGroup;

@Builder
public MLModelGroupGetResponse(MLModelGroup mlModelGroup) {
this.mlModelGroup = mlModelGroup;
}


public MLModelGroupGetResponse(StreamInput in) throws IOException {
super(in);
mlModelGroup = mlModelGroup.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException{
mlModelGroup.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return mlModelGroup.toXContent(xContentBuilder, params);
}

public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLModelGroupGetResponse) {
return (MLModelGroupGetResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e);
}
}
}
4 changes: 3 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.cluster.MLSyncUpCron',
'org.opensearch.ml.model.MLModelGroupManager',
'org.opensearch.ml.helper.ModelAccessControlHelper',
'org.opensearch.ml.action.models.DeleteModelTransportAction.2'
'org.opensearch.ml.action.models.DeleteModelTransportAction.2',
'org.opensearch.ml.action.model_group.GetModelGroupTransportAction',
'org.opensearch.ml.rest.RestMLGetModelGroupAction'
]

jacocoTestCoverageVerification {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.model_group;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(level = AccessLevel.PRIVATE)
public class GetModelGroupTransportAction extends HandledTransportAction<ActionRequest, MLModelGroupGetResponse> {

Client client;
NamedXContentRegistry xContentRegistry;
ClusterService clusterService;

ModelAccessControlHelper modelAccessControlHelper;

@Inject
public GetModelGroupTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper
) {
super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
this.modelAccessControlHelper = modelAccessControlHelper;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLModelGroupGetResponse> actionListener) {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.fromActionRequest(request);
String modelGroupId = mlModelGroupGetRequest.getModelGroupId();
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGroupGetRequest.isReturnContent());
GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId).fetchSourceContext(fetchSourceContext);
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModelGroupGetResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
GetResponse getResponse = r;

MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new MLValidationException(
"User doesn't have privilege to perform this operation on this model group"
)
);
} else {
wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build());
}
}, e -> {
log.error("Failed to validate access for Model Group " + modelGroupId, e);
wrappedListener.onFailure(e);
}));

} catch (Exception e) {
log.error("Failed to parse ml model group" + r.getId(), e);
wrappedListener.onFailure(e);
}
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model group with the provided model group id: " + modelGroupId,
RestStatus.NOT_FOUND
)
);
}
}, e -> {
if (e instanceof IndexNotFoundException) {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group index"));
} else {
log.error("Failed to get ML model group" + modelGroupId, e);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
log.error("Failed to get ML model group " + modelGroupId, e);
actionListener.onFailure(e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.ml.action.forward.TransportForwardAction;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.action.model_group.DeleteModelGroupTransportAction;
import org.opensearch.ml.action.model_group.GetModelGroupTransportAction;
import org.opensearch.ml.action.model_group.SearchModelGroupTransportAction;
import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction;
import org.opensearch.ml.action.model_group.TransportUpdateModelGroupAction;
Expand Down Expand Up @@ -103,6 +104,7 @@
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model.MLUpdateModelAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction;
Expand Down Expand Up @@ -153,6 +155,7 @@
import org.opensearch.ml.rest.RestMLExecuteAction;
import org.opensearch.ml.rest.RestMLGetConnectorAction;
import org.opensearch.ml.rest.RestMLGetModelAction;
import org.opensearch.ml.rest.RestMLGetModelGroupAction;
import org.opensearch.ml.rest.RestMLGetTaskAction;
import org.opensearch.ml.rest.RestMLPredictionAction;
import org.opensearch.ml.rest.RestMLProfileAction;
Expand Down Expand Up @@ -290,6 +293,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class),
new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class),
new ActionHandler<>(MLUpdateModelGroupAction.INSTANCE, TransportUpdateModelGroupAction.class),
new ActionHandler<>(MLModelGroupGetAction.INSTANCE, GetModelGroupTransportAction.class),
new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class),
new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class),
new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class),
Expand Down Expand Up @@ -539,6 +543,7 @@ public List<RestHandler> getRestHandlers(
RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings);
RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction();
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
RestMLGetModelGroupAction restMLGetModelGroupAction = new RestMLGetModelGroupAction();
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction();
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
Expand Down Expand Up @@ -574,6 +579,7 @@ public List<RestHandler> getRestHandlers(
restMLUploadModelChunkAction,
restMLCreateModelGroupAction,
restMLUpdateModelGroupAction,
restMLGetModelGroupAction,
restMLSearchModelGroupAction,
restMLDeleteModelGroupAction,
restMLCreateConnectorAction,
Expand Down
Loading

0 comments on commit e231a97

Please sign in to comment.