Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Encrypt/Decrypt template credentials #216

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ dependencies {
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'org.bouncycastle:bcprov-jdk18on:1.77'

configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction;
import org.opensearch.flowframework.transport.SearchWorkflowAction;
import org.opensearch.flowframework.transport.SearchWorkflowTransportAction;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
Expand Down Expand Up @@ -96,7 +97,8 @@ public Collection<Object> createComponents(
this.clusterService = clusterService;
flowFrameworkFeatureEnabledSetting = new FlowFrameworkFeatureEnabledSetting(clusterService, settings);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService);
EncryptorUtils encryptorUtils = new EncryptorUtils(clusterService, client);
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService, encryptorUtils);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(
settings,
clusterService,
Expand All @@ -106,7 +108,7 @@ public Collection<Object> createComponents(
);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler);
return ImmutableList.of(workflowStepFactory, workflowProcessSorter, encryptorUtils, flowFrameworkIndicesHandler);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ private CommonValue() {}
public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json";
/** Workflow State index mapping version */
public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1;
/** Config Index Name */
public static final String CONFIG_INDEX = ".plugins-flow-framework-config";
/** Config index mapping file path */
public static final String CONFIG_INDEX_MAPPING = "mappings/config.json";
/** Config index mapping version */
public static final Integer CONFIG_INDEX_VERSION = 1;
/** Master key field name */
public static final String MASTER_KEY = "master_key";
/** Create Time field name */
public static final String CREATE_TIME = "create_time";

/** The template field name for template use case */
public static final String USE_CASE_FIELD = "use_case";
Expand Down Expand Up @@ -119,8 +129,8 @@ private CommonValue() {}
public static final String PROTOCOL_FIELD = "protocol";
/** Connector parameters field */
public static final String PARAMETERS_FIELD = "parameters";
/** Connector credentials field */
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector credential field */
public static final String CREDENTIAL_FIELD = "credential";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
/** Backend roles for the model */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import java.util.function.Supplier;

import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX;
import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_VERSION;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
Expand All @@ -36,6 +38,14 @@ public enum FlowFrameworkIndex {
WORKFLOW_STATE_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getWorkflowStateMappings),
WORKFLOW_STATE_INDEX_VERSION
),
/**
* Config Index
*/
CONFIG(
CONFIG_INDEX,
ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getConfigIndexMappings),
CONFIG_INDEX_VERSION
);

private final String indexName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.script.Script;

import java.io.IOException;
Expand All @@ -48,6 +49,7 @@
import java.util.concurrent.atomic.AtomicBoolean;

import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.flowframework.common.CommonValue.CONFIG_INDEX_MAPPING;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING;
import static org.opensearch.flowframework.common.CommonValue.META;
Expand All @@ -64,17 +66,20 @@ public class FlowFrameworkIndicesHandler {
private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class);
private final Client client;
private final ClusterService clusterService;
private final EncryptorUtils encryptorUtils;
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
private static final Map<String, Object> indexSettings = Map.of("index.auto_expand_replicas", "0-1");

/**
* constructor
* @param client the open search client
* @param clusterService ClusterService
* @param encryptorUtils encryption utility
*/
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) {
public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService, EncryptorUtils encryptorUtils) {
this.client = client;
this.clusterService = clusterService;
this.encryptorUtils = encryptorUtils;
for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) {
indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false));
}
Expand Down Expand Up @@ -104,6 +109,15 @@ public static String getWorkflowStateMappings() throws IOException {
return getIndexMappings(WORKFLOW_STATE_INDEX_MAPPING);
}

/**
* Get config index mapping
* @return config index mapping
* @throws IOException if mapping file cannot be read correctly
*/
public static String getConfigIndexMappings() throws IOException {
return getIndexMappings(CONFIG_INDEX_MAPPING);
}

/**
* Create global context index if it's absent
* @param listener The action listener
Expand All @@ -120,6 +134,14 @@ public void initWorkflowStateIndexIfAbsent(ActionListener<Boolean> listener) {
initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener);
}

/**
* Create config index if it's absent
* @param listener The action listener
*/
public void initConfigIndexIfAbsent(ActionListener<Boolean> listener) {
initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.CONFIG, listener);
}

/**
* Checks if the given index exists
* @param indexName the name of the index
Expand Down Expand Up @@ -287,7 +309,8 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
Template templateWithEncryptedCredentials = encryptorUtils.encryptTemplateCredentials(template);
request.source(templateWithEncryptedCredentials.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand All @@ -301,6 +324,23 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
}));
}

/**
* Initializes config index and EncryptorUtils
* @param listener action listener
*/
public void initializeConfigIndex(ActionListener<Boolean> listener) {
initConfigIndexIfAbsent(ActionListener.wrap(indexCreated -> {
if (!indexCreated) {
listener.onFailure(new FlowFrameworkException("No response to create config index", INTERNAL_SERVER_ERROR));
return;
}
encryptorUtils.initializeMasterKey(listener);
}, createIndexException -> {
logger.error("Failed to create config index", createIndexException);
listener.onFailure(createIndexException);
}));
}

/**
* add document insert into global context index
* @param workflowId the workflowId, corresponds to document ID of
Expand Down Expand Up @@ -361,7 +401,8 @@ public void updateTemplateInGlobalContext(String documentId, Template template,
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
Template encryptedTemplate = encryptorUtils.encryptTemplateCredentials(template);
request.source(encryptedTemplate.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,49 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
listener.onFailure(ffe);
return;
} else {
// Create new global context and state index entries
flowFrameworkIndicesHandler.putTemplateToGlobalContext(templateWithUser, ActionListener.wrap(globalContextResponse -> {
flowFrameworkIndicesHandler.putInitialStateToWorkflowState(
globalContextResponse.getId(),
user,
ActionListener.wrap(stateResponse -> {
logger.info("create state workflow doc");
listener.onResponse(new WorkflowResponse(globalContextResponse.getId()));
}, exception -> {
logger.error("Failed to save workflow state : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST));
}
})
);
// Initialize config index and create new global context and state index entries
flowFrameworkIndicesHandler.initializeConfigIndex(ActionListener.wrap(isInitialized -> {
if (!isInitialized) {
listener.onFailure(
new FlowFrameworkException("Failed to initalize config index", RestStatus.INTERNAL_SERVER_ERROR)
);
} else {
// Create new global context and state index entries
flowFrameworkIndicesHandler.putTemplateToGlobalContext(
templateWithUser,
ActionListener.wrap(globalContextResponse -> {
flowFrameworkIndicesHandler.putInitialStateToWorkflowState(
globalContextResponse.getId(),
user,
ActionListener.wrap(stateResponse -> {
logger.info("create state workflow doc");
listener.onResponse(new WorkflowResponse(globalContextResponse.getId()));
}, exception -> {
logger.error("Failed to save workflow state : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(
new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)
);
}
})
);
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
listener.onFailure(
new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))
);
}

})
);
}
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
logger.error("Failed to initialize config index : {}", exception.getMessage());
if (exception instanceof FlowFrameworkException) {
listener.onFailure(exception);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.flowframework.model.State;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -61,6 +62,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private final EncryptorUtils encryptorUtils;

/**
* Instantiates a new ProvisionWorkflowTransportAction
Expand All @@ -70,6 +72,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor
* @param client The node client to retrieve a stored use case template
* @param workflowProcessSorter Utility class to generate a togologically sorted list of Process nodes
* @param flowFrameworkIndicesHandler Class to handle all internal system indices actions
* @param encryptorUtils Utility class to handle encryption/decryption
*/
@Inject
public ProvisionWorkflowTransportAction(
Expand All @@ -78,13 +81,15 @@ public ProvisionWorkflowTransportAction(
ThreadPool threadPool,
Client client,
WorkflowProcessSorter workflowProcessSorter,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler,
EncryptorUtils encryptorUtils
) {
super(ProvisionWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.threadPool = threadPool;
this.client = client;
this.workflowProcessSorter = workflowProcessSorter;
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
this.encryptorUtils = encryptorUtils;
}

@Override
Expand All @@ -110,6 +115,10 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work

// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

// Decrypt template
template = encryptorUtils.decryptTemplateCredentials(template);

// Sort and validate graph
Workflow provisionWorkflow = template.workflows().get(PROVISION_WORKFLOW);
List<ProcessNode> provisionProcessSequence = workflowProcessSorter.sortProcessNodes(provisionWorkflow, workflowId);
Expand Down
Loading
Loading