From 772fbbd9eb58aeda52ea57a36307ad7206b1cf0b Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 29 Nov 2023 14:53:22 -0800 Subject: [PATCH] Encrypt/Decrypt template credentials (#197) * added RegisterRemoteModelStep and tests Signed-off-by: Joshua Palis * Adding RegisterLocalModelStep, fixing tests, adding input/ouput definitions to workflow step json Signed-off-by: Joshua Palis * Fixing javadoc warnings, fixing log message Signed-off-by: Joshua Palis * Addressing PR comments,making description field optional for RegisterRemoteModelStep and RegisterLocalModelStep Signed-off-by: Joshua Palis * moving modelConfig builder before adding allConfig Signed-off-by: Joshua Palis * initial implementation Signed-off-by: Joshua Palis * Fixing create workflow transport action Signed-off-by: Joshua Palis * Removing duplicate register_remote_model validator Signed-off-by: Joshua Palis * Adding bouncy castle dependency to resolve encryption issue Signed-off-by: Joshua Palis * Fixing CreateWorkflowTransportActionTests Signed-off-by: Joshua Palis * Adding initial unit tests for encryptor utils Signed-off-by: Joshua Palis * Implemented encryption/decryption for workflow node user inputs with credential Signed-off-by: Joshua Palis * Addressing PR comments Signed-off-by: Joshua Palis * Suppressing unchecked warning, making credential strings constants Signed-off-by: Joshua Palis * Removing setMasterKey from initializeMasterKey method Signed-off-by: Joshua Palis * Adding final template encryption decryption test Signed-off-by: Joshua Palis * Addressing PR comments, changing master key index name to config, fixes error messages as well. Added create time field to config index, ensured that updates are also encrypted Signed-off-by: Joshua Palis * Added TODO Signed-off-by: Joshua Palis * changing getMasterKeyIndexMapping method name Signed-off-by: Joshua Palis * Removing unnecessary aws sdk dependency Signed-off-by: Joshua Palis --------- Signed-off-by: Joshua Palis --- build.gradle | 2 + .../flowframework/FlowFrameworkPlugin.java | 6 +- .../flowframework/common/CommonValue.java | 14 +- .../indices/FlowFrameworkIndex.java | 10 + .../indices/FlowFrameworkIndicesHandler.java | 47 ++- .../CreateWorkflowTransportAction.java | 60 ++-- .../ProvisionWorkflowTransportAction.java | 11 +- .../flowframework/util/EncryptorUtils.java | 286 ++++++++++++++++++ .../workflow/CreateConnectorStep.java | 6 +- src/main/resources/mappings/config.json | 15 + .../FlowFrameworkPluginTests.java | 2 +- .../FlowFrameworkIndicesHandlerTests.java | 5 +- .../CreateWorkflowTransportActionTests.java | 14 + ...ProvisionWorkflowTransportActionTests.java | 8 +- .../util/EncryptorUtilsTests.java | 193 ++++++++++++ .../workflow/CreateConnectorStepTests.java | 2 +- 16 files changed, 648 insertions(+), 33 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java create mode 100644 src/main/resources/mappings/config.json create mode 100644 src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java diff --git a/build.gradle b/build.gradle index 60bbed573..9f8bc9429 100644 --- a/build.gradle +++ b/build.gradle @@ -142,6 +142,8 @@ dependencies { api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' implementation "org.opensearch:common-utils:${common_utils_version}" + implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' + implementation 'org.bouncycastle:bcprov-jdk18on:1.77' configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 4b2e4a9cb..14df7e17e 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -39,6 +39,7 @@ import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; import org.opensearch.flowframework.transport.SearchWorkflowAction; import org.opensearch.flowframework.transport.SearchWorkflowTransportAction; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -96,7 +97,8 @@ public Collection createComponents( this.clusterService = clusterService; flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings); MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory( settings, clusterService, @@ -106,7 +108,7 @@ public Collection createComponents( ); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 1f9f33f4b..90f208c8d 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -33,6 +33,16 @@ private CommonValue() {} public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json"; /** Workflow State index mapping version */ public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1; + /** Config Index Name */ + public static final String CONFIG_INDEX = ".plugins-flow-framework-config"; + /** Config index mapping file path */ + public static final String CONFIG_INDEX_MAPPING = "mappings/config.json"; + /** Config index mapping version */ + public static final Integer CONFIG_INDEX_VERSION = 1; + /** Master key field name */ + public static final String MASTER_KEY = "master_key"; + /** Create Time field name */ + public static final String CREATE_TIME = "create_time"; /** The template field name for template use case */ public static final String USE_CASE_FIELD = "use_case"; @@ -119,8 +129,8 @@ private CommonValue() {} public static final String PROTOCOL_FIELD = "protocol"; /** Connector parameters field */ public static final String PARAMETERS_FIELD = "parameters"; - /** Connector credentials field */ - public static final String CREDENTIALS_FIELD = "credentials"; + /** Connector credential field */ + public static final String CREDENTIAL_FIELD = "credential"; /** Connector actions field */ public static final String ACTIONS_FIELD = "actions"; /** Backend roles for the model */ diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index 4b005e45d..b74c1d89a 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -12,6 +12,8 @@ import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @@ -36,6 +38,14 @@ public enum FlowFrameworkIndex { WORKFLOW_STATE_INDEX, ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings), WORKFLOW_STATE_INDEX_VERSION + ), + /** + * Config Index + */ + CONFIG( + CONFIG_INDEX, + ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getConfigIndexMappings), + CONFIG_INDEX_VERSION ); private final String indexName; diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 7dcc89de6..cce4ba839 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -38,6 +38,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.script.Script; import java.io.IOException; @@ -48,6 +49,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; import static org.opensearch.flowframework.common.CommonValue.META; @@ -64,6 +66,7 @@ public class FlowFrameworkIndicesHandler { private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class); private final Client client; private final ClusterService clusterService; + private final EncryptorUtils encryptorUtils; private static final Map indexMappingUpdated = new HashMap<>(); private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); @@ -71,10 +74,12 @@ public class FlowFrameworkIndicesHandler { * constructor * @param client the open search client * @param clusterService ClusterService + * @param encryptorUtils encryption utility */ - public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) { + public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) { this.client = client; this.clusterService = clusterService; + this.encryptorUtils = encryptorUtils; for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) { indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false)); } @@ -104,6 +109,15 @@ public static String getWorkflowStateMappings() throws IOException { return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING); } + /** + * Get config index mapping + * @return config index mapping + * @throws IOException if mapping file cannot be read correctly + */ + public static String getConfigIndexMappings() throws IOException { + return getIndexMappings(CONFIG_INDEX_MAPPING); + } + /** * Create global context index if it's absent * @param listener The action listener @@ -120,6 +134,14 @@ public void initWorkflowStateIndexIfAbsent(ActionListener listener) { initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener); } + /** + * Create config index if it's absent + * @param listener The action listener + */ + public void initConfigIndexIfAbsent(ActionListener listener) { + initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.CONFIG, listener); + } + /** * Checks if the given index exists * @param indexName the name of the index @@ -287,7 +309,8 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { @@ -301,6 +324,23 @@ public void putTemplateToGlobalContext(Template template, ActionListener listener) { + initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR)); + return; + } + encryptorUtils.initializeMasterKey(listener); + }, createIndexException -> { + logger.error("Failed to create config index", createIndexException); + listener.onFailure(createIndexException); + })); + } + /** * add document insert into global context index * @param workflowId the workflowId, corresponds to document ID of @@ -361,7 +401,8 @@ public void updateTemplateInGlobalContext(String documentId, Template template, XContentBuilder builder = XContentFactory.jsonBuilder(); ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() ) { - request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS)) + Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template); + request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index a77e7dd79..6ca1c4661 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -119,25 +119,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - flowFrameworkIndicesHandler.putInitialStateToWorkflowState( - globalContextResponse.getId(), - user, - ActionListener.wrap(stateResponse -> { - logger.info("create state workflow doc"); - listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); - }, exception -> { - logger.error("Failed to save workflow state : {}", exception.getMessage()); - if (exception instanceof FlowFrameworkException) { - listener.onFailure(exception); - } else { - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); - } - }) - ); + // Initialize config index and create new global context and state index entries + flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> { + if (!isInitialized) { + listener.onFailure( + new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR) + ); + } else { + // Create new global context and state index entries + flowFrameworkIndicesHandler.putTemplateToGlobalContext( + templateWithUser, + ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc"); + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST) + ); + } + }) + ); + }, exception -> { + logger.error("Failed to save use case template : {}", exception.getMessage()); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure( + new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) + ); + } + + }) + ); + } }, exception -> { - logger.error("Failed to save use case template : {}", exception.getMessage()); + logger.error("Failed to initialize config index : {}", exception.getMessage()); if (exception instanceof FlowFrameworkException) { listener.onFailure(exception); } else { diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 1f19a4d04..b381b41ec 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; @@ -61,6 +62,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java new file mode 100644 index 000000000..70b30b5cc --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -0,0 +1,286 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; + +import javax.crypto.spec.SecretKeySpec; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CommitmentPolicy; +import com.amazonaws.encryptionsdk.CryptoResult; +import com.amazonaws.encryptionsdk.jce.JceMasterKey; + +import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX; +import static org.opensearch.flowframework.common.CommonValue.CREATE_TIME; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; + +/** + * Encryption utility class + */ +public class EncryptorUtils { + + private static final Logger logger = LogManager.getLogger(EncryptorUtils.class); + + private static final String ALGORITHM = "AES"; + private static final String PROVIDER = "Custom"; + private static final String WRAPPING_ALGORITHM = "AES/GCM/NoPadding"; + + private ClusterService clusterService; + private Client client; + private String masterKey; + + /** + * Instantiates a new EncryptorUtils object + * @param clusterService the cluster service + * @param client the node client + */ + public EncryptorUtils(ClusterService clusterService, Client client) { + this.masterKey = null; + this.clusterService = clusterService; + this.client = client; + } + + /** + * Sets the master key + * @param masterKey the master key + */ + void setMasterKey(String masterKey) { + this.masterKey = masterKey; + } + + /** + * Returns the master key + * @return the master key + */ + String getMasterKey() { + return this.masterKey; + } + + /** + * Randomly generate a master key + * @return the master key string + */ + String generateMasterKey() { + byte[] masterKeyBytes = new byte[32]; + new SecureRandom().nextBytes(masterKeyBytes); + return Base64.getEncoder().encodeToString(masterKeyBytes); + } + + /** + * Encrypts template credentials + * @param template the template to encrypt + * @return template with encrypted credentials + */ + public Template encryptTemplateCredentials(Template template) { + return processTemplateCredentials(template, this::encrypt); + } + + /** + * Decrypts template credentials + * @param template the template to decrypt + * @return template with decrypted credentials + */ + public Template decryptTemplateCredentials(Template template) { + return processTemplateCredentials(template, this::decrypt); + } + + // TODO : Improve processTemplateCredentials to encrypt different fields based on the WorkflowStep type + /** + * Applies the given cipher function on template credentials + * @param template the template to process + * @param cipher the encryption/decryption function to apply on credential values + * @return template with encrypted credentials + */ + private Template processTemplateCredentials(Template template, Function cipherFunction) { + Template.Builder processedTemplateBuilder = new Template.Builder(); + + Map processedWorkflows = new HashMap<>(); + for (Map.Entry entry : template.workflows().entrySet()) { + + List processedNodes = new ArrayList<>(); + for (WorkflowNode node : entry.getValue().nodes()) { + if (node.userInputs().containsKey(CREDENTIAL_FIELD)) { + // Apply the cipher funcion on all values within credential field + @SuppressWarnings("unchecked") + Map credentials = new HashMap<>((Map) node.userInputs().get(CREDENTIAL_FIELD)); + credentials.replaceAll((key, cred) -> cipherFunction.apply(cred)); + + // Replace credentials field in node user inputs + Map processedUserInputs = new HashMap<>(); + processedUserInputs.putAll(node.userInputs()); + processedUserInputs.replace(CREDENTIAL_FIELD, credentials); + + // build new node to add to processed nodes + WorkflowNode processedWorkflowNode = new WorkflowNode( + node.id(), + node.type(), + node.previousNodeInputs(), + processedUserInputs + ); + processedNodes.add(processedWorkflowNode); + } else { + processedNodes.add(node); + } + } + + // Add processed workflow nodes to processed workflows + processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges())); + } + + Template processedTemplate = processedTemplateBuilder.name(template.name()) + .description(template.description()) + .useCase(template.useCase()) + .templateVersion(template.templateVersion()) + .compatibilityVersion(template.compatibilityVersion()) + .workflows(processedWorkflows) + .uiMetadata(template.getUiMetadata()) + .user(template.getUser()) + .build(); + + return processedTemplate; + } + + /** + * Encrypts the given credential + * @param credential the credential to encrypt + * @return the encrypted credential + */ + String encrypt(final String credential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + final CryptoResult encryptResult = crypto.encryptData( + jceMasterKey, + credential.getBytes(StandardCharsets.UTF_8) + ); + return Base64.getEncoder().encodeToString(encryptResult.getResult()); + } + + /** + * Decrypts the given credential + * @param encryptedCredential the credential to decrypt + * @return the decrypted credential + */ + String decrypt(final String encryptedCredential) { + initializeMasterKeyIfAbsent(); + final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); + + byte[] bytes = Base64.getDecoder().decode(masterKey); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, ALGORITHM), PROVIDER, "", WRAPPING_ALGORITHM); + + final CryptoResult decryptedResult = crypto.decryptData( + jceMasterKey, + Base64.getDecoder().decode(encryptedCredential) + ); + return new String(decryptedResult.getResult(), StandardCharsets.UTF_8); + } + + /** + * Retrieves an existing master key or generates a new key to index + * @param listener the action listener + */ + public void initializeMasterKey(ActionListener listener) { + // Index has either been created or it already exists, need to check if master key has been initalized already, if not then + // generate + // This is necessary in case of global context index restoration from snapshot, will need to use the same master key to decrypt + // stored credentials + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(getResponse -> { + + if (!getResponse.isExists()) { + + // Generate new key and index + final String masterKey = generateMasterKey(); + IndexRequest masterKeyIndexRequest = new IndexRequest(CONFIG_INDEX).id(MASTER_KEY) + .source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME, Instant.now().toEpochMilli())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.index(masterKeyIndexRequest, ActionListener.wrap(indexResponse -> { + // Set generated key to master + logger.info("Config has been initialized successfully"); + this.masterKey = masterKey; + listener.onResponse(true); + }, indexException -> { + logger.error("Failed to index config", indexException); + listener.onFailure(indexException); + })); + + } else { + // Set existing key to master + logger.info("Config has already been initialized"); + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + this.masterKey = masterKey; + listener.onResponse(true); + } + }, getRequestException -> { + logger.error("Failed to search for config from config index", getRequestException); + listener.onFailure(getRequestException); + })); + + } catch (Exception e) { + logger.error("Failed to retrieve config from config index", e); + listener.onFailure(e); + } + } + + /** + * Retrieves master key from system index if not yet set + */ + void initializeMasterKeyIfAbsent() { + if (masterKey != null) { + return; + } + + if (!clusterService.state().metadata().hasIndex(CONFIG_INDEX)) { + throw new FlowFrameworkException("Config Index has not been initialized", RestStatus.INTERNAL_SERVER_ERROR); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + GetRequest getRequest = new GetRequest(CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + this.masterKey = (String) response.getSourceAsMap().get(MASTER_KEY); + } else { + throw new FlowFrameworkException("Config has not been initialized", RestStatus.NOT_FOUND); + } + }, exception -> { throw new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)); })); + } + } + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 41bd71489..d685cbeaa 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -42,7 +42,7 @@ import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; -import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; @@ -155,8 +155,8 @@ public void onFailure(Exception e) { case PARAMETERS_FIELD: parameters = getParameterMap(entry.getValue()); break; - case CREDENTIALS_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); + case CREDENTIAL_FIELD: + credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD); break; case ACTIONS_FIELD: actions = getConnectorActionList(entry.getValue()); diff --git a/src/main/resources/mappings/config.json b/src/main/resources/mappings/config.json new file mode 100644 index 000000000..6871f541c --- /dev/null +++ b/src/main/resources/mappings/config.json @@ -0,0 +1,15 @@ +{ + "dynamic": false, + "_meta": { + "schema_version": 1 + }, + "properties": { + "master_key": { + "type": "keyword" + }, + "create_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } +} diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index c89baba13..e3827e0b3 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -78,7 +78,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { assertEquals( - 3, + 4, ffp.createComponents(client, clusterService, threadPool, null, null, null, environment, null, null, null, null).size() ); assertEquals(4, ffp.getRestHandlers(settings, null, null, null, null, null, null).size()); diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index 2f0fc256f..fd7d8e0d4 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -25,6 +25,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -53,6 +54,8 @@ public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { private CreateIndexStep createIndexStep; @Mock private ThreadPool threadPool; + @Mock + private EncryptorUtils encryptorUtils; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private AdminClient adminClient; private IndicesAdminClient indicesAdminClient; @@ -77,7 +80,7 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils); adminClient = mock(AdminClient.class); indicesAdminClient = mock(IndicesAdminClient.class); metadata = mock(Metadata.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 22d831f2b..70c066c0e 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -231,6 +231,13 @@ public void testFailedToCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); responseListener.onFailure(new Exception("Failed to create global_context index")); @@ -261,6 +268,13 @@ public void testCreateNewWorkflow() { return null; }).when(createWorkflowTransportAction).checkMaxWorkflows(any(TimeValue.class), anyInt(), any()); + // Bypass initializeConfigIndex and force onResponse + doAnswer(invocation -> { + ActionListener initalizeMasterKeyIndexListener = invocation.getArgument(0); + initalizeMasterKeyIndexListener.onResponse(true); + return null; + }).when(flowFrameworkIndicesHandler).initializeConfigIndex(any()); + // Bypass putTemplateToGlobalContext and force onResponse doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index 7d061e572..8bdcaa2c7 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -25,6 +25,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.index.get.GetResult; import org.opensearch.tasks.Task; @@ -53,6 +54,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; private Template template; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private EncryptorUtils encryptorUtils; @Override public void setUp() throws Exception { @@ -61,6 +63,7 @@ public void setUp() throws Exception { this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.encryptorUtils = mock(EncryptorUtils.class); this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( mock(TransportService.class), @@ -68,7 +71,8 @@ public void setUp() throws Exception { threadPool, client, workflowProcessSorter, - flowFrameworkIndicesHandler + flowFrameworkIndicesHandler, + encryptorUtils ); Version templateVersion = Version.fromString("1.0.0"); @@ -116,6 +120,8 @@ public void testProvisionWorkflow() { return null; }).when(client).get(any(GetRequest.class), any()); + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); diff --git a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java new file mode 100644 index 000000000..1101fa1a2 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.Version; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.TestHelpers; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import java.util.List; +import java.util.Map; + +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.MASTER_KEY; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class EncryptorUtilsTests extends OpenSearchTestCase { + + private ClusterService clusterService; + private Client client; + private EncryptorUtils encryptorUtils; + private String testMasterKey; + private Template testTemplate; + private String testCredentialKey; + private String testCredentialValue; + + @Override + public void setUp() throws Exception { + super.setUp(); + this.clusterService = mock(ClusterService.class); + this.client = mock(Client.class); + this.encryptorUtils = new EncryptorUtils(clusterService, client); + this.testMasterKey = encryptorUtils.generateMasterKey(); + this.testCredentialKey = "credential_key"; + this.testCredentialValue = "12345"; + + Version templateVersion = Version.fromString("1.0.0"); + List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); + WorkflowNode nodeA = new WorkflowNode( + "A", + "a-type", + Map.of(), + Map.of(CREDENTIAL_FIELD, Map.of(testCredentialKey, testCredentialValue)) + ); + List nodes = List.of(nodeA); + Workflow workflow = new Workflow(Map.of("key", "value"), nodes, List.of()); + + this.testTemplate = new Template( + "test", + "description", + "use case", + templateVersion, + compatibilityVersions, + Map.of("provision", workflow), + Map.of(), + TestHelpers.randomUser() + ); + + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(true); + + ThreadPool threadPool = mock(ThreadPool.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGenerateMasterKey() { + String generatedMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(generatedMasterKey); + assertEquals(generatedMasterKey, encryptorUtils.getMasterKey()); + } + + public void testEncryptDecrypt() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted = encryptorUtils.encrypt(testString); + assertNotNull(encrypted); + + String decrypted = encryptorUtils.decrypt(encrypted); + assertEquals(testString, decrypted); + } + + public void testEncryptWithDifferentMasterKey() { + encryptorUtils.setMasterKey(testMasterKey); + String testString = "test"; + String encrypted1 = encryptorUtils.encrypt(testString); + assertNotNull(encrypted1); + + // Change the master key prior to encryption + String differentMasterKey = encryptorUtils.generateMasterKey(); + encryptorUtils.setMasterKey(differentMasterKey); + String encrypted2 = encryptorUtils.encrypt(testString); + + assertNotEquals(encrypted1, encrypted2); + } + + public void testInitializeMasterKeySuccess() { + encryptorUtils.setMasterKey(null); + + String masterKey = encryptorUtils.generateMasterKey(); + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for success case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsMap()).thenReturn(ImmutableMap.of(MASTER_KEY, masterKey)); + + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + encryptorUtils.initializeMasterKeyIfAbsent(); + assertEquals(masterKey, encryptorUtils.getMasterKey()); + } + + public void testInitializeMasterKeyFailure() { + encryptorUtils.setMasterKey(null); + + doAnswer(invocation -> { + ActionListener getRequestActionListener = invocation.getArgument(1); + + // Stub get response for failure case + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(false); + getRequestActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> encryptorUtils.initializeMasterKeyIfAbsent()); + assertEquals("Config has not been initialized", ex.getMessage()); + } + + public void testEncryptDecryptTemplateCredential() { + encryptorUtils.setMasterKey(testMasterKey); + + // Ecnrypt template with credential field + Template processedtemplate = encryptorUtils.encryptTemplateCredentials(testTemplate); + + // Validate the encrytped field + WorkflowNode node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map encryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, encryptedCredentialMap.size()); + + String encryptedCredential = encryptedCredentialMap.get(testCredentialKey); + assertNotNull(encryptedCredential); + assertNotEquals(testCredentialValue, encryptedCredential); + + // Decrypt credential field + processedtemplate = encryptorUtils.decryptTemplateCredentials(processedtemplate); + + // Validate the decrypted field + node = processedtemplate.workflows().get("provision").nodes().get(0); + + @SuppressWarnings("unchecked") + Map decryptedCredentialMap = (Map) node.userInputs().get(CREDENTIAL_FIELD); + assertEquals(1, decryptedCredentialMap.size()); + + String decryptedCredential = decryptedCredentialMap.get(testCredentialKey); + assertNotNull(decryptedCredential); + assertEquals(testCredentialValue, decryptedCredential); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index a05a3927e..e26fdf0c4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -68,7 +68,7 @@ public void setUp() throws Exception { Map.entry(CommonValue.VERSION_FIELD, "1"), Map.entry(CommonValue.PROTOCOL_FIELD, "test"), Map.entry(CommonValue.PARAMETERS_FIELD, params), - Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), + Map.entry(CommonValue.CREDENTIAL_FIELD, credentials), Map.entry(CommonValue.ACTIONS_FIELD, actions) ), "test-id"