diff --git a/build.gradle b/build.gradle index fc3b629bf..0109ec935 100644 --- a/build.gradle +++ b/build.gradle @@ -142,6 +142,8 @@ dependencies { api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' implementation "org.opensearch:common-utils:${common_utils_version}" + implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' + implementation 'org.bouncycastle:bcprov-jdk18on:1.77' configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 4b2e4a9cb..14df7e17e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -39,6 +39,7 @@ import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; import org.opensearch.flowframework.transport.SearchWorkflowTransportAction; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -96,7 +97,8 @@ public Collection createComponents( this.clusterService = clusterService; flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); WorkflowStepFactory 
workflowStepFactory = new WorkflowStepFactory( settings, clusterService, @@ -106,7 +108,7 @@ public Collection createComponents( ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 1f9f33f4b..90f208c8d 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -33,6 +33,16 @@ private CommonValue() {} public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json"; /** Workflow State index mapping version */ public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1; + /** Config Index Name */ + public static final String CONFIG_INDEX = ".plugins-flow-framework-config"; + /** Config index mapping file path */ + public static final String CONFIG_INDEX_MAPPING = "mappings/config.json"; + /** Config index mapping version */ + public static final Integer CONFIG_INDEX_VERSION = 1; + /** Master key field name */ + public static final String MASTER_KEY = "master_key"; + /** Create Time field name */ + public static final String CREATE_TIME = "create_time"; /** The template field name for template use case */ public static final String USE_CASE_FIELD = "use_case"; @@ -119,8 +129,8 @@ private CommonValue() {} public static final String PROTOCOL_FIELD = "protocol"; /** Connector parameters field */ public static final String PARAMETERS_FIELD = "parameters"; - /** Connector credentials field */ - public static final String CREDENTIALS_FIELD = "credentials"; + /** Connector credential field */ + public static final String 
CREDENTIAL_FIELD = "credential"; /** Connector actions field */ public static final String ACTIONS_FIELD = "actions"; /** Backend roles for the model */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index 4b005e45d..b74c1d89a 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -12,6 +12,8 @@ import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @@ -36,6 +38,14 @@ public enum FlowFrameworkIndex { WORKFLOW_STATE_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings), WORKFLOW_STATE_INDEX_VERSION + ), + /** + * Config Index + */ + CONFIG( + CONFIG_INDEX, + ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getConfigIndexMappings), + CONFIG_INDEX_VERSION ); private final String indexName; diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 7dcc89de6..cce4ba839 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import 
org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.script.Script; import java.io.IOException; @@ -48,6 +49,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.META; @@ -64,6 +66,7 @@ public class FlowFrameworkIndicesHandler { private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class); private final Client client; private final ClusterService clusterService; + private final EncryptorUtils encryptorUtils; private static final Map indexMappingUpdated = new HashMap<>(); private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); @@ -71,10 +74,12 @@ public class FlowFrameworkIndicesHandler { * constructor * @param client the open search client * @param clusterService ClusterService + * @param encryptorUtils encryption utility */ - public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) { + public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) { this.client = client; this.clusterService = clusterService; + this.encryptorUtils = encryptorUtils; for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) { indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false)); } @@ -104,6 +109,15 @@ public static String getWorkflowStateMappings() throws IOException { return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING); } + /** + * Get config index mapping + * @return config index mapping + * @throws IOException if mapping file cannot be read correctly + */ + public static String getConfigIndexMappings() throws 
IOException { + return getIndexMappings(CONFIG_INDEX_MAPPING); + } + /** * Create global context index if it's absent * @param listener The action listener @@ -120,6 +134,14 @@ public void initWorkflowStateIndexIfAbsent(ActionListener listener) { initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener); } + /** + * Create config index if it's absent + * @param listener The action listener + */ + public void initConfigIndexIfAbsent(ActionListener listener) { + initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.CONFIG, listener); + } + /** * Checks if the given index exists * @param indexName the name of the index @@ -287,7 +309,8 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { @@ -301,6 +324,23 @@ public void putTemplateToGlobalContext(Template template, ActionListener listener) { + initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR)); + return; + } + encryptorUtils.initializeMasterKey(listener); + }, createIndexException -> { + logger.error("Failed to create config index", createIndexException); + listener.onFailure(createIndexException); + })); + } + /** * add document insert into global context index * @param workflowId the workflowId, corresponds to document ID of @@ -361,7 +401,8 @@ public void updateTemplateInGlobalContext(String documentId, Template template, XContentBuilder builder = XContentFactory.jsonBuilder(); ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() ) { - request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS)) + Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template); + request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, 
ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a77e7dd79..6ca1c4661 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -119,25 +119,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); - }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); - } - }) - ); + // Initialize config index and create new global context and state index entries + flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { + if (!isInitialized) { + listener.onFailure( + new FlowFrameworkException("Failed to initialize config index", RestStatus.INTERNAL_SERVER_ERROR) + ); + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext( + templateWithUser, + ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + }, exception -> { + logger.error("Failed to 
save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + }) + ); + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } + + }) + ); + } }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); + logger.error("Failed to initialize config index : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { listener.onFailure(exception); } else { diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 1f19a4d04..b381b41ec 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; @@ -61,6 +62,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java 
b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java new file mode 100644 index 000000000..70b30b5cc --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -0,0 +1,286 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; + +import javax.crypto.spec.SecretKeySpec; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoResult; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; + +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static 
org.opensearch.flowframework.common.CommonValue.CREATE_TIME; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; + +/** + * Encryption utility class + */ +public class EncryptorUtils { + + private static final Logger logger = LogManager.getLogger(EncryptorUtils.class); + + private static final String ALGORITHM = "AES"; + private static final String PROVIDER = "Custom"; + private static final String WRAPPING_ALGORITHM = "AES/GCM/NoPadding"; + + private ClusterService clusterService; + private Client client; + private String masterKey; + + /** + * Instantiates a new EncryptorUtils object + * @param clusterService the cluster service + * @param client the node client + */ + public EncryptorUtils(ClusterService clusterService, Client client) { + this.masterKey = null; + this.clusterService = clusterService; + this.client = client; + } + + /** + * Sets the master key + * @param masterKey the master key + */ + void setMasterKey(String masterKey) { + this.masterKey = masterKey; + } + + /** + * Returns the master key + * @return the master key + */ + String getMasterKey() { + return this.masterKey; + } + + /** + * Randomly generate a master key + * @return the master key string + */ + String generateMasterKey() { + byte[] masterKeyBytes = new byte[32]; + new SecureRandom().nextBytes(masterKeyBytes); + return Base64.getEncoder().encodeToString(masterKeyBytes); + } + + /** + * Encrypts template credentials + * @param template the template to encrypt + * @return template with encrypted credentials + */ + public Template encryptTemplateCredentials(Template template) { + return processTemplateCredentials(template, this::encrypt); + } + + /** + * Decrypts template credentials + * @param template the template to decrypt + * @return template with decrypted credentials + */ + public Template decryptTemplateCredentials(Template template) { + return 
processTemplateCredentials(template, this::decrypt); + } + + // TODO : Improve processTemplateCredentials to encrypt different fields based on the WorkflowStep type + /** + * Applies the given cipher function on template credentials + * @param template the template to process + * @param cipherFunction the encryption/decryption function to apply on credential values + * @return template with processed credentials + */ + private Template processTemplateCredentials(Template template, Function cipherFunction) { + Template.Builder processedTemplateBuilder = new Template.Builder(); + + Map processedWorkflows = new HashMap<>(); + for (Map.Entry entry : template.workflows().entrySet()) { + + List processedNodes = new ArrayList<>(); + for (WorkflowNode node : entry.getValue().nodes()) { + if (node.userInputs().containsKey(CREDENTIAL_FIELD)) { + // Apply the cipher function on all values within credential field + @SuppressWarnings("unchecked") + Map credentials = new HashMap<>((Map) node.userInputs().get(CREDENTIAL_FIELD)); + credentials.replaceAll((key, cred) -> cipherFunction.apply(cred)); + + // Replace credentials field in node user inputs + Map processedUserInputs = new HashMap<>(); + processedUserInputs.putAll(node.userInputs()); + processedUserInputs.replace(CREDENTIAL_FIELD, credentials); + + // build new node to add to processed nodes + WorkflowNode processedWorkflowNode = new WorkflowNode( + node.id(), + node.type(), + node.previousNodeInputs(), + processedUserInputs + ); + processedNodes.add(processedWorkflowNode); + } else { + processedNodes.add(node); + } + } + + // Add processed workflow nodes to processed workflows + processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges())); + } + + Template processedTemplate = processedTemplateBuilder.name(template.name()) + .description(template.description()) + .useCase(template.useCase()) + .templateVersion(template.templateVersion()) + 
.compatibilityVersion(template.compatibilityVersion()) + .workflows(processedWorkflows) + .uiMetadata(template.getUiMetadata()) + .user(template.getUser()) + .build(); + + return processedTemplate; + } + + /** + * Encrypts the given credential + * @param credential the credential to encrypt + * @return the encrypted credential + */ + String encrypt(final String credential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + final CryptoResult encryptResult = crypto.encryptData( + jceMasterKey, + credential.getBytes(StandardCharsets.UTF_8) + ); + return Base64.getEncoder().encodeToString(encryptResult.getResult()); + } + + /** + * Decrypts the given credential + * @param encryptedCredential the credential to decrypt + * @return the decrypted credential + */ + String decrypt(final String encryptedCredential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + + final CryptoResult decryptedResult = crypto.decryptData( + jceMasterKey, + Base64.getDecoder().decode(encryptedCredential) + ); + return new String(decryptedResult.getResult(), StandardCharsets.UTF_8); + } + + /** + * Retrieves an existing master key or generates a new key to index + * @param listener the action listener + */ + public void initializeMasterKey(ActionListener listener) { + // Index has either been created or it already exists, need to check if master key has been initalized already, if not then + // generate + // This is 
necessary in case of global context index restoration from snapshot, will need to use the same master key to decrypt + // stored credentials + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(getResponse -> { + + if (!getResponse.isExists()) { + + // Generate new key and index + final String masterKey = generateMasterKey(); + IndexRequest masterKeyIndexRequest = new IndexRequest(CONFIG_INDEX).id(MASTER_KEY) + .source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME, Instant.now().toEpochMilli())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(masterKeyIndexRequest, ActionListener.wrap(indexResponse -> { + // Set generated key to master + logger.info("Config has been initialized successfully"); + this.masterKey = masterKey; + listener.onResponse(true); + }, indexException -> { + logger.error("Failed to index config", indexException); + listener.onFailure(indexException); + })); + + } else { + // Set existing key to master + logger.info("Config has already been initialized"); + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + this.masterKey = masterKey; + listener.onResponse(true); + } + }, getRequestException -> { + logger.error("Failed to search for config from config index", getRequestException); + listener.onFailure(getRequestException); + })); + + } catch (Exception e) { + logger.error("Failed to retrieve config from config index", e); + listener.onFailure(e); + } + } + + /** + * Retrieves master key from system index if not yet set + */ + void initializeMasterKeyIfAbsent() { + if (masterKey != null) { + return; + } + + if (!clusterService.state().metadata().hasIndex(CONFIG_INDEX)) { + throw new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR); + } else { + try 
(ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + this.masterKey = (String) response.getSourceAsMap().get(MASTER_KEY); + } else { + throw new FlowFrameworkException("Config has not been initialized", RestStatus.NOT_FOUND); + } + }, exception -> { throw new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); })); + } + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 41bd71489..b00857ff6 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -42,7 +42,7 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; -import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; @@ -74,22 +74,28 @@ public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndi // TODO: need to add retry conflicts here @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void 
onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { + String workflowId = currentNodeInputs.getWorkflowId(); createConnectorFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())), - data.get(0).getWorkflowId() + workflowId, + currentNodeInputs.getNodeId() ) ); try { logger.info("Created connector successfully"); - String workflowId = data.get(0).getWorkflowId(); String workflowStepName = getName(); ResourceCreated newResource = new ResourceCreated(workflowStepName, mlCreateConnectorResponse.getConnectorId()); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -136,6 +142,12 @@ public void onFailure(Exception e) { Map credentials = Collections.emptyMap(); List actions = Collections.emptyList(); + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + try { for (WorkflowData workflowData : data) { for (Entry entry : workflowData.getContent().entrySet()) { @@ -155,8 +167,8 @@ public void onFailure(Exception e) { case PARAMETERS_FIELD: parameters = getParameterMap(entry.getValue()); break; - case CREDENTIALS_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); + case CREDENTIAL_FIELD: + credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD); break; case ACTIONS_FIELD: actions = getConnectorActionList(entry.getValue()); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index f3a82b26c..f443e9c2c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -19,6 +19,7 @@ import org.opensearch.core.action.ActionListener; import 
org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -54,14 +55,25 @@ public CreateIndexStep(ClusterService clusterService, Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(CreateIndexResponse createIndexResponse) { logger.info("created index: {}", createIndexResponse.index()); - future.complete(new WorkflowData(Map.of(INDEX_NAME, createIndexResponse.index()), data.get(0).getWorkflowId())); + future.complete( + new WorkflowData( + Map.of(INDEX_NAME, createIndexResponse.index()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); } @Override @@ -75,6 +87,12 @@ public void onFailure(Exception e) { String type = null; Settings settings = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); index = (String) content.get(INDEX_NAME); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index a63a800fd..77dae29eb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -20,6 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import 
java.util.Map; import java.util.Map.Entry; @@ -59,7 +60,12 @@ public CreateIngestPipelineStep(Client client) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture createIngestPipelineFuture = new CompletableFuture<>(); @@ -71,6 +77,12 @@ public CompletableFuture execute(List data) { String outputFieldName = null; BytesReference configuration = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // Extract required content from workflow data and generate the ingest pipeline configuration for (WorkflowData workflowData : data) { @@ -126,7 +138,11 @@ public CompletableFuture execute(List data) { // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead createIngestPipelineFuture.complete( - new WorkflowData(Map.of(PIPELINE_ID, putPipelineRequest.getId()), data.get(0).getWorkflowId()) + new WorkflowData( + Map.of(PIPELINE_ID, putPipelineRequest.getId()), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) ); // TODO : Use node client to index response data to global context (pending global context index implementation) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index c2cb77dfa..a1c52ef15 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -18,6 +18,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import java.io.IOException; +import java.util.ArrayList;
 import java.util.List; import java.util.Map; import java.util.Optional; @@ -45,7 +46,12 @@ public DeleteConnectorStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture deleteConnectorFuture = new CompletableFuture<>(); ActionListener actionListener = new ActionListener<>() { @@ -53,7 +59,11 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(DeleteResponse deleteResponse) { deleteConnectorFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())), data.get(0).getWorkflowId()) + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) ); } @@ -64,6 +74,12 @@ public void onFailure(Exception e) { } }; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + Optional connectorId = data.stream() .map(WorkflowData::getContent) .filter(m -> m.containsKey(CONNECTOR_ID)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 8ce89176c..aa6768605 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -17,6 +17,7 @@ import 
org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -41,7 +42,12 @@ public DeployModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture deployModelFuture = new CompletableFuture<>(); @@ -52,7 +58,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { deployModelFuture.complete( new WorkflowData( Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -66,6 +73,12 @@ public void onFailure(Exception e) { String modelId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { if (workflowData.getContent().containsKey(MODEL_ID)) { modelId = (String) workflowData.getContent().get(MODEL_ID); diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java index bb57adbae..018783b19 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetMLTaskStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -50,12 +51,23 @@ public GetMLTaskStep(Settings 
settings, ClusterService clusterService, MachineLe } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture getMLTaskFuture = new CompletableFuture<>(); String taskId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); for (Entry entry : content.entrySet()) { @@ -73,7 +85,7 @@ public CompletableFuture execute(List data) { logger.error("Failed to retrieve ML Task"); getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)); } else { - retryableGetMlTask(data.get(0).getWorkflowId(), getMLTaskFuture, taskId, 0); + retryableGetMlTask(currentNodeInputs.getWorkflowId(), currentNodeInputs.getNodeId(), getMLTaskFuture, taskId, 0); } return getMLTaskFuture; @@ -87,11 +99,18 @@ public String getName() { /** * Retryable GetMLTask * @param workflowId the workflow id + * @param nodeId the node id * @param getMLTaskFuture the workflow step future * @param taskId the ml task id * @param retries the current number of request retries */ - protected void retryableGetMlTask(String workflowId, CompletableFuture getMLTaskFuture, String taskId, int retries) { + protected void retryableGetMlTask( + String workflowId, + String nodeId, + CompletableFuture getMLTaskFuture, + String taskId, + int retries + ) { mlClient.getTask(taskId, ActionListener.wrap(response -> { if (response.getState() != MLTaskState.COMPLETED) { throw new IllegalStateException("MLTask is not yet completed"); @@ -103,7 +122,8 @@ protected void retryableGetMlTask(String workflowId, CompletableFuture execute(List data) throws IOException 
{ + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); @@ -67,7 +72,8 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse Map.entry("model_group_id", mlRegisterModelGroupResponse.getModelGroupId()), Map.entry("model_group_status", mlRegisterModelGroupResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -85,6 +91,12 @@ public void onFailure(Exception e) { AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index 098c5626c..bbf325e46 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -21,7 +21,12 @@ public class NoOpStep implements WorkflowStep { public static final String NAME = "noop"; @Override - public CompletableFuture execute(List data) throws IOException { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { return CompletableFuture.completedFuture(WorkflowData.EMPTY); } diff --git 
a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 729043074..c6bdcc6a5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -14,7 +14,7 @@ import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -152,10 +152,10 @@ public CompletableFuture execute() { logger.info("Starting {}.", this.id); // get the input data from predecessor(s) - List input = new ArrayList(); - input.add(this.input); + Map inputMap = new HashMap<>(); for (CompletableFuture cf : predFutures) { - input.add(cf.get()); + WorkflowData wd = cf.get(); + inputMap.put(wd.getNodeId(), wd); } ScheduledCancellable delayExec = null; @@ -167,7 +167,12 @@ public CompletableFuture execute() { }, this.nodeTimeout, ThreadPool.Names.SAME); } // record start time for this step. 
- CompletableFuture stepFuture = this.workflowStep.execute(input); + CompletableFuture stepFuture = this.workflowStep.execute( + this.id, + this.input, + inputMap, + this.previousNodeInputs + ); // If completed exceptionally, this is a no-op future.complete(stepFuture.get()); // record end time passing workflow steps diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index ad6cbff8f..27aa5e537 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -64,7 +65,12 @@ public RegisterLocalModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); @@ -78,7 +84,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -102,6 +109,12 @@ public void onFailure(Exception e) { String allConfig = null; String url = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + 
data.addAll(outputs.values()); + for (WorkflowData workflowData : data) { Map content = workflowData.getContent(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index de889b720..e41323a14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -20,6 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -55,7 +56,12 @@ public RegisterRemoteModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); @@ -69,7 +75,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { Map.entry(MODEL_ID, mlRegisterModelResponse.getModelId()), Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus()) ), - data.get(0).getWorkflowId() + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() ) ); } @@ -87,6 +94,12 @@ public void onFailure(Exception e) { String description = null; String connectorId = null; + // TODO: Recreating the list to get this compiling + // Need to refactor the below iteration to pull directly from the maps + List data = new ArrayList<>(); + data.add(currentNodeInputs); + data.addAll(outputs.values()); + // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 for (WorkflowData workflowData : data) { diff --git 
a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 4f62885e9..a0d901f74 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -28,18 +28,21 @@ public class WorkflowData { @Nullable private String workflowId; + @Nullable + private String nodeId; private WorkflowData() { - this(Collections.emptyMap(), Collections.emptyMap(), ""); + this(Collections.emptyMap(), Collections.emptyMap(), null, null); } /** * Instantiate this object with content and empty params. * @param content The content map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, @Nullable String workflowId) { - this(content, Collections.emptyMap(), workflowId); + public WorkflowData(Map content, @Nullable String workflowId, @Nullable String nodeId) { + this(content, Collections.emptyMap(), workflowId, nodeId); } /** @@ -47,11 +50,13 @@ public WorkflowData(Map content, @Nullable String workflowId) { * @param content The content map * @param params The params map * @param workflowId The workflow ID associated with this step + * @param nodeId The node ID associated with this step */ - public WorkflowData(Map content, Map params, @Nullable String workflowId) { + public WorkflowData(Map content, Map params, @Nullable String workflowId, @Nullable String nodeId) { this.content = Map.copyOf(content); this.params = Map.copyOf(params); this.workflowId = workflowId; + this.nodeId = nodeId; } /** @@ -72,11 +77,20 @@ public Map getParams() { }; /** - * Returns the workflowId associated with this workflow. + * Returns the workflowId associated with this data. * @return the workflowId of this data. 
*/ @Nullable public String getWorkflowId() { return this.workflowId; }; + + /** + * Returns the nodeId associated with this data. + * @return the nodeId of this data. + */ + @Nullable + public String getNodeId() { + return this.nodeId; + }; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index da0705eea..3e8b77f9d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -70,7 +70,7 @@ public List sortProcessNodes(Workflow workflow, String workflowId) Map idToNodeMap = new HashMap<>(); for (WorkflowNode node : sortedNodes) { WorkflowStep step = workflowStepFactory.createStep(node.type()); - WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId); + WorkflowData data = new WorkflowData(node.userInputs(), workflow.userParams(), workflowId, node.id()); List predecessorNodes = workflow.edges() .stream() .filter(e -> e.destination().equals(node.id())) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 1f5545cdf..f106ee652 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -9,7 +9,7 @@ package org.opensearch.flowframework.workflow; import java.io.IOException; -import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -19,11 +19,19 @@ public interface WorkflowStep { /** * Triggers the actual processing of the building block. - * @param data representing input params and content, or output content of previous steps. The first element of the list is data (if any) provided from parsing the template, and may be {@link WorkflowData#EMPTY}. 
+ * @param currentNodeId The id of the node executing this step + * @param currentNodeInputs Input params and content for this node, from workflow parsing + * @param previousNodeInputs Input params for this node that come from previous steps + * @param outputs WorkflowData content of previous steps. * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. * @throws IOException on a failure. */ - CompletableFuture execute(List data) throws IOException; + CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException; /** * Gets the name of the workflow step. diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json new file mode 100644 index 000000000..6871f541c --- /dev/null +++ b/src/main/resources/mappings/config.json @@ -0,0 +1,15 @@ +{ + "dynamic": false, + "_meta": { + "schema_version": 1 + }, + "properties": { + "master_key": { + "type": "keyword" + }, + "create_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } +} diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index c89baba13..e3827e0b3 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -78,7 +78,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 3, + 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(4, ffp.getRestHandlers(settings, null, null, 
null, null, null, null).size()); diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index 2f0fc256f..fd7d8e0d4 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -25,6 +25,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -53,6 +54,8 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { private CreateIndexStep createIndexStep; @Mock private ThreadPool threadPool; + @Mock + private EncryptorUtils encryptorUtils; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private AdminClient adminClient; private IndicesAdminClient indicesAdminClient; @@ -77,7 +80,7 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); adminClient = mock(AdminClient.class); indicesAdminClient = mock(IndicesAdminClient.class); metadata = mock(Metadata.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 22d831f2b..70c066c0e 100644 --- 
a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -231,6 +231,13 @@ public void testFailedToCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); responseListener.onFailure(new Exception("Failed to create global_context index")); @@ -261,6 +268,13 @@ public void testCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 7d061e572..8bdcaa2c7 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -25,6 +25,7 @@ import org.opensearch.flowframework.model.Workflow; import 
org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; @@ -53,6 +54,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; private Template template; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private EncryptorUtils encryptorUtils; @Override public void setUp() throws Exception { @@ -61,6 +63,7 @@ public void setUp() throws Exception { this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( mock(TransportService.class), @@ -68,7 +71,8 @@ public void setUp() throws Exception { threadPool, client, workflowProcessSorter, - flowFrameworkIndicesHandler + flowFrameworkIndicesHandler, + encryptorUtils ); Version templateVersion = Version.fromString("1.0.0"); @@ -116,6 +120,8 @@ public void testProvisionWorkflow() { return null; }).when(client).get(any(GetRequest.class), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java new file mode 100644 index 000000000..1101fa1a2 --- /dev/null +++ 
b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import java.util.List; +import java.util.Map; + +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class EncryptorUtilsTests extends OpenSearchTestCase { + + private ClusterService clusterService; + private Client client; + private EncryptorUtils encryptorUtils; + private String testMasterKey; + private Template testTemplate; + private String 
testCredentialKey; + private String testCredentialValue; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.clusterService = mock(ClusterService.class); + this.client = mock(Client.class); + this.encryptorUtils = new EncryptorUtils(clusterService, client); + this.testMasterKey = encryptorUtils.generateMasterKey(); + this.testCredentialKey = "credential_key"; + this.testCredentialValue = "12345"; + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode( + "A", + "a-type", + Map.of(), + Map.of(CREDENTIAL_FIELD, Map.of(testCredentialKey, testCredentialValue)) + ); + List nodes = List.of(nodeA); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, List.of()); + + this.testTemplate = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + + ThreadPool threadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGenerateMasterKey() { + String generatedMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(generatedMasterKey); + assertEquals(generatedMasterKey, encryptorUtils.getMasterKey()); + } + + public void testEncryptDecrypt() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted = encryptorUtils.encrypt(testString); + assertNotNull(encrypted); + + String 
decrypted = encryptorUtils.decrypt(encrypted); + assertEquals(testString, decrypted); + } + + public void testEncryptWithDifferentMasterKey() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted1 = encryptorUtils.encrypt(testString); + assertNotNull(encrypted1); + + // Change the master key prior to encryption + String differentMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(differentMasterKey); + String encrypted2 = encryptorUtils.encrypt(testString); + + assertNotEquals(encrypted1, encrypted2); + } + + public void testInitializeMasterKeySuccess() { + encryptorUtils.setMasterKey(null); + + String masterKey = encryptorUtils.generateMasterKey(); + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for success case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsMap()).thenReturn(ImmutableMap.of(MASTER_KEY, masterKey)); + + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + encryptorUtils.initializeMasterKeyIfAbsent(); + assertEquals(masterKey, encryptorUtils.getMasterKey()); + } + + public void testInitializeMasterKeyFailure() { + encryptorUtils.setMasterKey(null); + + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for failure case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(false); + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> encryptorUtils.initializeMasterKeyIfAbsent()); + assertEquals("Config has not been initialized", ex.getMessage()); + } + + public void testEncryptDecryptTemplateCredential() 
{ + encryptorUtils.setMasterKey(testMasterKey); + + // Ecnrypt template with credential field + Template processedtemplate = encryptorUtils.encryptTemplateCredentials(testTemplate); + + // Validate the encrytped field + WorkflowNode node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map encryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, encryptedCredentialMap.size()); + + String encryptedCredential = encryptedCredentialMap.get(testCredentialKey); + assertNotNull(encryptedCredential); + assertNotEquals(testCredentialValue, encryptedCredential); + + // Decrypt credential field + processedtemplate = encryptorUtils.decryptTemplateCredentials(processedtemplate); + + // Validate the decrypted field + node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map decryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, decryptedCredentialMap.size()); + + String decryptedCredential = decryptedCredentialMap.get(testCredentialKey); + assertNotNull(decryptedCredential); + assertEquals(testCredentialValue, decryptedCredential); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index a05a3927e..de3add996 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -68,10 +68,11 @@ public void setUp() throws Exception { Map.entry(CommonValue.VERSION_FIELD, "1"), 
Map.entry(CommonValue.PROTOCOL_FIELD, "test"), Map.entry(CommonValue.PARAMETERS_FIELD, params), - Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), + Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -90,7 +91,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); @@ -111,7 +117,12 @@ public void testCreateConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + CompletableFuture future = createConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 67cb6cb9b..8be5c5787 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -24,8 +24,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.util.Collections; import java.util.HashMap; -import 
java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -69,7 +69,7 @@ public class CreateIndexStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("index_name", "demo"), Map.entry("type", "knn")), "test-id", "test-node-id"); clusterService = mock(ClusterService.class); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -91,7 +91,12 @@ public void setUp() throws Exception { public void testCreateIndexStep() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new CreateIndexResponse(true, true, "demo")); @@ -106,7 +111,12 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio public void testCreateIndexStepFailure() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute(List.of(inputData)); + CompletableFuture future = createIndexStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); verify(indicesAdminClient, 
times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 194c80eb0..f0a970758 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -16,7 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -51,11 +51,12 @@ public void setUp() throws Exception { Map.entry("input_field_name", "inputField"), Map.entry("output_field_name", "outputField") ), - "test-id" + "test-id", + "test-node-id" ); // Set output data to returned pipelineId - outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id"); + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipeline_id", "pipelineId")), "test-id", "test-node-id"); client = mock(Client.class); adminClient = mock(AdminClient.class); @@ -71,7 +72,12 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -89,7 +95,12 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); 
- CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + CompletableFuture future = createIngestPipelineStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertFalse(future.isDone()); @@ -115,10 +126,16 @@ public void testMissingData() throws InterruptedException { Map.entry("type", "text_embedding"), Map.entry("model_id", "model_id") ), - "test-id" + "test-id", + "test-node-id" ); - CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); + CompletableFuture future = createIngestPipelineStep.execute( + incorrectData.getNodeId(), + incorrectData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone() && future.isCompletedExceptionally()); ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index 5cdde128c..478d46c91 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -45,7 +45,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id"); + inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id"); } public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { @@ -64,7 +64,12 @@ public void testDeleteConnector() throws IOException, 
ExecutionException, Interr return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); - CompletableFuture future = deleteConnectorStep.execute(List.of(inputData)); + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); @@ -85,7 +90,12 @@ public void testDeleteConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); - CompletableFuture future = deleteConnectorStep.execute(List.of(inputData)); + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deleteConnector(any(String.class), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fd856b945..670933373 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -19,7 +19,7 @@ import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -44,7 +44,7 @@ public class DeployModelStepTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), "test-id"); + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId")), 
"test-id", "test-node-id"); MockitoAnnotations.openMocks(this); @@ -67,7 +67,12 @@ public void testDeployModel() throws ExecutionException, InterruptedException { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); @@ -87,7 +92,12 @@ public void testDeployModelFailure() { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute(List.of(inputData)); + CompletableFuture future = deployModel.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java index efb59d42a..bd62ddfc7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/GetMLTaskStepTests.java @@ -21,7 +21,7 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -70,7 +70,7 @@ public void setUp() throws Exception { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.getMLTaskStep = spy(new GetMLTaskStep(testMaxRetrySetting, clusterService, mlNodeClient)); - this.workflowData = new WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id"); + this.workflowData = new 
WorkflowData(Map.ofEntries(Map.entry(TASK_ID, "test")), "test-id", "test-node-id"); } public void testGetMLTaskSuccess() throws Exception { @@ -85,7 +85,12 @@ public void testGetMLTaskSuccess() throws Exception { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).getTask(any(), any()); @@ -102,7 +107,12 @@ public void testGetMLTaskFailure() { return null; }).when(mlNodeClient).getTask(any(), any()); - CompletableFuture future = this.getMLTaskStep.execute(List.of(workflowData)); + CompletableFuture future = this.getMLTaskStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -111,7 +121,12 @@ public void testGetMLTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = this.getMLTaskStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.getMLTaskStep.execute( + "nodeID", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index f763c8005..bc914baa7 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -20,7 +20,7 @@ import 
org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -29,7 +29,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -54,7 +53,8 @@ public void setUp() throws Exception { Map.entry("access_mode", AccessMode.PUBLIC), Map.entry("add_all_backend_roles", false) ), - "test-id" + "test-id", + "test-node-id" ); } @@ -75,7 +75,12 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -97,7 +102,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute(List.of(inputData)); + CompletableFuture future = modelGroupStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); @@ -111,7 +121,12 @@ public void testRegisterModelGroupFailure() throws ExecutionException, Interrupt public void 
testRegisterModelGroupWithNoName() throws IOException { ModelGroupStep modelGroupStep = new ModelGroupStep(machineLearningNodeClient); - CompletableFuture future = modelGroupStep.execute(List.of(inputDataWithNoName)); + CompletableFuture future = modelGroupStep.execute( + inputDataWithNoName.getNodeId(), + inputDataWithNoName, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 6c03cd87d..1782375cc 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -19,7 +19,12 @@ public class NoOpStepTests extends OpenSearchTestCase { public void testNoOpStep() throws IOException { NoOpStep noopStep = new NoOpStep(); assertEquals(NoOpStep.NAME, noopStep.getName()); - CompletableFuture future = noopStep.execute(Collections.emptyList()); + CompletableFuture future = noopStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertFalse(future.isCompletedExceptionally()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 6aae139e4..f50250ea5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -56,9 +56,14 @@ public void testNode() throws InterruptedException, ExecutionException { // Tests where execute nas no timeout ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public 
CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"), "test-id")); + f.complete(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); return f; } @@ -68,7 +73,7 @@ public String getName() { } }, Map.of(), - new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id"), + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar"), "test-id", "test-node-id"), List.of(successfulNode), testThreadPool, TimeValue.timeValueMillis(50) @@ -78,6 +83,7 @@ public String getName() { assertEquals("input", nodeA.input().getContent().get("test")); assertEquals("bar", nodeA.input().getParams().get("foo")); assertEquals("test-id", nodeA.input().getWorkflowId()); + assertEquals("test-node-id", nodeA.input().getNodeId()); assertEquals(1, nodeA.predecessors().size()); assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); @@ -91,7 +97,12 @@ public void testNodeNoTimeout() throws InterruptedException, ExecutionException // Tests where execute finishes before timeout ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule( () -> future.complete(WorkflowData.EMPTY), @@ -121,7 +132,12 @@ public void testNodeTimeout() throws InterruptedException, ExecutionException { // Tests where execute finishes after timeout ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { 
CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC); return future; @@ -148,7 +164,12 @@ public void testExceptions() { // Tests where a predecessor future completed exceptionally ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() { @Override - public CompletableFuture execute(List data) { + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { CompletableFuture f = new CompletableFuture<>(); f.complete(WorkflowData.EMPTY); return f; diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index bd40c50ad..b7b47de46 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -20,7 +20,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -70,7 +70,8 @@ public void setUp() throws Exception { Map.entry("framework_type", "sentence_transformers"), Map.entry("url", "something.com") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -87,7 +88,12 @@ public void testRegisterLocalModelSuccess() throws Exception { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); 
verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); assertTrue(future.isDone()); @@ -105,7 +111,12 @@ public void testRegisterLocalModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerLocalModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -114,7 +125,12 @@ public void testRegisterLocalModelFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = registerLocalModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index e60707f67..cde194326 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -18,7 +18,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import java.util.List; +import java.util.Collections; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -55,7 +55,8 @@ public void setUp() throws Exception { 
Map.entry("description", "description"), Map.entry("connector_id", "abcdefg") ), - "test-id" + "test-id", + "test-node-id" ); } @@ -72,7 +73,12 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); @@ -90,7 +96,12 @@ public void testRegisterRemoteModelFailure() { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(workflowData)); + CompletableFuture future = this.registerRemoteModelStep.execute( + workflowData.getNodeId(), + workflowData, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); @@ -100,7 +111,12 @@ public void testRegisterRemoteModelFailure() { } public void testMissingInputs() { - CompletableFuture future = this.registerRemoteModelStep.execute(List.of(WorkflowData.EMPTY)); + CompletableFuture future = this.registerRemoteModelStep.execute( + "nodeId", + WorkflowData.EMPTY, + Collections.emptyMap(), + Collections.emptyMap() + ); assertTrue(future.isDone()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java index 8a4a1fda9..39023c6b4 100644 --- 
a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -26,14 +26,17 @@ public void testWorkflowData() { assertTrue(empty.getContent().isEmpty()); Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); - WorkflowData contentOnly = new WorkflowData(expectedContent, "test-id-123"); + WorkflowData contentOnly = new WorkflowData(expectedContent, null, null); assertTrue(contentOnly.getParams().isEmpty()); assertEquals(expectedContent, contentOnly.getContent()); + assertNull(contentOnly.getWorkflowId()); + assertNull(contentOnly.getNodeId()); Map expectedParams = Map.of("foo", "bar"); - WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123"); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams, "test-id-123", "test-node-id"); assertEquals(expectedParams, contentAndParams.getParams()); assertEquals(expectedContent, contentAndParams.getContent()); assertEquals("test-id-123", contentAndParams.getWorkflowId()); + assertEquals("test-node-id", contentAndParams.getNodeId()); } }