From 56111fcce8554e8868f34daa2c1c2688c49d0d8b Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 29 Nov 2023 15:06:16 -0800 Subject: [PATCH 1/5] [Backport feature/agent_framework] Encrypt/Decrypt template credentials (#217) Encrypt/Decrypt template credentials (#197) * added RegisterRemoteModelStep and tests * Adding RegisterLocalModelStep, fixing tests, adding input/ouput definitions to workflow step json * Fixing javadoc warnings, fixing log message * Addressing PR comments,making description field optional for RegisterRemoteModelStep and RegisterLocalModelStep * moving modelConfig builder before adding allConfig * initial implementation * Fixing create workflow transport action * Removing duplicate register_remote_model validator * Adding bouncy castle dependency to resolve encryption issue * Fixing CreateWorkflowTransportActionTests * Adding initial unit tests for encryptor utils * Implemented encryption/decryption for workflow node user inputs with credential * Addressing PR comments * Suppressing unchecked warning, making credential strings constants * Removing setMasterKey from initializeMasterKey method * Adding final template encryption decryption test * Addressing PR comments, changing master key index name to config, fixes error messages as well. Added create time field to config index, ensured that updates are also encrypted * Added TODO * changing getMasterKeyIndexMapping method name * Removing unnecessary aws sdk dependency --------- (cherry picked from commit 772fbbd9eb58aeda52ea57a36307ad7206b1cf0b) Signed-off-by: Joshua Palis Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- build.gradle | 2 + .../flowframework/FlowFrameworkPlugin.java | 6 +- .../flowframework/common/CommonValue.java | 14 +- .../indices/FlowFrameworkIndex.java | 10 + .../indices/FlowFrameworkIndicesHandler.java | 47 ++- .../CreateWorkflowTransportAction.java | 60 ++-- .../ProvisionWorkflowTransportAction.java | 11 +- .../flowframework/util/EncryptorUtils.java | 286 ++++++++++++++++++ .../workflow/CreateConnectorStep.java | 6 +- src/main/resources/mappings/config.json | 15 + .../FlowFrameworkPluginTests.java | 2 +- .../FlowFrameworkIndicesHandlerTests.java | 5 +- .../CreateWorkflowTransportActionTests.java | 14 + ...ProvisionWorkflowTransportActionTests.java | 8 +- .../util/EncryptorUtilsTests.java | 193 ++++++++++++ .../workflow/CreateConnectorStepTests.java | 2 +- 16 files changed, 648 insertions(+), 33 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java create mode 100644 src/main/resources/mappings/config.json create mode 100644 src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java diff --git a/build.gradle b/build.gradle index a7e6046d4..6b66247a6 100644 --- a/build.gradle +++ b/build.gradle @@ -142,6 +142,8 @@ dependencies { api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' implementation "org.opensearch:common-utils:${common_utils_version}" + implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' + implementation 'org.bouncycastle:bcprov-jdk18on:1.77' configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 4b2e4a9cb..14df7e17e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -39,6 +39,7 @@ import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; import org.opensearch.flowframework.transport.SearchWorkflowTransportAction; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -96,7 +97,8 @@ public Collection createComponents( this.clusterService = clusterService; flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( settings, clusterService, @@ -106,7 +108,7 @@ public Collection createComponents( ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 1f9f33f4b..90f208c8d 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -33,6 +33,16 @@ private CommonValue() {} public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json"; /** Workflow State index mapping version */ public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1; + /** Config Index Name */ + public static final String CONFIG_INDEX = ".plugins-flow-framework-config"; + /** Config index mapping file path */ + public static final String CONFIG_INDEX_MAPPING = "mappings/config.json"; + /** Config index mapping version */ + public static final Integer CONFIG_INDEX_VERSION = 1; + /** Master key field name */ + public static final String MASTER_KEY = "master_key"; + /** Create Time field name */ + public static final String CREATE_TIME = "create_time"; /** The template field name for template use case */ public static final String USE_CASE_FIELD = "use_case"; @@ -119,8 +129,8 @@ private CommonValue() {} public static final String PROTOCOL_FIELD = "protocol"; /** Connector parameters field */ public static final String PARAMETERS_FIELD = "parameters"; - /** Connector credentials field */ - public static final String CREDENTIALS_FIELD = "credentials"; + /** Connector credential field */ + public static final String CREDENTIAL_FIELD = "credential"; /** Connector actions field */ public static final String ACTIONS_FIELD = "actions"; /** Backend roles for the model */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index 4b005e45d..b74c1d89a 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -12,6 +12,8 @@ import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @@ -36,6 +38,14 @@ public enum FlowFrameworkIndex { WORKFLOW_STATE_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings), WORKFLOW_STATE_INDEX_VERSION + ), + /** + * Config Index + */ + CONFIG( + CONFIG_INDEX, + ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getConfigIndexMappings), + CONFIG_INDEX_VERSION ); private final String indexName; diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 7dcc89de6..cce4ba839 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.script.Script; import java.io.IOException; @@ -48,6 +49,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.META; @@ -64,6 +66,7 @@ public class FlowFrameworkIndicesHandler { private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class); private final Client client; private final ClusterService clusterService; + private final EncryptorUtils encryptorUtils; private static final Map indexMappingUpdated = new HashMap<>(); private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); @@ -71,10 +74,12 @@ public class FlowFrameworkIndicesHandler { * constructor * @param client the open search client * @param clusterService ClusterService + * @param encryptorUtils encryption utility */ - public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) { + public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) { this.client = client; this.clusterService = clusterService; + this.encryptorUtils = encryptorUtils; for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) { indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false)); } @@ -104,6 +109,15 @@ public static String getWorkflowStateMappings() throws IOException { return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING); } + /** + * Get config index mapping + * @return config index mapping + * @throws IOException if mapping file cannot be read correctly + */ + public static String getConfigIndexMappings() throws IOException { + return getIndexMappings(CONFIG_INDEX_MAPPING); + } + /** * Create global context index if it's absent * @param listener The action listener @@ -120,6 +134,14 @@ public void initWorkflowStateIndexIfAbsent(ActionListener listener) { initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener); } + /** + * Create config index if it's absent + * @param listener The action listener + */ + public void initConfigIndexIfAbsent(ActionListener listener) { + initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.CONFIG, listener); + } + /** * Checks if the given index exists * @param indexName the name of the index @@ -287,7 +309,8 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { @@ -301,6 +324,23 @@ public void putTemplateToGlobalContext(Template template, ActionListener listener) { + initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR)); + return; + } + encryptorUtils.initializeMasterKey(listener); + }, createIndexException -> { + logger.error("Failed to create config index", createIndexException); + listener.onFailure(createIndexException); + })); + } + /** * add document insert into global context index * @param workflowId the workflowId, corresponds to document ID of @@ -361,7 +401,8 @@ public void updateTemplateInGlobalContext(String documentId, Template template, XContentBuilder builder = XContentFactory.jsonBuilder(); ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() ) { - request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS)) + Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template); + request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a77e7dd79..6ca1c4661 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -119,25 +119,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); - }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); - } - }) - ); + // Initialize config index and create new global context and state index entries + flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { + if (!isInitialized) { + listener.onFailure( + new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) + ); + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext( + templateWithUser, + ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + }) + ); + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } + + }) + ); + } }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); + logger.error("Failed to initialize config index : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { listener.onFailure(exception); } else { diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 1f19a4d04..b381b41ec 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; @@ -61,6 +62,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java new file mode 100644 index 000000000..70b30b5cc --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -0,0 +1,286 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; + +import javax.crypto.spec.SecretKeySpec; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoResult; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; + +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static org.opensearch.flowframework.common.CommonValue.CREATE_TIME; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; + +/** + * Encryption utility class + */ +public class EncryptorUtils { + + private static final Logger logger = LogManager.getLogger(EncryptorUtils.class); + + private static final String ALGORITHM = "AES"; + private static final String PROVIDER = "Custom"; + private static final String WRAPPING_ALGORITHM = "AES/GCM/NoPadding"; + + private ClusterService clusterService; + private Client client; + private String masterKey; + + /** + * Instantiates a new EncryptorUtils object + * @param clusterService the cluster service + * @param client the node client + */ + public EncryptorUtils(ClusterService clusterService, Client client) { + this.masterKey = null; + this.clusterService = clusterService; + this.client = client; + } + + /** + * Sets the master key + * @param masterKey the master key + */ + void setMasterKey(String masterKey) { + this.masterKey = masterKey; + } + + /** + * Returns the master key + * @return the master key + */ + String getMasterKey() { + return this.masterKey; + } + + /** + * Randomly generate a master key + * @return the master key string + */ + String generateMasterKey() { + byte[] masterKeyBytes = new byte[32]; + new SecureRandom().nextBytes(masterKeyBytes); + return Base64.getEncoder().encodeToString(masterKeyBytes); + } + + /** + * Encrypts template credentials + * @param template the template to encrypt + * @return template with encrypted credentials + */ + public Template encryptTemplateCredentials(Template template) { + return processTemplateCredentials(template, this::encrypt); + } + + /** + * Decrypts template credentials + * @param template the template to decrypt + * @return template with decrypted credentials + */ + public Template decryptTemplateCredentials(Template template) { + return processTemplateCredentials(template, this::decrypt); + } + + // TODO : Improve processTemplateCredentials to encrypt different fields based on the WorkflowStep type + /** + * Applies the given cipher function on template credentials + * @param template the template to process + * @param cipher the encryption/decryption function to apply on credential values + * @return template with encrypted credentials + */ + private Template processTemplateCredentials(Template template, Function cipherFunction) { + Template.Builder processedTemplateBuilder = new Template.Builder(); + + Map processedWorkflows = new HashMap<>(); + for (Map.Entry entry : template.workflows().entrySet()) { + + List processedNodes = new ArrayList<>(); + for (WorkflowNode node : entry.getValue().nodes()) { + if (node.userInputs().containsKey(CREDENTIAL_FIELD)) { + // Apply the cipher funcion on all values within credential field + @SuppressWarnings("unchecked") + Map credentials = new HashMap<>((Map) node.userInputs().get(CREDENTIAL_FIELD)); + credentials.replaceAll((key, cred) -> cipherFunction.apply(cred)); + + // Replace credentials field in node user inputs + Map processedUserInputs = new HashMap<>(); + processedUserInputs.putAll(node.userInputs()); + processedUserInputs.replace(CREDENTIAL_FIELD, credentials); + + // build new node to add to processed nodes + WorkflowNode processedWorkflowNode = new WorkflowNode( + node.id(), + node.type(), + node.previousNodeInputs(), + processedUserInputs + ); + processedNodes.add(processedWorkflowNode); + } else { + processedNodes.add(node); + } + } + + // Add processed workflow nodes to processed workflows + processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges())); + } + + Template processedTemplate = processedTemplateBuilder.name(template.name()) + .description(template.description()) + .useCase(template.useCase()) + .templateVersion(template.templateVersion()) + .compatibilityVersion(template.compatibilityVersion()) + .workflows(processedWorkflows) + .uiMetadata(template.getUiMetadata()) + .user(template.getUser()) + .build(); + + return processedTemplate; + } + + /** + * Encrypts the given credential + * @param credential the credential to encrypt + * @return the encrypted credential + */ + String encrypt(final String credential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + final CryptoResult encryptResult = crypto.encryptData( + jceMasterKey, + credential.getBytes(StandardCharsets.UTF_8) + ); + return Base64.getEncoder().encodeToString(encryptResult.getResult()); + } + + /** + * Decrypts the given credential + * @param encryptedCredential the credential to decrypt + * @return the decrypted credential + */ + String decrypt(final String encryptedCredential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + + final CryptoResult decryptedResult = crypto.decryptData( + jceMasterKey, + Base64.getDecoder().decode(encryptedCredential) + ); + return new String(decryptedResult.getResult(), StandardCharsets.UTF_8); + } + + /** + * Retrieves an existing master key or generates a new key to index + * @param listener the action listener + */ + public void initializeMasterKey(ActionListener listener) { + // Index has either been created or it already exists, need to check if master key has been initalized already, if not then + // generate + // This is necessary in case of global context index restoration from snapshot, will need to use the same master key to decrypt + // stored credentials + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(getResponse -> { + + if (!getResponse.isExists()) { + + // Generate new key and index + final String masterKey = generateMasterKey(); + IndexRequest masterKeyIndexRequest = new IndexRequest(CONFIG_INDEX).id(MASTER_KEY) + .source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME, Instant.now().toEpochMilli())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(masterKeyIndexRequest, ActionListener.wrap(indexResponse -> { + // Set generated key to master + logger.info("Config has been initialized successfully"); + this.masterKey = masterKey; + listener.onResponse(true); + }, indexException -> { + logger.error("Failed to index config", indexException); + listener.onFailure(indexException); + })); + + } else { + // Set existing key to master + logger.info("Config has already been initialized"); + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + this.masterKey = masterKey; + listener.onResponse(true); + } + }, getRequestException -> { + logger.error("Failed to search for config from config index", getRequestException); + listener.onFailure(getRequestException); + })); + + } catch (Exception e) { + logger.error("Failed to retrieve config from config index", e); + listener.onFailure(e); + } + } + + /** + * Retrieves master key from system index if not yet set + */ + void initializeMasterKeyIfAbsent() { + if (masterKey != null) { + return; + } + + if (!clusterService.state().metadata().hasIndex(CONFIG_INDEX)) { + throw new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + this.masterKey = (String) response.getSourceAsMap().get(MASTER_KEY); + } else { + throw new FlowFrameworkException("Config has not been initialized", RestStatus.NOT_FOUND); + } + }, exception -> { throw new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); })); + } + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 41bd71489..d685cbeaa 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -42,7 +42,7 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; -import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; @@ -155,8 +155,8 @@ public void onFailure(Exception e) { case PARAMETERS_FIELD: parameters = getParameterMap(entry.getValue()); break; - case CREDENTIALS_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); + case CREDENTIAL_FIELD: + credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD); break; case ACTIONS_FIELD: actions = getConnectorActionList(entry.getValue()); diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json new file mode 100644 index 000000000..6871f541c --- /dev/null +++ b/src/main/resources/mappings/config.json @@ -0,0 +1,15 @@ +{ + "dynamic": false, + "_meta": { + "schema_version": 1 + }, + "properties": { + "master_key": { + "type": "keyword" + }, + "create_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } +} diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index c89baba13..e3827e0b3 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -78,7 +78,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 3, + 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index 2f0fc256f..fd7d8e0d4 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -25,6 +25,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -53,6 +54,8 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { private CreateIndexStep createIndexStep; @Mock private ThreadPool threadPool; + @Mock + private EncryptorUtils encryptorUtils; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private AdminClient adminClient; private IndicesAdminClient indicesAdminClient; @@ -77,7 +80,7 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); adminClient = mock(AdminClient.class); indicesAdminClient = mock(IndicesAdminClient.class); metadata = mock(Metadata.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 22d831f2b..70c066c0e 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -231,6 +231,13 @@ public void testFailedToCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); responseListener.onFailure(new Exception("Failed to create global_context index")); @@ -261,6 +268,13 @@ public void testCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 7d061e572..8bdcaa2c7 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -25,6 +25,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; @@ -53,6 +54,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; private Template template; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private EncryptorUtils encryptorUtils; @Override public void setUp() throws Exception { @@ -61,6 +63,7 @@ public void setUp() throws Exception { this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( mock(TransportService.class), @@ -68,7 +71,8 @@ public void setUp() throws Exception { threadPool, client, workflowProcessSorter, - flowFrameworkIndicesHandler + flowFrameworkIndicesHandler, + encryptorUtils ); Version templateVersion = Version.fromString("1.0.0"); @@ -116,6 +120,8 @@ public void testProvisionWorkflow() { return null; }).when(client).get(any(GetRequest.class), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java new file mode 100644 index 000000000..1101fa1a2 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import java.util.List; +import java.util.Map; + +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class EncryptorUtilsTests extends OpenSearchTestCase { + + private ClusterService clusterService; + private Client client; + private EncryptorUtils encryptorUtils; + private String testMasterKey; + private Template testTemplate; + private String testCredentialKey; + private String testCredentialValue; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.clusterService = mock(ClusterService.class); + this.client = mock(Client.class); + this.encryptorUtils = new EncryptorUtils(clusterService, client); + this.testMasterKey = encryptorUtils.generateMasterKey(); + this.testCredentialKey = "credential_key"; + this.testCredentialValue = "12345"; + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode( + "A", + "a-type", + Map.of(), + Map.of(CREDENTIAL_FIELD, Map.of(testCredentialKey, testCredentialValue)) + ); + List nodes = List.of(nodeA); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, List.of()); + + this.testTemplate = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + + ThreadPool threadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGenerateMasterKey() { + String generatedMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(generatedMasterKey); + assertEquals(generatedMasterKey, encryptorUtils.getMasterKey()); + } + + public void testEncryptDecrypt() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted = encryptorUtils.encrypt(testString); + assertNotNull(encrypted); + + String decrypted = encryptorUtils.decrypt(encrypted); + assertEquals(testString, decrypted); + } + + public void testEncryptWithDifferentMasterKey() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted1 = encryptorUtils.encrypt(testString); + assertNotNull(encrypted1); + + // Change the master key prior to encryption + String differentMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(differentMasterKey); + String encrypted2 = encryptorUtils.encrypt(testString); + + assertNotEquals(encrypted1, encrypted2); + } + + public void testInitializeMasterKeySuccess() { + encryptorUtils.setMasterKey(null); + + String masterKey = encryptorUtils.generateMasterKey(); + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for success case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsMap()).thenReturn(ImmutableMap.of(MASTER_KEY, masterKey)); + + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + encryptorUtils.initializeMasterKeyIfAbsent(); + assertEquals(masterKey, encryptorUtils.getMasterKey()); + } + + public void testInitializeMasterKeyFailure() { + encryptorUtils.setMasterKey(null); + + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for failure case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(false); + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> encryptorUtils.initializeMasterKeyIfAbsent()); + assertEquals("Config has not been initialized", ex.getMessage()); + } + + public void testEncryptDecryptTemplateCredential() { + encryptorUtils.setMasterKey(testMasterKey); + + // Ecnrypt template with credential field + Template processedtemplate = encryptorUtils.encryptTemplateCredentials(testTemplate); + + // Validate the encrytped field + WorkflowNode node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map encryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, encryptedCredentialMap.size()); + + String encryptedCredential = encryptedCredentialMap.get(testCredentialKey); + assertNotNull(encryptedCredential); + assertNotEquals(testCredentialValue, encryptedCredential); + + // Decrypt credential field + processedtemplate = encryptorUtils.decryptTemplateCredentials(processedtemplate); + + // Validate the decrypted field + node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map decryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, decryptedCredentialMap.size()); + + String decryptedCredential = decryptedCredentialMap.get(testCredentialKey); + assertNotNull(decryptedCredential); + assertEquals(testCredentialValue, decryptedCredential); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index a05a3927e..e26fdf0c4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -68,7 +68,7 @@ public void setUp() throws Exception { Map.entry(CommonValue.VERSION_FIELD, "1"), Map.entry(CommonValue.PROTOCOL_FIELD, "test"), Map.entry(CommonValue.PARAMETERS_FIELD, params), - Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), + Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) ), "test-id" From cac2e0b26ed11043013770f61e35bf1ac16b06c7 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 29 Nov 2023 17:14:37 -0800 Subject: [PATCH 2/5] Improve workflow step execute API (#215) * Improve workflow step execute API Signed-off-by: Daniel Widdis * Fix typo Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .../workflow/CreateConnectorStep.java | 18 ++++++++-- .../workflow/CreateIndexStep.java | 22 +++++++++++-- .../workflow/CreateIngestPipelineStep.java | 20 +++++++++-- .../workflow/DeployModelStep.java | 17 ++++++++-- .../flowframework/workflow/GetMLTaskStep.java | 30 ++++++++++++++--- .../workflow/ModelGroupStep.java | 16 +++++++-- .../flowframework/workflow/NoOpStep.java | 9 +++-- .../flowframework/workflow/ProcessNode.java | 15 ++++++--- .../workflow/RegisterLocalModelStep.java | 17 ++++++++-- .../workflow/RegisterRemoteModelStep.java | 17 ++++++++-- .../flowframework/workflow/WorkflowData.java | 24 +++++++++++--- .../workflow/WorkflowProcessSorter.java | 2 +- .../flowframework/workflow/WorkflowStep.java | 14 ++++++-- .../workflow/CreateConnectorStepTests.java | 19 ++++++++--- .../workflow/CreateIndexStepTests.java | 18 +++++++--- .../CreateIngestPipelineStepTests.java | 31 +++++++++++++---- .../workflow/DeployModelStepTests.java | 18 +++++++--- .../workflow/GetMLTaskStepTests.java | 25 +++++++++++--- .../workflow/ModelGroupStepTests.java | 27 +++++++++++---- .../flowframework/workflow/NoOpStepTests.java | 7 +++- .../workflow/ProcessNodeTests.java | 33 +++++++++++++++---- .../workflow/RegisterLocalModelStepTests.java | 26 ++++++++++++--- .../RegisterRemoteModelStepTests.java | 26 ++++++++++++--- .../workflow/WorkflowDataTests.java | 7 ++-- 24 files changed, 373 insertions(+), 85 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index d685cbeaa..b00857ff6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -74,22 +74,28 @@ public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndi // TODO: need to add retry conflicts here @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { + String workflowId = currentNodeInputs.getWorkflowId(); createConnectorFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), - data.get(0).getWorkflowId() + workflowId, + currentNodeInputs.getNodeId() ) ); try { logger.info("Created connector successfully"); - String workflowId = data.get(0).getWorkflowId(); String workflowStepName = getName(); ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -136,6 +142,12 @@ public void onFailure(Exception e) { Map credentials = Collections.emptyMap(); List actions = Collections.emptyList(); + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + try { for (WorkflowData workflowData : data) { for (Entry entry : workflowData.getContent().entrySet()) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f3a82b26c..f443e9c2c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -19,6 +19,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -54,14 +55,25 @@ public CreateIndexStep(ClusterService clusterService, Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); + future.complete( + new WorkflowData( + Map.of(INDEX_NAME, createIndexResponse.index()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); } @Override @@ -75,6 +87,12 @@ public void onFailure(Exception e) { String type = null; Settings settings = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); index = (String) content.get(INDEX_NAME); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index a63a800fd..77dae29eb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -20,6 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -59,7 +60,12 @@ public CreateIngestPipelineStep(Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture createIngestPipelineFuture = new CompletableFuture<>(); @@ -71,6 +77,12 @@ public CompletableFuture execute(List data) { String outputFieldName = null; BytesReference configuration = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // Extract required content from workflow data and generate the ingest pipeline configuration for (WorkflowData workflowData : data) { @@ -126,7 +138,11 @@ public CompletableFuture execute(List data) { // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead createIngestPipelineFuture.complete( - new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + new WorkflowData( + Map.of(PIPELINE_ID, putPipelineRequest.getId()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 8ce89176c..aa6768605 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -17,6 +17,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -41,7 +42,12 @@ public DeployModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture deployModelFuture = new CompletableFuture<>(); @@ -52,7 +58,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { deployModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -66,6 +73,12 @@ public void onFailure(Exception e) { String modelId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { if (workflowData.getContent().containsKey(MODEL_ID)) { modelId = (String) workflowData.getContent().get(MODEL_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index bb57adbae..018783b19 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -50,12 +51,23 @@ public GetMLTaskStep(Settings settings, ClusterService clusterService, MachineLe } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture getMLTaskFuture = new CompletableFuture<>(); String taskId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); for (Entry entry : content.entrySet()) { @@ -73,7 +85,7 @@ public CompletableFuture execute(List data) { logger.error("Failed to retrieve ML Task"); getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); } else { - retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0); + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId(), getMLTaskFuture, taskId, 0); } return getMLTaskFuture; @@ -87,11 +99,18 @@ public String getName() { /** * Retryable GetMLTask * @param workflowId the workflow id + * @param nodeId the node id * @param getMLTaskFuture the workflow step future * @param taskId the ml task id * @param retries the current number of request retries */ - protected void retryableGetMlTask(String workflowId, CompletableFuture getMLTaskFuture, String taskId, int retries) { + protected void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture getMLTaskFuture, + String taskId, + int retries + ) { mlClient.getTask(taskId, ActionListener.wrap(response -> { if (response.getState() != MLTaskState.COMPLETED) { throw new IllegalStateException("MLTask is not yet completed"); @@ -103,7 +122,8 @@ protected void retryableGetMlTask(String workflowId, CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); @@ -67,7 +72,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -85,6 +91,12 @@ public void onFailure(Exception e) { AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index 098c5626c..bbf325e46 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -21,7 +21,12 @@ public class NoOpStep implements WorkflowStep { public static final String NAME = "noop"; @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { return CompletableFuture.completedFuture(WorkflowData.EMPTY); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 729043074..c6bdcc6a5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -14,7 +14,7 @@ import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -152,10 +152,10 @@ public CompletableFuture execute() { logger.info("Starting {}.", this.id); // get the input data from predecessor(s) - List input = new ArrayList(); - input.add(this.input); + Map inputMap = new HashMap<>(); for (CompletableFuture cf : predFutures) { - input.add(cf.get()); + WorkflowData wd = cf.get(); + inputMap.put(wd.getNodeId(), wd); } ScheduledCancellable delayExec = null; @@ -167,7 +167,12 @@ public CompletableFuture execute() { }, this.nodeTimeout, ThreadPool.Names.SAME); } // record start time for this step. - CompletableFuture stepFuture = this.workflowStep.execute(input); + CompletableFuture stepFuture = this.workflowStep.execute( + this.id, + this.input, + inputMap, + this.previousNodeInputs + ); // If completed exceptionally, this is a no-op future.complete(stepFuture.get()); // record end time passing workflow steps diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index ad6cbff8f..27aa5e537 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -64,7 +65,12 @@ public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); @@ -78,7 +84,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -102,6 +109,12 @@ public void onFailure(Exception e) { String allConfig = null; String url = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index de889b720..e41323a14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -55,7 +56,12 @@ public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); @@ -69,7 +75,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -87,6 +94,12 @@ public void onFailure(Exception e) { String description = null; String connectorId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 for (WorkflowData workflowData : data) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 4f62885e9..a0d901f74 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -28,18 +28,21 @@ public class WorkflowData { @Nullable private String workflowId; + @Nullable + private String nodeId; private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap(), ""); + this(Collections.emptyMap(), Collections.emptyMap(), null, null); } /** * Instantiate this object with content and empty params. * @param content The content map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, @Nullable String workflowId) { - this(content, Collections.emptyMap(), workflowId); + public WorkflowData(Map content, @Nullable String workflowId, @Nullable String nodeId) { + this(content, Collections.emptyMap(), workflowId, nodeId); } /** @@ -47,11 +50,13 @@ public WorkflowData(Map content, @Nullable String workflowId) { * @param content The content map * @param params The params map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, Map params, @Nullable String workflowId) { + public WorkflowData(Map content, Map params, @Nullable String workflowId, @Nullable String nodeId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); this.workflowId = workflowId; + this.nodeId = nodeId; } /** @@ -72,11 +77,20 @@ public Map getParams() { }; /** - * Returns the workflowId associated with this workflow. + * Returns the workflowId associated with this data. * @return the workflowId of this data. */ @Nullable public String getWorkflowId() { return this.workflowId; }; + + /** + * Returns the nodeId associated with this data. + * @return the nodeId of this data. + */ + @Nullable + public String getNodeId() { + return this.nodeId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index da0705eea..3e8b77f9d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -70,7 +70,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId) Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId, node.id()); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 1f5545cdf..f106ee652 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -19,11 +19,19 @@ public interface WorkflowStep { /** * Triggers the actual processing of the building block. - * @param data representing input params and content, or output content of previous steps. The first element of the list is data (if any) provided from parsing the template, and may be {@link WorkflowData#EMPTY}. + * @param currentNodeId The id of the node executing this step + * @param currentNodeInputs Input params and content for this node, from workflow parsing + * @param previousNodeInputs Input params for this node that come from previous steps + * @param outputs WorkflowData content of previous steps. * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. * @throws IOException on a failure. */ - CompletableFuture execute(List data) throws IOException; + CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException; /** * Gets the name of the workflow step. diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index e26fdf0c4..de3add996 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -71,7 +71,8 @@ public void setUp() throws Exception { Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -90,7 +91,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); @@ -111,7 +117,12 @@ public void testCreateConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 67cb6cb9b..8be5c5787 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -24,8 +24,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id", "test-node-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -91,7 +91,12 @@ public void setUp() throws Exception { public void testCreateIndexStep() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new CreateIndexResponse(true, true, "demo")); @@ -106,7 +111,12 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio public void testCreateIndexStepFailure() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 194c80eb0..f0a970758 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -16,7 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -51,11 +51,12 @@ public void setUp() throws Exception { Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") ), - "test-id" + "test-id", + "test-node-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id", "test-node-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -71,7 +72,12 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -89,7 +95,12 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -115,10 +126,16 @@ public void testMissingData() throws InterruptedException { Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") ), - "test-id" + "test-id", + "test-node-id" ); - CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); + CompletableFuture future = createIngestPipelineStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone() && future.isCompletedExceptionally()); ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fd856b945..670933373 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -19,7 +19,7 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id", "test-node-id"); MockitoAnnotations.openMocks(this); @@ -67,7 +67,12 @@ public void testDeployModel() throws ExecutionException, InterruptedException { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); @@ -87,7 +92,12 @@ public void testDeployModelFailure() { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index efb59d42a..bd62ddfc7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -70,7 +70,7 @@ public void setUp() throws Exception { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.getMLTaskStep = spy(new GetMLTaskStep(testMaxRetrySetting, clusterService, mlNodeClient)); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); + this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id", "test-node-id"); } public void testGetMLTaskSuccess() throws Exception { @@ -85,7 +85,12 @@ public void testGetMLTaskSuccess() throws Exception { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).getTask(any(), any()); @@ -102,7 +107,12 @@ public void testGetMLTaskFailure() { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -111,7 +121,12 @@ public void testGetMLTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = this.getMLTaskStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.getMLTaskStep.execute( + "nodeID", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index f763c8005..bc914baa7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -29,7 +29,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -54,7 +53,8 @@ public void setUp() throws Exception { Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -75,7 +75,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -97,7 +102,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -111,7 +121,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt public void testRegisterModelGroupWithNoName() throws IOException { ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); - CompletableFuture future = modelGroupStep.execute(List.of(inputDataWithNoName)); + CompletableFuture future = modelGroupStep.execute( + inputDataWithNoName.getNodeId(), + inputDataWithNoName, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 6c03cd87d..1782375cc 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -19,7 +19,12 @@ public class NoOpStepTests extends OpenSearchTestCase { public void testNoOpStep() throws IOException { NoOpStep noopStep = new NoOpStep(); assertEquals(NoOpStep.NAME, noopStep.getName()); - CompletableFuture future = noopStep.execute(Collections.emptyList()); + CompletableFuture future = noopStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertFalse(future.isCompletedExceptionally()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 6aae139e4..f50250ea5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -56,9 +56,14 @@ public void testNode() throws InterruptedException, ExecutionException { // Tests where execute nas no timeout ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); return f; } @@ -68,7 +73,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id", "test-node-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -78,6 +83,7 @@ public String getName() { assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); assertEquals("test-id", nodeA.input().getWorkflowId()); + assertEquals("test-node-id", nodeA.input().getNodeId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); @@ -91,7 +97,12 @@ public void testNodeNoTimeout() throws InterruptedException, ExecutionException // Tests where execute finishes before timeout ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule( () -> future.complete(WorkflowData.EMPTY), @@ -121,7 +132,12 @@ public void testNodeTimeout() throws InterruptedException, ExecutionException { // Tests where execute finishes after timeout ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC); return future; @@ -148,7 +164,12 @@ public void testExceptions() { // Tests where a predecessor future completed exceptionally ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); f.complete(WorkflowData.EMPTY); return f; diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index bd40c50ad..b7b47de46 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -70,7 +70,8 @@ public void setUp() throws Exception { Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -87,7 +88,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); assertTrue(future.isDone()); @@ -105,7 +111,12 @@ public void testRegisterLocalModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -114,7 +125,12 @@ public void testRegisterLocalModelFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = registerLocalModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index e60707f67..cde194326 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -18,7 +18,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -55,7 +55,8 @@ public void setUp() throws Exception { Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -72,7 +73,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); @@ -90,7 +96,12 @@ public void testRegisterRemoteModelFailure() { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -100,7 +111,12 @@ public void testRegisterRemoteModelFailure() { } public void testMissingInputs() { - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.registerRemoteModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index 8a4a1fda9..39023c6b4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,14 +26,17 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); + WorkflowData contentOnly = new WorkflowData(expectedContent, null, null); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); + assertNull(contentOnly.getWorkflowId()); + assertNull(contentOnly.getNodeId()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123", "test-node-id"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); assertEquals("test-id-123", contentAndParams.getWorkflowId()); + assertEquals("test-node-id", contentAndParams.getNodeId()); } } From ae6db3ceb457d4b9c6e19a9e25e7615d318e41d7 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 29 Nov 2023 09:53:16 -0800 Subject: [PATCH 3/5] Add Delete Connector Step Signed-off-by: Daniel Widdis --- .../workflow/DeleteConnectorStep.java | 88 +++++++++++++++++ .../workflow/WorkflowStepFactory.java | 15 +-- .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/DeleteConnectorStepTests.java | 97 +++++++++++++++++++ 4 files changed, 195 insertions(+), 13 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java new file mode 100644 index 000000000..c2cb77dfa --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; + +/** + * Step to delete a connector for a remote model + */ +public class DeleteConnectorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteConnectorStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_connector"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) throws IOException { + CompletableFuture deleteConnectorFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteConnectorFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())), data.get(0).getWorkflowId()) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete connector"); + deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Optional connectorId = data.stream() + .map(WorkflowData::getContent) + .filter(m -> m.containsKey(CONNECTOR_ID)) + .map(m -> m.get(CONNECTOR_ID).toString()) + .findFirst(); + + if (connectorId.isPresent()) { + mlClient.deleteConnector(connectorId.get(), actionListener); + } else { + deleteConnectorFuture.completeExceptionally( + new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST) + ); + } + + return deleteConnectorFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 2e450d5b0..babb468b7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -25,7 +25,6 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); - private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. @@ -42,17 +41,6 @@ public WorkflowStepFactory( Client client, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler - ) { - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; - populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler); - } - - private void populateMap( - Settings settings, - ClusterService clusterService, - Client client, - MachineLearningNodeClient mlClient, - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); @@ -61,6 +49,7 @@ private void populateMap( stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient)); } @@ -79,7 +68,7 @@ public WorkflowStep createStep(String type) { /** * Gets the step map - * @return the step map + * @return a read-only copy of the step map */ public Map getStepMap() { return Map.copyOf(this.stepMap); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 5bd88147b..5769daa90 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -39,6 +39,14 @@ "connector_id" ] }, + "delete_connector": { + "inputs": [ + "connector_id" + ], + "outputs":[ + "connector_id" + ] + }, "register_local_model": { "inputs":[ "name", diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java new file mode 100644 index 000000000..5cdde128c --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteConnectorStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id"); + } + + public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { + + String connectorId = "connectorId"; + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, connectorId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); + + CompletableFuture future = deleteConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector_id")); + + } + + public void testDeleteConnectorFailure() throws IOException { + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); + + CompletableFuture future = deleteConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create connector", ex.getCause().getMessage()); + } +} From f84ca841dee7ee668f556c64ca5901f3573a5ca8 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 29 Nov 2023 10:38:22 -0800 Subject: [PATCH 4/5] Add eclipse core runtime version resolution Signed-off-by: Daniel Widdis --- build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/build.gradle b/build.gradle index 6b66247a6..0109ec935 100644 --- a/build.gradle +++ b/build.gradle @@ -148,6 +148,7 @@ dependencies { configurations.all { resolutionStrategy { force("com.google.guava:guava:32.1.3-jre") // CVE for 31.1 + force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // CVE for < 3.29.0 force("com.fasterxml.jackson.core:jackson-core:2.16.0") // Dependency Jar Hell } } From a9b4a83502b2db6c2345df8e40c3980ad328d785 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 29 Nov 2023 10:51:12 -0800 Subject: [PATCH 5/5] Use JDK17 for spotless Signed-off-by: Daniel Widdis --- .github/workflows/CI.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d85c3c361..bbfc48aeb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,6 +14,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + # Spotless requires JDK 17+ + - uses: actions/setup-java@v3 + with: + java-version: 17 + distribution: temurin - name: Spotless Check run: ./gradlew spotlessCheck build: @@ -41,7 +46,7 @@ jobs: - uses: actions/checkout@v4 - name: Build and Run Tests run: | - ./gradlew check + ./gradlew check -x spotlessJava - name: Upload Coverage Report if: matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v3