From a1887cf8b9dd4467f5beb481ae691a4f0123abc9 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 29 Sep 2023 20:17:18 -0700 Subject: [PATCH 1/6] Add timeout for node execution Signed-off-by: Daniel Widdis --- .../flowframework/model/WorkflowNode.java | 4 ++ .../flowframework/workflow/ProcessNode.java | 52 ++++++++++------ .../workflow/WorkflowProcessSorter.java | 13 +++- .../workflow/WorkflowStepFactory.java | 2 +- .../workflow/ProcessNodeTests.java | 60 ++++++++++++++++--- .../resources/template/finaltemplate.json | 15 +++-- 6 files changed, 112 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index b48b6e0d2..8c4a6ae52 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -41,6 +41,10 @@ public class WorkflowNode implements ToXContentObject { public static final String INPUTS_FIELD = "inputs"; /** The field defining processors in the inputs for search and ingest pipelines */ public static final String PROCESSORS_FIELD = "processors"; + /** The field defining the timeout value for this node */ + public static final String NODE_TIMEOUT_FIELD = "node_timeout"; + /** The default timeout value if the template doesn't override it */ + public static final String NODE_TIMEOUT_DEFAULT_VALUE = "10s"; private final String id; // unique id private final String type; // maps to a WorkflowStep diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index a2d7628c3..f8e739461 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.FutureUtils; import java.util.ArrayList; import java.util.List; @@ -17,6 +19,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; /** @@ -30,7 +33,8 @@ public class ProcessNode { private final WorkflowStep workflowStep; private final WorkflowData input; private final List predecessors; - private Executor executor; + private final Executor executor; + private final TimeValue nodeTimeout; private final CompletableFuture future = new CompletableFuture<>(); @@ -42,13 +46,22 @@ public class ProcessNode { * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow * @param executor The OpenSearch thread pool + * @param nodeTimeout The timeout value for executing on this node */ - public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input, List predecessors, Executor executor) { + public ProcessNode( + String id, + WorkflowStep workflowStep, + WorkflowData input, + List predecessors, + Executor executor, + TimeValue nodeTimeout + ) { this.id = id; this.workflowStep = workflowStep; this.input = input; this.predecessors = predecessors; this.executor = executor; + this.nodeTimeout = nodeTimeout; } /** @@ -102,22 +115,12 @@ public List predecessors() { * @return this node's future. This is returned immediately, while process execution continues asynchronously. */ public CompletableFuture execute() { - // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!) - // the generic executor from that pool should be passed to this runAsync call - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + if (this.future.isDone()) { + throw new IllegalStateException("Process Node [" + this.id + "] already executed."); + } CompletableFuture.runAsync(() -> { List> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList()); - if (!predecessors.isEmpty()) { - CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); - try { - // We need timeouts to be part of the user template or in settings - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/45 - waitForPredecessors.orTimeout(30, TimeUnit.SECONDS).get(); - } catch (InterruptedException | ExecutionException e) { - handleException(e); - return; - } - } + predFutures.stream().forEach(CompletableFuture::join); logger.info(">>> Starting {}.", this.id); // get the input data from predecessor(s) List input = new ArrayList(); @@ -130,11 +133,22 @@ public CompletableFuture execute() { return; } } - CompletableFuture stepFuture = this.workflowStep.execute(input); try { - stepFuture.orTimeout(15, TimeUnit.SECONDS).join(); + CompletableFuture delayFuture = null; + if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) { + Executor timeoutExec = CompletableFuture.delayedExecutor(this.nodeTimeout.nanos(), TimeUnit.NANOSECONDS, executor); + delayFuture = CompletableFuture.runAsync(() -> { + future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id)); + // If completed normally the above will be a no-op + if (future.isCompletedExceptionally()) { + logger.info(">>> Timed out {}.", this.id); + } + }, timeoutExec); + } + CompletableFuture stepFuture = this.workflowStep.execute(input); + future.complete(stepFuture.get()); // If completed exceptionally, this is a NOOP logger.info(">>> Finished {}.", this.id); - future.complete(stepFuture.get()); + FutureUtils.cancel(delayFuture); } catch (InterruptedException | ExecutionException e) { handleException(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 3370f6384..b5503f52f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; @@ -27,6 +29,10 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static org.opensearch.flowframework.model.WorkflowNode.INPUTS_FIELD; +import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; +import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; + /** * Utility class converting a workflow of nodes and edges into a topologically sorted list of Process Nodes. */ @@ -91,7 +97,12 @@ public List sortProcessNodes(Workflow workflow) { .map(e -> idToNodeMap.get(e.source())) .collect(Collectors.toList()); - ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, executor); + TimeValue nodeTimeout = Setting.parseTimeValue( + (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE), + TimeValue.ZERO, + String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD) + ); + ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, executor, nodeTimeout); idToNodeMap.put(processNode.id(), processNode); nodes.add(processNode); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index dc0dc29a2..5da3128d0 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -62,7 +62,7 @@ private void populateMap(Client client) { // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); - stepMap.put("demo_delay_5", new DemoWorkflowStep(3000)); + stepMap.put("demo_delay_5", new DemoWorkflowStep(5000)); // Use as a default until all the actual implementations are ready stepMap.put("placeholder", new WorkflowStep() { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 1972d20eb..e362c50f1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -8,29 +8,45 @@ */ package org.opensearch.flowframework.workflow; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.common.unit.TimeValue; import org.opensearch.test.OpenSearchTestCase; -import org.junit.After; -import org.junit.Before; +import org.junit.AfterClass; +import org.junit.BeforeClass; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class ProcessNodeTests extends OpenSearchTestCase { - private ExecutorService executor; + private static ExecutorService executor; - @Before - public void setup() { + @BeforeClass + public static void setup() { executor = Executors.newFixedThreadPool(10); } - @After - public void cleanup() { + @AfterClass + public static void cleanup() { executor.shutdown(); + try { + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow(); + executor.awaitTermination(5, TimeUnit.SECONDS); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } } public void testNode() throws InterruptedException, ExecutionException { @@ -46,7 +62,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, Collections.emptyList(), executor); + }, WorkflowData.EMPTY, Collections.emptyList(), executor, TimeValue.ZERO); assertEquals("A", nodeA.id()); assertEquals("test", nodeA.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeA.input()); @@ -57,4 +73,32 @@ public String getName() { assertEquals(f, nodeA.future()); assertEquals(WorkflowData.EMPTY, f.get()); } + + public void testNodeTimeout() throws InterruptedException, ExecutionException { + ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + CompletableFuture.delayedExecutor(250, TimeUnit.MILLISECONDS, executor).execute(() -> future.complete(WorkflowData.EMPTY)); + return future; + } + + @Override + public String getName() { + return "sleepy"; + } + }, WorkflowData.EMPTY, Collections.emptyList(), executor, TimeValue.timeValueMillis(100)); + assertEquals("Zzz", nodeZ.id()); + assertEquals("sleepy", nodeZ.workflowStep().getName()); + assertEquals(WorkflowData.EMPTY, nodeZ.input()); + assertEquals(Collections.emptyList(), nodeZ.predecessors()); + assertEquals("Zzz", nodeZ.toString()); + + CompletableFuture f = nodeZ.execute(); + CompletionException exception = assertThrows(CompletionException.class, () -> f.join()); + assertTrue(f.isCompletedExceptionally()); + assertEquals(TimeoutException.class, exception.getCause().getClass()); + + assertThrows(IllegalStateException.class, () -> nodeZ.execute()); + } } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index d8443c4c6..fe1a57e36 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -25,7 +25,8 @@ "type": "create_index", "inputs": { "name": "user_inputs.index_name", - "settings": "user_inputs.index_settings" + "settings": "user_inputs.index_settings", + "node_timeout": "10s" } }, { @@ -41,7 +42,8 @@ "input_field": "text_passage", "output_field": "text_embedding" } - }] + }], + "node_timeout": "10s" } } ], @@ -60,7 +62,8 @@ "inputs": { "index": "user_inputs.index_name", "ingest_pipeline": "my-ingest-pipeline", - "document": "user_params.document" + "document": "user_params.document", + "node_timeout": "10s" } }] }, @@ -73,7 +76,8 @@ "type": "transform_query", "inputs": { "template": "neural-search-template-1", - "plaintext": "user_params.plaintext" + "plaintext": "user_params.plaintext", + "node_timeout": "10s" } }, { @@ -83,7 +87,8 @@ "index": "user_inputs.index_name", "query": "{{output-from-prev-step}}.query", "search_request_processors": [], - "search_response_processors": [] + "search_response_processors": [], + "node_timeout": "10s" } } ], From 70275ac99753ccab1fc230bc9aaad322d193a781 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 30 Sep 2023 11:51:42 -0700 Subject: [PATCH 2/6] Properly implement delays using OpenSearch ThreadPool Signed-off-by: Daniel Widdis --- src/main/java/demo/Demo.java | 14 +++-- src/main/java/demo/TemplateParseDemo.java | 10 +++- .../flowframework/FlowFrameworkPlugin.java | 7 ++- .../flowframework/workflow/ProcessNode.java | 56 ++++++++++--------- .../workflow/WorkflowProcessSorter.java | 16 +++--- .../flowframework/workflow/WorkflowStep.java | 1 + .../workflow/ProcessNodeTests.java | 32 ++++------- .../workflow/WorkflowProcessSorterTests.java | 13 +++-- 8 files changed, 82 insertions(+), 67 deletions(-) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 53cf3499c..ecbf17162 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -14,10 +14,12 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.io.PathUtils; +import org.opensearch.common.settings.Settings; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -26,8 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; /** @@ -37,6 +38,8 @@ public class Demo { private static final Logger logger = LogManager.getLogger(Demo.class); + private Demo() {} + /** * Demonstrate parsing a JSON graph. * @@ -55,8 +58,9 @@ public static void main(String[] args) throws IOException { } Client client = new NodeClient(null, null); WorkflowStepFactory factory = WorkflowStepFactory.create(client); - ExecutorService executor = Executors.newFixedThreadPool(10); - WorkflowProcessSorter.create(factory, executor); + + ThreadPool threadPool = new ThreadPool(Settings.EMPTY); + WorkflowProcessSorter.create(factory, threadPool); logger.info("Parsing graph to sequence..."); Template t = Template.parse(json); @@ -80,6 +84,6 @@ public static void main(String[] args) throws IOException { } futureList.forEach(CompletableFuture::join); logger.info("All done!"); - executor.shutdown(); + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); } } diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index 307d707c0..55225b94b 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -14,16 +14,18 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.io.PathUtils; +import org.opensearch.common.settings.Settings; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.Map.Entry; -import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; /** * Demo class exercising {@link WorkflowProcessSorter}. This will be moved to a unit test. @@ -32,6 +34,8 @@ public class TemplateParseDemo { private static final Logger logger = LogManager.getLogger(TemplateParseDemo.class); + private TemplateParseDemo() {} + /** * Demonstrate parsing a JSON graph. * @@ -50,7 +54,8 @@ public static void main(String[] args) throws IOException { } Client client = new NodeClient(null, null); WorkflowStepFactory factory = WorkflowStepFactory.create(client); - WorkflowProcessSorter.create(factory, Executors.newFixedThreadPool(10)); + ThreadPool threadPool = new ThreadPool(Settings.EMPTY); + WorkflowProcessSorter.create(factory, threadPool); Template t = Template.parse(json); @@ -61,5 +66,6 @@ public static void main(String[] args) throws IOException { logger.info("Parsing {} workflow.", e.getKey()); WorkflowProcessSorter.get().sortProcessNodes(e.getValue()); } + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); } } diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index d701c832e..c24fa911d 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -32,6 +32,11 @@ */ public class FlowFrameworkPlugin extends Plugin { + /** + * Instantiate this plugin. + */ + public FlowFrameworkPlugin() {} + @Override public Collection createComponents( Client client, @@ -47,7 +52,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { WorkflowStepFactory workflowStepFactory = WorkflowStepFactory.create(client); - WorkflowProcessSorter workflowProcessSorter = WorkflowProcessSorter.create(workflowStepFactory, threadPool.generic()); + WorkflowProcessSorter workflowProcessSorter = WorkflowProcessSorter.create(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index f8e739461..f6364aa59 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -11,19 +11,19 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.FutureUtils; +import org.opensearch.threadpool.Scheduler.ScheduledCancellable; +import org.opensearch.threadpool.ThreadPool; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; /** - * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can start execution. + * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can + * start execution. */ public class ProcessNode { @@ -33,7 +33,7 @@ public class ProcessNode { private final WorkflowStep workflowStep; private final WorkflowData input; private final List predecessors; - private final Executor executor; + private final ThreadPool threadPool; private final TimeValue nodeTimeout; private final CompletableFuture future = new CompletableFuture<>(); @@ -41,31 +41,32 @@ public class ProcessNode { /** * Create this node linked to its executing process, including input data and any predecessor nodes. * - * @param id A string identifying the workflow step + * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. - * @param input Input required by the node encoded in a {@link WorkflowData} instance. + * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow - * @param executor The OpenSearch thread pool - * @param nodeTimeout The timeout value for executing on this node + * @param threadPool The OpenSearch thread pool + * @param nodeTimeout The timeout value for executing on this node */ public ProcessNode( String id, WorkflowStep workflowStep, WorkflowData input, List predecessors, - Executor executor, + ThreadPool threadPool, TimeValue nodeTimeout ) { this.id = id; this.workflowStep = workflowStep; this.input = input; this.predecessors = predecessors; - this.executor = executor; + this.threadPool = threadPool; this.nodeTimeout = nodeTimeout; } /** * Returns the node's id. + * * @return the node id. */ public String id() { @@ -74,6 +75,7 @@ public String id() { /** * Returns the node's workflow implementation. + * * @return the workflow step */ public WorkflowStep workflowStep() { @@ -82,6 +84,7 @@ public WorkflowStep workflowStep() { /** * Returns the input data for this node. + * * @return the input data */ public WorkflowData input() { @@ -89,28 +92,30 @@ public WorkflowData input() { } /** - * Returns a {@link CompletableFuture} if this process is executing. - * Relies on the node having been sorted and executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). + * Returns a {@link CompletableFuture} if this process is executing. Relies on the node having been sorted and + * executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). * - * @return A future indicating the processing state of this node. - * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. + * @return A future indicating the processing state of this node. Returns {@code null} if it has not begun + * executing, should not happen if a workflow is sorted and executed topologically. */ public CompletableFuture future() { return future; } /** - * Returns the predecessors of this node in the workflow. - * The predecessor's {@link #future()} must complete before execution begins on this node. + * Returns the predecessors of this node in the workflow. The predecessor's {@link #future()} must complete before + * execution begins on this node. * - * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node. + * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as + * a start node. */ public List predecessors() { return predecessors; } /** - * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the process completes. + * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the + * process completes. * * @return this node's future. This is returned immediately, while process execution continues asynchronously. */ @@ -134,25 +139,26 @@ public CompletableFuture execute() { } } try { - CompletableFuture delayFuture = null; + ScheduledCancellable delayExec = null; if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) { - Executor timeoutExec = CompletableFuture.delayedExecutor(this.nodeTimeout.nanos(), TimeUnit.NANOSECONDS, executor); - delayFuture = CompletableFuture.runAsync(() -> { + delayExec = threadPool.schedule(() -> { future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id)); // If completed normally the above will be a no-op if (future.isCompletedExceptionally()) { logger.info(">>> Timed out {}.", this.id); } - }, timeoutExec); + }, this.nodeTimeout, ThreadPool.Names.GENERIC); + // TODO: improve use of thread pool beyond generic + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61 } CompletableFuture stepFuture = this.workflowStep.execute(input); future.complete(stepFuture.get()); // If completed exceptionally, this is a NOOP + delayExec.cancel(); logger.info(">>> Finished {}.", this.id); - FutureUtils.cancel(delayFuture); } catch (InterruptedException | ExecutionException e) { handleException(e); } - }, executor); + }, threadPool.generic()); return this.future; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index b5503f52f..0f8c3a418 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -15,6 +15,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.threadpool.ThreadPool; import java.util.ArrayDeque; import java.util.ArrayList; @@ -25,7 +26,6 @@ import java.util.Map; import java.util.Queue; import java.util.Set; -import java.util.concurrent.Executor; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,20 +43,20 @@ public class WorkflowProcessSorter { private static WorkflowProcessSorter instance = null; private WorkflowStepFactory workflowStepFactory; - private Executor executor; + private ThreadPool threadPool; /** * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. * * @param workflowStepFactory The singleton instance of {@link WorkflowStepFactory} - * @param executor A thread executor + * @param threadPool A thread executor * @return The created instance */ - public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, Executor executor) { + public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { if (instance != null) { throw new IllegalStateException("This class was already created."); } - instance = new WorkflowProcessSorter(workflowStepFactory, executor); + instance = new WorkflowProcessSorter(workflowStepFactory, threadPool); return instance; } @@ -72,9 +72,9 @@ public static synchronized WorkflowProcessSorter get() { return instance; } - private WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, Executor executor) { + private WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { this.workflowStepFactory = workflowStepFactory; - this.executor = executor; + this.threadPool = threadPool; } /** @@ -102,7 +102,7 @@ public List sortProcessNodes(Workflow workflow) { TimeValue.ZERO, String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD) ); - ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, executor, nodeTimeout); + ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, threadPool, nodeTimeout); idToNodeMap.put(processNode.id(), processNode); nodes.add(processNode); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 6cd5f5a28..313bf8830 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -24,6 +24,7 @@ public interface WorkflowStep { CompletableFuture execute(List data); /** + * Gets the name of the workflow step. * * @return the name of this workflow step. */ diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index e362c50f1..718637ffe 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -8,10 +8,10 @@ */ package org.opensearch.flowframework.workflow; -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; - import org.opensearch.common.unit.TimeValue; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -20,33 +20,21 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class ProcessNodeTests extends OpenSearchTestCase { - private static ExecutorService executor; + private static TestThreadPool testThreadPool; @BeforeClass public static void setup() { - executor = Executors.newFixedThreadPool(10); + testThreadPool = new TestThreadPool(ProcessNodeTests.class.getName()); } @AfterClass public static void cleanup() { - executor.shutdown(); - try { - if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { - executor.shutdownNow(); - executor.awaitTermination(5, TimeUnit.SECONDS); - } - } catch (InterruptedException e) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); - } + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testNode() throws InterruptedException, ExecutionException { @@ -62,7 +50,7 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, Collections.emptyList(), executor, TimeValue.ZERO); + }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.ZERO); assertEquals("A", nodeA.id()); assertEquals("test", nodeA.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeA.input()); @@ -79,7 +67,11 @@ public void testNodeTimeout() throws InterruptedException, ExecutionException { @Override public CompletableFuture execute(List data) { CompletableFuture future = new CompletableFuture<>(); - CompletableFuture.delayedExecutor(250, TimeUnit.MILLISECONDS, executor).execute(() -> future.complete(WorkflowData.EMPTY)); + testThreadPool.schedule( + () -> future.complete(WorkflowData.EMPTY), + TimeValue.timeValueMillis(250), + ThreadPool.Names.GENERIC + ); return future; } @@ -87,7 +79,7 @@ public CompletableFuture execute(List data) { public String getName() { return "sleepy"; } - }, WorkflowData.EMPTY, Collections.emptyList(), executor, TimeValue.timeValueMillis(100)); + }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(100)); assertEquals("Zzz", nodeZ.id()); assertEquals("sleepy", nodeZ.workflowStep().getName()); assertEquals(WorkflowData.EMPTY, nodeZ.input()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 1e9c8e808..b398bed94 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,14 +14,15 @@ import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; import org.junit.BeforeClass; import java.io.IOException; import java.util.Collections; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; @@ -43,7 +44,7 @@ private static List parse(String json) throws IOException { return workflowProcessSorter.sortProcessNodes(w).stream().map(ProcessNode::id).collect(Collectors.toList()); } - private static ExecutorService executor; + private static TestThreadPool testThreadPool; private static WorkflowProcessSorter workflowProcessSorter; @BeforeClass @@ -52,14 +53,14 @@ public static void setup() { Client client = mock(Client.class); when(client.admin()).thenReturn(adminClient); - executor = Executors.newFixedThreadPool(10); + testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); WorkflowStepFactory factory = WorkflowStepFactory.create(client); - workflowProcessSorter = WorkflowProcessSorter.create(factory, executor); + workflowProcessSorter = WorkflowProcessSorter.create(factory, testThreadPool); } @AfterClass public static void cleanup() { - executor.shutdown(); + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void testOrdering() throws IOException { From d8c3241fbf4b6d3687fed7e5229ae63337874e4c Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 30 Sep 2023 15:48:25 -0700 Subject: [PATCH 3/6] Add coverage for plugin class, make test threshold dynamic Signed-off-by: Daniel Widdis --- .codecov.yml | 9 +++-- .../flowframework/workflow/ProcessNode.java | 38 +++++++++---------- .../workflow/WorkflowProcessSorter.java | 2 +- .../flowframework/workflow/WorkflowStep.java | 1 - .../FlowFrameworkPluginTests.java | 34 ++++++++++++++++- 5 files changed, 58 insertions(+), 26 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index e5bbd7262..827160da7 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,5 +1,5 @@ codecov: - require_ci_to_pass: yes + require_ci_to_pass: true # ignore files in demo package ignore: @@ -12,5 +12,8 @@ coverage: status: project: default: - target: 70% # the required coverage value - threshold: 1% # the leniency in hitting the target + target: auto + threshold: 2% # project coverage can drop + patch: + default: + target: 70% # required diff coverage value diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index f6364aa59..538d72d4e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -22,8 +22,8 @@ import java.util.stream.Collectors; /** - * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can - * start execution. + * Representation of a process node in a workflow graph. + * Tracks predecessor nodes which must be completed before it can start execution. */ public class ProcessNode { @@ -41,12 +41,12 @@ public class ProcessNode { /** * Create this node linked to its executing process, including input data and any predecessor nodes. * - * @param id A string identifying the workflow step + * @param id A string identifying the workflow step * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. - * @param input Input required by the node encoded in a {@link WorkflowData} instance. + * @param input Input required by the node encoded in a {@link WorkflowData} instance. * @param predecessors Nodes preceding this one in the workflow - * @param threadPool The OpenSearch thread pool - * @param nodeTimeout The timeout value for executing on this node + * @param threadPool The OpenSearch thread pool + * @param nodeTimeout The timeout value for executing on this node */ public ProcessNode( String id, @@ -66,7 +66,6 @@ public ProcessNode( /** * Returns the node's id. - * * @return the node id. */ public String id() { @@ -75,7 +74,6 @@ public String id() { /** * Returns the node's workflow implementation. - * * @return the workflow step */ public WorkflowStep workflowStep() { @@ -84,7 +82,6 @@ public WorkflowStep workflowStep() { /** * Returns the input data for this node. - * * @return the input data */ public WorkflowData input() { @@ -92,32 +89,33 @@ public WorkflowData input() { } /** - * Returns a {@link CompletableFuture} if this process is executing. Relies on the node having been sorted and - * executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). + * Returns a {@link CompletableFuture} if this process is executing. + * Relies on the node having been sorted and executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). * - * @return A future indicating the processing state of this node. Returns {@code null} if it has not begun - * executing, should not happen if a workflow is sorted and executed topologically. + * @return A future indicating the processing state of this node. + * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. */ public CompletableFuture future() { return future; } /** - * Returns the predecessors of this node in the workflow. The predecessor's {@link #future()} must complete before - * execution begins on this node. + * Returns the predecessors of this node in the workflow. + * The predecessor's {@link #future()} must complete before execution begins on this node. * - * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as - * a start node. + * @return a set of predecessor nodes, if any. + * At least one node in the graph must have no predecessors and serve as a start node. */ public List predecessors() { return predecessors; } /** - * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the - * process completes. + * Execute this node in the sequence. + * Initializes the node's {@link CompletableFuture} and completes it when the process completes. * - * @return this node's future. This is returned immediately, while process execution continues asynchronously. + * @return this node's future. + * This is returned immediately, while process execution continues asynchronously. */ public CompletableFuture execute() { if (this.future.isDone()) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 0f8c3a418..93753cee2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -49,7 +49,7 @@ public class WorkflowProcessSorter { * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. * * @param workflowStepFactory The singleton instance of {@link WorkflowStepFactory} - * @param threadPool A thread executor + * @param threadPool The Thread Pool to send to Process Nodes * @return The created instance */ public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 313bf8830..41e627016 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -25,7 +25,6 @@ public interface WorkflowStep { /** * Gets the name of the workflow step. - * * @return the name of this workflow step. */ String getName(); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 9f7075d19..d211e3928 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -8,8 +8,40 @@ */ package org.opensearch.flowframework; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class FlowFrameworkPluginTests extends OpenSearchTestCase { - // Add unit tests for your plugin + + private Client client; + private ThreadPool threadPool; + + @Override + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + when(client.admin()).thenReturn(mock(AdminClient.class)); + threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); + } + + @Override + public void tearDown() throws Exception { + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); + super.tearDown(); + } + + public void testPlugin() throws IOException { + try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { + assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); + } + } } From e1616a7b4de16bf43f3650c0a426c79b6ab76025 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 30 Sep 2023 12:37:20 -0700 Subject: [PATCH 4/6] Tests don't like singletons Signed-off-by: Daniel Widdis --- src/main/java/demo/Demo.java | 6 ++-- src/main/java/demo/TemplateParseDemo.java | 6 ++-- .../flowframework/FlowFrameworkPlugin.java | 4 +-- .../workflow/WorkflowProcessSorter.java | 33 +++---------------- .../workflow/WorkflowStepFactory.java | 27 ++------------- .../workflow/WorkflowProcessSorterTests.java | 4 +-- 6 files changed, 17 insertions(+), 63 deletions(-) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index ecbf17162..12bd6925d 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -57,14 +57,14 @@ public static void main(String[] args) throws IOException { return; } Client client = new NodeClient(null, null); - WorkflowStepFactory factory = WorkflowStepFactory.create(client); + WorkflowStepFactory factory = new WorkflowStepFactory(client); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); - WorkflowProcessSorter.create(factory, threadPool); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); logger.info("Parsing graph to sequence..."); Template t = Template.parse(json); - List processSequence = WorkflowProcessSorter.get().sortProcessNodes(t.workflows().get("demo")); + List processSequence = workflowProcessSorter.sortProcessNodes(t.workflows().get("demo")); List> futureList = new ArrayList<>(); for (ProcessNode n : processSequence) { diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index 55225b94b..dbe338217 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -53,9 +53,9 @@ public static void main(String[] args) throws IOException { return; } Client client = new NodeClient(null, null); - WorkflowStepFactory factory = WorkflowStepFactory.create(client); + WorkflowStepFactory factory = new WorkflowStepFactory(client); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); - WorkflowProcessSorter.create(factory, threadPool); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); Template t = Template.parse(json); @@ -64,7 +64,7 @@ public static void main(String[] args) throws IOException { for (Entry e : t.workflows().entrySet()) { logger.info("Parsing {} workflow.", e.getKey()); - WorkflowProcessSorter.get().sortProcessNodes(e.getValue()); + workflowProcessSorter.sortProcessNodes(e.getValue()); } ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); } diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index c24fa911d..853c138db 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -51,8 +51,8 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - WorkflowStepFactory workflowStepFactory = WorkflowStepFactory.create(client); - WorkflowProcessSorter workflowProcessSorter = WorkflowProcessSorter.create(workflowStepFactory, threadPool); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(client); + WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 93753cee2..4093e5351 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -34,45 +34,22 @@ import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; /** - * Utility class converting a workflow of nodes and edges into a topologically sorted list of Process Nodes. + * Converts a workflow of nodes and edges into a topologically sorted list of Process Nodes. */ public class WorkflowProcessSorter { private static final Logger logger = LogManager.getLogger(WorkflowProcessSorter.class); - private static WorkflowProcessSorter instance = null; - private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; /** - * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. - * - * @param workflowStepFactory The singleton instance of {@link WorkflowStepFactory} - * @param threadPool The Thread Pool to send to Process Nodes - * @return The created instance - */ - public static synchronized WorkflowProcessSorter create(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { - if (instance != null) { - throw new IllegalStateException("This class was already created."); - } - instance = new WorkflowProcessSorter(workflowStepFactory, threadPool); - return instance; - } - - /** - * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created. + * Instantiate this class. * - * @return The created instance + * @param workflowStepFactory The factory which matches template step types to instances. + * @param threadPool The OpenSearch Thread pool to pass to process nodes. */ - public static synchronized WorkflowProcessSorter get() { - if (instance == null) { - throw new IllegalStateException("This factory has not yet been created."); - } - return instance; - } - - private WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { + public WorkflowProcessSorter(WorkflowStepFactory workflowStepFactory, ThreadPool threadPool) { this.workflowStepFactory = workflowStepFactory; this.threadPool = threadPool; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 5da3128d0..26dab0f42 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -22,37 +22,14 @@ */ public class WorkflowStepFactory { - private static WorkflowStepFactory instance = null; - private final Map stepMap = new HashMap<>(); /** - * Create the singleton instance of this class. Throws an {@link IllegalStateException} if already created. + * Instantiate this class. * * @param client The OpenSearch client steps can use - * @return The created instance - */ - public static synchronized WorkflowStepFactory create(Client client) { - if (instance != null) { - throw new IllegalStateException("This factory was already created."); - } - instance = new WorkflowStepFactory(client); - return instance; - } - - /** - * Gets the singleton instance of this class. Throws an {@link IllegalStateException} if not yet created. - * - * @return The created instance */ - public static synchronized WorkflowStepFactory get() { - if (instance == null) { - throw new IllegalStateException("This factory has not yet been created."); - } - return instance; - } - - private WorkflowStepFactory(Client client) { + public WorkflowStepFactory(Client client) { populateMap(client); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index b398bed94..3eefc16ab 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -54,8 +54,8 @@ public static void setup() { when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = WorkflowStepFactory.create(client); - workflowProcessSorter = WorkflowProcessSorter.create(factory, testThreadPool); + WorkflowStepFactory factory = new WorkflowStepFactory(client); + workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } @AfterClass From ba23075dc5b91fe2881fba808ecd7a4eb27d02df Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Sat, 30 Sep 2023 13:07:03 -0700 Subject: [PATCH 5/6] More thorough ProcessNode testing Signed-off-by: Daniel Widdis --- .../flowframework/workflow/ProcessNode.java | 65 +++++++------ .../model/TemplateTestJsonUtil.java | 24 ++++- .../workflow/ProcessNodeTests.java | 96 +++++++++++++++++-- .../workflow/WorkflowProcessSorterTests.java | 34 ++++++- 4 files changed, 175 insertions(+), 44 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index 538d72d4e..2f902755c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -17,7 +17,6 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @@ -110,6 +109,14 @@ public List predecessors() { return predecessors; } + /** + * Returns the timeout value of this node in the workflow. A value of {@link TimeValue#ZERO} means no timeout. + * @return The node's timeout value. + */ + public TimeValue nodeTimeout() { + return nodeTimeout; + } + /** * Execute this node in the sequence. * Initializes the node's {@link CompletableFuture} and completes it when the process completes. @@ -123,49 +130,45 @@ public CompletableFuture execute() { } CompletableFuture.runAsync(() -> { List> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList()); - predFutures.stream().forEach(CompletableFuture::join); - logger.info(">>> Starting {}.", this.id); - // get the input data from predecessor(s) - List input = new ArrayList(); - input.add(this.input); - for (CompletableFuture cf : predFutures) { - try { + try { + if (!predecessors.isEmpty()) { + CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); + waitForPredecessors.join(); + } + + logger.info("Starting {}.", this.id); + // get the input data from predecessor(s) + List input = new ArrayList(); + input.add(this.input); + for (CompletableFuture cf : predFutures) { input.add(cf.get()); - } catch (InterruptedException | ExecutionException e) { - handleException(e); - return; } - } - try { + ScheduledCancellable delayExec = null; if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) { delayExec = threadPool.schedule(() -> { - future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id)); - // If completed normally the above will be a no-op - if (future.isCompletedExceptionally()) { - logger.info(">>> Timed out {}.", this.id); + if (!future.isDone()) { + future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id)); } - }, this.nodeTimeout, ThreadPool.Names.GENERIC); - // TODO: improve use of thread pool beyond generic - // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61 + }, this.nodeTimeout, ThreadPool.Names.SAME); } CompletableFuture stepFuture = this.workflowStep.execute(input); - future.complete(stepFuture.get()); // If completed exceptionally, this is a NOOP - delayExec.cancel(); - logger.info(">>> Finished {}.", this.id); - } catch (InterruptedException | ExecutionException e) { - handleException(e); + // If completed exceptionally, this is a no-op + future.complete(stepFuture.get()); + if (delayExec != null) { + delayExec.cancel(); + } + logger.info("Finished {}.", this.id); + } catch (Throwable e) { + // TODO: better handling of getCause + this.future.completeExceptionally(e); } + // TODO: improve use of thread pool beyond generic + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61 }, threadPool.generic()); return this.future; } - private void handleException(Exception e) { - // TODO: better handling of getCause - this.future.completeExceptionally(e); - logger.debug("<<< Completed Exceptionally {}", this.id, e.getCause()); - } - @Override public String toString() { return this.id; diff --git a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java index 247521084..b38346b29 100644 --- a/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java +++ b/src/test/java/org/opensearch/flowframework/model/TemplateTestJsonUtil.java @@ -27,7 +27,29 @@ public class TemplateTestJsonUtil { public static String node(String id) { - return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + "placeholder" + "\"}"; + return nodeWithType(id, "placeholder"); + } + + public static String nodeWithType(String id, String type) { + return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + type + "\"}"; + } + + public static String nodeWithTypeAndTimeout(String id, String type, String timeout) { + return "{\"" + + WorkflowNode.ID_FIELD + + "\": \"" + + id + + "\", \"" + + WorkflowNode.TYPE_FIELD + + "\": \"" + + type + + "\", \"" + + WorkflowNode.INPUTS_FIELD + + "\": {\"" + + WorkflowNode.NODE_TIMEOUT_FIELD + + "\": \"" + + timeout + + "\"}}"; } public static String edge(String sourceId, String destId) { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 718637ffe..1e421c58c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -17,19 +17,34 @@ import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class ProcessNodeTests extends OpenSearchTestCase { private static TestThreadPool testThreadPool; + private static ProcessNode successfulNode; + private static ProcessNode failedNode; @BeforeClass public static void setup() { testThreadPool = new TestThreadPool(ProcessNodeTests.class.getName()); + + CompletableFuture successfulFuture = new CompletableFuture<>(); + successfulFuture.complete(WorkflowData.EMPTY); + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Test exception")); + successfulNode = mock(ProcessNode.class); + when(successfulNode.future()).thenReturn(successfulFuture); + failedNode = mock(ProcessNode.class); + when(failedNode.future()).thenReturn(failedFuture); } @AfterClass @@ -38,11 +53,12 @@ public static void cleanup() { } public void testNode() throws InterruptedException, ExecutionException { + // Tests where execute nas no timeout ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { @Override public CompletableFuture execute(List data) { CompletableFuture f = new CompletableFuture<>(); - f.complete(WorkflowData.EMPTY); + f.complete(new WorkflowData(Map.of("test", "output"))); return f; } @@ -50,31 +66,65 @@ public CompletableFuture execute(List data) { public String getName() { return "test"; } - }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.ZERO); + }, + new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")), + List.of(successfulNode), + testThreadPool, + TimeValue.timeValueMillis(50) + ); assertEquals("A", nodeA.id()); assertEquals("test", nodeA.workflowStep().getName()); - assertEquals(WorkflowData.EMPTY, nodeA.input()); - assertEquals(Collections.emptyList(), nodeA.predecessors()); + assertEquals("input", nodeA.input().getContent().get("test")); + assertEquals("bar", nodeA.input().getParams().get("foo")); + assertEquals(1, nodeA.predecessors().size()); + assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); CompletableFuture f = nodeA.execute(); assertEquals(f, nodeA.future()); - assertEquals(WorkflowData.EMPTY, f.get()); + assertEquals("output", f.get().getContent().get("test")); } - public void testNodeTimeout() throws InterruptedException, ExecutionException { - ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { + public void testNodeNoTimeout() throws InterruptedException, ExecutionException { + // Tests where execute finishes before timeout + ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() { @Override public CompletableFuture execute(List data) { CompletableFuture future = new CompletableFuture<>(); testThreadPool.schedule( () -> future.complete(WorkflowData.EMPTY), - TimeValue.timeValueMillis(250), + TimeValue.timeValueMillis(100), ThreadPool.Names.GENERIC ); return future; } + @Override + public String getName() { + return "test"; + } + }, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250)); + assertEquals("B", nodeB.id()); + assertEquals("test", nodeB.workflowStep().getName()); + assertEquals(WorkflowData.EMPTY, nodeB.input()); + assertEquals(Collections.emptyList(), nodeB.predecessors()); + assertEquals("B", nodeB.toString()); + + CompletableFuture f = nodeB.execute(); + assertEquals(f, nodeB.future()); + assertEquals(WorkflowData.EMPTY, f.get()); + } + + public void testNodeTimeout() throws InterruptedException, ExecutionException { + // Tests where execute finishes after timeout + ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC); + return future; + } + @Override public String getName() { return "sleepy"; @@ -90,7 +140,35 @@ public String getName() { CompletionException exception = assertThrows(CompletionException.class, () -> f.join()); assertTrue(f.isCompletedExceptionally()); assertEquals(TimeoutException.class, exception.getCause().getClass()); + } + + public void testExceptions() { + // Tests where a predecessor future completed exceptionally + ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture f = new CompletableFuture<>(); + f.complete(WorkflowData.EMPTY); + return f; + } + + @Override + public String getName() { + return "test"; + } + }, WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15)); + assertEquals("E", nodeE.id()); + assertEquals("test", nodeE.workflowStep().getName()); + assertEquals(WorkflowData.EMPTY, nodeE.input()); + assertEquals(2, nodeE.predecessors().size()); + assertEquals("E", nodeE.toString()); + + CompletableFuture f = nodeE.execute(); + CompletionException exception = assertThrows(CompletionException.class, () -> f.join()); + assertTrue(f.isCompletedExceptionally()); + assertEquals("Test exception", exception.getCause().getMessage()); - assertThrows(IllegalStateException.class, () -> nodeZ.execute()); + // Tests where we already called execute + assertThrows(IllegalStateException.class, () -> nodeE.execute()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 3eefc16ab..74240d561 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -27,6 +27,8 @@ import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node; +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType; +import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout; import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -37,11 +39,16 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase { private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; private static final String CYCLE_DETECTED = "Cycle detected:"; - // Wrap parser into string list - private static List parse(String json) throws IOException { + // Wrap parser into node list + private static List parseToNodes(String json) throws IOException { XContentParser parser = TemplateTestJsonUtil.jsonToParser(json); Workflow w = Workflow.parse(parser); - return workflowProcessSorter.sortProcessNodes(w).stream().map(ProcessNode::id).collect(Collectors.toList()); + return workflowProcessSorter.sortProcessNodes(w); + } + + // Wrap parser into string list + private static List parse(String json) throws IOException { + return parseToNodes(json).stream().map(ProcessNode::id).collect(Collectors.toList()); } private static TestThreadPool testThreadPool; @@ -63,6 +70,27 @@ public static void cleanup() { ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } + public void testNodeDetails() throws IOException { + List workflow = null; + workflow = parseToNodes( + workflow( + List.of( + nodeWithType("default_timeout", "create_ingest_pipeline"), + nodeWithTypeAndTimeout("custom_timeout", "create_index", "100ms") + ), + Collections.emptyList() + ) + ); + ProcessNode node = workflow.get(0); + assertEquals("default_timeout", node.id()); + assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass()); + assertEquals(10, node.nodeTimeout().seconds()); + node = workflow.get(1); + assertEquals("custom_timeout", node.id()); + assertEquals(CreateIndexStep.class, node.workflowStep().getClass()); + assertEquals(100, node.nodeTimeout().millis()); + } + public void testOrdering() throws IOException { List workflow; From 869cb435e4520813d15f1a436aad0156951ff7dd Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 3 Oct 2023 10:51:00 -0700 Subject: [PATCH 6/6] Util method for timeout parsing Signed-off-by: Daniel Widdis --- .../workflow/WorkflowProcessSorter.java | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 4093e5351..71c44514e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.TimeValue; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; @@ -74,11 +73,7 @@ public List sortProcessNodes(Workflow workflow) { .map(e -> idToNodeMap.get(e.source())) .collect(Collectors.toList()); - TimeValue nodeTimeout = Setting.parseTimeValue( - (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE), - TimeValue.ZERO, - String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD) - ); + TimeValue nodeTimeout = parseTimeout(node); ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes, threadPool, nodeTimeout); idToNodeMap.put(processNode.id(), processNode); nodes.add(processNode); @@ -87,6 +82,18 @@ public List sortProcessNodes(Workflow workflow) { return nodes; } + private TimeValue parseTimeout(WorkflowNode node) { + String timeoutValue = (String) node.inputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE); + String fieldName = String.join(".", node.id(), INPUTS_FIELD, NODE_TIMEOUT_FIELD); + TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName); + if (timeValue.millis() < 0) { + throw new IllegalArgumentException( + "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive" + ); + } + return timeValue; + } + private static List topologicalSort(List workflowNodes, List workflowEdges) { // Basic validation Set nodeIds = workflowNodes.stream().map(n -> n.id()).collect(Collectors.toSet());