diff --git a/CHANGELOG.md b/CHANGELOG.md index 465053a1e..c7faa6fbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Enhancements ### Bug Fixes - Add user mapping to Workflow State index ([#705](https://github.com/opensearch-project/flow-framework/pull/705)) +- Add Workflow Step for Reindex from source index to destination ([#718](https://github.com/opensearch-project/flow-framework/pull/718)) ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ac0291687..a7e35ab5e 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -174,7 +174,30 @@ private CommonValue() {} public static final String DELAY_FIELD = "delay"; /** Model Interface Field */ public static final String INTERFACE_FIELD = "interface"; - + /** The reindex field for created resources */ + public static final String RE_INDEX_FIELD = "reindex"; + /** The source index field for reindex */ + public static final String SOURCE_INDEX = "source_index"; + /** The destination index field for reindex */ + public static final String DESTINATION_INDEX = "destination_index"; + /** The refresh field for reindex */ + public static final String REFRESH = "refresh"; + /** The timeout field for reindex */ + public static final String TIMEOUT = "timeout"; + /** The wait_for_active_shards field for reindex */ + public static final String WAIT_FOR_ACTIVE_SHARDS = "wait_for_active_shards"; + /** The wait_for_completion field for reindex */ + public static final String WAIT_FOR_COMPLETION = "wait_for_completion"; + /** The requests_per_second field for reindex */ + public static final String REQUESTS_PER_SECOND = "requests_per_second"; + /** The require_alias field for reindex */ + public static final 
String REQUIRE_ALIAS = "require_alias"; + /** The scroll field for reindex */ + public static final String SCROLL = "scroll"; + /** The slices field for reindex */ + public static final String SLICES = "slices"; + /** The max_docs field for reindex */ + public static final String MAX_DOCS = "max_docs"; /* * Constants associated with resource provisioning / state */ diff --git a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java index a024ec3b8..af00078d6 100644 --- a/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java +++ b/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java @@ -21,6 +21,7 @@ import org.opensearch.flowframework.workflow.DeleteModelStep; import org.opensearch.flowframework.workflow.DeployModelStep; import org.opensearch.flowframework.workflow.NoOpStep; +import org.opensearch.flowframework.workflow.ReIndexStep; import org.opensearch.flowframework.workflow.RegisterAgentStep; import org.opensearch.flowframework.workflow.RegisterLocalCustomModelStep; import org.opensearch.flowframework.workflow.RegisterLocalPretrainedModelStep; @@ -58,6 +59,8 @@ public enum WorkflowResources { CREATE_SEARCH_PIPELINE(CreateSearchPipelineStep.NAME, WorkflowResources.PIPELINE_ID, null), // TODO delete step /** Workflow steps for creating an index and associated created resource */ CREATE_INDEX(CreateIndexStep.NAME, WorkflowResources.INDEX_NAME, NoOpStep.NAME), + /** Workflow steps for reindex a source index to destination index and associated created resource */ + RE_INDEX(ReIndexStep.NAME, CommonValue.DESTINATION_INDEX, NoOpStep.NAME), /** Workflow steps for registering/deleting an agent and the associated created resource */ REGISTER_AGENT(RegisterAgentStep.NAME, WorkflowResources.AGENT_ID, DeleteAgentStep.NAME); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ReIndexStep.java 
b/src/main/java/org/opensearch/flowframework/workflow/ReIndexStep.java new file mode 100644 index 000000000..a245b6f1f --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/ReIndexStep.java
@@ -0,0 +1,251 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.flowframework.workflow;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.client.Client;
+import org.opensearch.common.Booleans;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.flowframework.exception.FlowFrameworkException;
+import org.opensearch.flowframework.exception.WorkflowStepException;
+import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ParseUtils;
+import org.opensearch.index.reindex.BulkByScrollResponse;
+import org.opensearch.index.reindex.ReindexAction;
+import org.opensearch.index.reindex.ReindexRequest;
+
+import java.util.Map;
+import java.util.Set;
+
+import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX;
+import static org.opensearch.flowframework.common.CommonValue.MAX_DOCS;
+import static org.opensearch.flowframework.common.CommonValue.REFRESH;
+import static org.opensearch.flowframework.common.CommonValue.REQUESTS_PER_SECOND;
+import static org.opensearch.flowframework.common.CommonValue.REQUIRE_ALIAS;
+import static org.opensearch.flowframework.common.CommonValue.RE_INDEX_FIELD;
+import static org.opensearch.flowframework.common.CommonValue.SCROLL;
+import static org.opensearch.flowframework.common.CommonValue.SLICES;
+import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;
+import static org.opensearch.flowframework.common.CommonValue.TIMEOUT;
+import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_ACTIVE_SHARDS;
+import static org.opensearch.flowframework.common.CommonValue.WAIT_FOR_COMPLETION;
+
+/**
+ * Workflow step that reindexes documents from a source index into a destination index.
+ * <p>
+ * On success the destination index is recorded as a created resource through the
+ * {@link FlowFrameworkIndicesHandler}, and the step output maps {@code reindex} to a single-entry map of
+ * source index to destination index.
+ */
+public class ReIndexStep implements WorkflowStep {
+
+    private static final Logger logger = LogManager.getLogger(ReIndexStep.class);
+    private final Client client;
+    private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
+
+    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */
+    public static final String NAME = "reindex";
+
+    /**
+     * Instantiate this class
+     *
+     * @param client Client used to execute the reindex action
+     * @param flowFrameworkIndicesHandler FlowFrameworkIndicesHandler class to update system indices
+     */
+    public ReIndexStep(Client client, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) {
+        this.client = client;
+        this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
+    }
+
+    /**
+     * Builds a reindex request from the resolved step inputs and executes it asynchronously.
+     * <p>
+     * The returned future completes with the step output after the destination index has been recorded as a
+     * created resource. It completes exceptionally when the inputs cannot be resolved or parsed, when the reindex
+     * request fails, when the reindex response reports bulk or search failures, or when recording the resource fails.
+     */
+    @Override
+    public PlainActionFuture<WorkflowData> execute(
+        String currentNodeId,
+        WorkflowData currentNodeInputs,
+        Map<String, WorkflowData> outputs,
+        Map<String, String> previousNodeInputs,
+        Map<String, String> params
+    ) {
+
+        PlainActionFuture<WorkflowData> reIndexFuture = PlainActionFuture.newFuture();
+
+        Set<String> requiredKeys = Set.of(SOURCE_INDEX, DESTINATION_INDEX);
+
+        // NOTE(review): TIMEOUT, WAIT_FOR_ACTIVE_SHARDS, WAIT_FOR_COMPLETION and SCROLL are accepted as optional inputs
+        // but are never read or applied to the request below.
+        Set<String> optionalKeys = Set.of(
+            REFRESH,
+            TIMEOUT,
+            WAIT_FOR_ACTIVE_SHARDS,
+            WAIT_FOR_COMPLETION,
+            REQUESTS_PER_SECOND,
+            REQUIRE_ALIAS,
+            SCROLL,
+            SLICES,
+            MAX_DOCS
+        );
+
+        try {
+            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
+                requiredKeys,
+                optionalKeys,
+                currentNodeInputs,
+                outputs,
+                previousNodeInputs,
+                params
+            );
+
+            String sourceIndices = (String) inputs.get(SOURCE_INDEX);
+            String destinationIndex = (String) inputs.get(DESTINATION_INDEX);
+            Boolean refresh = inputs.containsKey(REFRESH) ? Booleans.parseBoolean(inputs.get(REFRESH).toString()) : null;
+            // Numeric inputs may arrive as numbers or as strings, so they are parsed rather than cast.
+            Float requestsPerSecond = parseOptionalFloat(inputs.get(REQUESTS_PER_SECOND));
+            Boolean requireAlias = inputs.containsKey(REQUIRE_ALIAS) ? Booleans.parseBoolean(inputs.get(REQUIRE_ALIAS).toString()) : null;
+            Integer slices = parseOptionalInteger(inputs.get(SLICES));
+            Integer maxDocs = parseOptionalInteger(inputs.get(MAX_DOCS));
+
+            ReindexRequest reindexRequest = new ReindexRequest();
+            reindexRequest.setSourceIndices(sourceIndices);
+            reindexRequest.setDestIndex(destinationIndex);
+            if (refresh != null) {
+                reindexRequest.setRefresh(refresh);
+            }
+            if (requestsPerSecond != null) {
+                reindexRequest.setRequestsPerSecond(requestsPerSecond);
+            }
+            if (requireAlias != null) {
+                reindexRequest.setRequireAlias(requireAlias);
+            }
+            if (maxDocs != null) {
+                reindexRequest.setMaxDocs(maxDocs);
+            }
+            if (slices != null) {
+                reindexRequest.setSlices(slices);
+            }
+
+            ActionListener<BulkByScrollResponse> actionListener = new ActionListener<>() {
+
+                @Override
+                public void onResponse(BulkByScrollResponse bulkByScrollResponse) {
+                    logger.info("Reindex from source: {} to destination {}", sourceIndices, destinationIndex);
+                    try {
+                        int bulkFailureCount = bulkByScrollResponse.getBulkFailures().size();
+                        int searchFailureCount = bulkByScrollResponse.getSearchFailures().size();
+                        if (bulkFailureCount > 0 || searchFailureCount > 0) {
+                            // The response itself can report bulk or search failures; complete the future exceptionally
+                            // instead of leaving it pending, using the first reported failure as the cause.
+                            Throwable cause = bulkFailureCount > 0
+                                ? bulkByScrollResponse.getBulkFailures().get(0).getCause()
+                                : bulkByScrollResponse.getSearchFailures().get(0).getReason();
+                            String errorMessage = "Failed to reindex from source "
+                                + sourceIndices
+                                + " to "
+                                + destinationIndex
+                                + " with "
+                                + bulkFailureCount
+                                + " bulk failure(s) and "
+                                + searchFailureCount
+                                + " search failure(s)";
+                            logger.error(errorMessage, cause);
+                            reIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(cause)));
+                            return;
+                        }
+
+                        flowFrameworkIndicesHandler.updateResourceInStateIndex(
+                            currentNodeInputs.getWorkflowId(),
+                            currentNodeId,
+                            getName(),
+                            destinationIndex,
+                            ActionListener.wrap(response -> {
+                                logger.info("successfully updated resource created in state index: {}", response.getIndex());
+
+                                reIndexFuture.onResponse(
+                                    new WorkflowData(
+                                        Map.of(RE_INDEX_FIELD, Map.of(sourceIndices, destinationIndex)),
+                                        currentNodeInputs.getWorkflowId(),
+                                        currentNodeInputs.getNodeId()
+                                    )
+                                );
+                            }, exception -> {
+                                String errorMessage = "Failed to update new reindexed "
+                                    + currentNodeId
+                                    + " resource "
+                                    + getName()
+                                    + " id "
+                                    + destinationIndex;
+                                logger.error(errorMessage, exception);
+                                reIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception)));
+                            })
+                        );
+                    } catch (Exception e) {
+                        String errorMessage = "Failed to parse and update new created resource";
+                        logger.error(errorMessage, e);
+                        reIndexFuture.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    String errorMessage = "Failed to reindex from source " + sourceIndices + " to " + destinationIndex;
+                    logger.error(errorMessage, e);
+                    reIndexFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
+                }
+            };
+
+            client.execute(ReindexAction.INSTANCE, reindexRequest, actionListener);
+
+        } catch (Exception e) {
+            reIndexFuture.onFailure(e);
+        }
+
+        return reIndexFuture;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /**
+     * Reads an optional integer input.
+     *
+     * @param value the raw input value; may be {@code null}
+     * @return the value as an {@link Integer}, or {@code null} when the input is absent
+     * @throws NumberFormatException if the value is not an {@link Integer} and its string form is not a valid integer
+     */
+    private static Integer parseOptionalInteger(Object value) {
+        if (value == null) {
+            return null;
+        }
+        return value instanceof Integer ? (Integer) value : Integer.valueOf(value.toString().trim());
+    }
+
+    /**
+     * Reads an optional floating point input.
+     *
+     * @param value the raw input value; may be {@code null}
+     * @return the value as a {@link Float}, or {@code null} when the input is absent
+     * @throws NumberFormatException if the value is not a {@link Number} and its string form is not a valid float
+     */
+    private static Float parseOptionalFloat(Object value) {
+        if (value == null) {
+            return null;
+        }
+        return value instanceof Number ? Float.valueOf(((Number) value).floatValue()) : Float.valueOf(value.toString().trim());
+    }
+}
diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 224cbf1eb..cedba194c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -34,6 +34,7 @@ import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX; import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; @@ -47,6 +48,8 @@ import static org.opensearch.flowframework.common.CommonValue.PIPELINE_ID; import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; import static
org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.RE_INDEX_FIELD; +import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; @@ -84,6 +87,7 @@ public WorkflowStepFactory( ) { stepMap.put(NoOpStep.NAME, NoOpStep::new); stepMap.put(CreateIndexStep.NAME, () -> new CreateIndexStep(client, flowFrameworkIndicesHandler)); + stepMap.put(ReIndexStep.NAME, () -> new ReIndexStep(client, flowFrameworkIndicesHandler)); stepMap.put( RegisterLocalCustomModelStep.NAME, () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) @@ -125,6 +129,9 @@ public enum WorkflowSteps { /** Create Index Step */ CREATE_INDEX(CreateIndexStep.NAME, List.of(INDEX_NAME, CONFIGURATIONS), List.of(INDEX_NAME), Collections.emptyList(), null), + /** Create ReIndex Step */ + RE_INDEX(ReIndexStep.NAME, List.of(SOURCE_INDEX, DESTINATION_INDEX), List.of(RE_INDEX_FIELD), Collections.emptyList(), null), + /** Create Connector Step */ CREATE_CONNECTOR( CreateConnectorStep.NAME, diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 80a9788c2..19cb3d718 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -46,7 +46,7 @@ public void testParseWorkflowValidator() throws IOException { WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(17, validator.getWorkflowStepValidators().size()); + assertEquals(18, validator.getWorkflowStepValidators().size()); 
assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ReIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ReIndexStepTests.java new file mode 100644 index 000000000..d62804d3e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/ReIndexStepTests.java
@@ -0,0 +1,202 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.flowframework.workflow;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.opensearch.OpenSearchException;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.action.update.UpdateResponse;
+import org.opensearch.client.Client;
+import org.opensearch.common.Randomness;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.flowframework.exception.WorkflowStepException;
+import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.index.reindex.BulkByScrollResponse;
+import org.opensearch.index.reindex.BulkByScrollTask;
+import org.opensearch.index.reindex.ReindexRequest;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.IntStream;
+
+import org.mockito.ArgumentCaptor;
+import org.mockito.MockitoAnnotations;
+
+import static java.lang.Math.abs;
+import static java.util.stream.Collectors.toList;
+import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
+import static org.opensearch.common.unit.TimeValue.timeValueMillis;
+import static org.opensearch.flowframework.common.CommonValue.DESTINATION_INDEX;
+import static org.opensearch.flowframework.common.CommonValue.MAX_DOCS;
+import static org.opensearch.flowframework.common.CommonValue.REFRESH;
+import static org.opensearch.flowframework.common.CommonValue.REQUESTS_PER_SECOND;
+import static org.opensearch.flowframework.common.CommonValue.REQUIRE_ALIAS;
+import static org.opensearch.flowframework.common.CommonValue.RE_INDEX_FIELD;
+import static org.opensearch.flowframework.common.CommonValue.SLICES;
+import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX;
+import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX;
+import static org.apache.lucene.tests.util.TestUtil.randomSimpleString;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/** Unit tests for {@link ReIndexStep} using a mocked client and a mocked state index handler. */
+public class ReIndexStepTests extends OpenSearchTestCase {
+    private WorkflowData inputData = WorkflowData.EMPTY;
+    private Client client;
+    private ReIndexStep reIndexStep;
+
+    private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
+        MockitoAnnotations.openMocks(this);
+
+        inputData = new WorkflowData(
+            Map.ofEntries(
+                Map.entry(SOURCE_INDEX, "demo"),
+                Map.entry(DESTINATION_INDEX, "dest"),
+                Map.entry(REFRESH, true),
+                Map.entry(REQUESTS_PER_SECOND, 2),
+                Map.entry(REQUIRE_ALIAS, false),
+                Map.entry(SLICES, 1),
+                Map.entry(MAX_DOCS, 2)
+            ),
+            "test-id",
+            "test-node-id"
+        );
+
+        client = mock(Client.class);
+        reIndexStep = new ReIndexStep(client, flowFrameworkIndicesHandler);
+    }
+
+    /** A reindex response without failures completes the future with the source-to-destination mapping. */
+    public void testReIndexStep() throws ExecutionException, InterruptedException, IOException {
+
+        doAnswer(invocation -> {
+            ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
+            updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
+            return null;
+        }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());
+
+        @SuppressWarnings({ "unchecked" })
+        ArgumentCaptor<ActionListener<BulkByScrollResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
+        PlainActionFuture<WorkflowData> future = reIndexStep.execute(
+            inputData.getNodeId(),
+            inputData,
+            Collections.emptyMap(),
+            Collections.emptyMap(),
+            Collections.emptyMap()
+        );
+
+        verify(client, times(1)).execute(any(), any(ReindexRequest.class), actionListenerCaptor.capture());
+        actionListenerCaptor.getValue()
+            .onResponse(
+                new BulkByScrollResponse(
+                    timeValueMillis(randomNonNegativeLong()),
+                    randomStatus(),
+                    Collections.emptyList(),
+                    Collections.emptyList(),
+                    randomBoolean()
+                )
+            );
+
+        assertTrue(future.isDone());
+
+        Map<String, Object> outputData = Map.of(RE_INDEX_FIELD, Map.of("demo", "dest"));
+        assertEquals(outputData, future.get().getContent());
+
+    }
+
+    /** A failure reported by the client is surfaced as a {@link WorkflowStepException} carrying the step's message. */
+    public void testReIndexStepFailure() throws ExecutionException, InterruptedException {
+        @SuppressWarnings({ "unchecked" })
+        ArgumentCaptor<ActionListener<BulkByScrollResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
+        PlainActionFuture<WorkflowData> future = reIndexStep.execute(
+            inputData.getNodeId(),
+            inputData,
+            Collections.emptyMap(),
+            Collections.emptyMap(),
+            Collections.emptyMap()
+        );
+        assertFalse(future.isDone());
+        verify(client, times(1)).execute(any(), any(ReindexRequest.class), actionListenerCaptor.capture());
+
+        actionListenerCaptor.getValue().onFailure(new Exception("Failed to reindex from source demo to dest"));
+
+        assertTrue(future.isDone());
+        ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
+        assertTrue(ex.getCause() instanceof WorkflowStepException);
+        assertEquals("Failed to reindex from source demo to dest", ex.getCause().getMessage());
+    }
+
+    /** Builds a random task status: either a single working status or a status built from up to ten slice entries. */
+    private static BulkByScrollTask.Status randomStatus() {
+        if (randomBoolean()) {
+            return randomWorkingStatus(null);
+        }
+        boolean canHaveNullStatues = randomBoolean();
+        List<BulkByScrollTask.StatusOrException> statuses = IntStream.range(0, between(0, 10)).mapToObj(i -> {
+            if (canHaveNullStatues && LuceneTestCase.rarely()) {
+                return null;
+            }
+            if (randomBoolean()) {
+                return new BulkByScrollTask.StatusOrException(new OpenSearchException(randomAlphaOfLength(5)));
+            }
+            return new BulkByScrollTask.StatusOrException(randomWorkingStatus(i));
+        }).collect(toList());
+        return new BulkByScrollTask.Status(statuses, randomBoolean() ? "test" : null);
+    }
+
+    /** Builds a random working status for the given slice id; {@code null} is passed when there is no slice. */
+    private static BulkByScrollTask.Status randomWorkingStatus(Integer sliceId) {
+        // These all should be believably small because we sum them if we have multiple workers
+        int total = between(0, 10000000);
+        int updated = between(0, total);
+        int created = between(0, total - updated);
+        int deleted = between(0, total - updated - created);
+        int noops = total - updated - created - deleted;
+        int batches = between(0, 10000);
+        long versionConflicts = between(0, total);
+        long bulkRetries = between(0, 10000000);
+        long searchRetries = between(0, 100000);
+        // smallest unit of time during toXContent is Milliseconds
+        TimeUnit[] timeUnits = { TimeUnit.MILLISECONDS, TimeUnit.SECONDS, TimeUnit.MINUTES, TimeUnit.HOURS, TimeUnit.DAYS };
+        TimeValue throttled = new TimeValue(randomIntBetween(0, 1000), randomFrom(timeUnits));
+        TimeValue throttledUntil = new TimeValue(randomIntBetween(0, 1000), randomFrom(timeUnits));
+        return new BulkByScrollTask.Status(
+            sliceId,
+            total,
+            updated,
+            created,
+            deleted,
+            batches,
+            versionConflicts,
+            noops,
+            bulkRetries,
+            searchRetries,
+            throttled,
+            abs(Randomness.get().nextFloat()),
+            randomBoolean() ? null : randomSimpleString(Randomness.get()),
+            throttledUntil
+        );
+    }
+}