diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4f0f42ef2..f02a8f538 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -192,4 +192,4 @@ jobs: distribution: temurin - name: Build and Run Tests run: | - ./gradlew integTest -Dtests.rest.tenantaware=true + ./gradlew integTest "-Dtests.rest.tenantaware=true" diff --git a/CHANGELOG.md b/CHANGELOG.md index c8f99f0bb..6ab0612de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.18...2.x) ### Features +- Implemented multitenant remote metadata client ([#980](https://github.com/opensearch-project/flow-framework/pull/980)) + ### Enhancements ### Bug Fixes - Remove useCase and defaultParams field in WorkflowRequest ([#758](https://github.com/opensearch-project/flow-framework/pull/758)) diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index 94e168920..52c0f7754 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -74,23 +74,32 @@ public AbstractSearchWorkflowAction( @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { - FlowFrameworkException ffe = new FlowFrameworkException( - "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", - RestStatus.FORBIDDEN - ); + try { + if (!flowFrameworkSettings.isFlowFrameworkEnabled()) { + FlowFrameworkException ffe = new FlowFrameworkException( + "This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.", + RestStatus.FORBIDDEN + ); + return channel -> channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } + String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); + + // The transport action needs the tenant id but also only takes a SearchRequest. + // The tenant filtering will be handled by the metadata client. + // We'll use the preference field to communicate the tenant ID and strip it on the other end + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index).preference(tenantId); + return channel -> client.execute(actionType, searchRequest, search(channel)); + } catch (FlowFrameworkException ex) { return channel -> channel.sendResponse( - new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); } - String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); - searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); - - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); - return channel -> client.execute(actionType, searchRequest, search(channel)); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java index f20c57adb..138c55ea2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -46,7 +46,13 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try { - searchHandler.search(request, actionListener); + // We used the SearchRequest preference field to convey a tenant id if any + String tenantId = null; + if (request.preference() != null) { + tenantId = request.preference(); + request.preference(null); + } + searchHandler.search(request, tenantId, actionListener); } catch (Exception e) { String errorMessage = "Failed to search workflow states in global context"; logger.error(errorMessage, e); diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java index 46f0afb10..40c0a72e2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java @@ -46,7 +46,13 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try { - searchHandler.search(request, actionListener); + // We used the SearchRequest preference field to convey a tenant id if any + String tenantId = null; + if (request.preference() != null) { + tenantId = request.preference(); + request.preference(null); + } + searchHandler.search(request, tenantId, actionListener); } catch (Exception e) { String errorMessage = "Failed to search workflows in global context"; logger.error(errorMessage, e); diff --git a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java index 512b0bea2..c58d68988 100644 --- a/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java +++ b/src/main/java/org/opensearch/flowframework/transport/handler/SearchHandler.java @@ -19,9 +19,17 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import java.util.Arrays; + +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.util.ParseUtils.isAdmin; import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; @@ -31,6 +39,7 @@ public class SearchHandler { private final Logger logger = LogManager.getLogger(SearchHandler.class); private final Client client; + private final SdkClient sdkClient; private volatile Boolean filterByBackendRole; /** @@ -40,8 +49,15 @@ public class SearchHandler { * @param client The node client to retrieve a stored use case template * @param filterByBackendRoleSetting filter role backend settings */ - public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting filterByBackendRoleSetting) { + public SearchHandler( + Settings settings, + ClusterService clusterService, + Client client, + SdkClient sdkClient, + Setting filterByBackendRoleSetting + ) { this.client = client; + this.sdkClient = sdkClient; filterByBackendRole = filterByBackendRoleSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByBackendRole = it); } @@ -49,16 +65,17 @@ public SearchHandler(Settings settings, ClusterService clusterService, Client cl /** * Search workflows in global context * @param request SearchRequest + * @param tenantId the tenant ID * @param actionListener ActionListener */ - public void search(SearchRequest request, ActionListener actionListener) { + public void search(SearchRequest request, String tenantId, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { logger.info("Searching workflows in global context"); SearchSourceBuilder searchSourceBuilder = request.source(); searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); - validateRole(request, user, actionListener, context); + validateRole(request, tenantId, user, actionListener, context); } catch (Exception e) { logger.error("Failed to search workflows in global context", e); actionListener.onFailure(e); @@ -68,12 +85,14 @@ public void search(SearchRequest request, ActionListener actionL /** * Validate user role and call search * @param request SearchRequest + * @param tenantId the tenant id * @param user User * @param listener ActionListener * @param context ThreadContext */ public void validateRole( SearchRequest request, + String tenantId, User user, ActionListener listener, ThreadContext.StoredContext context @@ -83,16 +102,40 @@ public void validateRole( // Case 2: If Security is enabled and filter is disabled, proceed with search as // user is already authenticated to hit this API. // case 3: user is admin which means we don't have to check backend role filtering - client.search(request, ActionListener.runBefore(listener, context::restore)); + doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore)); } else { // Security is enabled, filter is enabled and user isn't admin try { ParseUtils.addUserBackendRolesFilter(user, request.source()); logger.debug("Filtering result by {}", user.getBackendRoles()); - client.search(request, ActionListener.runBefore(listener, context::restore)); + doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore)); } catch (Exception e) { listener.onFailure(e); } } } + + private void doSearch(SearchRequest request, String tenantId, ActionListener listener) { + SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder() + .indices(request.indices()) + .tenantId(tenantId) + .searchSourceBuilder(request.source()) + .build(); + sdkClient.searchDataObjectAsync(searchRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + logger.info(Arrays.toString(request.indices()) + " search complete: {}", searchResponse.getHits().getTotalHits()); + listener.onResponse(searchResponse); + } catch (Exception e) { + logger.error("Failed to parse search response", e); + listener.onFailure(new FlowFrameworkException("Failed to parse search response", INTERNAL_SERVER_ERROR)); + } + } else { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + logger.error(Arrays.toString(request.indices()) + " search failed", cause); + listener.onFailure(cause); + } + }); + } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java index ba9f3383e..0120f65f2 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestWorkflowTenantAwareIT.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.rest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.flowframework.FlowFrameworkTenantAwareRestTestCase; @@ -16,6 +17,7 @@ import java.io.IOException; import java.util.Map; +import java.util.concurrent.TimeUnit; import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; @@ -183,7 +185,6 @@ public void testWorkflowCRUD() throws Exception { // Retry these tests until they pass. Search requires refresh, can take 15s on DDB refreshAllIndices(); - /* Search not yet implemented TODO assertBusy(() -> { // Search should show only the workflow for tenant Response restResponse = makeRequest(tenantMatchAllRequest, GET, WORKFLOW_PATH + "_search"); diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java index ce23e6289..898840e59 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java @@ -23,6 +23,7 @@ import org.opensearch.transport.TransportService; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -66,10 +67,10 @@ public void testSearchWorkflow() { SearchRequest request = invocation.getArgument(0); ActionListener responseListener = invocation.getArgument(1); ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class); - searchHandler.validateRole(request, null, responseListener, storedContext); + searchHandler.validateRole(request, null, null, responseListener, storedContext); responseListener.onResponse(mock(SearchResponse.class)); return null; - }).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); + }).when(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class)); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -78,7 +79,7 @@ public void testSearchWorkflow() { }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); - verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); + verify(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class)); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java index 001aca48d..8d33630dc 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java @@ -23,6 +23,7 @@ import org.opensearch.transport.TransportService; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -70,10 +71,10 @@ public void testSearchWorkflow() { SearchRequest request = invocation.getArgument(0); ActionListener responseListener = invocation.getArgument(1); ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class); - searchHandler.validateRole(request, null, responseListener, storedContext); + searchHandler.validateRole(request, null, null, responseListener, storedContext); responseListener.onResponse(mock(SearchResponse.class)); return null; - }).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); + }).when(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class)); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); @@ -82,7 +83,7 @@ public void testSearchWorkflow() { }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener); - verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class)); + verify(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class)); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java b/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java index ca744481d..b55574c31 100644 --- a/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/handler/SearchHandlerTests.java @@ -8,33 +8,63 @@ */ package org.opensearch.flowframework.transport.handler; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.junit.AfterClass; import org.junit.Before; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + import static org.opensearch.flowframework.TestHelpers.clusterSetting; import static org.opensearch.flowframework.TestHelpers.matchAllRequest; +import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class SearchHandlerTests extends OpenSearchTestCase { + private static final TestThreadPool testThreadPool = spy( + new TestThreadPool( + SearchHandlerTests.class.getName(), + new ScalingExecutorBuilder( + WORKFLOW_THREAD_POOL, + 1, + Math.max(2, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + FLOW_FRAMEWORK_THREAD_POOL_PREFIX + WORKFLOW_THREAD_POOL + ) + ) + ); + private Client client; + private SdkClient sdkClient; private Settings settings; private ClusterService clusterService; private SearchHandler searchHandler; @@ -53,41 +83,61 @@ public void setUp() throws Exception { clusterSettings = clusterSetting(settings, FILTER_BY_BACKEND_ROLES); clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); client = mock(Client.class); - searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + searchHandler = new SearchHandler(settings, clusterService, client, sdkClient, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + + when(client.threadPool()).thenReturn(testThreadPool); ThreadContext threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); - org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); - when(client.threadPool()).thenReturn(mockThreadPool); - when(client.threadPool().getThreadContext()).thenReturn(threadContext); - when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + when(testThreadPool.getThreadContext()).thenReturn(threadContext); request = mock(SearchRequest.class); listener = mock(ActionListener.class); } - public void testSearchException() { - doThrow(new RuntimeException("test")).when(client).search(any(), any()); - searchHandler.search(request, listener); + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); + } + + public void testSearchException() throws InterruptedException { + doThrow(new RuntimeException("test")).when(client).search(any()); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(listener, latch); + searchHandler.search(request, null, latchedActionListener); + latch.await(1, TimeUnit.SECONDS); + verify(listener, times(1)).onFailure(any()); } - public void testFilterEnabledWithWrongSearch() { + public void testFilterEnabledWithWrongSearch() throws InterruptedException { settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); - searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); - searchHandler.search(request, listener); + searchHandler = new SearchHandler(settings, clusterService, client, sdkClient, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(listener, latch); + searchHandler.search(request, null, latchedActionListener); + latch.await(1, TimeUnit.SECONDS); + verify(listener, times(1)).onFailure(any()); } - public void testFilterEnabled() { + public void testFilterEnabled() throws InterruptedException { settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); clusterService = new ClusterService(settings, clusterSettings, mock(ThreadPool.class), null); - searchHandler = new SearchHandler(settings, clusterService, client, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); - searchHandler.search(matchAllRequest(), listener); - verify(client, times(1)).search(any(), any()); + searchHandler = new SearchHandler(settings, clusterService, client, sdkClient, FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(listener, latch); + searchHandler.search(matchAllRequest(), null, latchedActionListener); + latch.await(1, TimeUnit.SECONDS); + + verify(client, times(1)).search(any()); } }