Skip to content

Commit

Permalink
Enable tenant aware search
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 28, 2024
1 parent c46d904 commit 7c10eac
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,4 @@ jobs:
distribution: temurin
- name: Build and Run Tests
run: |
./gradlew integTest -Dtests.rest.tenantaware=true
./gradlew integTest "-Dtests.rest.tenantaware=true"
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)

## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.18...2.x)
### Features
- Implemented multitenant remote metadata client ([#980](https://github.com/opensearch-project/flow-framework/pull/980))

### Enhancements
### Bug Fixes
- Remove useCase and defaultParams field in WorkflowRequest ([#758](https://github.com/opensearch-project/flow-framework/pull/758))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,32 @@ public AbstractSearchWorkflowAction(

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
FlowFrameworkException ffe = new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
try {
if (!flowFrameworkSettings.isFlowFrameworkEnabled()) {
FlowFrameworkException ffe = new FlowFrameworkException(
"This API is disabled. To enable it, update the setting [" + FLOW_FRAMEWORK_ENABLED.getKey() + "] to true.",
RestStatus.FORBIDDEN
);
return channel -> channel.sendResponse(
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout());

// The transport action needs the tenant id but also only takes a SearchRequest.
// The tenant filtering will be handled by the metadata client.
// We'll use the preference field to communicate the tenant ID and strip it on the other end
SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index).preference(tenantId);
return channel -> client.execute(actionType, searchRequest, search(channel));
} catch (FlowFrameworkException ex) {
return channel -> channel.sendResponse(
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout());

SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return channel -> client.execute(actionType, searchRequest, search(channel));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
try {
searchHandler.search(request, actionListener);
// We used the SearchRequest preference field to convey a tenant id if any
String tenantId = null;
if (request.preference() != null) {
tenantId = request.preference();
request.preference(null);
}
searchHandler.search(request, tenantId, actionListener);
} catch (Exception e) {
String errorMessage = "Failed to search workflow states in global context";
logger.error(errorMessage, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
try {
searchHandler.search(request, actionListener);
// We used the SearchRequest preference field to convey a tenant id if any
String tenantId = null;
if (request.preference() != null) {
tenantId = request.preference();
request.preference(null);
}
searchHandler.search(request, tenantId, actionListener);
} catch (Exception e) {
String errorMessage = "Failed to search workflows in global context";
logger.error(errorMessage, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.Arrays;

import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.util.ParseUtils.isAdmin;
import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

Expand All @@ -31,6 +39,7 @@
public class SearchHandler {
private final Logger logger = LogManager.getLogger(SearchHandler.class);
private final Client client;
private final SdkClient sdkClient;
private volatile Boolean filterByBackendRole;

/**
Expand All @@ -40,25 +49,33 @@ public class SearchHandler {
* @param client The node client to retrieve a stored use case template
* @param filterByBackendRoleSetting filter role backend settings
*/
public SearchHandler(Settings settings, ClusterService clusterService, Client client, Setting<Boolean> filterByBackendRoleSetting) {
public SearchHandler(
Settings settings,
ClusterService clusterService,
Client client,
SdkClient sdkClient,
Setting<Boolean> filterByBackendRoleSetting
) {
this.client = client;
this.sdkClient = sdkClient;
filterByBackendRole = filterByBackendRoleSetting.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSetting, it -> filterByBackendRole = it);
}

/**
* Search workflows in global context
* @param request SearchRequest
* @param tenantId the tenant ID
* @param actionListener ActionListener
*/
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
public void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflows in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
validateRole(request, user, actionListener, context);
validateRole(request, tenantId, user, actionListener, context);
} catch (Exception e) {
logger.error("Failed to search workflows in global context", e);
actionListener.onFailure(e);
Expand All @@ -68,12 +85,14 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
/**
* Validate user role and call search
* @param request SearchRequest
* @param tenantId the tenant id
* @param user User
* @param listener ActionListener
* @param context ThreadContext
*/
public void validateRole(
SearchRequest request,
String tenantId,
User user,
ActionListener<SearchResponse> listener,
ThreadContext.StoredContext context
Expand All @@ -83,16 +102,40 @@ public void validateRole(
// Case 2: If Security is enabled and filter is disabled, proceed with search as
// user is already authenticated to hit this API.
// case 3: user is admin which means we don't have to check backend role filtering
client.search(request, ActionListener.runBefore(listener, context::restore));
doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore));
} else {
// Security is enabled, filter is enabled and user isn't admin
try {
ParseUtils.addUserBackendRolesFilter(user, request.source());
logger.debug("Filtering result by {}", user.getBackendRoles());
client.search(request, ActionListener.runBefore(listener, context::restore));
doSearch(request, tenantId, ActionListener.runBefore(listener, context::restore));
} catch (Exception e) {
listener.onFailure(e);
}
}
}

private void doSearch(SearchRequest request, String tenantId, ActionListener<SearchResponse> listener) {
SearchDataObjectRequest searchRequest = SearchDataObjectRequest.builder()
.indices(request.indices())
.tenantId(tenantId)
.searchSourceBuilder(request.source())
.build();
sdkClient.searchDataObjectAsync(searchRequest, client.threadPool().executor(WORKFLOW_THREAD_POOL)).whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
logger.info(Arrays.toString(request.indices()) + " search complete: {}", searchResponse.getHits().getTotalHits());
listener.onResponse(searchResponse);
} catch (Exception e) {
logger.error("Failed to parse search response", e);
listener.onFailure(new FlowFrameworkException("Failed to parse search response", INTERNAL_SERVER_ERROR));
}
} else {
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
logger.error(Arrays.toString(request.indices()) + " search failed", cause);
listener.onFailure(cause);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.rest;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.flowframework.FlowFrameworkTenantAwareRestTestCase;
Expand All @@ -16,6 +17,7 @@

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
Expand Down Expand Up @@ -183,7 +185,6 @@ public void testWorkflowCRUD() throws Exception {
// Retry these tests until they pass. Search requires refresh, can take 15s on DDB
refreshAllIndices();

/* Search not yet implemented TODO
assertBusy(() -> {
// Search should show only the workflow for tenant
Response restResponse = makeRequest(tenantMatchAllRequest, GET, WORKFLOW_PATH + "_search");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.transport.TransportService;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -66,10 +67,10 @@ public void testSearchWorkflow() {
SearchRequest request = invocation.getArgument(0);
ActionListener<SearchResponse> responseListener = invocation.getArgument(1);
ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class);
searchHandler.validateRole(request, null, responseListener, storedContext);
searchHandler.validateRole(request, null, null, responseListener, storedContext);
responseListener.onResponse(mock(SearchResponse.class));
return null;
}).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class));
}).when(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> responseListener = invocation.getArgument(1);
Expand All @@ -78,7 +79,7 @@ public void testSearchWorkflow() {
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));

searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class));
verify(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.transport.TransportService;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -70,10 +71,10 @@ public void testSearchWorkflow() {
SearchRequest request = invocation.getArgument(0);
ActionListener<SearchResponse> responseListener = invocation.getArgument(1);
ThreadContext.StoredContext storedContext = mock(ThreadContext.StoredContext.class);
searchHandler.validateRole(request, null, responseListener, storedContext);
searchHandler.validateRole(request, null, null, responseListener, storedContext);
responseListener.onResponse(mock(SearchResponse.class));
return null;
}).when(searchHandler).search(any(SearchRequest.class), any(ActionListener.class));
}).when(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> responseListener = invocation.getArgument(1);
Expand All @@ -82,7 +83,7 @@ public void testSearchWorkflow() {
}).when(client).search(any(SearchRequest.class), any(ActionListener.class));

searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(searchHandler).search(any(SearchRequest.class), any(ActionListener.class));
verify(searchHandler).search(any(SearchRequest.class), nullable(String.class), any(ActionListener.class));
}

}
Loading

0 comments on commit 7c10eac

Please sign in to comment.