Skip to content

Commit

Permalink
Extract tenant id from REST header into RestAction
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 13, 2024
1 parent c67a807 commit 03104cf
Show file tree
Hide file tree
Showing 14 changed files with 200 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES;
Expand Down Expand Up @@ -144,7 +145,7 @@ public Collection<Object> createComponents(
Map.entry(REMOTE_METADATA_ENDPOINT_KEY, REMOTE_METADATA_ENDPOINT.get(settings)),
Map.entry(REMOTE_METADATA_REGION_KEY, REMOTE_METADATA_REGION.get(settings)),
Map.entry(REMOTE_METADATA_SERVICE_NAME_KEY, REMOTE_METADATA_SERVICE_NAME.get(settings)),
Map.entry(TENANT_ID_FIELD_KEY, "tenant_id")
Map.entry(TENANT_ID_FIELD_KEY, TENANT_ID_FIELD)
)
: Collections.emptyMap()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ private CommonValue() {}
public static final String USE_CASE = "use_case";
/** The param name for reprovisioning, used by the create workflow API */
public static final String REPROVISION_WORKFLOW = "reprovision";
/** The Rest header containing the tenant id */
public static final String TENANT_ID_HEADER = "x-tenant-id";
/** The field name containing the tenant id */
public static final String TENANT_ID_FIELD = "tenant_id";

/*
* Constants associated with plugin configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class FlowFrameworkSettings {
protected volatile Integer maxWorkflows;
/** Timeout for internal requests*/
protected volatile TimeValue requestTimeout;
/** Whether multitenancy is enabled */
private final Boolean isMultiTenancyEnabled;

/** The upper limit of max workflows that can be created */
public static final int MAX_WORKFLOWS_LIMIT = 10000;
Expand Down Expand Up @@ -150,6 +152,7 @@ public FlowFrameworkSettings(ClusterService clusterService, Settings settings) {
this.maxWorkflowSteps = MAX_WORKFLOW_STEPS.get(settings);
this.maxWorkflows = MAX_WORKFLOWS.get(settings);
this.requestTimeout = WORKFLOW_REQUEST_TIMEOUT.get(settings);
this.isMultiTenancyEnabled = FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(FLOW_FRAMEWORK_ENABLED, it -> isFlowFrameworkEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(TASK_REQUEST_RETRY_DURATION, it -> retryDuration = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_WORKFLOW_STEPS, it -> maxWorkflowSteps = it);
Expand Down Expand Up @@ -196,4 +199,12 @@ public Integer getMaxWorkflows() {
public TimeValue getRequestTimeout() {
return requestTimeout;
}

/**
* Whether multitenancy is enabled.
* @return whether Flow Framework multitenancy is enabled.
*/
public boolean isMultiTenancyEnabled() {
return isMultiTenancyEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
Expand Down Expand Up @@ -82,6 +83,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -145,6 +146,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
return processError(ffe, params, request);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request);
try {
Template template;
Map<String, String> useCaseDefaultsMap = Collections.emptyMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.DeleteWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -71,6 +72,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
RestStatus.FORBIDDEN
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request);

// Always consume content to silently ignore it
// https://github.com/opensearch-project/flow-framework/issues/578
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.DeprovisionWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -68,6 +69,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
RestStatus.FORBIDDEN
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request);

// Always consume content to silently ignore it
// https://github.com/opensearch-project/flow-framework/issues/578
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.GetWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -69,6 +70,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
RestStatus.FORBIDDEN
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request);

// Always consume content to silently ignore it
// https://github.com/opensearch-project/flow-framework/issues/578
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.GetWorkflowStateAction;
import org.opensearch.flowframework.transport.GetWorkflowStateRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -65,6 +66,7 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request
RestStatus.FORBIDDEN
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request);

// Always consume content to silently ignore it
// https://github.com/opensearch-project/flow-framework/issues/578
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.GetWorkflowStepAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -70,6 +71,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
RestStatus.FORBIDDEN
);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkSettings.isMultiTenancyEnabled(), request);

// Always consume content to silently ignore it
// https://github.com/opensearch-project/flow-framework/issues/578
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.flowframework.util.RestActionUtils;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -85,6 +86,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (workflowId == null) {
throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST);
}
String tenantId = RestActionUtils.getTenantID(flowFrameworkFeatureEnabledSetting.isMultiTenancyEnabled(), request);
// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null, params);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.util;

import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.rest.RestRequest;

import java.util.List;
import java.util.Map;

/**
* Utility methods for Rest Handlers
*/
public class RestActionUtils {

private RestActionUtils() {}

/**
* Finds the tenant id in the REST Headers
* @param isMultiTenancyEnabled whether multitenancy is enabled
* @param restRequest the RestRequest
* @return The tenant ID from the headers or null if multitenancy is not enabled
*/
public static String getTenantID(Boolean isMultiTenancyEnabled, RestRequest restRequest) {
if (isMultiTenancyEnabled) {
Map<String, List<String>> headers = restRequest.getHeaders();
if (headers.containsKey(CommonValue.TENANT_ID_HEADER)) {
List<String> tenantIdList = headers.get(CommonValue.TENANT_ID_HEADER);
if (tenantIdList != null && !tenantIdList.isEmpty()) {
String tenantId = tenantIdList.get(0);
if (tenantId != null) {
return tenantId;
} else {
throw new FlowFrameworkException("Tenant ID can't be null", RestStatus.FORBIDDEN);
}
} else {
throw new FlowFrameworkException("Tenant ID header is present but has no value", RestStatus.FORBIDDEN);
}
} else {
throw new FlowFrameworkException("Tenant ID header is missing", RestStatus.FORBIDDEN);
}
} else {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public void setUp() throws Exception {
FlowFrameworkSettings.TASK_REQUEST_RETRY_DURATION,
FlowFrameworkSettings.MAX_WORKFLOW_STEPS,
FlowFrameworkSettings.MAX_WORKFLOWS,
FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT
FlowFrameworkSettings.WORKFLOW_REQUEST_TIMEOUT,
FlowFrameworkSettings.FLOW_FRAMEWORK_MULTI_TENANCY_ENABLED
)
).collect(Collectors.toSet());
clusterSettings = new ClusterSettings(settings, settingsSet);
Expand All @@ -63,5 +64,6 @@ public void testSettings() throws IOException {
assertEquals(Optional.of(50), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflowSteps()));
assertEquals(Optional.of(1000), Optional.ofNullable(flowFrameworkSettings.getMaxWorkflows()));
assertEquals(Optional.of(TimeValue.timeValueSeconds(10)), Optional.ofNullable(flowFrameworkSettings.getRequestTimeout()));
assertFalse(flowFrameworkSettings.isMultiTenancyEnabled());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.util;

import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.junit.Assert;
import org.junit.Test;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RestActionUtilsTests extends OpenSearchTestCase {
@Test
public void testGetTenantID() {
String tenantId = "test-tenant";
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, Collections.singletonList(tenantId));
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

String actualTenantID = RestActionUtils.getTenantID(Boolean.TRUE, restRequest);
Assert.assertEquals(tenantId, actualTenantID);
}

@Test
public void testGetTenantID_NullTenantID() {
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, Collections.singletonList(null));
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

try {
RestActionUtils.getTenantID(Boolean.TRUE, restRequest);
Assert.fail("Expected FlowFrameworkException");
} catch (Exception e) {
Assert.assertTrue(e instanceof FlowFrameworkException);
Assert.assertEquals(RestStatus.FORBIDDEN, ((FlowFrameworkException) e).status());
Assert.assertEquals("Tenant ID can't be null", e.getMessage());
}
}

@Test
public void testGetTenantID_NoMultiTenancy() {
String tenantId = "test-tenant";
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, Collections.singletonList(tenantId));
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

String tenantID = RestActionUtils.getTenantID(Boolean.FALSE, restRequest);
Assert.assertNull(tenantID);
}

@Test
public void testGetTenantID_EmptyTenantIDList() {
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, Collections.emptyList());
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

FlowFrameworkException exception = expectThrows(
FlowFrameworkException.class,
() -> RestActionUtils.getTenantID(Boolean.TRUE, restRequest)
);
assertEquals(RestStatus.FORBIDDEN, exception.status());
assertEquals("Tenant ID header is present but has no value", exception.getMessage());
}

@Test
public void testGetTenantID_MissingTenantIDHeader() {
Map<String, List<String>> headers = new HashMap<>();
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

FlowFrameworkException exception = expectThrows(
FlowFrameworkException.class,
() -> RestActionUtils.getTenantID(Boolean.TRUE, restRequest)
);
assertEquals(RestStatus.FORBIDDEN, exception.status());
assertEquals("Tenant ID header is missing", exception.getMessage());
}

@Test
public void testGetTenantID_MultipleValues() {
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, List.of("tenant1", "tenant2"));
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

String actualTenantID = RestActionUtils.getTenantID(Boolean.TRUE, restRequest);
assertEquals("tenant1", actualTenantID);
}

@Test
public void testGetTenantID_EmptyStringTenantID() {
Map<String, List<String>> headers = new HashMap<>();
headers.put(CommonValue.TENANT_ID_HEADER, Collections.singletonList(""));
RestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(headers).build();

String actualTenantID = RestActionUtils.getTenantID(Boolean.TRUE, restRequest);
assertEquals("", actualTenantID);
}
}

0 comments on commit 03104cf

Please sign in to comment.