Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.17] change to model group access for batch job task APIs #3103

Merged
merged 1 commit into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -42,6 +43,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -54,9 +56,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -73,6 +77,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

Expand All @@ -88,6 +93,7 @@ public CancelBatchJobTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
Expand All @@ -99,6 +105,7 @@ public CancelBatchJobTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
Expand Down Expand Up @@ -177,25 +184,39 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model with the given ID", e);
actionListener
Expand All @@ -211,26 +232,20 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
}

private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -59,6 +60,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -71,9 +73,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

Expand All @@ -111,6 +116,7 @@ public GetTaskTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
Expand All @@ -123,6 +129,7 @@ public GetTaskTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
Expand Down Expand Up @@ -238,26 +245,40 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model for the given task ID", e);
actionListener
Expand All @@ -280,26 +301,20 @@ private void executeConnector(
Map<String, Object> remoteJob,
ActionListener<MLTaskGetResponse> actionListener
) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

protected void processTaskResponse(
Expand Down
Loading
Loading