Skip to content

Commit

Permalink
More thorough ProcessNode testing
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Oct 1, 2023
1 parent e1616a7 commit ba23075
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -110,6 +109,14 @@ public List<ProcessNode> predecessors() {
return predecessors;
}

/**
* Returns the timeout value of this node in the workflow. A value of {@link TimeValue#ZERO} means no timeout.
* @return The node's timeout value.
*/
public TimeValue nodeTimeout() {
return nodeTimeout;
}

/**
* Execute this node in the sequence.
* Initializes the node's {@link CompletableFuture} and completes it when the process completes.
Expand All @@ -123,49 +130,45 @@ public CompletableFuture<WorkflowData> execute() {
}
CompletableFuture.runAsync(() -> {
List<CompletableFuture<WorkflowData>> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList());
predFutures.stream().forEach(CompletableFuture::join);
logger.info(">>> Starting {}.", this.id);
// get the input data from predecessor(s)
List<WorkflowData> input = new ArrayList<WorkflowData>();
input.add(this.input);
for (CompletableFuture<WorkflowData> cf : predFutures) {
try {
try {
if (!predecessors.isEmpty()) {
CompletableFuture<Void> waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture<?>[0]));
waitForPredecessors.join();
}

logger.info("Starting {}.", this.id);
// get the input data from predecessor(s)
List<WorkflowData> input = new ArrayList<WorkflowData>();
input.add(this.input);
for (CompletableFuture<WorkflowData> cf : predFutures) {
input.add(cf.get());
} catch (InterruptedException | ExecutionException e) {
handleException(e);
return;
}
}
try {

ScheduledCancellable delayExec = null;
if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) {
delayExec = threadPool.schedule(() -> {
future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id));
// If completed normally the above will be a no-op
if (future.isCompletedExceptionally()) {
logger.info(">>> Timed out {}.", this.id);
if (!future.isDone()) {
future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id));
}
}, this.nodeTimeout, ThreadPool.Names.GENERIC);
// TODO: improve use of thread pool beyond generic
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61
}, this.nodeTimeout, ThreadPool.Names.SAME);
}
CompletableFuture<WorkflowData> stepFuture = this.workflowStep.execute(input);
future.complete(stepFuture.get()); // If completed exceptionally, this is a NOOP
delayExec.cancel();
logger.info(">>> Finished {}.", this.id);
} catch (InterruptedException | ExecutionException e) {
handleException(e);
// If completed exceptionally, this is a no-op
future.complete(stepFuture.get());
if (delayExec != null) {
delayExec.cancel();
}
logger.info("Finished {}.", this.id);
} catch (Throwable e) {
// TODO: better handling of getCause
this.future.completeExceptionally(e);
}
// TODO: improve use of thread pool beyond generic
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/61
}, threadPool.generic());
return this.future;
}

private void handleException(Exception e) {
// TODO: better handling of getCause
this.future.completeExceptionally(e);
logger.debug("<<< Completed Exceptionally {}", this.id, e.getCause());
}

@Override
public String toString() {
return this.id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,29 @@
public class TemplateTestJsonUtil {

public static String node(String id) {
return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + "placeholder" + "\"}";
return nodeWithType(id, "placeholder");
}

public static String nodeWithType(String id, String type) {
return "{\"" + WorkflowNode.ID_FIELD + "\": \"" + id + "\", \"" + WorkflowNode.TYPE_FIELD + "\": \"" + type + "\"}";
}

public static String nodeWithTypeAndTimeout(String id, String type, String timeout) {
return "{\""
+ WorkflowNode.ID_FIELD
+ "\": \""
+ id
+ "\", \""
+ WorkflowNode.TYPE_FIELD
+ "\": \""
+ type
+ "\", \""
+ WorkflowNode.INPUTS_FIELD
+ "\": {\""
+ WorkflowNode.NODE_TIMEOUT_FIELD
+ "\": \""
+ timeout
+ "\"}}";
}

public static String edge(String sourceId, String destId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,34 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ProcessNodeTests extends OpenSearchTestCase {

private static TestThreadPool testThreadPool;
private static ProcessNode successfulNode;
private static ProcessNode failedNode;

@BeforeClass
public static void setup() {
testThreadPool = new TestThreadPool(ProcessNodeTests.class.getName());

CompletableFuture<WorkflowData> successfulFuture = new CompletableFuture<>();
successfulFuture.complete(WorkflowData.EMPTY);
CompletableFuture<WorkflowData> failedFuture = new CompletableFuture<>();
failedFuture.completeExceptionally(new RuntimeException("Test exception"));
successfulNode = mock(ProcessNode.class);
when(successfulNode.future()).thenReturn(successfulFuture);
failedNode = mock(ProcessNode.class);
when(failedNode.future()).thenReturn(failedFuture);
}

@AfterClass
Expand All @@ -38,43 +53,78 @@ public static void cleanup() {
}

public void testNode() throws InterruptedException, ExecutionException {
// Tests where execute nas no timeout
ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() {
@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture<WorkflowData> f = new CompletableFuture<>();
f.complete(WorkflowData.EMPTY);
f.complete(new WorkflowData(Map.of("test", "output")));
return f;
}

@Override
public String getName() {
return "test";
}
}, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.ZERO);
},
new WorkflowData(Map.of("test", "input"), Map.of("foo", "bar")),
List.of(successfulNode),
testThreadPool,
TimeValue.timeValueMillis(50)
);
assertEquals("A", nodeA.id());
assertEquals("test", nodeA.workflowStep().getName());
assertEquals(WorkflowData.EMPTY, nodeA.input());
assertEquals(Collections.emptyList(), nodeA.predecessors());
assertEquals("input", nodeA.input().getContent().get("test"));
assertEquals("bar", nodeA.input().getParams().get("foo"));
assertEquals(1, nodeA.predecessors().size());
assertEquals(50, nodeA.nodeTimeout().millis());
assertEquals("A", nodeA.toString());

CompletableFuture<WorkflowData> f = nodeA.execute();
assertEquals(f, nodeA.future());
assertEquals(WorkflowData.EMPTY, f.get());
assertEquals("output", f.get().getContent().get("test"));
}

public void testNodeTimeout() throws InterruptedException, ExecutionException {
ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() {
public void testNodeNoTimeout() throws InterruptedException, ExecutionException {
// Tests where execute finishes before timeout
ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() {
@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture<WorkflowData> future = new CompletableFuture<>();
testThreadPool.schedule(
() -> future.complete(WorkflowData.EMPTY),
TimeValue.timeValueMillis(250),
TimeValue.timeValueMillis(100),
ThreadPool.Names.GENERIC
);
return future;
}

@Override
public String getName() {
return "test";
}
}, WorkflowData.EMPTY, Collections.emptyList(), testThreadPool, TimeValue.timeValueMillis(250));
assertEquals("B", nodeB.id());
assertEquals("test", nodeB.workflowStep().getName());
assertEquals(WorkflowData.EMPTY, nodeB.input());
assertEquals(Collections.emptyList(), nodeB.predecessors());
assertEquals("B", nodeB.toString());

CompletableFuture<WorkflowData> f = nodeB.execute();
assertEquals(f, nodeB.future());
assertEquals(WorkflowData.EMPTY, f.get());
}

public void testNodeTimeout() throws InterruptedException, ExecutionException {
// Tests where execute finishes after timeout
ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() {
@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture<WorkflowData> future = new CompletableFuture<>();
testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), ThreadPool.Names.GENERIC);
return future;
}

@Override
public String getName() {
return "sleepy";
Expand All @@ -90,7 +140,35 @@ public String getName() {
CompletionException exception = assertThrows(CompletionException.class, () -> f.join());
assertTrue(f.isCompletedExceptionally());
assertEquals(TimeoutException.class, exception.getCause().getClass());
}

public void testExceptions() {
// Tests where a predecessor future completed exceptionally
ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() {
@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture<WorkflowData> f = new CompletableFuture<>();
f.complete(WorkflowData.EMPTY);
return f;
}

@Override
public String getName() {
return "test";
}
}, WorkflowData.EMPTY, List.of(successfulNode, failedNode), testThreadPool, TimeValue.timeValueSeconds(15));
assertEquals("E", nodeE.id());
assertEquals("test", nodeE.workflowStep().getName());
assertEquals(WorkflowData.EMPTY, nodeE.input());
assertEquals(2, nodeE.predecessors().size());
assertEquals("E", nodeE.toString());

CompletableFuture<WorkflowData> f = nodeE.execute();
CompletionException exception = assertThrows(CompletionException.class, () -> f.join());
assertTrue(f.isCompletedExceptionally());
assertEquals("Test exception", exception.getCause().getMessage());

assertThrows(IllegalStateException.class, () -> nodeZ.execute());
// Tests where we already called execute
assertThrows(IllegalStateException.class, () -> nodeE.execute());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

import static org.opensearch.flowframework.model.TemplateTestJsonUtil.edge;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.node;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithType;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.nodeWithTypeAndTimeout;
import static org.opensearch.flowframework.model.TemplateTestJsonUtil.workflow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -37,11 +39,16 @@ public class WorkflowProcessSorterTests extends OpenSearchTestCase {
private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor.";
private static final String CYCLE_DETECTED = "Cycle detected:";

// Wrap parser into string list
private static List<String> parse(String json) throws IOException {
// Wrap parser into node list
private static List<ProcessNode> parseToNodes(String json) throws IOException {
XContentParser parser = TemplateTestJsonUtil.jsonToParser(json);
Workflow w = Workflow.parse(parser);
return workflowProcessSorter.sortProcessNodes(w).stream().map(ProcessNode::id).collect(Collectors.toList());
return workflowProcessSorter.sortProcessNodes(w);
}

// Wrap parser into string list
private static List<String> parse(String json) throws IOException {
return parseToNodes(json).stream().map(ProcessNode::id).collect(Collectors.toList());
}

private static TestThreadPool testThreadPool;
Expand All @@ -63,6 +70,27 @@ public static void cleanup() {
ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS);
}

public void testNodeDetails() throws IOException {
List<ProcessNode> workflow = null;
workflow = parseToNodes(
workflow(
List.of(
nodeWithType("default_timeout", "create_ingest_pipeline"),
nodeWithTypeAndTimeout("custom_timeout", "create_index", "100ms")
),
Collections.emptyList()
)
);
ProcessNode node = workflow.get(0);
assertEquals("default_timeout", node.id());
assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass());
assertEquals(10, node.nodeTimeout().seconds());
node = workflow.get(1);
assertEquals("custom_timeout", node.id());
assertEquals(CreateIndexStep.class, node.workflowStep().getClass());
assertEquals(100, node.nodeTimeout().millis());
}

public void testOrdering() throws IOException {
List<String> workflow;

Expand Down

0 comments on commit ba23075

Please sign in to comment.