change to model group access for batch job task APIs #3098

Merged (2 commits, Oct 12, 2024)
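Summary of the change: both CancelBatchJobTransportAction and GetTaskTransportAction now resolve the requesting user and validate model group access before resolving the model's connector, so the batch job status and cancel task APIs are governed by the same model-level access control as the Predict API, while the connector-level check previously performed inside executeConnector is removed. Below is a condensed sketch of the gate that both diffs add; the helper calls and messages are taken from the hunks that follow, while the wrapping method name and the proceedWithConnector placeholder are illustrative only, and imports and fields (client, log, modelAccessControlHelper) are those of the surrounding classes.

    // Condensed sketch, not the literal PR code. Fields and imports come from the
    // transport action classes in this diff; proceedWithConnector is an illustrative
    // placeholder for the existing connector-resolution logic.
    private void validateAccessThenProceed(MLModel model, ActionListener<?> actionListener) {
        // Resolve the caller from the thread context (User / RestActionUtils are new imports in this PR).
        User user = RestActionUtils.getUserContext(client);

        // New gate: remote batch job task APIs now require model group access.
        modelAccessControlHelper
            .validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
                if (!access) {
                    actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job"));
                } else {
                    // Unchanged: use the inline connector if present, otherwise fetch it from the
                    // ML connector index, then execute BATCH_PREDICT_STATUS or CANCEL_BATCH_PREDICT.
                    proceedWithConnector(model); // illustrative placeholder
                }
            }, e -> {
                log.error("Failed to validate access for model group " + model.getModelGroupId(), e);
                actionListener.onFailure(e);
            }));
    }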
CancelBatchJobTransportAction.java

@@ -28,6 +28,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -42,6 +43,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -54,9 +56,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
@@ -73,6 +77,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

@@ -88,6 +93,7 @@ public CancelBatchJobTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
@@ -99,6 +105,7 @@ public CancelBatchJobTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
@@ -177,25 +184,39 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}

[Review thread on the getConnector call above]

Collaborator:
Do we need to check connector access control again here? Are models and connectors both access-controlled through the model group, so that ideally any role would have the same permission to access both?

Collaborator (Author):
I think it's fine. batch_predict also checks only model access for now. Since the Predict API is controlled by model access, it makes sense to apply the same control to the batch job APIs. If we plan to add connector access checks as well, we should also add them to the Predict APIs for consistency.

And yes, I don't think connector access is controlled by model group access control; they were designed separately. I think we have to either unify them or remove one of them for remote models. It does not make sense to have both at the connector and the model level.
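To make the comparison in this thread concrete, these are the two checks involved, quoted from the hunks in this file with the unchanged bodies elided and comments added:

    // Removed in this PR: executeConnector no longer performs its own synchronous
    // connector-level check before running the CANCEL_BATCH_PREDICT action.
    if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
        // ... create and execute the CANCEL_BATCH_PREDICT action (unchanged) ...
    } else {
        actionListener
            .onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
    }

    // Added in this PR: processRemoteBatchPrediction gates the whole request on model
    // group access instead, before the connector is resolved at all.
    modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
        if (!access) {
            actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
        } else {
            // ... resolve the connector (inline or from the ML connector index) and execute (unchanged) ...
        }
    }, e -> actionListener.onFailure(e)));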
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model with the given ID", e);
actionListener
@@ -211,26 +232,20 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
}

private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(CANCEL_BATCH_PREDICT.name());
if (!cancelBatchPredictAction.isPresent() || cancelBatchPredictAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, CANCEL_BATCH_PREDICT);
connector.addAction(connectorAction);
}
connector.decrypt(CANCEL_BATCH_PREDICT.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(CANCEL_BATCH_PREDICT.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(taskResponse, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {
GetTaskTransportAction.java

@@ -45,6 +45,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -59,6 +60,7 @@
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
@@ -71,9 +73,11 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
@@ -90,6 +94,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
ScriptService scriptService;

ConnectorAccessControlHelper connectorAccessControlHelper;
ModelAccessControlHelper modelAccessControlHelper;
EncryptorImpl encryptor;
MLModelManager mlModelManager;

@@ -111,6 +116,7 @@ public GetTaskTransportAction(
ClusterService clusterService,
ScriptService scriptService,
ConnectorAccessControlHelper connectorAccessControlHelper,
ModelAccessControlHelper modelAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
@@ -123,6 +129,7 @@ public GetTaskTransportAction(
this.clusterService = clusterService;
this.scriptService = scriptService;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
@@ -238,26 +245,40 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ActionType.BATCH_PREDICT_STATUS);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("You don't have permission to access this batch job"));
} else {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
}, e -> {
log.error("Failed to get connector " + model.getConnectorId(), e);
actionListener.onFailure(e);
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(
client,
model.getConnectorId(),
ActionListener.runBefore(listener, threadContext::restore)
);
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e);
actionListener.onFailure(e);
}));
}, e -> {
log.error("Failed to retrieve the ML model for the given task ID", e);
actionListener
@@ -280,26 +301,20 @@ private void executeConnector(
Map<String, Object> remoteJob,
ActionListener<MLTaskGetResponse> actionListener
) {
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader
.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
} else {
actionListener
.onFailure(new OpenSearchStatusException("You don't have permission to access this connector", RestStatus.FORBIDDEN));
Optional<ConnectorAction> batchPredictStatusAction = connector.findAction(BATCH_PREDICT_STATUS.name());
if (!batchPredictStatusAction.isPresent() || batchPredictStatusAction.get().getRequestBody() == null) {
ConnectorAction connectorAction = ConnectorUtils.createConnectorAction(connector, BATCH_PREDICT_STATUS);
connector.addAction(connectorAction);
}
connector.decrypt(BATCH_PREDICT_STATUS.name(), (credential) -> encryptor.decrypt(credential));
RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
connectorExecutor.setScriptService(scriptService);
connectorExecutor.setClusterService(clusterService);
connectorExecutor.setClient(client);
connectorExecutor.setXContentRegistry(xContentRegistry);
connectorExecutor.executeAction(BATCH_PREDICT_STATUS.name(), mlInput, ActionListener.wrap(taskResponse -> {
processTaskResponse(mlTask, taskId, taskResponse, remoteJob, actionListener);
}, e -> { actionListener.onFailure(e); }));
}

protected void processTaskResponse(