From cd9ddb90f1fb0bfab3deef1bc16709e67ac41acd Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Thu, 19 Oct 2023 17:45:00 -0700 Subject: [PATCH] adding state index initial Signed-off-by: Amit Galitzky --- build.gradle | 2 + .../flowframework/FlowFrameworkPlugin.java | 8 +- .../flowframework/common/CommonValue.java | 6 + .../indices/FlowFrameworkIndex.java | 9 +- .../indices/FlowFrameworkIndicesHandler.java | 432 ++++++++++++++++++ .../indices/GlobalContextHandler.java | 151 ------ .../model/PipelineProcessor.java | 4 +- .../model/ProvisioningProgress.java | 15 + .../opensearch/flowframework/model/State.java | 16 + .../flowframework/model/Template.java | 4 +- .../flowframework/model/WorkflowNode.java | 4 +- .../flowframework/model/WorkflowState.java | 271 +++++++++++ .../CreateWorkflowTransportAction.java | 48 +- .../ProvisionWorkflowTransportAction.java | 26 +- .../ParseUtils.java} | 44 +- .../workflow/CreateIndexStep.java | 287 ++++++------ .../mappings/knn-text-search-default.json | 20 + .../resources/mappings/workflow-state.json | 87 ++++ .../FlowFrameworkIndicesHandlerTests.java | 254 ++++++++++ .../indices/GlobalContextHandlerTests.java | 146 ------ .../CreateWorkflowTransportActionTests.java | 19 +- ...ProvisionWorkflowTransportActionTests.java | 6 +- .../workflow/CreateIndexStepTests.java | 88 ---- 23 files changed, 1373 insertions(+), 574 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java delete mode 100644 src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java create mode 100644 src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java create mode 100644 src/main/java/org/opensearch/flowframework/model/State.java create mode 100644 src/main/java/org/opensearch/flowframework/model/WorkflowState.java rename src/main/java/org/opensearch/flowframework/{common/TemplateUtil.java => util/ParseUtils.java} (60%) create mode 100644 src/main/resources/mappings/knn-text-search-default.json create mode 100644 src/main/resources/mappings/workflow-state.json create mode 100644 src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java delete mode 100644 src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java diff --git a/build.gradle b/build.gradle index 68e5dffa6..9bc3a027c 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,7 @@ buildscript { opensearch_group = "org.opensearch" opensearch_no_snapshot = opensearch_build.replace("-SNAPSHOT","") System.setProperty('tests.security.manager', 'false') + common_utils_version = System.getProperty("common_utils.version", opensearch_build) } repositories { @@ -135,6 +136,7 @@ dependencies { implementation 'org.junit.jupiter:junit-jupiter:5.10.0' implementation "com.google.guava:guava:32.1.3-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" + implementation "org.opensearch:common-utils:${common_utils_version}" configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 0bac15c61..dd0fe09a1 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -24,14 +24,13 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; -import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.rest.RestCreateWorkflowAction; import org.opensearch.flowframework.rest.RestProvisionWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.CreateWorkflowTransportAction; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction; -import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.ActionPlugin; @@ -79,10 +78,9 @@ public Collection createComponents( WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); - // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep - GlobalContextHandler globalContextHandler = new GlobalContextHandler(client, new CreateIndexStep(clusterService, client)); + FlowFrameworkIndicesHandler flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); - return ImmutableList.of(workflowStepFactory, workflowProcessSorter, globalContextHandler); + return ImmutableList.of(workflowStepFactory, workflowProcessSorter, flowFrameworkIndicesHandler); } @Override diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 528590d7a..945a123ed 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -27,6 +27,12 @@ private CommonValue() {} public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; /** Global Context index mapping version */ public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + /** Workflow State Index Name */ + public static final String WORKFLOW_STATE_INDEX = ".plugins-workflow-state"; + /** Workflow State index mapping file path */ + public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json"; + /** Workflow State index mapping version */ + public static final Integer WORKFLOW_STATE_INDEX_VERSION = 1; /** The transport action name prefix */ public static final String TRANSPORT_ACION_NAME_PREFIX = "cluster:admin/opensearch/flow_framework/"; diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java index d0ef3503c..8c259dd32 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndex.java @@ -14,6 +14,8 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_VERSION; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX_VERSION; /** * An enumeration of Flow Framework indices @@ -24,8 +26,13 @@ public enum FlowFrameworkIndex { */ GLOBAL_CONTEXT( GLOBAL_CONTEXT_INDEX, - ThrowingSupplierWrapper.throwingSupplierWrapper(GlobalContextHandler::getGlobalContextMappings), + ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getGlobalContextMappings), GLOBAL_CONTEXT_INDEX_VERSION + ), + WORKFLOW_STATE( + WORKFLOW_STATE_INDEX, + ThrowingSupplierWrapper.throwingSupplierWrapper(FlowFrameworkIndicesHandler::getGlobalContextMappings), + WORKFLOW_STATE_INDEX_VERSION ); private final String indexName; diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java new file mode 100644 index 000000000..2ecc8142b --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -0,0 +1,432 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.indices; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.State; +import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.WorkflowState; +import org.opensearch.flowframework.transport.WorkflowResponse; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; +import static org.opensearch.flowframework.common.CommonValue.META; +import static org.opensearch.flowframework.common.CommonValue.NO_SCHEMA_VERSION; +import static org.opensearch.flowframework.common.CommonValue.SCHEMA_VERSION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; +import static org.opensearch.flowframework.model.WorkflowState.WORKFLOW_ID_FIELD; + +/** + * A handler for global context related operations + */ +public class FlowFrameworkIndicesHandler { + private static final Logger logger = LogManager.getLogger(FlowFrameworkIndicesHandler.class); + private final Client client; + ClusterService clusterService; + private static final Map indexMappingUpdated = new HashMap<>(); + private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); + + /** + * constructor + * @param client the open search client + * @param clusterService ClusterService + */ + public FlowFrameworkIndicesHandler(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + static { + for (FlowFrameworkIndex mlIndex : FlowFrameworkIndex.values()) { + indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false)); + } + } + + /** + * Get global-context index mapping + * @return global-context index mapping + * @throws IOException if mapping file cannot be read correctly + */ + public static String getGlobalContextMappings() throws IOException { + return getIndexMappings(GLOBAL_CONTEXT_INDEX_MAPPING); + } + + public void initGlobalContextIndexIfAbsent(ActionListener listener) { + initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.GLOBAL_CONTEXT, listener); + } + + public void initWorkflowStateIndexIfAbsent(ActionListener listener) { + initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.WORKFLOW_STATE, listener); + } + + /** + * Checks if the given index exists + * @param indexName the name of the index + * @return boolean indicating the existence of an index + */ + public boolean doesIndexExist(String indexName) { + return clusterService.state().metadata().hasIndex(indexName); + } + + /** + * Create Index if it's absent + * @param index The index that needs to be created + * @param listener The action listener + */ + public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListener listener) { + String indexName = index.getIndexName(); + String mapping = index.getMapping(); + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + if (!clusterService.state().metadata().hasIndex(indexName)) { + @SuppressWarnings("deprecation") + ActionListener actionListener = ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("create index:{}", indexName); + internalListener.onResponse(true); + } else { + internalListener.onResponse(false); + } + }, e -> { + logger.error("Failed to create index " + indexName, e); + internalListener.onFailure(e); + }); + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); + client.admin().indices().create(request, actionListener); + } else { + logger.debug("index:{} is already created", indexName); + if (indexMappingUpdated.containsKey(indexName) && !indexMappingUpdated.get(indexName).get()) { + shouldUpdateIndex(indexName, index.getVersion(), ActionListener.wrap(r -> { + if (r) { + // return true if update index is needed + client.admin() + .indices() + .putMapping( + new PutMappingRequest().indices(indexName).source(mapping, XContentType.JSON), + ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + UpdateSettingsRequest updateSettingRequest = new UpdateSettingsRequest(); + updateSettingRequest.indices(indexName).settings(indexSettings); + client.admin() + .indices() + .updateSettings(updateSettingRequest, ActionListener.wrap(updateResponse -> { + if (response.isAcknowledged()) { + indexMappingUpdated.get(indexName).set(true); + internalListener.onResponse(true); + } else { + internalListener.onFailure( + new FlowFrameworkException( + "Failed to update index setting for: " + indexName, + INTERNAL_SERVER_ERROR + ) + ); + } + }, exception -> { + logger.error("Failed to update index setting for: " + indexName, exception); + internalListener.onFailure(exception); + })); + } else { + internalListener.onFailure( + new FlowFrameworkException("Failed to update index: " + indexName, INTERNAL_SERVER_ERROR) + ); + } + }, exception -> { + logger.error("Failed to update index " + indexName, exception); + internalListener.onFailure(exception); + }) + ); + } else { + // no need to update index if it does not exist or the version is already up-to-date. + indexMappingUpdated.get(indexName).set(true); + internalListener.onResponse(true); + } + }, e -> { + logger.error("Failed to update index mapping", e); + internalListener.onFailure(e); + })); + } else { + // No need to update index if it's already updated. + internalListener.onResponse(true); + } + } + } catch (Exception e) { + logger.error("Failed to init index " + indexName, e); + listener.onFailure(e); + } + } + + /** + * Check if we should update index based on schema version. + * @param indexName index name + * @param newVersion new index mapping version + * @param listener action listener, if update index is needed, will pass true to its onResponse method + */ + private void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { + IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); + if (indexMetaData == null) { + listener.onResponse(Boolean.FALSE); + return; + } + Integer oldVersion = NO_SCHEMA_VERSION; + Map indexMapping = indexMetaData.mapping().getSourceAsMap(); + Object meta = indexMapping.get(META); + if (meta != null && meta instanceof Map) { + @SuppressWarnings("unchecked") + Map metaMapping = (Map) meta; + Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); + if (schemaVersion instanceof Integer) { + oldVersion = (Integer) schemaVersion; + } + } + listener.onResponse(newVersion > oldVersion); + } + + /** + * Get index mapping json content. + * + * @param mapping type of the index to fetch the specific mapping file + * @return index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getIndexMappings(String mapping) throws IOException { + URL url = FlowFrameworkIndicesHandler.class.getClassLoader().getResource(mapping); + return Resources.toString(url, Charsets.UTF_8); + } + + /** + * add document insert into global context index + * @param template the use-case template + * @param listener action listener + */ + public void putTemplateToGlobalContext(Template template, ActionListener listener) { + initGlobalContextIndexIfAbsent(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new FlowFrameworkException("No response to create global_context index", INTERNAL_SERVER_ERROR)); + return; + } + IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX); + try ( + XContentBuilder builder = XContentFactory.jsonBuilder(); + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() + ) { + request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to index global_context index"); + listener.onFailure(e); + } + }, e -> { + logger.error("Failed to create global_context index", e); + listener.onFailure(e); + })); + } + + /** + * add document insert into global context index + * @param workflowId the workflowId, corresponds to document ID of + * @param listener action listener + */ + public void putInitialStateToWorkflowState(String workflowId, User user, ActionListener listener) { + WorkflowState state = new WorkflowState.Builder().workflowId(workflowId) + .state(State.NOT_STARTED.name()) + .provisioningProgress(ProvisioningProgress.NOT_STARTED.name()) + .user(user) + .build(); + initWorkflowStateIndexIfAbsent(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new FlowFrameworkException("No response to create workflow_state index", INTERNAL_SERVER_ERROR)); + return; + } + IndexRequest request = new IndexRequest(WORKFLOW_STATE_INDEX); + try ( + XContentBuilder builder = XContentFactory.jsonBuilder(); + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); + + ) { + request.source(state.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to put state index document", e); + listener.onFailure(e); + } + + }, e -> { + logger.error("Failed to create global_context index", e); + listener.onFailure(e); + })); + } + + /** + * Replaces a document in the global context index + * @param documentId the document Id + * @param template the use-case template + * @param listener action listener + */ + public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener listener) { + if (!doesIndexExist(GLOBAL_CONTEXT_INDEX)) { + String exceptionMessage = "Failed to update workflow state for workflow_id : " + + documentId + + ", workflow_state index does not exist."; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); + try ( + XContentBuilder builder = XContentFactory.jsonBuilder(); + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() + ) { + request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); + listener.onFailure(e); + } + } + } + + /** + * Updates a document in the workflow state index + * @param workflowStateDocId the document ID + * @param updatedFields the fields to update the global state index with + * @param listener action listener + */ + public void updateWorkflowState( + String workflowStateDocId, + ThreadContext.StoredContext context, + Map updatedFields, + ActionListener listener + ) { + if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { + String exceptionMessage = "Failed to update state for given workflow due to missing workflow_state index"; + logger.error(exceptionMessage); + listener.onFailure(new Exception(exceptionMessage)); + } else { + UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, workflowStateDocId); + Map updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } + } + + public void getWorkflowStateID(String workflowId, ActionListener listener) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(WORKFLOW_ID_FIELD, workflowId)); + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).size(1); // we are making the assumption there is only one document with this workflowID + searchRequest.source(sourceBuilder).indices(WORKFLOW_STATE_INDEX); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (searchResponse == null + || searchResponse.getHits().getTotalHits() == null + || !(searchResponse.getHits().getTotalHits().value == 1)) { + logger.error("There are either one or no workflow state documents with the same workflowID: " + workflowId); + listener.onFailure(new FlowFrameworkException("Workflow state cannot be updated", INTERNAL_SERVER_ERROR)); + return; + } + String stateWorkflowDocID = searchResponse.getHits().getHits()[0].getId(); + listener.onResponse(stateWorkflowDocID); + }, exception -> { + logger.error("Failed to find workflow state for workflowID : {}. {}", workflowId, exception.getMessage()); + listener.onFailure(new FlowFrameworkException("Failed to find workflow state for workflowID: " + workflowId, BAD_REQUEST)); + })); + } + + public void getAndUpdateWorkflowStateDoc( + String workflowId, + Map updatedFields, + ActionListener workflowResponseListener + ) { + try { + ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext(); + getWorkflowStateID(workflowId, ActionListener.wrap(stateWorkflowId -> { + updateWorkflowState(stateWorkflowId, context, updatedFields, ActionListener.wrap(r -> {}, e -> { + logger.error("Failed to update workflow state : {}", e.getMessage()); + workflowResponseListener.onFailure( + new FlowFrameworkException("Failed to update workflow state", RestStatus.BAD_REQUEST) + ); + })); + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + workflowResponseListener.onFailure(new FlowFrameworkException("couldn't find workflow state", RestStatus.BAD_REQUEST)); + })); + } catch (Exception e) { + logger.error("Failed to update workflow state : {}", e.getMessage()); + workflowResponseListener.onFailure(new FlowFrameworkException("Failed to update workflow state", RestStatus.BAD_REQUEST)); + } + + } + + /** + * Update global context index for specific fields + * @param documentId global context index document id + * @param updatedFields updated fields; key: field name, value: new value + * @param listener UpdateResponse action listener + */ + public void storeResponseToGlobalContext( + String documentId, + Map updatedFields, + ActionListener listener + ) { + UpdateRequest updateRequest = new UpdateRequest(GLOBAL_CONTEXT_INDEX, documentId); + Map updatedUserOutputsContext = new HashMap<>(); + updatedUserOutputsContext.putAll(updatedFields); + updateRequest.doc(updatedUserOutputsContext); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + // TODO: decide what condition can be considered as an update conflict and add retry strategy + + try { + client.update(updateRequest, listener); + } catch (Exception e) { + logger.error("Failed to update global_context index"); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java b/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java deleted file mode 100644 index 53037d7ce..000000000 --- a/src/main/java/org/opensearch/flowframework/indices/GlobalContextHandler.java +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.indices; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.client.Client; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.model.Template; -import org.opensearch.flowframework.workflow.CreateIndexStep; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX_MAPPING; -import static org.opensearch.flowframework.workflow.CreateIndexStep.getIndexMappings; - -/** - * A handler for global context related operations - */ -public class GlobalContextHandler { - private static final Logger logger = LogManager.getLogger(GlobalContextHandler.class); - private final Client client; - private final CreateIndexStep createIndexStep; - - /** - * constructor - * @param client the open search client - * @param createIndexStep create index step - */ - public GlobalContextHandler(Client client, CreateIndexStep createIndexStep) { - this.client = client; - this.createIndexStep = createIndexStep; - } - - /** - * Get global-context index mapping - * @return global-context index mapping - * @throws IOException if mapping file cannot be read correctly - */ - public static String getGlobalContextMappings() throws IOException { - return getIndexMappings(GLOBAL_CONTEXT_INDEX_MAPPING); - } - - private void initGlobalContextIndexIfAbsent(ActionListener listener) { - createIndexStep.initIndexIfAbsent(FlowFrameworkIndex.GLOBAL_CONTEXT, listener); - } - - /** - * add document insert into global context index - * @param template the use-case template - * @param listener action listener - */ - public void putTemplateToGlobalContext(Template template, ActionListener listener) { - initGlobalContextIndexIfAbsent(ActionListener.wrap(indexCreated -> { - if (!indexCreated) { - listener.onFailure(new FlowFrameworkException("No response to create global_context index", INTERNAL_SERVER_ERROR)); - return; - } - IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX); - try ( - XContentBuilder builder = XContentFactory.jsonBuilder(); - ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() - ) { - request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.runBefore(listener, () -> context.restore())); - } catch (Exception e) { - logger.error("Failed to index global_context index"); - listener.onFailure(e); - } - }, e -> { - logger.error("Failed to create global_context index", e); - listener.onFailure(e); - })); - } - - /** - * Replaces a document in the global context index - * @param documentId the document Id - * @param template the use-case template - * @param listener action listener - */ - public void updateTemplateInGlobalContext(String documentId, Template template, ActionListener listener) { - if (!createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX)) { - String exceptionMessage = "Failed to update template for workflow_id : " - + documentId - + ", global_context index does not exist."; - logger.error(exceptionMessage); - listener.onFailure(new Exception(exceptionMessage)); - } else { - IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); - try ( - XContentBuilder builder = XContentFactory.jsonBuilder(); - ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext() - ) { - request.source(template.toDocumentSource(builder, ToXContent.EMPTY_PARAMS)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.runBefore(listener, () -> context.restore())); - } catch (Exception e) { - logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); - listener.onFailure(e); - } - } - } - - /** - * Update global context index for specific fields - * @param documentId global context index document id - * @param updatedFields updated fields; key: field name, value: new value - * @param listener UpdateResponse action listener - */ - public void storeResponseToGlobalContext( - String documentId, - Map updatedFields, - ActionListener listener - ) { - UpdateRequest updateRequest = new UpdateRequest(GLOBAL_CONTEXT_INDEX, documentId); - Map updatedUserOutputsContext = new HashMap<>(); - updatedUserOutputsContext.putAll(updatedFields); - updateRequest.doc(updatedUserOutputsContext); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - // TODO: decide what condition can be considered as an update conflict and add retry strategy - - try { - client.update(updateRequest, listener); - } catch (Exception e) { - logger.error("Failed to update global_context index"); - listener.onFailure(e); - } - } -} diff --git a/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java index b6da0abe5..f4f6f7d4e 100644 --- a/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java +++ b/src/main/java/org/opensearch/flowframework/model/PipelineProcessor.java @@ -17,8 +17,8 @@ import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.TemplateUtil.buildStringToStringMap; -import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** * This represents a processor associated with search and ingest pipelines in the {@link Template}. diff --git a/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java new file mode 100644 index 000000000..e0812893e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/ProvisioningProgress.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +public enum ProvisioningProgress { + IN_PROGRESS, + DONE, + NOT_STARTED +} diff --git a/src/main/java/org/opensearch/flowframework/model/State.java b/src/main/java/org/opensearch/flowframework/model/State.java new file mode 100644 index 000000000..d2d95000f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/State.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +public enum State { + NOT_STARTED, + PROVISIONING, + FAILED, + READY +} diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index 7d08ef240..d9fc3023e 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -25,8 +25,8 @@ import java.util.Map.Entry; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.TemplateUtil.jsonToParser; -import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.jsonToParser; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** * The Template is the central data structure which configures workflows. This object is used to parse JSON communicated via REST API. diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index e34c4ddec..d2046f096 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -24,8 +24,8 @@ import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.TemplateUtil.buildStringToStringMap; -import static org.opensearch.flowframework.common.TemplateUtil.parseStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; +import static org.opensearch.flowframework.util.ParseUtils.parseStringToStringMap; /** * This represents a process node (step) in a workflow graph in the {@link Template}. diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java new file mode 100644 index 000000000..19b12ab38 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -0,0 +1,271 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.model; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.util.ParseUtils; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +/** + * The WorkflowState is used to store all additional information regarding a workflow that isn't part of the + * global context. + */ +public class WorkflowState implements ToXContentObject { + public static final String WORKFLOW_ID_FIELD = "workflow_id"; + public static final String ERROR_FIELD = "error"; + public static final String STATE_FIELD = "state"; + public static final String PROVISIONING_PROGRESS_FIELD = "provisioning_progress"; + public static final String PROVISION_START_TIME_FIELD = "provision_start_time"; + public static final String PROVISION_END_TIME_FIELD = "provision_end_time"; + public static final String USER_FIELD = "user"; + public static final String UI_METADATA_FIELD = "ui_metadata"; + + private String workflowId; + private String error; + private String state; + private String provisioningProgress; + private Instant provisionStartTime; + private Instant provisionEndTime; + private User user; + private Map uiMetadata; + + /** + * Instantiate the object representing the workflow state + * + * @param workflowId The workflow ID representing the given workflow + * @param error + * @param state + * @param provisioningProgress + * @param provisionStartTime + * @param provisionEndTime + * @param user + * @param uiMetadata + */ + public WorkflowState( + String workflowId, + String error, + String state, + String provisioningProgress, + Instant provisionStartTime, + Instant provisionEndTime, + User user, + Map uiMetadata + ) { + this.workflowId = workflowId; + this.error = error; + this.state = state; + this.provisioningProgress = provisioningProgress; + this.provisionStartTime = provisionStartTime; + this.provisionEndTime = provisionEndTime; + this.user = user; + this.uiMetadata = uiMetadata; + } + + private WorkflowState() {} + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String workflowId = null; + private String error = null; + private String state = null; + private String provisioningProgress = null; + private Instant provisionStartTime = null; + private Instant provisionEndTime = null; + private User user = null; + private Map uiMetadata = null; + + public Builder() {} + + public Builder workflowId(String workflowId) { + this.workflowId = workflowId; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder state(String state) { + this.state = state; + return this; + } + + public Builder provisioningProgress(String provisioningProgress) { + this.provisioningProgress = provisioningProgress; + return this; + } + + public Builder provisionStartTime(Instant provisionStartTime) { + this.provisionStartTime = provisionStartTime; + return this; + } + + public Builder provisionEndTime(Instant provisionEndTime) { + this.provisionEndTime = provisionEndTime; + return this; + } + + public Builder user(User user) { + this.user = user; + return this; + } + + public Builder uiMetadata(Map uiMetadata) { + this.uiMetadata = uiMetadata; + return this; + } + + public WorkflowState build() { + WorkflowState workflowState = new WorkflowState(); + workflowState.workflowId = this.workflowId; + workflowState.error = this.error; + workflowState.state = this.state; + workflowState.provisioningProgress = this.provisioningProgress; + workflowState.provisionStartTime = this.provisionStartTime; + workflowState.provisionEndTime = this.provisionEndTime; + workflowState.user = this.user; + workflowState.uiMetadata = this.uiMetadata; + return workflowState; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (workflowId != null) { + xContentBuilder.field(WORKFLOW_ID_FIELD, workflowId); + } + if (error != null) { + xContentBuilder.field(ERROR_FIELD, error); + } + if (state != null) { + xContentBuilder.field(STATE_FIELD, state); + } + if (provisioningProgress != null) { + xContentBuilder.field(PROVISIONING_PROGRESS_FIELD, provisioningProgress); + } + if (provisionStartTime != null) { + xContentBuilder.field(PROVISION_START_TIME_FIELD, provisionStartTime.toEpochMilli()); + } + if (provisionEndTime != null) { + xContentBuilder.field(PROVISION_END_TIME_FIELD, provisionEndTime.toEpochMilli()); + } + if (user != null) { + xContentBuilder.field(USER_FIELD, user); + } + if (uiMetadata != null && !uiMetadata.isEmpty()) { + xContentBuilder.field(UI_METADATA_FIELD, uiMetadata); + } + return xContentBuilder.endObject(); + } + + // TODO: might need to add another parse that takes in a workflow ID. + /** + * Parse raw json content into a Template instance. + * + * @param parser json based content parser + * @return an instance of the template + * @throws IOException if content can't be parsed correctly + */ + public static WorkflowState parse(XContentParser parser) throws IOException { + String workflowId = null; + String error = null; + String state = null; + String provisioningProgress = null; + Instant provisionStartTime = null; + Instant provisionEndTime = null; + User user = null; + Map uiMetadata = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case WORKFLOW_ID_FIELD: + workflowId = parser.text(); + break; + case ERROR_FIELD: + error = parser.text(); + break; + case STATE_FIELD: + state = parser.text(); + break; + case PROVISIONING_PROGRESS_FIELD: + provisioningProgress = parser.text(); + break; + case PROVISION_START_TIME_FIELD: + provisionStartTime = ParseUtils.toInstant(parser); + break; + case PROVISION_END_TIME_FIELD: + provisionEndTime = ParseUtils.toInstant(parser); + break; + case USER_FIELD: + user = User.parse(parser); + break; + case UI_METADATA_FIELD: + uiMetadata = parser.map(); + break; + } + } + return new Builder().workflowId(workflowId) + .error(error) + .state(state) + .provisioningProgress(provisioningProgress) + .provisionStartTime(provisionStartTime) + .provisionEndTime(provisionEndTime) + .user(user) + .uiMetadata(uiMetadata) + .build(); + } + + public String getWorkflowId() { + return workflowId; + } + + public String getError() { + return workflowId; + } + + public String getState() { + return state; + } + + public String getProvisioningProgress() { + return provisioningProgress; + } + + public Instant getProvisionStartTime() { + return provisionStartTime; + } + + public Instant getProvisionEndTime() { + return provisionEndTime; + } + + public User getUser() { + return user; + } + + public Map getUiMetadata() { + return uiMetadata; + } +} diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index f4147b144..d018950c1 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -8,18 +8,27 @@ */ package org.opensearch.flowframework.transport; +import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.State; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import static org.opensearch.flowframework.model.WorkflowState.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.model.WorkflowState.STATE_FIELD; +import static org.opensearch.flowframework.util.ParseUtils.getUserContext; + /** * Transport Action to index or update a use case template within the Global Context */ @@ -27,44 +36,59 @@ public class CreateWorkflowTransportAction extends HandledTransportAction listener) { + User user = getUserContext(client); if (request.getWorkflowId() == null) { // Create new global context and state index entries - globalContextHandler.putTemplateToGlobalContext(request.getTemplate(), ActionListener.wrap(response -> { - // TODO : Check if state index exists, create if not - // TODO : Create StateIndexRequest for workflowId, default to NOT_STARTED - listener.onResponse(new WorkflowResponse(response.getId())); + flowFrameworkIndicesHandler.putTemplateToGlobalContext(request.getTemplate(), ActionListener.wrap(globalContextResponse -> { + flowFrameworkIndicesHandler.putInitialStateToWorkflowState( + globalContextResponse.getId(), + user, + ActionListener.wrap(stateResponse -> { + logger.info("create state workflow doc " + stateResponse); + listener.onResponse(new WorkflowResponse(globalContextResponse.getId())); + }, exception -> { + logger.error("Failed to save workflow state : {}", exception.getMessage()); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + }) + ); }, exception -> { logger.error("Failed to save use case template : {}", exception.getMessage()); listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); })); } else { // Update existing entry, full document replacement - globalContextHandler.updateTemplateInGlobalContext( + flowFrameworkIndicesHandler.updateTemplateInGlobalContext( request.getWorkflowId(), request.getTemplate(), ActionListener.wrap(response -> { - // TODO : Create StateIndexRequest for workflowId to reset entry to NOT_STARTED - listener.onResponse(new WorkflowResponse(response.getId())); + flowFrameworkIndicesHandler.getAndUpdateWorkflowStateDoc( + request.getWorkflowId(), + ImmutableMap.of(STATE_FIELD, State.NOT_STARTED, PROVISIONING_PROGRESS_FIELD, ProvisioningProgress.NOT_STARTED), + listener + ); }, exception -> { logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index e03a1b4d8..1d517b197 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.transport; +import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.get.GetRequest; @@ -19,6 +20,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ProvisioningProgress; +import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.ProcessNode; @@ -27,6 +31,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Locale; @@ -38,6 +43,9 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; +import static org.opensearch.flowframework.model.WorkflowState.PROVISIONING_PROGRESS_FIELD; +import static org.opensearch.flowframework.model.WorkflowState.PROVISION_START_TIME_FIELD; +import static org.opensearch.flowframework.model.WorkflowState.STATE_FIELD; /** * Transport Action to provision a workflow from a stored use case template @@ -49,6 +57,7 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction parseStringToStringMap(XContentParser parser) return map; } + /** + * Parse content parser to {@link java.time.Instant}. + * + * @param parser json based content parser + * @return instance of {@link java.time.Instant} + * @throws IOException IOException if content can't be parsed correctly + */ + public static Instant toInstant(XContentParser parser) throws IOException { + if (parser.currentToken() == null || parser.currentToken() == XContentParser.Token.VALUE_NULL) { + return null; + } + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + return null; + } + + /** + * Generates a user string formed by the username, backend roles, roles and requested tenants separated by '|' + * (e.g., john||own_index,testrole|__user__, no backend role so you see two verticle line after john.). + * This is the user string format used internally in the OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT and may be + * parsed using User.parse(string). + * @param client Client containing user info. A public API request will fill in the user info in the thread context. + * @return parsed user object + */ + public static User getUserContext(Client client) { + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + logger.debug("Filtering result by " + userStr); + return User.parse(userStr); + } + } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 2b2f7338d..9415d99ec 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -8,38 +8,23 @@ */ package org.opensearch.flowframework.workflow; -import com.google.common.base.Charsets; -import com.google.common.io.Resources; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; import org.opensearch.client.Client; -import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.exception.FlowFrameworkException; -import org.opensearch.flowframework.indices.FlowFrameworkIndex; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; -import java.io.IOException; -import java.net.URL; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; -import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; -import static org.opensearch.flowframework.common.CommonValue.META; -import static org.opensearch.flowframework.common.CommonValue.NO_SCHEMA_VERSION; -import static org.opensearch.flowframework.common.CommonValue.SCHEMA_VERSION_FIELD; - /** * Step to create an index */ @@ -101,7 +86,7 @@ public void onFailure(Exception e) { try { CreateIndexRequest request = new CreateIndexRequest(index).mapping( - getIndexMappings("mappings/" + type + ".json"), + FlowFrameworkIndicesHandler.getIndexMappings("mappings/" + type + ".json"), JsonXContent.jsonXContent.mediaType() ); client.admin().indices().create(request, actionListener); @@ -116,140 +101,140 @@ public void onFailure(Exception e) { public String getName() { return NAME; } + // + // /** + // * Checks if the given index exists + // * @param indexName the name of the index + // * @return boolean indicating the existence of an index + // */ + // public boolean doesIndexExist(String indexName) { + // return clusterService.state().metadata().hasIndex(indexName); + // } // TODO : Move to index management class, pending implementation - /** - * Checks if the given index exists - * @param indexName the name of the index - * @return boolean indicating the existence of an index - */ - public boolean doesIndexExist(String indexName) { - return clusterService.state().metadata().hasIndex(indexName); - } - - /** - * Create Index if it's absent - * @param index The index that needs to be created - * @param listener The action listener - */ - public void initIndexIfAbsent(FlowFrameworkIndex index, ActionListener listener) { - String indexName = index.getIndexName(); - String mapping = index.getMapping(); - - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - if (!clusterService.state().metadata().hasIndex(indexName)) { - @SuppressWarnings("deprecation") - ActionListener actionListener = ActionListener.wrap(r -> { - if (r.isAcknowledged()) { - logger.info("create index:{}", indexName); - internalListener.onResponse(true); - } else { - internalListener.onResponse(false); - } - }, e -> { - logger.error("Failed to create index " + indexName, e); - internalListener.onFailure(e); - }); - CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); - client.admin().indices().create(request, actionListener); - } else { - logger.debug("index:{} is already created", indexName); - if (indexMappingUpdated.containsKey(indexName) && !indexMappingUpdated.get(indexName).get()) { - shouldUpdateIndex(indexName, index.getVersion(), ActionListener.wrap(r -> { - if (r) { - // return true if update index is needed - client.admin() - .indices() - .putMapping( - new PutMappingRequest().indices(indexName).source(mapping, XContentType.JSON), - ActionListener.wrap(response -> { - if (response.isAcknowledged()) { - UpdateSettingsRequest updateSettingRequest = new UpdateSettingsRequest(); - updateSettingRequest.indices(indexName).settings(indexSettings); - client.admin() - .indices() - .updateSettings(updateSettingRequest, ActionListener.wrap(updateResponse -> { - if (response.isAcknowledged()) { - indexMappingUpdated.get(indexName).set(true); - internalListener.onResponse(true); - } else { - internalListener.onFailure( - new FlowFrameworkException( - "Failed to update index setting for: " + indexName, - INTERNAL_SERVER_ERROR - ) - ); - } - }, exception -> { - logger.error("Failed to update index setting for: " + indexName, exception); - internalListener.onFailure(exception); - })); - } else { - internalListener.onFailure( - new FlowFrameworkException("Failed to update index: " + indexName, INTERNAL_SERVER_ERROR) - ); - } - }, exception -> { - logger.error("Failed to update index " + indexName, exception); - internalListener.onFailure(exception); - }) - ); - } else { - // no need to update index if it does not exist or the version is already up-to-date. - indexMappingUpdated.get(indexName).set(true); - internalListener.onResponse(true); - } - }, e -> { - logger.error("Failed to update index mapping", e); - internalListener.onFailure(e); - })); - } else { - // No need to update index if it's already updated. - internalListener.onResponse(true); - } - } - } catch (Exception e) { - logger.error("Failed to init index " + indexName, e); - listener.onFailure(e); - } - } - - /** - * Get index mapping json content. - * - * @param mapping type of the index to fetch the specific mapping file - * @return index mapping - * @throws IOException IOException if mapping file can't be read correctly - */ - public static String getIndexMappings(String mapping) throws IOException { - URL url = CreateIndexStep.class.getClassLoader().getResource(mapping); - return Resources.toString(url, Charsets.UTF_8); - } - - /** - * Check if we should update index based on schema version. - * @param indexName index name - * @param newVersion new index mapping version - * @param listener action listener, if update index is needed, will pass true to its onResponse method - */ - private void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { - IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); - if (indexMetaData == null) { - listener.onResponse(Boolean.FALSE); - return; - } - Integer oldVersion = NO_SCHEMA_VERSION; - Map indexMapping = indexMetaData.mapping().getSourceAsMap(); - Object meta = indexMapping.get(META); - if (meta != null && meta instanceof Map) { - @SuppressWarnings("unchecked") - Map metaMapping = (Map) meta; - Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); - if (schemaVersion instanceof Integer) { - oldVersion = (Integer) schemaVersion; - } - } - listener.onResponse(newVersion > oldVersion); - } + // /** + // * Create Index if it's absent + // * @param index The index that needs to be created + // * @param listener The action listener + // */ + // public void initIndexIfAbsent(FlowFrameworkIndex index, ActionListener listener) { + // String indexName = index.getIndexName(); + // String mapping = index.getMapping(); + // + // try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + // ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + // if (!clusterService.state().metadata().hasIndex(indexName)) { + // @SuppressWarnings("deprecation") + // ActionListener actionListener = ActionListener.wrap(r -> { + // if (r.isAcknowledged()) { + // logger.info("create index:{}", indexName); + // internalListener.onResponse(true); + // } else { + // internalListener.onResponse(false); + // } + // }, e -> { + // logger.error("Failed to create index " + indexName, e); + // internalListener.onFailure(e); + // }); + // CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); + // client.admin().indices().create(request, actionListener); + // } else { + // logger.debug("index:{} is already created", indexName); + // if (indexMappingUpdated.containsKey(indexName) && !indexMappingUpdated.get(indexName).get()) { + // shouldUpdateIndex(indexName, index.getVersion(), ActionListener.wrap(r -> { + // if (r) { + // // return true if update index is needed + // client.admin() + // .indices() + // .putMapping( + // new PutMappingRequest().indices(indexName).source(mapping, XContentType.JSON), + // ActionListener.wrap(response -> { + // if (response.isAcknowledged()) { + // UpdateSettingsRequest updateSettingRequest = new UpdateSettingsRequest(); + // updateSettingRequest.indices(indexName).settings(indexSettings); + // client.admin() + // .indices() + // .updateSettings(updateSettingRequest, ActionListener.wrap(updateResponse -> { + // if (response.isAcknowledged()) { + // indexMappingUpdated.get(indexName).set(true); + // internalListener.onResponse(true); + // } else { + // internalListener.onFailure( + // new FlowFrameworkException( + // "Failed to update index setting for: " + indexName, + // INTERNAL_SERVER_ERROR + // ) + // ); + // } + // }, exception -> { + // logger.error("Failed to update index setting for: " + indexName, exception); + // internalListener.onFailure(exception); + // })); + // } else { + // internalListener.onFailure( + // new FlowFrameworkException("Failed to update index: " + indexName, INTERNAL_SERVER_ERROR) + // ); + // } + // }, exception -> { + // logger.error("Failed to update index " + indexName, exception); + // internalListener.onFailure(exception); + // }) + // ); + // } else { + // // no need to update index if it does not exist or the version is already up-to-date. + // indexMappingUpdated.get(indexName).set(true); + // internalListener.onResponse(true); + // } + // }, e -> { + // logger.error("Failed to update index mapping", e); + // internalListener.onFailure(e); + // })); + // } else { + // // No need to update index if it's already updated. + // internalListener.onResponse(true); + // } + // } + // } catch (Exception e) { + // logger.error("Failed to init index " + indexName, e); + // listener.onFailure(e); + // } + // } + // + // /** + // * Get index mapping json content. + // * + // * @param mapping type of the index to fetch the specific mapping file + // * @return index mapping + // * @throws IOException IOException if mapping file can't be read correctly + // */ + // public static String getIndexMappings(String mapping) throws IOException { + // URL url = CreateIndexStep.class.getClassLoader().getResource(mapping); + // return Resources.toString(url, Charsets.UTF_8); + // } + // + // /** + // * Check if we should update index based on schema version. + // * @param indexName index name + // * @param newVersion new index mapping version + // * @param listener action listener, if update index is needed, will pass true to its onResponse method + // */ + // private void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { + // IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); + // if (indexMetaData == null) { + // listener.onResponse(Boolean.FALSE); + // return; + // } + // Integer oldVersion = NO_SCHEMA_VERSION; + // Map indexMapping = indexMetaData.mapping().getSourceAsMap(); + // Object meta = indexMapping.get(META); + // if (meta != null && meta instanceof Map) { + // @SuppressWarnings("unchecked") + // Map metaMapping = (Map) meta; + // Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); + // if (schemaVersion instanceof Integer) { + // oldVersion = (Integer) schemaVersion; + // } + // } + // listener.onResponse(newVersion > oldVersion); + // } } diff --git a/src/main/resources/mappings/knn-text-search-default.json b/src/main/resources/mappings/knn-text-search-default.json new file mode 100644 index 000000000..5d7e20baf --- /dev/null +++ b/src/main/resources/mappings/knn-text-search-default.json @@ -0,0 +1,20 @@ +{ + "properties": { + "id": { + "type": "text" + }, + "passage_embedding": { + "type": "knn_vector", + "dimension": 768, + "method": { + "engine": "lucene", + "space_type": "l2", + "name": "hnsw", + "parameters": {} + } + }, + "passage_text": { + "type": "text" + } + } +} diff --git a/src/main/resources/mappings/workflow-state.json b/src/main/resources/mappings/workflow-state.json new file mode 100644 index 000000000..1102a6a3d --- /dev/null +++ b/src/main/resources/mappings/workflow-state.json @@ -0,0 +1,87 @@ +{ + "dynamic": false, + "_meta": { + "schema_version": 1 + }, + "properties": { + "schema_version": { + "type": "integer" + }, + "workflow_id": { + "type": "keyword" + }, + "error": { + "type": "text" + } + "state": { + "type": "keyword" + }, + "provisioning_progress": { + "type": "keyword" + }, + "provision_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "provision_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "user": { + "type": "nested", + "properties": { + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "backend_roles": { + "type" : "text", + "fields" : { + "keyword" : { + "type" : "keyword" + } + } + }, + "roles": { + "type" : "text", + "fields" : { + "keyword" : { + "type" : "keyword" + } + } + }, + "custom_attribute_names": { + "type" : "text", + "fields" : { + "keyword" : { + "type" : "keyword" + } + } + } + } + }, + "ui_metadata": { + "type": "object", + "enabled": false + } + } + "ui_metadata": { + "features": { + "sum_http_5xx": { + "aggregationBy": "sum", + "aggregationOf": "http_5xx", + "featureType": "simple_aggs" + }, + "sum_http_4xx": { + "aggregationBy": "sum", + "aggregationOf": "http_4xx", + "featureType": "simple_aggs" + } + }, + "filters": [] + }, diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java new file mode 100644 index 000000000..ca45544d7 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -0,0 +1,254 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.indices; + +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.flowframework.workflow.CreateIndexStep; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import java.util.Map; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FlowFrameworkIndicesHandlerTests extends OpenSearchTestCase { + @Mock + private Client client; + @Mock + private CreateIndexStep createIndexStep; + @Mock + private ThreadPool threadPool; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private AdminClient adminClient; + private IndicesAdminClient indicesAdminClient; + private ThreadContext threadContext; + @Mock + protected ClusterService clusterService; + @Mock + private FlowFrameworkIndicesHandler flowMock; + // private static final String META = "_meta"; + // private static final String SCHEMA_VERSION_FIELD = "schemaVersion"; + @Mock + private Metadata metadata; + // private Map indexMappingUpdated = new HashMap<>(); + @Mock + IndexMetadata indexMetadata; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + flowFrameworkIndicesHandler = new FlowFrameworkIndicesHandler(client, clusterService); + adminClient = mock(AdminClient.class); + indicesAdminClient = mock(IndicesAdminClient.class); + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(client.admin()).thenReturn(adminClient); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test cluster")).build()); + when(metadata.indices()).thenReturn(Map.of(GLOBAL_CONTEXT_INDEX, indexMetadata)); + when(adminClient.indices()).thenReturn(indicesAdminClient); + } + // + // public void testPutTemplateToGlobalContext() throws IOException { + // Template template = mock(Template.class); + // when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { + // XContentBuilder builder = invocation.getArgument(0); + // return builder; + // }); + // @SuppressWarnings("unchecked") + // + // ActionListener listener = mock(ActionListener.class); + // doAnswer(invocation -> { + // ActionListener callback = invocation.getArgument(1); + // callback.onResponse(true); + // return null; + // }).when(flowMock).initFlowFrameworkIndexIfAbsent(any(FlowFrameworkIndex.class), any()); + // flowMock.initFlowFrameworkIndexIfAbsent(any(FlowFrameworkIndex.class), any()); + // // when(flowFrameworkIndicesHandler.initFlowFrameworkIndexIfAbsent(flowFrameworkIndex, listener). + //// flowFrameworkIndicesHandler.initFlowFrameworkIndexIfAbsent(listener); + // //verify(flowMock).initFlowFrameworkIndexIfAbsent(any(FlowFrameworkIndex.class), any()); + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + // verify(indicesAdminClient, times(1)).create(requestCaptor.capture(), any()); + // + // assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); + // } + + // public void testPutTemplateToGlobalContext() throws IOException { + // Template template = mock(Template.class); + // when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { + // XContentBuilder builder = invocation.getArgument(0); + // return builder; + // }); + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // + // doAnswer(invocation -> { + // ActionListener callback = invocation.getArgument(1); + // callback.onResponse(true); + // return null; + // }).when(flowMock).initFlowFrameworkIndexIfAbsent(any(FlowFrameworkIndex.class), any()); + // + // flowFrameworkIndicesHandler.putTemplateToGlobalContext(template, listener); + // + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + // verify(client, times(1)).index(requestCaptor.capture(), any()); + // + // assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); + // } + + // + // public void testStoreResponseToGlobalContext() { + // String documentId = "docId"; + // Map updatedFields = new HashMap<>(); + // updatedFields.put("field1", "value1"); + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // + // flowFrameworkIndicesHandler.storeResponseToGlobalContext(documentId, updatedFields, listener); + // + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + // verify(client, times(1)).update(requestCaptor.capture(), any()); + // + // assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); + // assertEquals(documentId, requestCaptor.getValue().id()); + // } + + // public void testUpdateTemplateInGlobalContext() throws IOException { + // Template template = mock(Template.class); + // when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { + // XContentBuilder builder = invocation.getArgument(0); + // return builder; + // }); + // when(createIndexStep.doesIndexExist(any())).thenReturn(true); + // + // flowFrameworkIndicesHandler.updateTemplateInGlobalContext("1", template, null); + // + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + // verify(client, times(1)).index(requestCaptor.capture(), any()); + // + // assertEquals("1", requestCaptor.getValue().id()); + // } + + // public void testFailedUpdateTemplateInGlobalContext() throws IOException { + // Template template = mock(Template.class); + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // // when(createIndexStep.doesIndexExist(any())).thenReturn(false); + // + // flowFrameworkIndicesHandler.updateTemplateInGlobalContext("1", template, listener); + // ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + // + // verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + // + // assertEquals( + // "Failed to update template for workflow_id : 1, global_context index does not exist.", + // exceptionCaptor.getValue().getMessage() + // ); + // } + // public void testInitIndexIfAbsent_IndexNotPresent() { + // when(metadata.hasIndex(FlowFrameworkIndex.GLOBAL_CONTEXT.getIndexName())).thenReturn(false); + // + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // flowFrameworkIndicesHandler.initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex.GLOBAL_CONTEXT, listener); + // + // verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), any()); + // } + + // public void testInitIndexIfAbsent_IndexExist() { + // FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; + // indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false)); + // + // ClusterState mockClusterState = mock(ClusterState.class); + // Metadata mockMetadata = mock(Metadata.class); + // when(clusterService.state()).thenReturn(mockClusterState); + // when(mockClusterState.metadata()).thenReturn(mockMetadata); + // when(mockMetadata.hasIndex(index.getIndexName())).thenReturn(true); + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // + // IndexMetadata mockIndexMetadata = mock(IndexMetadata.class); + // @SuppressWarnings("unchecked") + // Map mockIndices = mock(Map.class); + // when(clusterService.state()).thenReturn(mockClusterState); + // when(mockClusterState.getMetadata()).thenReturn(mockMetadata); + // when(mockMetadata.indices()).thenReturn(mockIndices); + // when(mockIndices.get(anyString())).thenReturn(mockIndexMetadata); + // Map mockMapping = new HashMap<>(); + // Map mockMetaMapping = new HashMap<>(); + // mockMetaMapping.put(SCHEMA_VERSION_FIELD, 1); + // mockMapping.put(META, mockMetaMapping); + // MappingMetadata mockMappingMetadata = mock(MappingMetadata.class); + // when(mockIndexMetadata.mapping()).thenReturn(mockMappingMetadata); + // when(mockMappingMetadata.getSourceAsMap()).thenReturn(mockMapping); + // + // flowFrameworkIndicesHandler.initFlowFrameworkIndexIfAbsent(index, listener); + // + // ArgumentCaptor putMappingRequestArgumentCaptor = ArgumentCaptor.forClass(PutMappingRequest.class); + // @SuppressWarnings({ "unchecked" }) + // ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + // verify(indicesAdminClient, times(1)).putMapping(putMappingRequestArgumentCaptor.capture(), listenerCaptor.capture()); + // PutMappingRequest capturedRequest = putMappingRequestArgumentCaptor.getValue(); + // assertEquals(index.getIndexName(), capturedRequest.indices()[0]); + // } + // + // public void testInitIndexIfAbsent_IndexExist_returnFalse() { + // FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; + // indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false)); + // + // ClusterState mockClusterState = mock(ClusterState.class); + // Metadata mockMetadata = mock(Metadata.class); + // when(clusterService.state()).thenReturn(mockClusterState); + // when(mockClusterState.metadata()).thenReturn(mockMetadata); + // when(mockMetadata.hasIndex(index.getIndexName())).thenReturn(true); + // + // @SuppressWarnings("unchecked") + // ActionListener listener = mock(ActionListener.class); + // @SuppressWarnings("unchecked") + // Map mockIndices = mock(Map.class); + // when(mockClusterState.getMetadata()).thenReturn(mockMetadata); + // when(mockMetadata.indices()).thenReturn(mockIndices); + // when(mockIndices.get(anyString())).thenReturn(null); + // + // flowFrameworkIndicesHandler.initFlowFrameworkIndexIfAbsent(index, listener); + // assertTrue(indexMappingUpdated.get(index.getIndexName()).get()); + // } + // + // public void testDoesIndexExist() { + // ClusterState mockClusterState = mock(ClusterState.class); + // Metadata mockMetaData = mock(Metadata.class); + // when(clusterService.state()).thenReturn(mockClusterState); + // when(mockClusterState.metadata()).thenReturn(mockMetaData); + // + // flowFrameworkIndicesHandler.doesIndexExist(GLOBAL_CONTEXT_INDEX); + // + // ArgumentCaptor indexExistsCaptor = ArgumentCaptor.forClass(String.class); + // verify(mockMetaData, times(1)).hasIndex(indexExistsCaptor.capture()); + // + // assertEquals(GLOBAL_CONTEXT_INDEX, indexExistsCaptor.getValue()); + // } +} diff --git a/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java deleted file mode 100644 index 78f0fc618..000000000 --- a/src/test/java/org/opensearch/flowframework/indices/GlobalContextHandlerTests.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.indices; - -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.update.UpdateRequest; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.client.AdminClient; -import org.opensearch.client.Client; -import org.opensearch.client.IndicesAdminClient; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.flowframework.model.Template; -import org.opensearch.flowframework.workflow.CreateIndexStep; -import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ThreadPool; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class GlobalContextHandlerTests extends OpenSearchTestCase { - @Mock - private Client client; - @Mock - private CreateIndexStep createIndexStep; - @Mock - private ThreadPool threadPool; - private GlobalContextHandler globalContextHandler; - private AdminClient adminClient; - private IndicesAdminClient indicesAdminClient; - private ThreadContext threadContext; - - @Override - public void setUp() throws Exception { - super.setUp(); - MockitoAnnotations.openMocks(this); - - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - - globalContextHandler = new GlobalContextHandler(client, createIndexStep); - adminClient = mock(AdminClient.class); - indicesAdminClient = mock(IndicesAdminClient.class); - when(adminClient.indices()).thenReturn(indicesAdminClient); - when(client.admin()).thenReturn(adminClient); - } - - public void testPutTemplateToGlobalContext() throws IOException { - Template template = mock(Template.class); - when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { - XContentBuilder builder = invocation.getArgument(0); - return builder; - }); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - doAnswer(invocation -> { - ActionListener callback = invocation.getArgument(1); - callback.onResponse(true); - return null; - }).when(createIndexStep).initIndexIfAbsent(any(FlowFrameworkIndex.class), any()); - - globalContextHandler.putTemplateToGlobalContext(template, listener); - - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); - verify(client, times(1)).index(requestCaptor.capture(), any()); - - assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); - } - - public void testStoreResponseToGlobalContext() { - String documentId = "docId"; - Map updatedFields = new HashMap<>(); - updatedFields.put("field1", "value1"); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - globalContextHandler.storeResponseToGlobalContext(documentId, updatedFields, listener); - - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); - verify(client, times(1)).update(requestCaptor.capture(), any()); - - assertEquals(GLOBAL_CONTEXT_INDEX, requestCaptor.getValue().index()); - assertEquals(documentId, requestCaptor.getValue().id()); - } - - public void testUpdateTemplateInGlobalContext() throws IOException { - Template template = mock(Template.class); - when(template.toDocumentSource(any(XContentBuilder.class), eq(ToXContent.EMPTY_PARAMS))).thenAnswer(invocation -> { - XContentBuilder builder = invocation.getArgument(0); - return builder; - }); - when(createIndexStep.doesIndexExist(any())).thenReturn(true); - - globalContextHandler.updateTemplateInGlobalContext("1", template, null); - - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(IndexRequest.class); - verify(client, times(1)).index(requestCaptor.capture(), any()); - - assertEquals("1", requestCaptor.getValue().id()); - } - - public void testFailedUpdateTemplateInGlobalContext() throws IOException { - Template template = mock(Template.class); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - when(createIndexStep.doesIndexExist(any())).thenReturn(false); - - globalContextHandler.updateTemplateInGlobalContext("1", template, listener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - - verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - - assertEquals( - "Failed to update template for workflow_id : 1, global_context index does not exist.", - exceptionCaptor.getValue().getMessage() - ); - - } -} diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 83c322fea..02fd7648f 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -11,9 +11,10 @@ import org.opensearch.Version; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.flowframework.indices.GlobalContextHandler; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -37,17 +38,19 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private CreateWorkflowTransportAction createWorkflowTransportAction; - private GlobalContextHandler globalContextHandler; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private Template template; + private Client client = mock(Client.class); @Override public void setUp() throws Exception { super.setUp(); - this.globalContextHandler = mock(GlobalContextHandler.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); this.createWorkflowTransportAction = new CreateWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), - globalContextHandler + flowFrameworkIndicesHandler, + client ); List operations = List.of("operation"); @@ -84,7 +87,7 @@ public void testCreateNewWorkflow() { ActionListener responseListener = invocation.getArgument(1); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(globalContextHandler).putTemplateToGlobalContext(any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(Template.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -103,7 +106,7 @@ public void testFailedToCreateNewWorkflow() { ActionListener responseListener = invocation.getArgument(1); responseListener.onFailure(new Exception("Failed to create global_context index")); return null; - }).when(globalContextHandler).putTemplateToGlobalContext(any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).putTemplateToGlobalContext(any(Template.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); @@ -121,7 +124,7 @@ public void testUpdateWorkflow() { ActionListener responseListener = invocation.getArgument(2); responseListener.onResponse(new IndexResponse(new ShardId(GLOBAL_CONTEXT_INDEX, "", 1), "1", 1L, 1L, 1L, true)); return null; - }).when(globalContextHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -139,7 +142,7 @@ public void testFailedToUpdateWorkflow() { ActionListener responseListener = invocation.getArgument(2); responseListener.onFailure(new Exception("Failed to update use case template")); return null; - }).when(globalContextHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); + }).when(flowFrameworkIndicesHandler).updateTemplateInGlobalContext(any(), any(Template.class), any()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index f8b2d8490..9f1158392 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -20,6 +20,7 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -51,6 +52,7 @@ public class ProvisionWorkflowTransportActionTests extends OpenSearchTestCase { private WorkflowProcessSorter workflowProcessSorter; private ProvisionWorkflowTransportAction provisionWorkflowTransportAction; private Template template; + private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; @Override public void setUp() throws Exception { @@ -58,13 +60,15 @@ public void setUp() throws Exception { this.threadPool = mock(ThreadPool.class); this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); + this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); this.provisionWorkflowTransportAction = new ProvisionWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), threadPool, client, - workflowProcessSorter + workflowProcessSorter, + flowFrameworkIndicesHandler ); List operations = List.of("operation"); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 036714ba8..ab5dd476a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -10,21 +10,17 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.indices.FlowFrameworkIndex; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -41,7 +37,6 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -122,87 +117,4 @@ public void testCreateIndexStepFailure() throws ExecutionException, InterruptedE assertTrue(ex.getCause() instanceof Exception); assertEquals("Failed to create an index", ex.getCause().getMessage()); } - - public void testInitIndexIfAbsent_IndexNotPresent() { - when(metadata.hasIndex(FlowFrameworkIndex.GLOBAL_CONTEXT.getIndexName())).thenReturn(false); - - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - createIndexStep.initIndexIfAbsent(FlowFrameworkIndex.GLOBAL_CONTEXT, listener); - - verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), any()); - } - - public void testInitIndexIfAbsent_IndexExist() { - FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; - indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false)); - - ClusterState mockClusterState = mock(ClusterState.class); - Metadata mockMetadata = mock(Metadata.class); - when(clusterService.state()).thenReturn(mockClusterState); - when(mockClusterState.metadata()).thenReturn(mockMetadata); - when(mockMetadata.hasIndex(index.getIndexName())).thenReturn(true); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - IndexMetadata mockIndexMetadata = mock(IndexMetadata.class); - @SuppressWarnings("unchecked") - Map mockIndices = mock(Map.class); - when(clusterService.state()).thenReturn(mockClusterState); - when(mockClusterState.getMetadata()).thenReturn(mockMetadata); - when(mockMetadata.indices()).thenReturn(mockIndices); - when(mockIndices.get(anyString())).thenReturn(mockIndexMetadata); - Map mockMapping = new HashMap<>(); - Map mockMetaMapping = new HashMap<>(); - mockMetaMapping.put(SCHEMA_VERSION_FIELD, 1); - mockMapping.put(META, mockMetaMapping); - MappingMetadata mockMappingMetadata = mock(MappingMetadata.class); - when(mockIndexMetadata.mapping()).thenReturn(mockMappingMetadata); - when(mockMappingMetadata.getSourceAsMap()).thenReturn(mockMapping); - - createIndexStep.initIndexIfAbsent(index, listener); - - ArgumentCaptor putMappingRequestArgumentCaptor = ArgumentCaptor.forClass(PutMappingRequest.class); - @SuppressWarnings({ "unchecked" }) - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - verify(indicesAdminClient, times(1)).putMapping(putMappingRequestArgumentCaptor.capture(), listenerCaptor.capture()); - PutMappingRequest capturedRequest = putMappingRequestArgumentCaptor.getValue(); - assertEquals(index.getIndexName(), capturedRequest.indices()[0]); - } - - public void testInitIndexIfAbsent_IndexExist_returnFalse() { - FlowFrameworkIndex index = FlowFrameworkIndex.GLOBAL_CONTEXT; - indexMappingUpdated.put(index.getIndexName(), new AtomicBoolean(false)); - - ClusterState mockClusterState = mock(ClusterState.class); - Metadata mockMetadata = mock(Metadata.class); - when(clusterService.state()).thenReturn(mockClusterState); - when(mockClusterState.metadata()).thenReturn(mockMetadata); - when(mockMetadata.hasIndex(index.getIndexName())).thenReturn(true); - - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - @SuppressWarnings("unchecked") - Map mockIndices = mock(Map.class); - when(mockClusterState.getMetadata()).thenReturn(mockMetadata); - when(mockMetadata.indices()).thenReturn(mockIndices); - when(mockIndices.get(anyString())).thenReturn(null); - - createIndexStep.initIndexIfAbsent(index, listener); - assertTrue(indexMappingUpdated.get(index.getIndexName()).get()); - } - - public void testDoesIndexExist() { - ClusterState mockClusterState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); - when(clusterService.state()).thenReturn(mockClusterState); - when(mockClusterState.metadata()).thenReturn(mockMetaData); - - createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX); - - ArgumentCaptor indexExistsCaptor = ArgumentCaptor.forClass(String.class); - verify(mockMetaData, times(1)).hasIndex(indexExistsCaptor.capture()); - - assertEquals(GLOBAL_CONTEXT_INDEX, indexExistsCaptor.getValue()); - } }