forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Bhavana Ramaram <[email protected]>
- Loading branch information
Showing
7 changed files
with
369 additions
and
1 deletion.
There are no files selected for viewing
15 changes: 15 additions & 0 deletions
15
...n/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model_group; | ||
|
||
import org.opensearch.action.ActionType; | ||
|
||
public class MLModelGroupGetAction extends ActionType<MLModelGroupGetResponse> { | ||
public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get"; | ||
|
||
private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);} | ||
} |
80 changes: 80 additions & 0 deletions
80
.../src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model_group; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLModelGroupGetRequest extends ActionRequest { | ||
|
||
String modelGroupId; | ||
boolean returnContent; | ||
|
||
@Builder | ||
public MLModelGroupGetRequest(String modelGroupId, boolean returnContent) { | ||
this.modelGroupId = modelGroupId; | ||
this.returnContent = returnContent; | ||
} | ||
|
||
public MLModelGroupGetRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.modelGroupId = in.readString(); | ||
this.returnContent = in.readBoolean(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.modelGroupId); | ||
out.writeBoolean(returnContent); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.modelGroupId == null) { | ||
exception = addValidationError("Model group id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLModelGroupGetRequest) { | ||
return (MLModelGroupGetRequest)actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLModelGroupGetRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e); | ||
} | ||
} | ||
} |
67 changes: 67 additions & 0 deletions
67
...src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model_group; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.MLModelGroup; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
@Getter | ||
@ToString | ||
public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject { | ||
|
||
MLModelGroup mlModelGroup; | ||
|
||
@Builder | ||
public MLModelGroupGetResponse(MLModelGroup mlModelGroup) { | ||
this.mlModelGroup = mlModelGroup; | ||
} | ||
|
||
|
||
public MLModelGroupGetResponse(StreamInput in) throws IOException { | ||
super(in); | ||
mlModelGroup = mlModelGroup.fromStream(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException{ | ||
mlModelGroup.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { | ||
return mlModelGroup.toXContent(xContentBuilder, params); | ||
} | ||
|
||
public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) { | ||
if (actionResponse instanceof MLModelGroupGetResponse) { | ||
return (MLModelGroupGetResponse) actionResponse; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionResponse.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLModelGroupGetResponse(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
131 changes: 131 additions & 0 deletions
131
plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.model_group; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; | ||
|
||
import org.opensearch.OpenSearchStatusException; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.get.GetRequest; | ||
import org.opensearch.action.get.GetResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.ml.common.MLModelGroup; | ||
import org.opensearch.ml.common.exception.MLResourceNotFoundException; | ||
import org.opensearch.ml.common.exception.MLValidationException; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; | ||
import org.opensearch.ml.helper.ModelAccessControlHelper; | ||
import org.opensearch.ml.utils.RestActionUtils; | ||
import org.opensearch.search.fetch.subphase.FetchSourceContext; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.experimental.FieldDefaults; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@FieldDefaults(level = AccessLevel.PRIVATE) | ||
public class GetModelGroupTransportAction extends HandledTransportAction<ActionRequest, MLModelGroupGetResponse> { | ||
|
||
Client client; | ||
NamedXContentRegistry xContentRegistry; | ||
ClusterService clusterService; | ||
|
||
ModelAccessControlHelper modelAccessControlHelper; | ||
|
||
@Inject | ||
public GetModelGroupTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
ClusterService clusterService, | ||
ModelAccessControlHelper modelAccessControlHelper | ||
) { | ||
super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new); | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.clusterService = clusterService; | ||
this.modelAccessControlHelper = modelAccessControlHelper; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<MLModelGroupGetResponse> actionListener) { | ||
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.fromActionRequest(request); | ||
String modelGroupId = mlModelGroupGetRequest.getModelGroupId(); | ||
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGroupGetRequest.isReturnContent()); | ||
GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId).fetchSourceContext(fetchSourceContext); | ||
User user = RestActionUtils.getUserContext(client); | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
ActionListener<MLModelGroupGetResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); | ||
client.get(getRequest, ActionListener.wrap(r -> { | ||
if (r != null && r.isExists()) { | ||
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
GetResponse getResponse = r; | ||
|
||
MLModelGroup mlModelGroup = MLModelGroup.parse(parser); | ||
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { | ||
if (!access) { | ||
wrappedListener | ||
.onFailure( | ||
new MLValidationException( | ||
"User doesn't have privilege to perform this operation on this model group" | ||
) | ||
); | ||
} else { | ||
wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); | ||
} | ||
}, e -> { | ||
log.error("Failed to validate access for Model Group " + modelGroupId, e); | ||
wrappedListener.onFailure(e); | ||
})); | ||
|
||
} catch (Exception e) { | ||
log.error("Failed to parse ml model group" + r.getId(), e); | ||
wrappedListener.onFailure(e); | ||
} | ||
} else { | ||
wrappedListener | ||
.onFailure( | ||
new OpenSearchStatusException( | ||
"Failed to find model group with the provided model group id: " + modelGroupId, | ||
RestStatus.NOT_FOUND | ||
) | ||
); | ||
} | ||
}, e -> { | ||
if (e instanceof IndexNotFoundException) { | ||
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group index")); | ||
} else { | ||
log.error("Failed to get ML model group" + modelGroupId, e); | ||
wrappedListener.onFailure(e); | ||
} | ||
})); | ||
} catch (Exception e) { | ||
log.error("Failed to get ML model group " + modelGroupId, e); | ||
actionListener.onFailure(e); | ||
} | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.