Skip to content

Commit

Permalink
Encrypt/Decrypt template credentials (opensearch-project#197)
Browse files Browse the repository at this point in the history
* added RegisterRemoteModelStep and tests

Signed-off-by: Joshua Palis <[email protected]>

* Adding RegisterLocalModelStep, fixing tests, adding input/ouput definitions to workflow step json

Signed-off-by: Joshua Palis <[email protected]>

* Fixing javadoc warnings, fixing log message

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments,making description field optional for RegisterRemoteModelStep and RegisterLocalModelStep

Signed-off-by: Joshua Palis <[email protected]>

* moving modelConfig builder before adding allConfig

Signed-off-by: Joshua Palis <[email protected]>

* initial implementation

Signed-off-by: Joshua Palis <[email protected]>

* Fixing create workflow transport action

Signed-off-by: Joshua Palis <[email protected]>

* Removing duplicate register_remote_model validator

Signed-off-by: Joshua Palis <[email protected]>

* Adding bouncy castle dependency to resolve encryption issue

Signed-off-by: Joshua Palis <[email protected]>

* Fixing CreateWorkflowTransportActionTests

Signed-off-by: Joshua Palis <[email protected]>

* Adding initial unit tests for encryptor utils

Signed-off-by: Joshua Palis <[email protected]>

* Implemented encryption/decryption for workflow node user inputs with credential

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments

Signed-off-by: Joshua Palis <[email protected]>

* Suppressing unchecked warning, making credential strings constants

Signed-off-by: Joshua Palis <[email protected]>

* Removing setMasterKey from initializeMasterKey method

Signed-off-by: Joshua Palis <[email protected]>

* Adding final template encryption decryption test

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments, changing master key index name to config, fixes error messages as well. Added create time field to config index, ensured that updates are also encrypted

Signed-off-by: Joshua Palis <[email protected]>

* Added TODO

Signed-off-by: Joshua Palis <[email protected]>

* changing getMasterKeyIndexMapping method name

Signed-off-by: Joshua Palis <[email protected]>

* Removing unnecessary aws sdk dependency

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis authored Nov 29, 2023
1 parent 70e7926 commit 772fbbd
Show file tree
Hide file tree
Showing 16 changed files with 648 additions and 33 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ dependencies {
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'

configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction;
import org.opensearch.flowframework.transport.SearchWorkflowAction;
import org.opensearch.flowframework.transport.SearchWorkflowTransportAction;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand Down Expand Up @@ -96,7 +97,8 @@ public Collection<Object> createComponents(
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
Expand All @@ -106,7 +108,7 @@ public Collection<Object> createComponents(
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}

@Override
Expand Down
14 changes: 12 additions & 2 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ private CommonValue() {}
public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json";
/** Workflow State index mapping version */
public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1;
/** Config Index Name */
public static final String CONFIG_INDEX = ".plugins-flow-framework-config";
/** Config index mapping file path */
public static final String CONFIG_INDEX_MAPPING = "mappings/config.json";
/** Config index mapping version */
public static final Integer CONFIG_INDEX_VERSION = 1;
/** Master key field name */
public static final String MASTER_KEY = "master_key";
/** Create Time field name */
public static final String CREATE_TIME = "create_time";

/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
Expand Down Expand Up @@ -119,8 +129,8 @@ private CommonValue() {}
public static final String PROTOCOL_FIELD = "protocol";
/** Connector parameters field */
public static final String PARAMETERS_FIELD = "parameters";
/** Connector credentials field */
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector credential field */
public static final String CREDENTIAL_FIELD = "credential";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import java.util.function.Supplier;

import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX;
import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_VERSION;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
Expand All @@ -36,6 +38,14 @@ public enum FlowFrameworkIndex {
WORKFLOW_STATE_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings),
WORKFLOW_STATE_INDEX_VERSION
),
/**
* Config Index
*/
CONFIG(
CONFIG_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getConfigIndexMappings),
CONFIG_INDEX_VERSION
);

private final String indexName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.script.Script;

import java.io.IOException;
Expand All @@ -48,6 +49,7 @@
import java.util.concurrent.atomic.AtomicBoolean;

import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING;
import static org.opensearch.flowframework.common.CommonValue.META;
Expand All @@ -64,17 +66,20 @@ public class FlowFrameworkIndicesHandler {
private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class);
private final Client client;
private final ClusterService clusterService;
private final EncryptorUtils encryptorUtils;
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
private static final Map<String, Object> indexSettings = Map.of("index.auto_expand_replicas", "0-1");

/**
* constructor
* @param client the open search client
* @param clusterService ClusterService
* @param encryptorUtils encryption utility
*/
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) {
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) {
this.client = client;
this.clusterService = clusterService;
this.encryptorUtils = encryptorUtils;
for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) {
indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false));
}
Expand Down Expand Up @@ -104,6 +109,15 @@ public static String getWorkflowStateMappings() throws IOException {
return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING);
}

/**
* Get config index mapping
* @return config index mapping
* @throws IOException if mapping file cannot be read correctly
*/
public static String getConfigIndexMappings() throws IOException {
return getIndexMappings(CONFIG_INDEX_MAPPING);
}

/**
* Create global context index if it's absent
* @param listener The action listener
Expand All @@ -120,6 +134,14 @@ public void initWorkflowStateIndexIfAbsent(ActionListener<Boolean> listener) {
initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener);
}

/**
* Create config index if it's absent
* @param listener The action listener
*/
public void initConfigIndexIfAbsent(ActionListener<Boolean> listener) {
initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.CONFIG, listener);
}

/**
* Checks if the given index exists
* @param indexName the name of the index
Expand Down Expand Up @@ -287,7 +309,8 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
Template templateWithEncryptedCredentials = encryptorUtils.encryptTemplateCredentials(template);
request.source(templateWithEncryptedCredentials.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand All @@ -301,6 +324,23 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
}));
}

/**
* Initializes config index and EncryptorUtils
* @param listener action listener
*/
public void initializeConfigIndex(ActionListener<Boolean> listener) {
initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> {
if (!indexCreated) {
listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR));
return;
}
encryptorUtils.initializeMasterKey(listener);
}, createIndexException -> {
logger.error("Failed to create config index", createIndexException);
listener.onFailure(createIndexException);
}));
}

/**
* add document insert into global context index
* @param workflowId the workflowId, corresponds to document ID of
Expand Down Expand Up @@ -361,7 +401,8 @@ public void updateTemplateInGlobalContext(String documentId, Template template,
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template);
request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
listener.onFailure(ffe);
return;
} else {
// Create new global context and state index entries
flowFrameworkIndicesHandler.putTemplateToGlobalContext(templateWithUser, ActionListener.wrap(globalContextResponse -> {
flowFrameworkIndicesHandler.putInitialStateToWorkflowState(
globalContextResponse.getId(),
user,
ActionListener.wrap(stateResponse -> {
logger.info("create state workflow doc");
listener.onResponse(new WorkflowResponse(globalContextResponse.getId()));
}, exception -> {
logger.error("Failed to save workflow state : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST));
}
})
);
// Initialize config index and create new global context and state index entries
flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> {
if (!isInitialized) {
listener.onFailure(
new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR)
);
} else {
// Create new global context and state index entries
flowFrameworkIndicesHandler.putTemplateToGlobalContext(
templateWithUser,
ActionListener.wrap(globalContextResponse -> {
flowFrameworkIndicesHandler.putInitialStateToWorkflowState(
globalContextResponse.getId(),
user,
ActionListener.wrap(stateResponse -> {
logger.info("create state workflow doc");
listener.onResponse(new WorkflowResponse(globalContextResponse.getId()));
}, exception -> {
logger.error("Failed to save workflow state : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(
new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)
);
}
})
);
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}

})
);
}
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
logger.error("Failed to initialize config index : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -61,6 +62,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final EncryptorUtils encryptorUtils;

/**
* Instantiates a new ProvisionWorkflowTransportAction
Expand All @@ -70,6 +72,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
* @param client The node client to retrieve a stored use case template
* @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes
* @param flowFrameworkIndicesHandler Class to handle all internal system indices actions
* @param encryptorUtils Utility class to handle encryption/decryption
*/
@Inject
public ProvisionWorkflowTransportAction(
Expand All @@ -78,13 +81,15 @@ public ProvisionWorkflowTransportAction(
ThreadPool threadPool,
Client client,
WorkflowProcessSorter workflowProcessSorter,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
EncryptorUtils encryptorUtils
) {
super(ProvisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.threadPool = threadPool;
this.client = client;
this.workflowProcessSorter = workflowProcessSorter;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.encryptorUtils = encryptorUtils;
}

@Override
Expand All @@ -110,6 +115,10 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

// Decrypt template
template = encryptorUtils.decryptTemplateCredentials(template);

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
Expand Down
Loading

0 comments on commit 772fbbd

Please sign in to comment.