Skip to content

Commit

Permalink
Merge branch 'feature/multi_tenancy' into tenant-setting-static
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Sep 22, 2024
2 parents 52c3a43 + d59fced commit 494820c
Show file tree
Hide file tree
Showing 21 changed files with 263 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -51,6 +52,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -68,6 +70,7 @@ public class CreateControllerTransportAction extends HandledTransportAction<Acti
ClusterService clusterService;
MLModelCacheHelper mlModelCacheHelper;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public CreateControllerTransportAction(
Expand All @@ -78,7 +81,8 @@ public CreateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new);
this.mlIndicesHandler = mlIndicesHandler;
Expand All @@ -87,6 +91,7 @@ public CreateControllerTransportAction(
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -98,6 +103,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -40,6 +41,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -57,6 +59,7 @@ public class DeleteControllerTransportAction extends HandledTransportAction<Acti
MLModelManager mlModelManager;
MLModelCacheHelper mlModelCacheHelper;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public DeleteControllerTransportAction(
Expand All @@ -67,7 +70,8 @@ public DeleteControllerTransportAction(
ClusterService clusterService,
MLModelManager mlModelManager,
MLModelCacheHelper mlModelCacheHelper,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLControllerDeleteAction.NAME, transportService, actionFilters, MLControllerDeleteRequest::new);
this.client = client;
Expand All @@ -76,6 +80,7 @@ public DeleteControllerTransportAction(
this.mlModelManager = mlModelManager;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -85,6 +90,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
User user = RestActionUtils.getUserContext(client);
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

Expand All @@ -33,6 +34,7 @@
import org.opensearch.ml.common.transport.controller.MLControllerGetResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
Expand All @@ -50,6 +52,7 @@ public class GetControllerTransportAction extends HandledTransportAction<ActionR
ClusterService clusterService;
MLModelManager mlModelManager;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public GetControllerTransportAction(
Expand All @@ -59,14 +62,16 @@ public GetControllerTransportAction(
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLControllerGetAction.NAME, transportService, actionFilters, MLControllerGetRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -79,6 +84,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<MLControllerGetResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +47,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -62,6 +64,7 @@ public class UpdateControllerTransportAction extends HandledTransportAction<Acti
MLModelCacheHelper mlModelCacheHelper;
ClusterService clusterService;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public UpdateControllerTransportAction(
Expand All @@ -71,14 +74,16 @@ public UpdateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new);
this.client = client;
this.mlModelManager = mlModelManager;
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -90,6 +95,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
// TODO: Add support for multi tenancy
mlModelManager.getModel(modelId, null, null, excludes, ActionListener.wrap(mlModel -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -748,10 +748,10 @@ public List<RestHandler> getRestHandlers(
RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction();
RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction();
RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction();
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction();
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction();
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting);
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(mlFeatureEnabledSetting);
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting);
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(mlFeatureEnabledSetting);
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting);
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting);
RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction();
Expand Down Expand Up @@ -944,6 +944,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
// Settings for SdkClient
SdkClientSettings.REMOTE_METADATA_TYPE,
SdkClientSettings.REMOTE_METADATA_ENDPOINT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

Expand All @@ -20,6 +21,7 @@
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLCreateControllerAction;
import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -29,11 +31,14 @@
public class RestMLCreateControllerAction extends BaseRestHandler {

public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLCreateControllerAction() {}
public RestMLCreateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand Down Expand Up @@ -61,6 +66,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
* @return MLCreateControllerRequest
*/
private MLCreateControllerRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}

if (!request.hasContent()) {
throw new OpenSearchParseException("Create model controller request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;

import java.io.IOException;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction;
import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -25,9 +27,13 @@
* This class consists of the REST handler to delete ML Model.
*/
public class RestMLDeleteControllerAction extends BaseRestHandler {

private static final String ML_DELETE_CONTROLLER_ACTION = "ml_delete_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

public void RestMLDeleteControllerAction() {}
public RestMLDeleteControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -42,6 +48,9 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
String modelId = request.param(PARAMETER_MODEL_ID);

MLControllerDeleteRequest mlControllerDeleteRequest = new MLControllerDeleteRequest(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
import static org.opensearch.ml.utils.RestActionUtils.returnContent;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.controller.MLControllerGetAction;
import org.opensearch.ml.common.transport.controller.MLControllerGetRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -25,12 +27,16 @@
import com.google.common.collect.ImmutableList;

public class RestMLGetControllerAction extends BaseRestHandler {

private static final String ML_GET_CONTROLLER_ACTION = "ml_get_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLGetControllerAction() {}
public RestMLGetControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -57,6 +63,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
*/
@VisibleForTesting
MLControllerGetRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}

String modelId = getParameterId(request, PARAMETER_MODEL_ID);
boolean returnContent = returnContent(request);

Expand Down
Loading

0 comments on commit 494820c

Please sign in to comment.