diff --git a/CHANGELOG.md b/CHANGELOG.md index e7a79a658..b4111a12c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.14...2.x) ### Features - Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804)) +- Adds user level access control based on backend roles ([#838](https://github.com/opensearch-project/flow-framework/pull/838)) ### Enhancements ### Bug Fixes diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index f69534a77..92121dce5 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -56,6 +56,7 @@ import org.opensearch.flowframework.transport.SearchWorkflowStateAction; import org.opensearch.flowframework.transport.SearchWorkflowStateTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowTransportAction; +import org.opensearch.flowframework.transport.handler.SearchHandler; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; @@ -84,6 +85,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; @@ -135,7 +137,16 @@ public Collection createComponents( ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool, flowFrameworkSettings); - return List.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler, flowFrameworkSettings); + SearchHandler searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + + return List.of( + workflowStepFactory, + workflowProcessSorter, + encryptorUtils, + flowFrameworkIndicesHandler, + searchHandler, + flowFrameworkSettings + ); } @Override @@ -179,7 +190,14 @@ public List getRestHandlers( @Override public List> getSettings() { - return List.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION); + return List.of( + FLOW_FRAMEWORK_ENABLED, + MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, + WORKFLOW_REQUEST_TIMEOUT, + TASK_REQUEST_RETRY_DURATION, + FILTER_BY_BACKEND_ROLES + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java index 3bbdfac28..922212a38 100644 --- a/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java +++ b/src/main/java/org/opensearch/flowframework/common/FlowFrameworkSettings.java @@ -75,6 +75,14 @@ public class FlowFrameworkSettings { Setting.Property.Dynamic ); + /** This setting sets the backend role filtering */ + public static final Setting FILTER_BY_BACKEND_ROLES = Setting.boolSetting( + "plugins.flow_framework.filter_by_backend_roles", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + /** * Instantiate this class. * diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index ecb015ffc..72b2e8a12 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -16,12 +16,15 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -49,7 +52,10 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.checkFilterByBackendRoles; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.getWorkflow; /** * Transport Action to index or update a use case template within the Global Context @@ -63,6 +69,9 @@ public class CreateWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); + this.xContentRegistry = xContentRegistry; } @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { - User user = getUserContext(client); + String workflowId = request.getWorkflowId(); + try { + resolveUserAndExecute(user, workflowId, listener, () -> createExecute(request, user, listener)); + } catch (Exception e) { + logger.error("Failed to create workflow", e); + listener.onFailure(e); + } + } + + /** + * Resolve user and execute the workflow function + * @param requestedUser the user making the request + * @param workflowId the workflow id + * @param listener the action listener + * @param function the workflow function to execute + */ + private void resolveUserAndExecute( + User requestedUser, + String workflowId, + ActionListener listener, + Runnable function + ) { + try { + // Check if user has backend roles + // When filter by is enabled, block users creating/updating workflows who do not have backend roles. + if (filterByEnabled == Boolean.TRUE) { + try { + checkFilterByBackendRoles(requestedUser); + } catch (FlowFrameworkException e) { + logger.error(e.getMessage(), e); + listener.onFailure(e); + return; + } + } + if (workflowId != null) { + // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to + // check if request user have access to the workflow or not. But we still need to get current workflow for + // this case, so we can keep current workflow's user data. + boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; + // Update workflow request, check if user has permissions to update the workflow + // Get workflow and verify backend roles + getWorkflow(requestedUser, workflowId, filterByBackendRole, listener, function, client, clusterService, xContentRegistry); + } else { + // Create Workflow. No need to get current workflow. + function.run(); + } + } catch (Exception e) { + String errorMessage = "Failed to create or update workflow"; + if (e instanceof FlowFrameworkException) { + listener.onFailure(e); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } + } + } + + /** + * Execute the create or update request + * 1. Validate workflows if requested + * 2. Create or update global context index + * 3. Create or update state index + * 4. Create or update provisioning progress index + * @param request the workflow request + * @param user the user making the request + * @param listener the action listener + */ + private void createExecute(WorkflowRequest request, User user, ActionListener listener) { Instant creationTime = Instant.now(); Template templateWithUser = new Template( request.getTemplate().name(), diff --git a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java index 151c3da7c..2974f522f 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportAction.java @@ -16,10 +16,14 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.tasks.Task; @@ -27,6 +31,9 @@ import static org.opensearch.flowframework.common.CommonValue.CLEAR_STATUS; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; /** * Transport action to retrieve a use case template within the Global Context @@ -37,6 +44,9 @@ public class DeleteWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); } @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { if (flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { String workflowId = request.getWorkflowId(); - DeleteRequest deleteRequest = new DeleteRequest(GLOBAL_CONTEXT_INDEX, workflowId); + User user = getUserContext(client); ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); - logger.info("Deleting workflow doc: {}", workflowId); - client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore)); - // Whether to force deletion of corresponding state - final boolean clearStatus = Booleans.parseBoolean(request.getParams().get(CLEAR_STATUS), false); - ActionListener stateListener = ActionListener.wrap(response -> { - logger.info("Deleted workflow state doc: {}", workflowId); - }, exception -> { logger.info("Failed to delete workflow state doc: {}", workflowId, exception); }); - flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(workflowId, clearStatus, canDelete -> { - if (Boolean.TRUE.equals(canDelete)) { - flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc(workflowId, stateListener); - } - }, stateListener); + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeDeleteRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); + } else { String errorMessage = "There are no templates in the global context"; logger.error(errorMessage); listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); } } + + /** + * Executes the delete request + * @param request the workflow request + * @param listener the action listener + * @param context the thread context + */ + private void executeDeleteRequest( + WorkflowRequest request, + ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + DeleteRequest deleteRequest = new DeleteRequest(GLOBAL_CONTEXT_INDEX, workflowId); + logger.info("Deleting workflow doc: {}", workflowId); + client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore)); + + // Whether to force deletion of corresponding state + final boolean clearStatus = Booleans.parseBoolean(request.getParams().get(CLEAR_STATUS), false); + ActionListener stateListener = ActionListener.wrap(response -> { + logger.info("Deleted workflow state doc: {}", workflowId); + }, exception -> { logger.info("Failed to delete workflow state doc: {}", workflowId, exception); }); + flowFrameworkIndicesHandler.canDeleteWorkflowStateDoc(workflowId, clearStatus, canDelete -> { + if (Boolean.TRUE.equals(canDelete)) { + flowFrameworkIndicesHandler.deleteFlowFrameworkSystemIndexDoc(workflowId, stateListener); + } + }, stateListener); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index cf3f2361a..1b58e66db 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -16,11 +16,15 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -51,9 +55,11 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_END_TIME_FIELD; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; /** * Transport Action to deprovision a workflow from a stored use case template @@ -67,6 +73,9 @@ public class DeprovisionWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); } @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { String workflowId = request.getWorkflowId(); - String allowDelete = request.getParams().get(ALLOW_DELETE); - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + + User user = getUserContext(client); // Stash thread context to interact with system index try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Querying state for workflow: {}", workflowId); - client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { - context.restore(); + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeDeprovisionRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); - Set deleteAllowedResources = Strings.tokenizeByCommaToSet(allowDelete); - // Retrieve resources from workflow state and deprovision - threadPool.executor(DEPROVISION_WORKFLOW_THREAD_POOL) - .execute( - () -> executeDeprovisionSequence( - workflowId, - response.getWorkflowState().resourcesCreated(), - deleteAllowedResources, - listener - ) - ); - }, exception -> { - String errorMessage = "Failed to get workflow state for workflow " + workflowId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); } catch (Exception e) { String errorMessage = "Failed to retrieve template from global context."; logger.error(errorMessage, e); @@ -131,6 +141,36 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + String allowDelete = request.getParams().get(ALLOW_DELETE); + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + logger.info("Querying state for workflow: {}", workflowId); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { + context.restore(); + + Set deleteAllowedResources = Strings.tokenizeByCommaToSet(allowDelete); + // Retrieve resources from workflow state and deprovision + threadPool.executor(DEPROVISION_WORKFLOW_THREAD_POOL) + .execute( + () -> executeDeprovisionSequence( + workflowId, + response.getWorkflowState().resourcesCreated(), + deleteAllowedResources, + listener + ) + ); + }, exception -> { + String errorMessage = "Failed to get workflow state for workflow " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + })); + } + private void executeDeprovisionSequence( String workflowId, List resourcesCreated, diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java index 9625ce731..b8e0eaf2a 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportAction.java @@ -15,7 +15,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -31,6 +33,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; //TODO: Currently we only get the workflow status but we should change to be able to get the // full template as well @@ -44,6 +48,8 @@ public class GetWorkflowStateTransportAction extends HandledTransportAction filterByEnabled = it); } @Override protected void doExecute(Task task, GetWorkflowStateRequest request, ActionListener listener) { String workflowId = request.getWorkflowId(); User user = ParseUtils.getUserContext(client); - GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Querying state workflow doc: {}", workflowId); - client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - WorkflowState workflowState = WorkflowState.parse(parser); - listener.onResponse(new GetWorkflowStateResponse(workflowState, request.getAll())); - } catch (Exception e) { - String errorMessage = "Failed to parse workflowState: " + r.getId(); - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } - } else { - listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); - } else { - String errorMessage = "Failed to get workflow status of: " + workflowId; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - } - }), context::restore)); + + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeGetWorkflowStateRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { String errorMessage = "Failed to get workflow: " + workflowId; logger.error(errorMessage, e); listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } } + + /** + * Execute the get workflow state request + * @param request the get workflow state request + * @param listener the action listener + * @param context the thread context + */ + private void executeGetWorkflowStateRequest( + GetWorkflowStateRequest request, + ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX).id(workflowId); + logger.info("Querying state workflow doc: {}", workflowId); + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = ParseUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + WorkflowState workflowState = WorkflowState.parse(parser); + listener.onResponse(new GetWorkflowStateResponse(workflowState, request.getAll())); + } catch (Exception e) { + String errorMessage = "Failed to parse workflowState: " + r.getId(); + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); + } + } else { + listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new FlowFrameworkException("Fail to find workflow status of " + workflowId, RestStatus.NOT_FOUND)); + } else { + String errorMessage = "Failed to get workflow status of: " + workflowId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } + }), context::restore)); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 6ffbd191b..5994bfe4a 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -57,7 +57,7 @@ public GetWorkflowStepTransportAction( protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { try { logger.info("Getting workflow validator from the WorkflowStepFactory"); - List steps = request.getParams().size() > 0 + List steps = !request.getParams().isEmpty() ? Arrays.asList(Strings.splitStringByCommaToArray(request.getParams().get(WORKFLOW_STEP))) : Collections.emptyList(); WorkflowValidator workflowValidator; diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index d713e8f48..59b129baf 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -15,11 +15,14 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; @@ -29,6 +32,9 @@ import org.opensearch.transport.TransportService; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; /** * Transport action to retrieve a use case template within the Global Context @@ -40,6 +46,9 @@ public class GetWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); } @Override @@ -68,29 +87,22 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - context.restore(); - if (!response.isExists()) { - String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; - logger.error(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - } else { - // Remove any secured field from response - User user = ParseUtils.getUserContext(client); - Template template = encryptorUtils.redactTemplateSecuredFields(user, Template.parse(response.getSourceAsString())); - listener.onResponse(new GetWorkflowResponse(template)); - } - }, exception -> { - String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - })); + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeGetRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); } catch (Exception e) { String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; logger.error(errorMessage, e); @@ -103,4 +115,38 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + logger.info("Querying workflow from global context: {}", workflowId); + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } else { + // Remove any secured field from response + User user = ParseUtils.getUserContext(client); + Template template = encryptorUtils.redactTemplateSecuredFields(user, Template.parse(response.getSourceAsString())); + listener.onResponse(new GetWorkflowResponse(template)); + } + }, exception -> { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + })); + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 44eb07df4..1457cdb8e 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -16,10 +16,14 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; @@ -51,6 +55,9 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; /** * Transport Action to provision a workflow from a stored use case template @@ -65,6 +72,9 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); } @Override protected void doExecute(Task task, WorkflowRequest request, ActionListener listener) { // Retrieve use case template from global context String workflowId = request.getWorkflowId(); - GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + + User user = getUserContext(client); // Stash thread context to interact with system index try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Querying workflow from global context: {}", workflowId); - client.get(getRequest, ActionListener.wrap(response -> { - context.restore(); - if (!response.isExists()) { - String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; - logger.error(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); - return; - } + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeProvisionRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } + } - // Parse template from document source - Template parsedTemplate = Template.parse(response.getSourceAsString()); + /** + * Execute the provision request + * 1. Retrieve template from global context + * 2. Decrypt template + * 3. Sort and validate graph + * 4. Update state index + * 5. Execute workflow asynchronously + * 6. Update last provisioned field in template + * 7. Return response + * @param request the workflow request + * @param listener the action listener + * @param context the thread context + */ + private void executeProvisionRequest( + WorkflowRequest request, + ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + logger.info("Querying workflow from global context: {}", workflowId); + client.get(getRequest, ActionListener.wrap(response -> { + context.restore(); + + if (!response.isExists()) { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + return; + } - // Decrypt template - final Template template = encryptorUtils.decryptTemplateCredentials(parsedTemplate); + // Parse template from document source + Template parsedTemplate = Template.parse(response.getSourceAsString()); - // Sort and validate graph - Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); - List provisionProcessSequence = workflowProcessSorter.sortProcessNodes( - provisionWorkflow, - workflowId, - request.getParams() - ); - workflowProcessSorter.validate(provisionProcessSequence, pluginsService); + // Decrypt template + final Template template = encryptorUtils.decryptTemplateCredentials(parsedTemplate); + + // Sort and validate graph + Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List provisionProcessSequence = workflowProcessSorter.sortProcessNodes( + provisionWorkflow, + workflowId, + request.getParams() + ); + workflowProcessSorter.validate(provisionProcessSequence, pluginsService); - flowFrameworkIndicesHandler.getProvisioningProgress(workflowId, progress -> { - if (ProvisioningProgress.NOT_STARTED.equals(progress.orElse(null))) { - // update state index - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - workflowId, - Map.ofEntries( - Map.entry(STATE_FIELD, State.PROVISIONING), - Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), - Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), - Map.entry(RESOURCES_CREATED_FIELD, Collections.emptyList()) - ), - ActionListener.wrap(updateResponse -> { - logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); - executeWorkflowAsync(workflowId, provisionProcessSequence, listener); - // update last provisioned field in template - Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); - flowFrameworkIndicesHandler.updateTemplateInGlobalContext( - request.getWorkflowId(), - newTemplate, - ActionListener.wrap(templateResponse -> { - listener.onResponse(new WorkflowResponse(request.getWorkflowId())); - }, exception -> { - String errorMessage = "Failed to update use case template " + request.getWorkflowId(); - logger.error(errorMessage, exception); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure( - new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)) - ); - } - }), - // We've already checked workflow is not started, ignore second check - true - ); - }, exception -> { - String errorMessage = "Failed to update workflow state: " + workflowId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - } else { - String errorMessage = "The workflow provisioning state is " - + (progress.isPresent() ? progress.get().toString() : "unknown") - + " and can not be provisioned unless its state is NOT_STARTED: " - + workflowId - + ". Deprovision the workflow to reset the state."; - logger.info(errorMessage); - listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); - } - }, listener); - }, exception -> { - if (exception instanceof FlowFrameworkException) { - logger.error("Workflow validation failed for workflow {}", workflowId); - listener.onFailure(exception); + flowFrameworkIndicesHandler.getProvisioningProgress(workflowId, progress -> { + if (ProvisioningProgress.NOT_STARTED.equals(progress.orElse(null))) { + // update state index + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.PROVISIONING), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, Collections.emptyList()) + ), + ActionListener.wrap(updateResponse -> { + logger.info("updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); + executeWorkflowAsync(workflowId, provisionProcessSequence, listener); + // update last provisioned field in template + Template newTemplate = Template.builder(template).lastProvisionedTime(Instant.now()).build(); + flowFrameworkIndicesHandler.updateTemplateInGlobalContext( + request.getWorkflowId(), + newTemplate, + ActionListener.wrap(templateResponse -> { + listener.onResponse(new WorkflowResponse(request.getWorkflowId())); + }, exception -> { + String errorMessage = "Failed to update use case template " + request.getWorkflowId(); + logger.error(errorMessage, exception); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + }), + // We've already checked workflow is not started, ignore second check + true + ); + }, exception -> { + String errorMessage = "Failed to update workflow state: " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + }) + ); } else { - String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + String errorMessage = "The workflow provisioning state is " + + (progress.isPresent() ? progress.get().toString() : "unknown") + + " and can not be provisioned unless its state is NOT_STARTED: " + + workflowId + + ". Deprovision the workflow to reset the state."; + logger.info(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST)); } - })); - } catch (Exception e) { - String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } + }, listener); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + logger.error("Workflow validation failed for workflow {}", workflowId); + listener.onFailure(exception); + } else { + String errorMessage = "Failed to retrieve template from global context for workflow " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + })); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index 90fe8066c..8d024d180 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -15,10 +15,14 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -55,6 +59,9 @@ import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.flowframework.util.ParseUtils.getUserContext; +import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; /** * Transport Action to reprovision a provisioned template @@ -71,6 +78,9 @@ public class ReprovisionWorkflowTransportAction extends HandledTransportAction filterByEnabled = it); } @Override protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionListener listener) { String workflowId = request.getWorkflowId(); + User user = getUserContext(client); - // Retrieve state and resources created - GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Querying state for workflow: {}", workflowId); - client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { - context.restore(); - - State currentState = State.valueOf(response.getWorkflowState().getState()); - if (State.PROVISIONING.equals(currentState) || State.NOT_STARTED.equals(currentState)) { - String errorMessage = "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: " - + workflowId; - throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); - } - - // Generate reprovision sequence - List resourceCreated = response.getWorkflowState().resourcesCreated(); - - // Original template is retrieved from index, attempt to decrypt any exisiting credentials before processing - Template originalTemplate = encryptorUtils.decryptTemplateCredentials(request.getOriginalTemplate()); - Template updatedTemplate = request.getUpdatedTemplate(); - - // Validate updated template prior to execution - Workflow provisionWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW); - List updatedProcessSequence = workflowProcessSorter.sortProcessNodes( - provisionWorkflow, - request.getWorkflowId(), - Collections.emptyMap() // TODO : Add suport to reprovision substitution templates - ); + resolveUserAndExecute( + user, + workflowId, + filterByEnabled, + listener, + () -> executeReprovisionRequest(request, listener, context), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + String errorMessage = "Failed to get workflow state for workflow " + workflowId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + } - try { - workflowProcessSorter.validate(updatedProcessSequence, pluginsService); - } catch (Exception e) { - String errormessage = "Workflow validation failed for workflow " + request.getWorkflowId(); - logger.error(errormessage, e); - listener.onFailure(new FlowFrameworkException(errormessage, RestStatus.BAD_REQUEST)); - } - List reprovisionProcessSequence = workflowProcessSorter.createReprovisionSequence( - workflowId, - originalTemplate, - updatedTemplate, - resourceCreated - ); + } - // Remove error field if any prior to subsequent execution - if (response.getWorkflowState().getError() != null) { - Script script = new Script( - ScriptType.INLINE, - "painless", - "if(ctx._source.containsKey('error')){ctx._source.remove('error')}", - Collections.emptyMap() - ); - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( - WORKFLOW_STATE_INDEX, - workflowId, - script, - ActionListener.wrap(updateResponse -> { + /** + * Execute the reprovision request + * @param request the reprovision request + * @param listener the action listener + * @param context the thread context + */ + private void executeReprovisionRequest( + ReprovisionWorkflowRequest request, + ActionListener listener, + ThreadContext.StoredContext context + ) { + String workflowId = request.getWorkflowId(); + logger.info("Querying state for workflow: {}", workflowId); + // Retrieve state and resources created + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { + context.restore(); - }, exception -> { - String errorMessage = "Failed to update workflow state: " + workflowId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); - } + State currentState = State.valueOf(response.getWorkflowState().getState()); + if (State.PROVISIONING.equals(currentState) || State.NOT_STARTED.equals(currentState)) { + String errorMessage = "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: " + + workflowId; + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } - // Update State Index, maintain resources created for subsequent execution - flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( - workflowId, - Map.ofEntries( - Map.entry(STATE_FIELD, State.PROVISIONING), - Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), - Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), - Map.entry(RESOURCES_CREATED_FIELD, resourceCreated) - ), - ActionListener.wrap(updateResponse -> { + // Generate reprovision sequence + List resourceCreated = response.getWorkflowState().resourcesCreated(); - logger.info("Updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); + // Original template is retrieved from index, attempt to decrypt any exisiting credentials before processing + Template originalTemplate = encryptorUtils.decryptTemplateCredentials(request.getOriginalTemplate()); + Template updatedTemplate = request.getUpdatedTemplate(); - // Attach last provisioned time to updated template and execute reprovisioning - Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate) - .lastProvisionedTime(Instant.now()) - .build(); - executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); + // Validate updated template prior to execution + Workflow provisionWorkflow = updatedTemplate.workflows().get(PROVISION_WORKFLOW); + List updatedProcessSequence = workflowProcessSorter.sortProcessNodes( + provisionWorkflow, + request.getWorkflowId(), + Collections.emptyMap() // TODO : Add suport to reprovision substitution templates + ); - listener.onResponse(new WorkflowResponse(workflowId)); + try { + workflowProcessSorter.validate(updatedProcessSequence, pluginsService); + } catch (Exception e) { + String errormessage = "Workflow validation failed for workflow " + request.getWorkflowId(); + logger.error(errormessage, e); + listener.onFailure(new FlowFrameworkException(errormessage, RestStatus.BAD_REQUEST)); + } + List reprovisionProcessSequence = workflowProcessSorter.createReprovisionSequence( + workflowId, + originalTemplate, + updatedTemplate, + resourceCreated + ); + + // Remove error field if any prior to subsequent execution + if (response.getWorkflowState().getError() != null) { + Script script = new Script( + ScriptType.INLINE, + "painless", + "if(ctx._source.containsKey('error')){ctx._source.remove('error')}", + Collections.emptyMap() + ); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( + WORKFLOW_STATE_INDEX, + workflowId, + script, + ActionListener.wrap(updateResponse -> { }, exception -> { String errorMessage = "Failed to update workflow state: " + workflowId; @@ -206,21 +228,44 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); }) ); - }, exception -> { - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - String errorMessage = "Failed to get workflow state for workflow " + workflowId; + } + + // Update State Index, maintain resources created for subsequent execution + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( + workflowId, + Map.ofEntries( + Map.entry(STATE_FIELD, State.PROVISIONING), + Map.entry(PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.IN_PROGRESS), + Map.entry(PROVISION_START_TIME_FIELD, Instant.now().toEpochMilli()), + Map.entry(RESOURCES_CREATED_FIELD, resourceCreated) + ), + ActionListener.wrap(updateResponse -> { + + logger.info("Updated workflow {} state to {}", request.getWorkflowId(), State.PROVISIONING); + + // Attach last provisioned time to updated template and execute reprovisioning + Template updatedTemplateWithProvisionedTime = Template.builder(updatedTemplate) + .lastProvisionedTime(Instant.now()) + .build(); + executeWorkflowAsync(workflowId, updatedTemplateWithProvisionedTime, reprovisionProcessSequence, listener); + + listener.onResponse(new WorkflowResponse(workflowId)); + + }, exception -> { + String errorMessage = "Failed to update workflow state: " + workflowId; logger.error(errorMessage, exception); listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - } - })); - } catch (Exception e) { - String errorMessage = "Failed to get workflow state for workflow " + workflowId; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); - } - + }) + ); + }, exception -> { + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + String errorMessage = "Failed to get workflow state for workflow " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + } + })); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java index afe3b85d4..f20c57adb 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -10,22 +10,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.util.ParseUtils; -import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.handler.SearchHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; - /** * Transport Action to search workflow states */ @@ -33,31 +29,29 @@ public class SearchWorkflowStateTransportAction extends HandledTransportAction actionListener) { - // AccessController should take care of letting the user with right permission to view the workflow - User user = ParseUtils.getUserContext(client); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - SearchSourceBuilder searchSourceBuilder = request.source(); - searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); - client.search(request, ActionListener.runBefore(actionListener, context::restore)); + try { + searchHandler.search(request, actionListener); } catch (Exception e) { - logger.error("Failed to search workflow states in global context", e); - actionListener.onFailure(e); + String errorMessage = "Failed to search workflow states in global context"; + logger.error(errorMessage, e); + actionListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } + } } diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java index 41a8b23f9..46f0afb10 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java @@ -10,22 +10,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.util.ParseUtils; -import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.transport.handler.SearchHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; - /** * Transport Action to search workflows created */ @@ -33,32 +29,28 @@ public class SearchWorkflowTransportAction extends HandledTransportAction actionListener) { - // AccessController should take care of letting the user with right permission to view the workflow - User user = ParseUtils.getUserContext(client); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Searching workflows in global context"); - SearchSourceBuilder searchSourceBuilder = request.source(); - searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); - client.search(request, ActionListener.runBefore(actionListener, context::restore)); + try { + searchHandler.search(request, actionListener); } catch (Exception e) { - logger.error("Failed to search workflows in global context", e); - actionListener.onFailure(e); + String errorMessage = "Failed to search workflows in global context"; + logger.error(errorMessage, e); + actionListener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } } } diff --git a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java new file mode 100644 index 000000000..512b0bea2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.builder.SearchSourceBuilder; + +import static org.opensearch.flowframework.util.ParseUtils.isAdmin; +import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; + +/** + * Handle general search request, check user role and return search response. + */ +public class SearchHandler { + private final Logger logger = LogManager.getLogger(SearchHandler.class); + private final Client client; + private volatile Boolean filterByBackendRole; + + /** + * Instantiates a new SearchHandler + * @param settings settings + * @param clusterService cluster service + * @param client The node client to retrieve a stored use case template + * @param filterByBackendRoleSetting filter role backend settings + */ + public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting filterByBackendRoleSetting) { + this.client = client; + filterByBackendRole = filterByBackendRoleSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByBackendRole = it); + } + + /** + * Search workflows in global context + * @param request SearchRequest + * @param actionListener ActionListener + */ + public void search(SearchRequest request, ActionListener actionListener) { + // AccessController should take care of letting the user with right permission to view the workflow + User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + logger.info("Searching workflows in global context"); + SearchSourceBuilder searchSourceBuilder = request.source(); + searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); + validateRole(request, user, actionListener, context); + } catch (Exception e) { + logger.error("Failed to search workflows in global context", e); + actionListener.onFailure(e); + } + } + + /** + * Validate user role and call search + * @param request SearchRequest + * @param user User + * @param listener ActionListener + * @param context ThreadContext + */ + public void validateRole( + SearchRequest request, + User user, + ActionListener listener, + ThreadContext.StoredContext context + ) { + if (user == null || !filterByBackendRole || isAdmin(user)) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. + // case 3: user is admin which means we don't have to check backend role filtering + client.search(request, ActionListener.runBefore(listener, context::restore)); + } else { + // Security is enabled, filter is enabled and user isn't admin + try { + ParseUtils.addUserBackendRolesFilter(user, request.source()); + logger.debug("Filtering result by {}", user.getBackendRoles()); + client.search(request, ActionListener.runBefore(listener, context::restore)); + } catch (Exception e) { + listener.onFailure(e); + } + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 16e8b25e1..cd598aa91 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -12,22 +12,38 @@ import com.google.gson.JsonParser; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Booleans; import org.opensearch.common.io.Streams; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.transport.WorkflowResponse; import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; +import org.opensearch.search.builder.SearchSourceBuilder; import java.io.FileNotFoundException; import java.io.IOException; @@ -49,6 +65,7 @@ import jakarta.json.bind.JsonbBuilder; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; /** * Utility methods for Template parsing @@ -225,10 +242,225 @@ public static Instant parseInstant(XContentParser parser) throws IOException { */ public static User getUserContext(Client client) { String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - logger.debug("Filtering result by " + userStr); + logger.debug("Filtering result by {}", userStr); return User.parse(userStr); } + /** + * Add user backend roles filter to search source builder= + * @param user the user + * @param searchSourceBuilder search builder + * @return search builder with filter added + */ + public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) { + if (user == null) { + return searchSourceBuilder; + } + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + String userFieldName = "user"; + String userBackendRoleFieldName = "user.backend_roles.keyword"; + List backendRoles = user.getBackendRoles() != null ? user.getBackendRoles() : ImmutableList.of(); + // For normal case, user should have backend roles. + TermsQueryBuilder userRolesFilterQuery = QueryBuilders.termsQuery(userBackendRoleFieldName, backendRoles); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(userFieldName, userRolesFilterQuery, ScoreMode.None); + boolQueryBuilder.must(nestedQueryBuilder); + QueryBuilder query = searchSourceBuilder.query(); + if (query == null) { + searchSourceBuilder.query(boolQueryBuilder); + } else if (query instanceof BoolQueryBuilder) { + ((BoolQueryBuilder) query).filter(boolQueryBuilder); + } else { + // Convert any other query to a BoolQueryBuilder + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery().must(query); + boolQuery.filter(boolQueryBuilder); + searchSourceBuilder.query(boolQuery); + } + return searchSourceBuilder; + } + + /** + * Resolve user and execute the function + * @param requestedUser the user to execute the request + * @param workflowId workflow id + * @param filterByEnabled filter by enabled setting + * @param listener action listener + * @param function workflow function + * @param client node client + * @param clusterService cluster service + * @param xContentRegistry contentRegister to parse get response + */ + public static void resolveUserAndExecute( + User requestedUser, + String workflowId, + Boolean filterByEnabled, + ActionListener listener, + Runnable function, + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry + ) { + try { + if (requestedUser == null || filterByEnabled == Boolean.FALSE) { + // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to + // check if request user have access to the workflow or not. + // !filterByEnabled means security is enabled and filterByEnabled is disabled + function.run(); + } else { + getWorkflow(requestedUser, workflowId, filterByEnabled, listener, function, client, clusterService, xContentRegistry); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Check if requested user has backend role required to access the resource + * @param requestedUser the user to execute the request + * @param resourceUser user of the resource + * @param workflowId workflow id + * @return boolean if the requested user has backend role required to access the resource + * @throws Exception exception + */ + private static boolean checkUserPermissions(User requestedUser, User resourceUser, String workflowId) throws Exception { + if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { + return false; + } + // Check if requested user has backend role required to access the resource + for (String backendRole : requestedUser.getBackendRoles()) { + if (resourceUser.getBackendRoles().contains(backendRole)) { + logger.debug( + "User: " + + requestedUser.getName() + + " has backend role: " + + backendRole + + " permissions to access config: " + + workflowId + ); + return true; + } + } + return false; + } + + /** + * Check if filter by backend roles is enabled and validate the requested user + * @param requestedUser the user to execute the request + */ + public static void checkFilterByBackendRoles(User requestedUser) { + if (requestedUser == null) { + String errorMessage = "Filter by backend roles is enabled and User is null"; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + if (requestedUser.getBackendRoles().isEmpty()) { + String userErrorMessage = "Filter by backend roles is enabled, but User " + + requestedUser.getName() + + " does not have any backend roles configured"; + + logger.error(userErrorMessage); + throw new FlowFrameworkException(userErrorMessage, RestStatus.FORBIDDEN); + } + } + + /** + * Get workflow + * @param requestUser the user to execute the request + * @param workflowId workflow id + * @param filterByEnabled filter by enabled setting + * @param listener action listener + * @param function workflow function + * @param client node client + * @param clusterService cluster service + * @param xContentRegistry contentRegister to parse get response + */ + public static void getWorkflow( + User requestUser, + String workflowId, + Boolean filterByEnabled, + ActionListener listener, + Runnable function, + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry + ) { + if (clusterService.state().metadata().hasIndex(GLOBAL_CONTEXT_INDEX)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequest request = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId); + client.get( + request, + ActionListener.wrap( + response -> onGetWorkflowResponse( + response, + requestUser, + workflowId, + filterByEnabled, + listener, + function, + xContentRegistry + ), + exception -> { + logger.error("Failed to get workflow: {}", workflowId, exception); + listener.onFailure(exception); + } + ) + ); + } + } else { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } + } + + /** + * Execute the function if user has permissions to access the resource + * @param requestUser the user to execute the request + * @param response get response + * @param workflowId workflow id + * @param filterByEnabled filter by enabled setting + * @param listener action listener + * @param function workflow function + * @param xContentRegistry contentRegister to parse get response + */ + public static void onGetWorkflowResponse( + GetResponse response, + User requestUser, + String workflowId, + Boolean filterByEnabled, + ActionListener listener, + Runnable function, + NamedXContentRegistry xContentRegistry + ) { + if (response.isExists()) { + try ( + XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Template template = Template.parse(parser); + User resourceUser = template.getUser(); + + if (!filterByEnabled || checkUserPermissions(requestUser, resourceUser, workflowId) || isAdmin(requestUser)) { + function.run(); + } else { + logger.debug("User: " + requestUser.getName() + " does not have permissions to access workflow: " + workflowId); + listener.onFailure( + new FlowFrameworkException( + "User does not have permissions to access workflow: " + workflowId, + RestStatus.BAD_REQUEST + ) + ); + } + } catch (Exception e) { + logger.error("Failed to parse workflow: {}", workflowId, e); + listener.onFailure(e); + } + } else { + String errorMessage = "Failed to retrieve template (" + workflowId + ") from global context."; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } + } + /** * Creates a XContentParser from a given Registry * diff --git a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java index a119d6809..daa93a9e9 100644 --- a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java @@ -9,12 +9,20 @@ package org.opensearch.flowframework.util; import org.apache.commons.lang3.ArrayUtils; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.CommonValue; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import java.io.IOException; + /** * Utility methods for Rest Handlers */ @@ -52,4 +60,16 @@ public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES); } } + + /** + * Create an XContentParser from the provided NamedXContentRegistry and BytesReference + * @param xContentRegistry content registry + * @param bytesReference bytes reference + * @return XContentParser + * @throws IOException if error occurs while creating parser + */ + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 8f540868d..65b0a9a75 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -28,6 +28,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOWS; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; @@ -65,7 +66,14 @@ public void setUp() throws Exception { final Set> settingsSet = Stream.concat( ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), - Stream.of(FLOW_FRAMEWORK_ENABLED, MAX_WORKFLOWS, MAX_WORKFLOW_STEPS, WORKFLOW_REQUEST_TIMEOUT, TASK_REQUEST_RETRY_DURATION) + Stream.of( + FLOW_FRAMEWORK_ENABLED, + MAX_WORKFLOWS, + MAX_WORKFLOW_STEPS, + WORKFLOW_REQUEST_TIMEOUT, + TASK_REQUEST_RETRY_DURATION, + FILTER_BY_BACKEND_ROLES + ) ).collect(Collectors.toSet()); clusterSettings = new ClusterSettings(settings, settingsSet); clusterService = mock(ClusterService.class); @@ -81,13 +89,13 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 5, + 6, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(9, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); assertEquals(10, ffp.getActions().size()); assertEquals(3, ffp.getExecutorBuilders(settings).size()); - assertEquals(5, ffp.getSettings().size()); + assertEquals(6, ffp.getSettings().size()); Collection systemIndexDescriptors = ffp.getSystemIndexDescriptors(Settings.EMPTY); assertEquals(3, systemIndexDescriptors.size()); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index eeb310be4..17c72b21f 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework; +import com.google.gson.JsonArray; import org.apache.commons.lang3.RandomStringUtils; import org.apache.http.Header; import org.apache.http.HttpHeaders; @@ -31,7 +32,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.commons.rest.SecureRestClientBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -50,6 +50,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -66,14 +67,6 @@ */ public abstract class FlowFrameworkRestTestCase extends OpenSearchRestTestCase { - private static String FLOW_FRAMEWORK_FULL_ACCESS_ROLE = "flow_framework_full_access"; - private static String ML_COMMONS_FULL_ACCESS_ROLE = "ml_full_access"; - private static String READ_ACCESS_ROLE = "flow_framework_read_access"; - public static String FULL_ACCESS_USER = "fullAccessUser"; - public static String READ_ACCESS_USER = "readAccessUser"; - private static RestClient readAccessClient; - private static RestClient fullAccessClient; - @Before protected void setUpSettings() throws Exception { @@ -132,42 +125,6 @@ protected void setUpSettings() throws Exception { ); assertEquals(200, response.getStatusLine().getStatusCode()); - // Set up clients if running in security enabled cluster - if (isHttps()) { - String fullAccessUserPassword = generatePassword(FULL_ACCESS_USER); - String readAccessUserPassword = generatePassword(READ_ACCESS_USER); - - // Configure full access user and client, needs ML Full Access role as well - response = createUser( - FULL_ACCESS_USER, - fullAccessUserPassword, - List.of(FLOW_FRAMEWORK_FULL_ACCESS_ROLE, ML_COMMONS_FULL_ACCESS_ROLE) - ); - fullAccessClient = new SecureRestClientBuilder( - getClusterHosts().toArray(new HttpHost[0]), - isHttps(), - FULL_ACCESS_USER, - fullAccessUserPassword - ).setSocketTimeout(60000).build(); - - // Configure read access user and client - response = createUser(READ_ACCESS_USER, readAccessUserPassword, List.of(READ_ACCESS_ROLE)); - readAccessClient = new SecureRestClientBuilder( - getClusterHosts().toArray(new HttpHost[0]), - isHttps(), - READ_ACCESS_USER, - readAccessUserPassword - ).setSocketTimeout(60000).build(); - } - - } - - protected static RestClient fullAccessClient() { - return fullAccessClient; - } - - protected static RestClient readAccessClient() { - return readAccessClient; } protected boolean isHttps() { @@ -338,7 +295,7 @@ protected Response createWorkflow(RestClient client, Template template) throws E * Helper method to invoke the Create Workflow Rest Action without validation * @param client the rest client * @param useCase the usecase to create - * @param the required params + * @param params the required params * @throws Exception if the request fails * @return a rest response */ @@ -362,6 +319,99 @@ protected Response createWorkflowWithUseCaseWithNoValidation(RestClient client, ); } + public Response createIndexRole(String role, String index) throws IOException { + return TestHelpers.makeRequest( + client(), + "PUT", + "/_plugins/_security/api/roles/" + role, + null, + TestHelpers.toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "\"cluster:admin/ingest/pipeline/delete\"\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"crud\",\n" + + "\"indices:admin/create\",\n" + + "\"indices:admin/aliases\",\n" + + "\"indices:admin/delete\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createRoleMapping(String role, List users) throws IOException { + JsonArray usersString = new JsonArray(); + for (int i = 0; i < users.size(); i++) { + usersString.add(users.get(i)); + } + return TestHelpers.makeRequest( + client(), + "PUT", + "/_plugins/_security/api/rolesmapping/" + role, + null, + TestHelpers.toHttpEntity( + "{\n" + " \"backend_roles\" : [ ],\n" + " \"hosts\" : [ ],\n" + " \"users\" : " + usersString + "\n" + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response enableFilterBy() throws IOException { + return TestHelpers.makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + TestHelpers.toHttpEntity( + "{\n" + " \"persistent\": {\n" + " \"plugins.flow_framework.filter_by_backend_roles\" : \"true\"\n" + " }\n" + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response disableFilterBy() throws IOException { + return TestHelpers.makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + TestHelpers.toHttpEntity( + "{\n" + " \"persistent\": {\n" + " \"plugins.flow_framework.filter_by_backend_roles\" : \"false\"\n" + " }\n" + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public void confirmingClientIsAdmin() throws IOException { + Response resp = TestHelpers.makeRequest( + client(), + "GET", + "_plugins/_security/api/account", + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "admin")) + ); + Map responseMap = entityAsMap(resp); + ArrayList roles = (ArrayList) responseMap.get("roles"); + assertTrue(roles.contains("all_access")); + } + /** * Helper method to invoke the create workflow API with a use case and also the provision param as true * @param client the rest client @@ -759,25 +809,40 @@ protected List getResourcesCreated(RestClient client, String wo } } - protected Response createUser(String name, String password, List backendRoles) throws IOException { - String backendRolesString = backendRoles.stream().map(item -> "\"" + item + "\"").collect(Collectors.joining(",")); - String json = "{\"password\": \"" - + password - + "\",\"opendistro_security_roles\": [" - + backendRolesString - + "],\"backend_roles\": [],\"attributes\": {}}"; + public Response createUser(String name, String password, List backendRoles) throws IOException { + JsonArray backendRolesString = new JsonArray(); + for (int i = 0; i < backendRoles.size(); i++) { + backendRolesString.add(backendRoles.get(i)); + } return TestHelpers.makeRequest( client(), "PUT", - "/_opendistro/_security/api/internalusers/" + name, + "/_plugins/_security/api/internalusers/" + name, null, - TestHelpers.toHttpEntity(json), - null + TestHelpers.toHttpEntity( + " {\n" + + "\"password\": \"" + + password + + "\",\n" + + "\"backend_roles\": " + + backendRolesString + + ",\n" + + "\"attributes\": {\n" + + "}} " + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); } protected Response deleteUser(String user) throws IOException { - return TestHelpers.makeRequest(client(), "DELETE", "/_opendistro/_security/api/internalusers/" + user, null, "", null); + return TestHelpers.makeRequest( + client(), + "DELETE", + "/_plugins/_security/api/internalusers/" + user, + null, + "", + List.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); } protected GetPipelineResponse getPipelines(String pipelineId) throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/TestHelpers.java b/src/test/java/org/opensearch/flowframework/TestHelpers.java index 7dae5b09c..b65e7b599 100644 --- a/src/test/java/org/opensearch/flowframework/TestHelpers.java +++ b/src/test/java/org/opensearch/flowframework/TestHelpers.java @@ -11,6 +11,8 @@ import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.nio.entity.NStringEntity; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.client.Request; import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; @@ -34,6 +36,10 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; import java.io.BufferedReader; import java.io.IOException; @@ -47,6 +53,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; import static org.apache.http.entity.ContentType.APPLICATION_JSON; @@ -177,4 +184,27 @@ public static Workflow createSampleWorkflow() { return workflow; } + public static SearchRequest matchAllRequest() { + BoolQueryBuilder query = new BoolQueryBuilder().filter(new MatchAllQueryBuilder()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + return new SearchRequest().source(searchSourceBuilder); + } + + public static GetResponse createGetResponse(ToXContentObject o, String id, String indexName) throws IOException { + XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + return new GetResponse( + new GetResult( + indexName, + id, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.bytes(content), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + } + } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkSecureRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkSecureRestApiIT.java index 40efaff83..4ec64c26b 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkSecureRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkSecureRestApiIT.java @@ -8,147 +8,553 @@ */ package org.opensearch.flowframework.rest; +import org.apache.http.HttpHost; +import org.opensearch.action.ingest.GetPipelineResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.common.util.io.IOUtils; +import org.opensearch.client.RestClient; +import org.opensearch.commons.rest.SecureRestClientBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.FlowFrameworkRestTestCase; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; import org.junit.After; +import org.junit.Before; import java.io.IOException; +import java.security.SecureRandom; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; public class FlowFrameworkSecureRestApiIT extends FlowFrameworkRestTestCase { + String aliceUser = "alice"; + RestClient aliceClient; + String bobUser = "bob"; + RestClient bobClient; + String catUser = "cat"; + RestClient catClient; + String dogUser = "dog"; + RestClient dogClient; + String elkUser = "elk"; + RestClient elkClient; + String fishUser = "fish"; + RestClient fishClient; + String lionUser = "lion"; + RestClient lionClient; + private String indexAllAccessRole = "index_all_access"; + private static String FLOW_FRAMEWORK_FULL_ACCESS_ROLE = "flow_framework_full_access"; + private static String ML_COMMONS_FULL_ACCESS_ROLE = "ml_full_access"; + private static String FLOW_FRAMEWORK_READ_ACCESS_ROLE = "flow_framework_read_access"; + + @Before + public void setupSecureTests() throws IOException { + if (!isHttps()) throw new IllegalArgumentException("Secure Tests are running but HTTPS is not set"); + createIndexRole(indexAllAccessRole, "*"); + String alicePassword = generatePassword(aliceUser); + createUser(aliceUser, alicePassword, List.of("odfe")); + aliceClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), aliceUser, alicePassword) + .setSocketTimeout(60000) + .build(); + + String bobPassword = generatePassword(bobUser); + createUser(bobUser, bobPassword, List.of("odfe")); + bobClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), bobUser, bobPassword) + .setSocketTimeout(60000) + .build(); + + String catPassword = generatePassword(catUser); + createUser(catUser, catPassword, List.of("aes")); + catClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), catUser, catPassword) + .setSocketTimeout(60000) + .build(); + + String dogPassword = generatePassword(dogUser); + createUser(dogUser, dogPassword, List.of()); + dogClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), dogUser, dogPassword) + .setSocketTimeout(60000) + .build(); + + String elkPassword = generatePassword(elkUser); + createUser(elkUser, elkPassword, List.of("odfe")); + elkClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), elkUser, elkPassword) + .setSocketTimeout(60000) + .build(); + + String fishPassword = generatePassword(fishUser); + createUser(fishUser, fishPassword, List.of("odfe", "aes")); + fishClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), fishUser, fishPassword) + .setSocketTimeout(60000) + .build(); + + String lionPassword = generatePassword(lionUser); + createUser(lionUser, lionPassword, List.of("opensearch")); + lionClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), lionUser, lionPassword) + .setSocketTimeout(60000) + .build(); + + createRoleMapping(FLOW_FRAMEWORK_READ_ACCESS_ROLE, List.of(bobUser)); + createRoleMapping(ML_COMMONS_FULL_ACCESS_ROLE, List.of(aliceUser, catUser, dogUser, elkUser, fishUser)); + createRoleMapping(FLOW_FRAMEWORK_FULL_ACCESS_ROLE, List.of(aliceUser, catUser, dogUser, elkUser, fishUser)); + createRoleMapping(indexAllAccessRole, List.of(aliceUser)); + } + @After public void tearDownSecureTests() throws IOException { - IOUtils.close(fullAccessClient(), readAccessClient()); - deleteUser(FULL_ACCESS_USER); - deleteUser(READ_ACCESS_USER); + aliceClient.close(); + bobClient.close(); + catClient.close(); + dogClient.close(); + elkClient.close(); + fishClient.close(); + lionClient.close(); + deleteUser(aliceUser); + deleteUser(bobUser); + deleteUser(catUser); + deleteUser(dogUser); + deleteUser(elkUser); + deleteUser(fishUser); + deleteUser(lionUser); + } + + /** + * Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk + * @return a random password. + */ + public static String generatePassword(String username) { + String upperCase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + String lowerCase = "abcdefghijklmnopqrstuvwxyz"; + String digits = "0123456789"; + String special = "_"; + String characters = upperCase + lowerCase + digits + special; + + SecureRandom rng = new SecureRandom(); + + // Ensure password includes at least one character from each set + char[] password = new char[15]; + password[0] = upperCase.charAt(rng.nextInt(upperCase.length())); + password[1] = lowerCase.charAt(rng.nextInt(lowerCase.length())); + password[2] = digits.charAt(rng.nextInt(digits.length())); + password[3] = special.charAt(rng.nextInt(special.length())); + + for (int i = 4; i < 15; i++) { + char nextChar; + do { + nextChar = characters.charAt(rng.nextInt(characters.length())); + } while (username.indexOf(nextChar) > -1); + password[i] = nextChar; + } + + // Shuffle the array to ensure the first 4 characters are not always in the same position + for (int i = password.length - 1; i > 0; i--) { + int index = rng.nextInt(i + 1); + char temp = password[index]; + password[index] = password[i]; + password[i] = temp; + } + + return new String(password); } public void testCreateWorkflowWithReadAccess() throws Exception { Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); - ResponseException exception = expectThrows(ResponseException.class, () -> createWorkflow(readAccessClient(), template)); + ResponseException exception = expectThrows(ResponseException.class, () -> createWorkflow(bobClient, template)); + assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); + } + + public void testCreateWorkflowWithWriteAccess() throws Exception { + // User Alice has FF full access, should be able to create a workflow + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response response = createWorkflow(aliceClient, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + } + + public void testCreateWorkflowWithNoFFAccess() throws Exception { + // User Lion has no FF access at all, should not be able to create a workflow + disableFilterBy(); + Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + + ResponseException exception = expectThrows(ResponseException.class, () -> { createWorkflow(lionClient, template); }); assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); } public void testProvisionWorkflowWithReadAccess() throws Exception { - ResponseException exception = expectThrows(ResponseException.class, () -> provisionWorkflow(readAccessClient(), "test")); + ResponseException exception = expectThrows(ResponseException.class, () -> provisionWorkflow(bobClient, "test")); + assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); + } + + public void testReprovisionWorkflowWithReadAccess() throws Exception { + Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); + ResponseException exception = expectThrows(ResponseException.class, () -> reprovisionWorkflow(bobClient, "test", template)); assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); } public void testDeleteWorkflowWithReadAccess() throws Exception { - ResponseException exception = expectThrows(ResponseException.class, () -> deleteWorkflow(readAccessClient(), "test")); + ResponseException exception = expectThrows(ResponseException.class, () -> deleteWorkflow(bobClient, "test")); assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); } public void testDeprovisionWorkflowWithReadAcess() throws Exception { - ResponseException exception = expectThrows(ResponseException.class, () -> deprovisionWorkflow(readAccessClient(), "test")); + ResponseException exception = expectThrows(ResponseException.class, () -> deprovisionWorkflow(bobClient, "test")); assertEquals(RestStatus.FORBIDDEN.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); } public void testGetWorkflowStepsWithReadAccess() throws Exception { - Response response = getWorkflowStep(readAccessClient()); + Response response = getWorkflowStep(bobClient); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); } public void testGetWorkflowWithReadAccess() throws Exception { // No permissions to create, so we assert only that the response status isnt forbidden - ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflow(readAccessClient(), "test")); + ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflow(bobClient, "test")); assertEquals(RestStatus.NOT_FOUND, TestHelpers.restStatus(exception.getResponse())); } + public void testFilterByDisabled() throws Exception { + disableFilterBy(); + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + Response response = getWorkflow(catClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + public void testSearchWorkflowWithReadAccess() throws Exception { // Use full access client to invoke create workflow to ensure the template/state indices are created Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); - Response response = createWorkflow(fullAccessClient(), template); + Response response = createWorkflow(aliceClient, template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); // No permissions to create, so we assert only that the response status isnt forbidden - String termIdQuery = "{\"query\":{\"ids\":{\"values\":[\"test\"]}}}"; - SearchResponse seachResponse = searchWorkflows(readAccessClient(), termIdQuery); + String termIdQuery = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"ids\":\"test\"}}]}}}"; + SearchResponse seachResponse = searchWorkflows(bobClient, termIdQuery); assertEquals(RestStatus.OK, seachResponse.status()); } public void testGetWorkflowStateWithReadAccess() throws Exception { // Use the full access client to invoke create workflow to ensure the template/state indices are created Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); - Response response = createWorkflow(fullAccessClient(), template); + Response response = createWorkflow(aliceClient, template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); // No permissions to create or provision, so we assert only that the response status isnt forbidden - ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflowStatus(readAccessClient(), "test", false)); - assertTrue(exception.getMessage().contains("Fail to find workflow")); - assertEquals(RestStatus.NOT_FOUND, TestHelpers.restStatus(exception.getResponse())); + Response searchResponse = getWorkflowStatus(bobClient, workflowId, false); + assertEquals(RestStatus.OK, TestHelpers.restStatus(searchResponse)); } public void testSearchWorkflowStateWithReadAccess() throws Exception { // Use the full access client to invoke create workflow to ensure the template/state indices are created Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); - Response response = createWorkflow(fullAccessClient(), template); + Response response = createWorkflow(aliceClient, template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); // No permissions to create, so we assert only that the response status isnt forbidden - String termIdQuery = "{\"query\":{\"ids\":{\"values\":[\"test\"]}}}"; - SearchResponse searchResponse = searchWorkflowState(readAccessClient(), termIdQuery); + String termIdQuery = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"ids\":\"test\"}}]}}}"; + SearchResponse searchResponse = searchWorkflowState(bobClient, termIdQuery); assertEquals(RestStatus.OK, searchResponse.status()); } + public void testCreateWorkflowWithNoBackendRole() throws IOException { + enableFilterBy(); + // User Dog has FF full access, but has no backend role + // When filter by is enabled, we block creating workflows + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Exception exception = expectThrows(IOException.class, () -> { createWorkflow(dogClient, template); }); + assertTrue( + exception.getMessage().contains("Filter by backend roles is enabled, but User dog does not have any backend roles configured") + ); + } + + public void testDeprovisionWorkflowWithWriteAccess() throws Exception { + // User Alice has FF full access, should be able to deprovision a workflow + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + enableFilterBy(); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + Response response = deprovisionWorkflow(aliceClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testGetWorkflowWithFilterEnabled() throws Exception { + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + enableFilterBy(); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // User Cat has FF full access, but is part of different backend role so Cat should not be able to access alice workflow + ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflow(catClient, workflowId)); + assertTrue(exception.getMessage().contains("User does not have permissions to access workflow: " + workflowId)); + } + + public void testGetWorkflowFilterbyEnabledForAdmin() throws Exception { + // User Alice has FF full access, should be able to create a workflow and has backend role "odfe" + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + enableFilterBy(); + confirmingClientIsAdmin(); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + Response response = getWorkflow(aliceClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testProvisionWorkflowWithWriteAccess() throws Exception { + // User Alice has FF full access, should be able to provision a workflow + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + enableFilterBy(); + confirmingClientIsAdmin(); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + Response response = provisionWorkflow(aliceClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testReprovisionWorkflowWithWriteAccess() throws Exception { + // User Alice has FF full access, should be able to reprovision a workflow + // Begin with a template to register a local pretrained model and create an index, no edges + Template template = TestHelpers.createTemplateFromFile("registerremotemodel-createindex.json"); + + enableFilterBy(); + Response response = createWorkflowWithProvision(aliceClient, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(aliceClient, workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(aliceClient, workflowId, 30); + assertEquals(4, resourcesCreated.size()); + Map resourceMap = resourcesCreated.stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_index")); + + // Reprovision template to add ingest pipeline which uses the model ID + template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); + response = reprovisionWorkflow(aliceClient, workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + resourcesCreated = getResourcesCreated(aliceClient, workflowId, 30); + assertEquals(5, resourcesCreated.size()); + resourceMap = resourcesCreated.stream().collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_remote_model")); + assertTrue(resourceMap.containsKey("create_ingest_pipeline")); + assertTrue(resourceMap.containsKey("create_index")); + + // Ensure ingest pipeline configuration contains the model id and index settings have the ingest pipeline as default + String modelId = resourceMap.get("register_remote_model").resourceId(); + String pipelineId = resourceMap.get("create_ingest_pipeline").resourceId(); + GetPipelineResponse getPipelineResponse = getPipelines(pipelineId); + assertEquals(1, getPipelineResponse.pipelines().size()); + assertTrue(getPipelineResponse.pipelines().get(0).getConfigAsMap().toString().contains(modelId)); + + String indexName = resourceMap.get("create_index").resourceId(); + Map indexSettings = getIndexSettingsAsMap(indexName); + assertEquals(pipelineId, indexSettings.get("index.default_pipeline")); + + // Deprovision and delete all resources + Response deprovisionResponse = deprovisionWorkflowWithAllowDelete(aliceClient, workflowId, pipelineId + "," + indexName); + assertBusy( + () -> { getAndAssertWorkflowStatus(aliceClient, workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 60, + TimeUnit.SECONDS + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(aliceClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + } + + public void testDeleteWorkflowWithWriteAccess() throws Exception { + // User Alice has FF full access, should be able to delete a workflow + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + Response aliceWorkflow = createWorkflow(aliceClient, template); + enableFilterBy(); + Map responseMap = entityAsMap(aliceWorkflow); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + Response response = deleteWorkflow(aliceClient, workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + public void testCreateProvisionDeprovisionWorkflowWithFullAccess() throws Exception { // Invoke create workflow API Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); - Response response = createWorkflow(fullAccessClient(), template); + Response response = createWorkflow(aliceClient, template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + enableFilterBy(); + // Retrieve workflow ID Map responseMap = entityAsMap(response); String workflowId = (String) responseMap.get(WORKFLOW_ID); // Invoke search workflows API - String termIdQuery = "{\"query\":{\"ids\":{\"values\":[\"" + workflowId + "\"]}}}"; - SearchResponse searchResponse = searchWorkflows(fullAccessClient(), termIdQuery); + String termIdQuery = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"ids\":\"" + workflowId + "\"}}]}}}"; + SearchResponse searchResponse = searchWorkflows(aliceClient, termIdQuery); assertEquals(RestStatus.OK, searchResponse.status()); // Invoke provision API if (!indexExistsWithAdminClient(".plugins-ml-config")) { assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS); - response = provisionWorkflow(fullAccessClient(), workflowId); + response = provisionWorkflow(aliceClient, workflowId); } else { - response = provisionWorkflow(fullAccessClient(), workflowId); + response = provisionWorkflow(aliceClient, workflowId); } assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); // Invoke status API - response = getWorkflowStatus(fullAccessClient(), workflowId, false); + response = getWorkflowStatus(aliceClient, workflowId, false); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); - // Invoke delete API while state still exists - response = deleteWorkflow(fullAccessClient(), workflowId); - assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); - - // Invoke status API - response = getWorkflowStatus(fullAccessClient(), workflowId, false); + // Invoke deprovision API + response = deprovisionWorkflow(aliceClient, workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); - // Invoke deprovision API - response = deprovisionWorkflow(fullAccessClient(), workflowId); + // Invoke delete API + response = deleteWorkflow(aliceClient, workflowId); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); // Invoke status API with failure - ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflowStatus(fullAccessClient(), workflowId, false)); + ResponseException exception = expectThrows(ResponseException.class, () -> getWorkflowStatus(aliceClient, workflowId, false)); assertEquals(RestStatus.NOT_FOUND.getStatus(), exception.getResponse().getStatusLine().getStatusCode()); } + public void testUpdateWorkflowEnabledForAdmin() throws Exception { + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + + // Remove register model input to test validation + Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List modifiednodes = originalWorkflow.nodes() + .stream() + .map( + n -> "workflow_step_1".equals(n.id()) + ? new WorkflowNode( + "workflow_step_1", + "register_local_sparse_encoding_model", + Collections.emptyMap(), + Collections.emptyMap() + ) + : n + ) + .collect(Collectors.toList()); + Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges()); + Template templateWithMissingInputs = Template.builder(template).workflows(Map.of(PROVISION_WORKFLOW, missingInputs)).build(); + + Response response = createWorkflow(aliceClient, templateWithMissingInputs); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + + ResponseException exception = expectThrows(ResponseException.class, () -> provisionWorkflow(aliceClient, workflowId)); + assertTrue(exception.getMessage().contains("Invalid workflow, node [workflow_step_1] missing the following required inputs")); + getAndAssertWorkflowStatus(aliceClient, workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); + + enableFilterBy(); + // User alice has admin all access, and has "odfe" backend role so client should be able to update workflow + Response updateResponse = updateWorkflow(aliceClient, workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(updateResponse)); + } + + public void testUpdateWorkflowWithFilterEnabled() throws Exception { + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + + // Remove register model input to test validation + Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List modifiednodes = originalWorkflow.nodes() + .stream() + .map( + n -> "workflow_step_1".equals(n.id()) + ? new WorkflowNode( + "workflow_step_1", + "register_local_sparse_encoding_model", + Collections.emptyMap(), + Collections.emptyMap() + ) + : n + ) + .collect(Collectors.toList()); + Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges()); + Template templateWithMissingInputs = Template.builder(template).workflows(Map.of(PROVISION_WORKFLOW, missingInputs)).build(); + + Response response = createWorkflow(aliceClient, templateWithMissingInputs); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + + enableFilterBy(); + // User Fish has FF full access, and has "odfe" backend role which is one of Alice's backend role, so + // Fish should be able to update workflows created by Alice. But the workflow's backend role should + // not be replaced as Fish's backend roles. + Response updateResponse = updateWorkflow(fishClient, workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(updateResponse)); + } + + public void testUpdateWorkflowWithNoFFAccess() throws Exception { + Template template = TestHelpers.createTemplateFromFile("register-deploylocalsparseencodingmodel.json"); + + // Remove register model input to test validation + Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW); + List modifiednodes = originalWorkflow.nodes() + .stream() + .map( + n -> "workflow_step_1".equals(n.id()) + ? new WorkflowNode( + "workflow_step_1", + "register_local_sparse_encoding_model", + Collections.emptyMap(), + Collections.emptyMap() + ) + : n + ) + .collect(Collectors.toList()); + Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges()); + Template templateWithMissingInputs = Template.builder(template).workflows(Map.of(PROVISION_WORKFLOW, missingInputs)).build(); + + Response response = createWorkflow(aliceClient, templateWithMissingInputs); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + + enableFilterBy(); + + // User lion has no FF access and should not be able to update the workflow created by alice + ResponseException exception1 = expectThrows(ResponseException.class, () -> { updateWorkflow(lionClient, workflowId, template); }); + assertEquals(RestStatus.FORBIDDEN.getStatus(), exception1.getResponse().getStatusLine().getStatusCode()); + } + public void testGetWorkflowStepWithFullAccess() throws Exception { - Response response = getWorkflowStep(fullAccessClient()); + Response response = getWorkflowStep(aliceClient); + enableFilterBy(); + confirmingClientIsAdmin(); assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); } - } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 90b60d1d3..86499e3d8 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -18,13 +18,21 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.indices.FlowFrameworkIndex; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -39,9 +47,14 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.io.IOException; +import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.mockito.ArgumentCaptor; @@ -79,6 +92,8 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private ThreadPool threadPool; private FlowFrameworkSettings flowFrameworkSettings; private PluginsService pluginsService; + private ClusterService clusterService; + private ClusterSettings clusterSettings; @Override public void setUp() throws Exception { @@ -95,6 +110,26 @@ public void setUp() throws Exception { this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.pluginsService = mock(PluginsService.class); + clusterService = mock(ClusterService.class); + clusterSettings = new ClusterSettings(Settings.EMPTY, Set.copyOf(List.of(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; + ClusterName clusterName = new ClusterName("test"); + + Settings indexSettings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); + + IndexMetadata indexMetaData = IndexMetadata.builder(GLOBAL_CONTEXT_INDEX).settings(existingSettings).build(); + final Map indices = new HashMap<>(); + indices.put(GLOBAL_CONTEXT_INDEX, indexMetaData); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build(); + when(clusterService.state()).thenReturn(clusterState); + // Spy this action to stub check max workflows this.createWorkflowTransportAction = spy( new CreateWorkflowTransportAction( @@ -104,7 +139,10 @@ public void setUp() throws Exception { flowFrameworkIndicesHandler, flowFrameworkSettings, client, - pluginsService + pluginsService, + clusterService, + xContentRegistry(), + Settings.EMPTY ) ); // client = mock(Client.class); @@ -338,7 +376,169 @@ public void testCreateNewWorkflow() { assertEquals("1", workflowResponseCaptor.getValue().getWorkflowId()); } - public void testUpdateWorkflowWithReprovision() { + public void testCreateWithUserAndFilterOn() { + Settings settings = Settings.builder().put(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + CreateWorkflowTransportAction createWorkflowTransportAction1 = spy( + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client, + pluginsService, + clusterService, + xContentRegistry(), + settings + ) + ); + + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap(), + false + ); + + // Bypass checkMaxWorkflows and force onResponse + doAnswer(invocation -> { + ActionListener checkMaxWorkflowListener = invocation.getArgument(2); + checkMaxWorkflowListener.onResponse(true); + return null; + }).when(createWorkflowTransportAction1).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + + // Bypass putTemplateToGlobalContext and force onResponse + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(), any()); + + // Bypass putInitialStateToWorkflowState and force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(new IndexResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "1", 1L, 1L, 1L, true)); + return null; + }).when(flowFrameworkIndicesHandler).putInitialStateToWorkflowState(any(), any(), any()); + + ArgumentCaptor workflowResponseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + + createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); + + verify(listener, times(1)).onResponse(workflowResponseCaptor.capture()); + } + + public void testFailedToCreateNewWorkflowWithNullUser() { + @SuppressWarnings("unchecked") + Settings settings = Settings.builder().put(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, null); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + CreateWorkflowTransportAction createWorkflowTransportAction1 = spy( + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client, + pluginsService, + clusterService, + xContentRegistry(), + settings + ) + ); + + ActionListener listener = mock(ActionListener.class); + + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap(), + false + ); + + createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Filter by backend roles is enabled and User is null", exceptionCaptor.getValue().getMessage()); + } + + public void testFailedToCreateNewWorkflowWithNoBackendRoleUser() { + @SuppressWarnings("unchecked") + Settings settings = Settings.builder().put(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + CreateWorkflowTransportAction createWorkflowTransportAction1 = spy( + new CreateWorkflowTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + workflowProcessSorter, + flowFrameworkIndicesHandler, + flowFrameworkSettings, + client, + pluginsService, + clusterService, + xContentRegistry(), + settings + ) + ); + + ActionListener listener = mock(ActionListener.class); + + WorkflowRequest workflowRequest = new WorkflowRequest( + null, + template, + new String[] { "off" }, + false, + Collections.emptyMap(), + null, + Collections.emptyMap(), + false + ); + + createWorkflowTransportAction1.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Filter by backend roles is enabled, but User test does not have any backend roles configured", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( @@ -361,6 +561,23 @@ public void testUpdateWorkflowWithReprovision() { return null; }).when(client).get(any(GetRequest.class), any()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new WorkflowResponse("1")); @@ -374,7 +591,7 @@ public void testUpdateWorkflowWithReprovision() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } - public void testFailedToUpdateWorkflowWithReprovision() { + public void testFailedToUpdateWorkflowWithReprovision() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest( @@ -397,6 +614,23 @@ public void testFailedToUpdateWorkflowWithReprovision() { return null; }).when(client).get(any(GetRequest.class), any()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onFailure(new Exception("failed")); @@ -410,7 +644,7 @@ public void testFailedToUpdateWorkflowWithReprovision() { assertEquals("Reprovisioning failed for workflow 1", responseCaptor.getValue().getMessage()); } - public void testFailedToUpdateWorkflow() { + public void testFailedToUpdateWorkflow() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); @@ -424,6 +658,23 @@ public void testFailedToUpdateWorkflow() { return null; }).when(client).get(any(GetRequest.class), any()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onFailure(new Exception("failed")); @@ -436,7 +687,7 @@ public void testFailedToUpdateWorkflow() { assertEquals("Failed to update use case template 1", exceptionCaptor.getValue().getMessage()); } - public void testFailedToUpdateNonExistingWorkflow() { + public void testFailedToUpdateNonExistingWorkflow() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("2", template); @@ -449,6 +700,23 @@ public void testFailedToUpdateNonExistingWorkflow() { return null; }).when(client).get(any(GetRequest.class), any()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onFailure(new Exception("Failed to retrieve template (2) from global context.")); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onFailure(new Exception("failed")); @@ -461,7 +729,7 @@ public void testFailedToUpdateNonExistingWorkflow() { assertEquals("Failed to retrieve template (2) from global context.", exceptionCaptor.getValue().getMessage()); } - public void testUpdateWorkflow() { + public void testUpdateWorkflow() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest updateWorkflow = new WorkflowRequest("1", template); @@ -481,6 +749,23 @@ public void testUpdateWorkflow() { return null; }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(anyString(), any(Template.class), any(), anyBoolean()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener updateResponseListener = invocation.getArgument(2); updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); @@ -494,14 +779,13 @@ public void testUpdateWorkflow() { assertEquals("1", responseCaptor.getValue().getWorkflowId()); } - public void testUpdateWorkflowWithField() { + public void testUpdateWorkflowWithField() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest updateWorkflow = new WorkflowRequest( - "1", - Template.builder().name("new name").description("test").useCase(null).uiMetadata(Map.of("foo", "bar")).build(), - Map.of(UPDATE_WORKFLOW_FIELDS, "true") - ); + + Template template1 = Template.builder().name("new name").description("test").useCase(null).uiMetadata(Map.of("foo", "bar")).build(); + + WorkflowRequest updateWorkflow = new WorkflowRequest("1", template1, Map.of(UPDATE_WORKFLOW_FIELDS, "true")); doAnswer(invocation -> { ActionListener getListener = invocation.getArgument(1); @@ -512,6 +796,23 @@ public void testUpdateWorkflowWithField() { return null; }).when(client).get(any(GetRequest.class), any()); + GetResponse getWorkflowResponse = TestHelpers.createGetResponse(template, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); @@ -552,6 +853,24 @@ public void testUpdateWorkflowWithField() { getListener.onResponse(getResponse); return null; }).when(client).get(any(GetRequest.class), any()); + + GetResponse getWorkflowResponse1 = TestHelpers.createGetResponse(template1, "123", GLOBAL_CONTEXT_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertEquals( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + 2, + args.length + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener getListener = (ActionListener) args[1]; + getListener.onResponse(getWorkflowResponse1); + return null; + }).when(client).get(any(GetRequest.class), any()); + createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); verify(listener, times(2)).onResponse(any()); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java index 9cd392de1..ef4dfe097 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeleteWorkflowTransportActionTests.java @@ -13,17 +13,24 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + import org.mockito.ArgumentCaptor; import static org.mockito.ArgumentMatchers.any; @@ -45,11 +52,22 @@ public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.deleteWorkflowTransportAction = new DeleteWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), flowFrameworkIndicesHandler, - client + client, + clusterService, + xContentRegistry(), + Settings.EMPTY ); ThreadPool clientThreadPool = mock(ThreadPool.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 51561c28e..203255361 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -12,6 +12,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; @@ -37,6 +39,9 @@ import org.opensearch.transport.TransportService; import org.junit.AfterClass; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -109,6 +114,13 @@ public void setUp() throws Exception { flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10)); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), @@ -116,7 +128,10 @@ public void setUp() throws Exception { client, workflowStepFactory, flowFrameworkIndicesHandler, - flowFrameworkSettings + flowFrameworkSettings, + clusterService, + xContentRegistry(), + Settings.EMPTY ); } diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java index 7aa0323b4..d9c188fc6 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStateTransportActionTests.java @@ -8,19 +8,29 @@ */ package org.opensearch.flowframework.transport; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -29,12 +39,21 @@ import java.io.IOException; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Map; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class GetWorkflowStateTransportActionTests extends OpenSearchTestCase { @@ -52,11 +71,20 @@ public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); this.threadPool = mock(ThreadPool.class); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.getWorkflowStateTransportAction = new GetWorkflowStateTransportAction( mock(TransportService.class), mock(ActionFilters.class), client, - xContentRegistry() + xContentRegistry(), + clusterService, + Settings.EMPTY ); task = Mockito.mock(Task.class); ThreadPool clientThreadPool = mock(ThreadPool.class); @@ -124,4 +152,71 @@ public void testGetWorkflowStateResponse() throws IOException { Assert.assertEquals(map.get("state"), workFlowState.getState()); Assert.assertEquals(map.get("workflow_id"), workFlowState.getWorkflowId()); } + + public void testExecuteGetWorkflowStateRequestFailure() throws IOException { + String workflowId = "test-workflow"; + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + ActionListener listener = mock(ActionListener.class); + + // Stub client.get to force on failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("failed")); + return null; + }).when(client).get(any(GetRequest.class), any()); + + getWorkflowStateTransportAction.doExecute(null, request, listener); + + verify(listener, never()).onResponse(any(GetWorkflowStateResponse.class)); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(listener, times(1)).onFailure(responseCaptor.capture()); + + assertEquals("Failed to get workflow status of: " + workflowId, responseCaptor.getValue().getMessage()); + } + + public void testExecuteGetWorkflowStateRequestIndexNotFound() throws IOException { + String workflowId = "test-workflow"; + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + ActionListener listener = mock(ActionListener.class); + + // Stub client.get to force on failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new IndexNotFoundException("index not found")); + return null; + }).when(client).get(any(GetRequest.class), any()); + + getWorkflowStateTransportAction.doExecute(null, request, listener); + + verify(listener, never()).onResponse(any(GetWorkflowStateResponse.class)); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(listener, times(1)).onFailure(responseCaptor.capture()); + + assertEquals("Fail to find workflow status of " + workflowId, responseCaptor.getValue().getMessage()); + } + + public void testExecuteGetWorkflowStateRequestParseFailure() throws IOException { + String workflowId = "test-workflow"; + GetWorkflowStateRequest request = new GetWorkflowStateRequest(workflowId, false); + ActionListener listener = mock(ActionListener.class); + + // Stub client.get to force on response + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + XContentBuilder builder = XContentFactory.jsonBuilder(); + BytesReference templateBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + getWorkflowStateTransportAction.doExecute(null, request, listener); + + verify(listener, never()).onResponse(any(GetWorkflowStateResponse.class)); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FlowFrameworkException.class); + verify(listener, times(1)).onFailure(responseCaptor.capture()); + + assertEquals("Failed to parse workflowState: " + workflowId, responseCaptor.getValue().getMessage()); + } + } diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java index 7a88199fb..0ce92b3f0 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowTransportActionTests.java @@ -14,6 +14,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -22,6 +23,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; @@ -34,7 +36,9 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -65,12 +69,22 @@ public void setUp() throws Exception { this.xContentRegistry = mock(NamedXContentRegistry.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); this.encryptorUtils = new EncryptorUtils(mock(ClusterService.class), client, xContentRegistry); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.getTemplateTransportAction = new GetWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), flowFrameworkIndicesHandler, client, - encryptorUtils + encryptorUtils, + clusterService, + xContentRegistry, + Settings.EMPTY ); Version templateVersion = Version.fromString("1.0.0"); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 5cc11a92d..a6eacc069 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -15,6 +15,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -23,6 +25,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.Template; @@ -38,7 +41,9 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -75,6 +80,12 @@ public void setUp() throws Exception { this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); this.encryptorUtils = mock(EncryptorUtils.class); this.pluginsService = mock(PluginsService.class); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( mock(TransportService.class), @@ -84,7 +95,10 @@ public void setUp() throws Exception { workflowProcessSorter, flowFrameworkIndicesHandler, encryptorUtils, - pluginsService + pluginsService, + clusterService, + xContentRegistry(), + Settings.EMPTY ); Version templateVersion = Version.fromString("1.0.0"); diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index ab2485be4..e654b0482 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -11,6 +11,8 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; @@ -31,7 +33,10 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; @@ -75,6 +80,13 @@ public void setUp() throws Exception { this.encryptorUtils = mock(EncryptorUtils.class); this.pluginsService = mock(PluginsService.class); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + this.reprovisionWorkflowTransportAction = new ReprovisionWorkflowTransportAction( transportService, actionFilters, @@ -85,7 +97,10 @@ public void setUp() throws Exception { flowFrameworkIndicesHandler, flowFrameworkSettings, encryptorUtils, - pluginsService + pluginsService, + clusterService, + xContentRegistry(), + Settings.EMPTY ); ThreadPool clientThreadPool = mock(ThreadPool.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java index f3f55c052..ce23e6289 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.transport.handler.SearchHandler; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; @@ -24,7 +25,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -34,10 +34,12 @@ public class SearchWorkflowStateTransportActionTests extends OpenSearchTestCase private Client client; private ThreadPool threadPool; private ThreadContext threadContext; + private SearchHandler searchHandler; @Override public void setUp() throws Exception { super.setUp(); + searchHandler = mock(SearchHandler.class); this.client = mock(Client.class); this.threadPool = mock(ThreadPool.class); this.threadContext = new ThreadContext(Settings.EMPTY); @@ -48,35 +50,35 @@ public void setUp() throws Exception { this.searchWorkflowStateTransportAction = new SearchWorkflowStateTransportAction( mock(TransportService.class), mock(ActionFilters.class), - client + searchHandler ); } - public void testFailedSearchWorkflow() { + public void testSearchWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); doAnswer(invocation -> { + SearchRequest request = invocation.getArgument(0); ActionListener responseListener = invocation.getArgument(1); - responseListener.onFailure(new Exception("Search failed")); + ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class); + searchHandler.validateRole(request, null, responseListener, storedContext); + responseListener.onResponse(mock(SearchResponse.class)); return null; - }).when(client).search(any(), any()); - - searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); - verify(listener, times(1)).onFailure(any()); - } + }).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); - public void testSearchWorkflow() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - SearchRequest searchRequest = new SearchRequest(); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(mock(SearchResponse.class)); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); - verify(client, times(1)).search(any(SearchRequest.class), any()); + verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java index 763ae73b5..001aca48d 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java @@ -15,18 +15,16 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.transport.handler.SearchHandler; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import org.mockito.ArgumentCaptor; - import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -36,11 +34,13 @@ public class SearchWorkflowTransportActionTests extends OpenSearchTestCase { private Client client; private ThreadPool threadPool; ThreadContext threadContext; + private SearchHandler searchHandler; @Override public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); + searchHandler = mock(SearchHandler.class); this.threadPool = mock(ThreadPool.class); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -49,27 +49,11 @@ public void setUp() throws Exception { this.searchWorkflowTransportAction = new SearchWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - client + searchHandler ); } - public void testFailedSearchWorkflow() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - SearchRequest searchRequest = new SearchRequest(); - - doAnswer(invocation -> { - ActionListener responseListener = invocation.getArgument(2); - responseListener.onFailure(new Exception("Search failed")); - return null; - }).when(client).search(searchRequest); - - searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - } - public void testSearchWorkflow() { threadPool = mock(ThreadPool.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -82,8 +66,23 @@ public void testSearchWorkflow() { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchRequest.source(searchSourceBuilder); + doAnswer(invocation -> { + SearchRequest request = invocation.getArgument(0); + ActionListener responseListener = invocation.getArgument(1); + ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class); + searchHandler.validateRole(request, null, responseListener, storedContext); + responseListener.onResponse(mock(SearchResponse.class)); + return null; + }).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(mock(SearchResponse.class)); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener); - verify(client, times(1)).search(any(SearchRequest.class), any()); + verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java b/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java new file mode 100644 index 000000000..ca744481d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.transport.handler; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.junit.Before; + +import static org.opensearch.flowframework.TestHelpers.clusterSetting; +import static org.opensearch.flowframework.TestHelpers.matchAllRequest; +import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SearchHandlerTests extends OpenSearchTestCase { + + private Client client; + private Settings settings; + private ClusterService clusterService; + private SearchHandler searchHandler; + private ClusterSettings clusterSettings; + + private SearchRequest request; + + private ActionListener listener; + + @SuppressWarnings("unchecked") + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), false).build(); + clusterSettings = clusterSetting(settings, FILTER_BY_BACKEND_ROLES); + clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); + client = mock(Client.class); + searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(client.threadPool().getThreadContext()).thenReturn(threadContext); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + request = mock(SearchRequest.class); + listener = mock(ActionListener.class); + } + + public void testSearchException() { + doThrow(new RuntimeException("test")).when(client).search(any(), any()); + searchHandler.search(request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testFilterEnabledWithWrongSearch() { + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); + + searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + searchHandler.search(request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testFilterEnabled() { + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); + + searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + searchHandler.search(matchAllRequest(), listener); + verify(client, times(1)).search(any(), any()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 8237a7a93..3644ab0ee 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.util; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,6 +18,8 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -27,6 +30,7 @@ import java.util.Map; import java.util.Set; +import static org.opensearch.flowframework.util.ParseUtils.isAdmin; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -111,6 +115,71 @@ public void testConditionallySubstituteWithNoPlaceholders() { assertEquals("This string has no placeholders", result); } + public void testAddUserRoleFilterWithNullUser() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + ParseUtils.addUserBackendRolesFilter(null, searchSourceBuilder); + assertEquals("{}", searchSourceBuilder.toString()); + } + + public void testAddUserRoleFilterWithNullUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + ParseUtils.addUserBackendRolesFilter( + new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + + public void testAddUserRoleFilterWithEmptyUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + ParseUtils.addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + + public void testAddUserRoleFilterWithUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + String backendRole1 = randomAlphaOfLength(5); + String backendRole2 = randomAlphaOfLength(5); + ParseUtils.addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(backendRole1, backendRole2), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":" + + "[\"" + + backendRole1 + + "\",\"" + + backendRole2 + + "\"]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + public void testConditionallySubstituteWithUnmatchedPlaceholders() { String input = "This string has unmatched ${{placeholder}}"; Map outputs = new HashMap<>(); @@ -347,4 +416,29 @@ public void testPrependIndexToSettings() throws Exception { assertTrue(prependedSettings.entrySet().stream().allMatch(x -> x.getKey().startsWith("index."))); } + + public void testIsAdmin() { + User user1 = new User( + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of("all_access"), + ImmutableList.of(randomAlphaOfLength(5)) + ); + assertTrue(isAdmin(user1)); + } + + public void testIsAdminBackendRoleIsAllAccess() { + String backendRole1 = "all_access"; + User user1 = new User( + randomAlphaOfLength(5), + ImmutableList.of(backendRole1), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ); + assertFalse(isAdmin(user1)); + } + + public void testIsAdminNull() { + assertFalse(isAdmin(null)); + } }