diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index 47f5aa40a..6885dc0c2 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -13,6 +13,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -23,19 +24,19 @@ import org.opensearch.rest.RestResponse; import org.opensearch.rest.action.RestResponseListener; import org.opensearch.search.builder.SearchSourceBuilder; -import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; import java.io.IOException; import java.util.ArrayList; import java.util.List; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting.FLOW_FRAMEWORK_ENABLED; import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; /** * Abstract class to handle search request. */ -public class AbstractSearchWorkflowAction extends BaseRestHandler { +public abstract class AbstractSearchWorkflowAction extends BaseRestHandler { protected final List urlPaths; protected final String index; @@ -49,8 +50,15 @@ public class AbstractSearchWorkflowAction extends Ba * @param index index the search should be done on * @param clazz model class * @param actionType from which action abstract class is called + * @param flowFrameworkFeatureEnabledSetting Whether this API is enabled */ - public AbstractSearchWorkflowAction(List urlPaths, String index, Class clazz, ActionType actionType, FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting) { + public AbstractSearchWorkflowAction( + List urlPaths, + String index, + Class clazz, + ActionType actionType, + FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting + ) { this.urlPaths = urlPaths; this.index = index; this.clazz = clazz; @@ -58,17 +66,15 @@ public AbstractSearchWorkflowAction(List urlPaths, String index, Class channel.sendResponse( + new BytesRestResponse(ffe.getRestStatus(), ffe.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) ); } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowAction.java index a7334a61a..6843199b4 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestSearchWorkflowAction.java @@ -25,8 +25,14 @@ public class RestSearchWorkflowAction extends AbstractSearchWorkflowAction>>>>>> 57869ff (Added javadoc) assertEquals(1, ffp.getExecutorBuilders(settings).size()); assertEquals(1, ffp.getSettings().size()); } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowActionTests.java index ce6f5e597..a92950534 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestSearchWorkflowActionTests.java @@ -10,11 +10,14 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.flowframework.common.FlowFrameworkFeatureEnabledSetting; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; import java.util.List; @@ -22,23 +25,27 @@ import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class RestSearchWorkflowActionTests extends OpenSearchTestCase { private RestSearchWorkflowAction restSearchWorkflowAction; private String searchPath; private NodeClient nodeClient; + private FlowFrameworkFeatureEnabledSetting flowFrameworkFeatureEnabledSetting; @Override public void setUp() throws Exception { super.setUp(); this.searchPath = String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, "_search"); - this.restSearchWorkflowAction = new RestSearchWorkflowAction(); + flowFrameworkFeatureEnabledSetting = mock(FlowFrameworkFeatureEnabledSetting.class); + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(true); + this.restSearchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting); this.nodeClient = mock(NodeClient.class); } public void testConstructor() { - RestSearchWorkflowAction searchWorkflowAction = new RestSearchWorkflowAction(); + RestSearchWorkflowAction searchWorkflowAction = new RestSearchWorkflowAction(flowFrameworkFeatureEnabledSetting); assertNotNull(searchWorkflowAction); } @@ -69,4 +76,15 @@ public void testInvalidSearchRequest() { }); assertEquals("unknown named object category [org.opensearch.index.query.QueryBuilder]", ex.getMessage()); } + + public void testFeatureFlagNotEnabled() throws Exception { + when(flowFrameworkFeatureEnabledSetting.isFlowFrameworkEnabled()).thenReturn(false); + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath(this.searchPath) + .build(); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + restSearchWorkflowAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.FORBIDDEN, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("This API is disabled.")); + } }