Skip to content

Commit

Permalink
Adding guardrails to default use case params (#658)
Browse files Browse the repository at this point in the history
* Adding guardrails to default use case params

Signed-off-by: Joshua Palis <[email protected]>

* Updating changelog and adding javadocs

Signed-off-by: Joshua Palis <[email protected]>

* Fixing test

Signed-off-by: Joshua Palis <[email protected]>

* Fixing integration tests, covering case in which no content is passed at all for default cases

Signed-off-by: Joshua Palis <[email protected]>

* Fixing tests

Signed-off-by: Joshua Palis <[email protected]>

* Fixing rest create workflow action

Signed-off-by: Joshua Palis <[email protected]>

* addressing PR comments

Signed-off-by: Joshua Palis <[email protected]>

* fixing test

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis authored Apr 11, 2024
1 parent 3c67cff commit 9d28045
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.13...2.x)
### Features
### Enhancements
- Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658))
### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
- Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639))
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,18 @@ private CommonValue() {}
public static final String RESOURCE_ID = "resource_id";
/** The field name for the opensearch-ml plugin */
public static final String OPENSEARCH_ML = "opensearch-ml";

/*
* Constants assoicated with substitution / default templates
*/
/** The field name for connector credential key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_KEY = "create_connector.credential.key";
/** The field name for connector credential access key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY = "create_connector.credential.access_key";
/** The field name for connector credential secret key substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY = "create_connector.credential.secret_key";
/** The field name for connector credential session token substitution */
public static final String CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN = "create_connector.credential.session_token";
/** The field name for ingest pipeline model ID substitution */
public static final String CREATE_INGEST_PIPELINE_MODEL_ID = "create_ingest_pipeline.model_id";
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN;
import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID;

/**
* Enum encapsulating the different default use cases and templates we have stored
*/
Expand All @@ -22,94 +32,119 @@ public enum DefaultUseCases {
OPEN_AI_EMBEDDING_MODEL_DEPLOY(
"open_ai_embedding_model_deploy",
"defaults/openai-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-template.json"
"substitutionTemplates/deploy-remote-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for Cohere embedding model */
COHERE_EMBEDDING_MODEL_DEPLOY(
"cohere_embedding_model_deploy",
"defaults/cohere-embedding-defaults.json",
"substitutionTemplates/deploy-remote-model-extra-params-template.json"
"substitutionTemplates/deploy-remote-model-extra-params-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for Bedrock Titan embedding model */
BEDROCK_TITAN_EMBEDDING_MODEL_DEPLOY(
"bedrock_titan_embedding_model_deploy",
"defaults/bedrock-titan-embedding-defaults.json",
"substitutionTemplates/deploy-remote-bedrock-model-template.json"
"substitutionTemplates/deploy-remote-bedrock-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for Bedrock Titan multimodal embedding model */
BEDROCK_TITAN_MULTIMODAL_MODEL_DEPLOY(
"bedrock_titan_multimodal_model_deploy",
"defaults/bedrock-titan-multimodal-defaults.json",
"substitutionTemplates/deploy-remote-bedrock-model-template.json"
"substitutionTemplates/deploy-remote-bedrock-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for Cohere chat model */
COHERE_CHAT_MODEL_DEPLOY(
"cohere_chat_model_deploy",
"defaults/cohere-chat-defaults.json",
"substitutionTemplates/deploy-remote-model-chat-template.json"
"substitutionTemplates/deploy-remote-model-chat-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for OpenAI chat model */
OPENAI_CHAT_MODEL_DEPLOY(
"openai_chat_model_deploy",
"defaults/openai-chat-defaults.json",
"substitutionTemplates/deploy-remote-model-chat-template.json"
"substitutionTemplates/deploy-remote-model-chat-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for local neural sparse model and ingest pipeline*/
LOCAL_NEURAL_SPARSE_SEARCH_BI_ENCODER(
"local_neural_sparse_search_bi_encoder",
"defaults/local-sparse-search-biencoder-defaults.json",
"substitutionTemplates/neural-sparse-local-biencoder-template.json"
"substitutionTemplates/neural-sparse-local-biencoder-template.json",
Collections.emptyList()
),
/** defaults file and substitution ready template for semantic search, no model creation*/
SEMANTIC_SEARCH("semantic_search", "defaults/semantic-search-defaults.json", "substitutionTemplates/semantic-search-template.json"),
SEMANTIC_SEARCH(
"semantic_search",
"defaults/semantic-search-defaults.json",
"substitutionTemplates/semantic-search-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for multimodal search, no model creation*/
MULTI_MODAL_SEARCH(
"multimodal_search",
"defaults/multi-modal-search-defaults.json",
"substitutionTemplates/multi-modal-search-template.json"
"substitutionTemplates/multi-modal-search-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for multimodal search, no model creation*/
MULTI_MODAL_SEARCH_WITH_BEDROCK_TITAN(
"multimodal_search_with_bedrock_titan",
"defaults/multimodal-search-bedrock-titan-defaults.json",
"substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json"
"substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_ACCESS_KEY, CREATE_CONNECTOR_CREDENTIAL_SECRET_KEY, CREATE_CONNECTOR_CREDENTIAL_SESSION_TOKEN)
),
/** defaults file and substitution ready template for semantic search with query enricher processor attached, no model creation*/
SEMANTIC_SEARCH_WITH_QUERY_ENRICHER(
"semantic_search_with_query_enricher",
"defaults/semantic-search-query-enricher-defaults.json",
"substitutionTemplates/semantic-search-with-query-enricher-template.json"
"substitutionTemplates/semantic-search-with-query-enricher-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for semantic search with cohere embedding model*/
SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING(
"semantic_search_with_cohere_embedding",
"defaults/cohere-embedding-semantic-search-defaults.json",
"substitutionTemplates/semantic-search-with-model-template.json"
"substitutionTemplates/semantic-search-with-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for semantic search with query enricher processor attached and cohere embedding model*/
SEMANTIC_SEARCH_WITH_COHERE_EMBEDDING_AND_QUERY_ENRICHER(
"semantic_search_with_cohere_embedding_query_enricher",
"defaults/cohere-embedding-semantic-search-with-query-enricher-defaults.json",
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json"
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
),
/** defaults file and substitution ready template for hybrid search, no model creation*/
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"),
HYBRID_SEARCH(
"hybrid_search",
"defaults/hybrid-search-defaults.json",
"substitutionTemplates/hybrid-search-template.json",
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
),
/** defaults file and substitution ready template for conversational search with cohere chat model*/
CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY(
"conversational_search_with_llm_deploy",
"defaults/conversational-search-defaults.json",
"substitutionTemplates/conversational-search-with-cohere-model-template.json"
"substitutionTemplates/conversational-search-with-cohere-model-template.json",
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
);

private final String useCaseName;
private final String defaultsFile;
private final String substitutionReadyFile;
private final List<String> requiredParams;
private static final Logger logger = LogManager.getLogger(DefaultUseCases.class);

DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) {
DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile, List<String> requiredParams) {
this.useCaseName = useCaseName;
this.defaultsFile = defaultsFile;
this.substitutionReadyFile = substitutionReadyFile;
this.requiredParams = requiredParams;
}

/**
Expand All @@ -136,6 +171,14 @@ public String getSubstitutionReadyFile() {
return substitutionReadyFile;
}

/**
* Returns the required params for the given enum Constant
* @return the required params of the given useCase
*/
public List<String> getRequiredParams() {
return requiredParams;
}

/**
* Gets the defaultsFile based on the given use case.
* @param useCaseName name of the given use case
Expand Down Expand Up @@ -171,4 +214,21 @@ public static String getSubstitutionReadyFileByUseCaseName(String useCaseName) t
logger.error("Unable to find substitution ready file for use case: {}", useCaseName);
throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
}

/**
* Gets the required parameters based on the given use case
* @param useCaseName name of the given use case
* @return the list of required params
*/
public static List<String> getRequiredParamsByUseCaseName(String useCaseName) {
if (useCaseName != null && !useCaseName.isEmpty()) {
for (DefaultUseCases useCase : values()) {
if (useCase.getUseCaseName().equals(useCaseName)) {
return new ArrayList<String>(useCase.getRequiredParams());
}
}
}
logger.error("Default use case [" + useCaseName + "] does not exist");
throw new FlowFrameworkException("Default use case [" + useCaseName + "] does not exist", RestStatus.BAD_REQUEST);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -126,12 +127,31 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
);
String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase);
useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath);

if (request.hasContent()) {
List<String> requiredParams = DefaultUseCases.getRequiredParamsByUseCaseName(useCase);

if (!request.hasContent()) {
if (!requiredParams.isEmpty()) {
throw new FlowFrameworkException(
"Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(),
RestStatus.BAD_REQUEST
);
}
} else {
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Map<String, Object> userDefaults = ParseUtils.parseStringToObjectMap(parser);

// Validate user defaults key set
Set<String> userDefaultKeys = userDefaults.keySet();
if (!userDefaultKeys.containsAll(requiredParams)) {
requiredParams.removeAll(userDefaultKeys);
throw new FlowFrameworkException(
"Missing the following required parameters for use case [" + useCase + "] : " + requiredParams.toString(),
RestStatus.BAD_REQUEST
);
}

// updates the default params with anything user has given that matches
for (Map.Entry<String, Object> userDefaultsEntry : userDefaults.entrySet()) {
String key = userDefaultsEntry.getKey();
Expand All @@ -141,13 +161,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
}
} catch (Exception ex) {
RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex);
String errorMessage =
"failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, status);
if (ex instanceof FlowFrameworkException) {
throw ex;
} else {
RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex);
String errorMessage =
"failure parsing request body when a use case is given, make sure to provide a map with values that are either Strings, Arrays, or Map of Strings to Strings";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, status);
}
}

}

useCaseTemplateFileInStringFormat = (String) ParseUtils.conditionallySubstitute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,26 @@ protected Response createWorkflow(RestClient client, Template template) throws E
* Helper method to invoke the Create Workflow Rest Action without validation
* @param client the rest client
* @param useCase the usecase to create
* @param the required params
* @throws Exception if the request fails
* @return a rest response
*/
protected Response createWorkflowWithUseCase(RestClient client, String useCase) throws Exception {
protected Response createWorkflowWithUseCase(RestClient client, String useCase, List<String> params) throws Exception {

StringBuilder sb = new StringBuilder();
for (String param : params) {
sb.append('"').append(param).append("\" : \"\",");
}
if (!params.isEmpty()) {
sb.deleteCharAt(sb.length() - 1);
}

return TestHelpers.makeRequest(
client,
"POST",
WORKFLOW_URI + "?validation=off&use_case=" + useCase,
Collections.emptyMap(),
"{}",
"{" + sb.toString() + "}",
null
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.CREATE_CONNECTOR_CREDENTIAL_KEY;
import static org.opensearch.flowframework.common.CommonValue.CREATE_INGEST_PIPELINE_MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;

Expand Down Expand Up @@ -406,7 +408,7 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
public void testDefaultCohereUseCase() throws Exception {

// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy");
Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY));
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down Expand Up @@ -442,8 +444,18 @@ public void testDefaultCohereUseCase() throws Exception {
}

public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Exception {
// Hit Create Workflow API with original template
Response response = createWorkflowWithUseCase(client(), "semantic_search");
// Hit Create Workflow API with original template without required params
ResponseException exception = expectThrows(
ResponseException.class,
() -> createWorkflowWithUseCase(client(), "semantic_search", Collections.emptyList())
);
assertTrue(
exception.getMessage()
.contains("Missing the following required parameters for use case [semantic_search] : [create_ingest_pipeline.model_id]")
);

// Pass in required params
Response response = createWorkflowWithUseCase(client(), "semantic_search", List.of(CREATE_INGEST_PIPELINE_MODEL_ID));
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down Expand Up @@ -483,7 +495,11 @@ public void testAllDefaultUseCasesCreation() throws Exception {
.collect(Collectors.toSet());

for (String useCaseName : allUseCaseNames) {
Response response = createWorkflowWithUseCase(client(), useCaseName);
Response response = createWorkflowWithUseCase(
client(),
useCaseName,
DefaultUseCases.getRequiredParamsByUseCaseName(useCaseName)
);
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));

Map<String, Object> responseMap = entityAsMap(response);
Expand Down
Loading

0 comments on commit 9d28045

Please sign in to comment.