From a574f478920d0f0f7205f9044c0b84a3c4927f5d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 19 Sep 2023 10:51:41 -0700 Subject: [PATCH] Topological Sorting and Sequenced Execution (#26) * Topological Sorting and Sequenced Execution Signed-off-by: Daniel Widdis * Add javadocs Signed-off-by: Daniel Widdis * Update demo to link to Workflow interface Signed-off-by: Daniel Widdis * Replace System.out with logging Signed-off-by: Daniel Widdis * Update with new interface signatures Signed-off-by: Daniel Widdis * Demo passing input data at parse-time Signed-off-by: Daniel Widdis * Demo passing data in between steps Signed-off-by: Daniel Widdis * Change execute arg to list and refactor demo classes to own package Signed-off-by: Daniel Widdis * Significantly simplify input/output data passing Signed-off-by: Daniel Widdis * Add tests Signed-off-by: Daniel Widdis * Fix javadocs and forbidden API issues Signed-off-by: Daniel Widdis * Address code review comments Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis --- .codecov.yml | 4 + build.gradle | 1 + formatter/formatting.gradle | 1 + .../java/demo/CreateIndexWorkflowStep.java | 83 ++++++++ src/main/java/demo/DataDemo.java | 85 ++++++++ src/main/java/demo/Demo.java | 88 ++++++++ src/main/java/demo/DemoWorkflowStep.java | 52 +++++ src/main/java/demo/README.txt | 13 ++ .../flowframework/template/ProcessNode.java | 189 ++++++++++++++++++ .../template/ProcessSequenceEdge.java | 67 +++++++ .../template/TemplateParser.java | 165 +++++++++++++++ .../flowframework/workflow/WorkflowData.java | 30 ++- .../flowframework/workflow/WorkflowStep.java | 11 +- src/main/resources/log4j2.xml | 17 ++ .../flowframework/FlowFrameworkPluginIT.java | 4 +- .../template/ProcessNodeTests.java | 65 ++++++ .../template/ProcessSequenceEdgeTests.java | 32 +++ .../template/TemplateParserTests.java | 153 ++++++++++++++ .../workflow/WorkflowDataTests.java | 28 +++ src/test/resources/template/datademo.json | 20 ++ src/test/resources/template/demo.json | 36 ++++ 21 files changed, 1133 insertions(+), 11 deletions(-) create mode 100644 src/main/java/demo/CreateIndexWorkflowStep.java create mode 100644 src/main/java/demo/DataDemo.java create mode 100644 src/main/java/demo/Demo.java create mode 100644 src/main/java/demo/DemoWorkflowStep.java create mode 100644 src/main/java/demo/README.txt create mode 100644 src/main/java/org/opensearch/flowframework/template/ProcessNode.java create mode 100644 src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java create mode 100644 src/main/java/org/opensearch/flowframework/template/TemplateParser.java create mode 100644 src/main/resources/log4j2.xml create mode 100644 src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java create mode 100644 src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java create mode 100644 src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java create mode 100644 src/test/resources/template/datademo.json create mode 100644 src/test/resources/template/demo.json diff --git a/.codecov.yml b/.codecov.yml index 7c38e4e63..e5bbd7262 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,6 +1,10 @@ codecov: require_ci_to_pass: yes +# ignore files in demo package +ignore: + - "src/main/java/demo" + coverage: precision: 2 round: down diff --git a/build.gradle b/build.gradle index 748757484..aa20423ee 100644 --- a/build.gradle +++ b/build.gradle @@ -105,6 +105,7 @@ repositories { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" implementation 'org.junit.jupiter:junit-jupiter:5.10.0' + implementation "com.google.code.gson:gson:2.10.1" compileOnly "com.google.guava:guava:32.1.2-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" diff --git a/formatter/formatting.gradle b/formatter/formatting.gradle index e3bc090e0..8f842128f 100644 --- a/formatter/formatting.gradle +++ b/formatter/formatting.gradle @@ -35,6 +35,7 @@ allprojects { trimTrailingWhitespace() endWithNewline() + indentWithSpaces() } format("license", { licenseHeaderFile("${rootProject.file("formatter/license-header.txt")}", "package "); diff --git a/src/main/java/demo/CreateIndexWorkflowStep.java b/src/main/java/demo/CreateIndexWorkflowStep.java new file mode 100644 index 000000000..c1a79188b --- /dev/null +++ b/src/main/java/demo/CreateIndexWorkflowStep.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Sample to show other devs how to pass data around. Will be deleted once other PRs are merged. + */ +public class CreateIndexWorkflowStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateIndexWorkflowStep.class); + + private final String name; + + /** + * Instantiate this class. + */ + public CreateIndexWorkflowStep() { + this.name = "CREATE_INDEX"; + } + + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + // TODO we will be passing a thread pool to this object when it's instantiated + // we should either add the generic executor from that pool to this call + // or use executorservice.submit or any of various threading options + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + CompletableFuture.runAsync(() -> { + String inputIndex = null; + boolean first = true; + for (WorkflowData wfData : data) { + logger.debug( + "{} sent params: {}, content: {}", + first ? "Initialization" : "Previous step", + wfData.getParams(), + wfData.getContent() + ); + if (first) { + Map params = data.get(0).getParams(); + if (params.containsKey("index")) { + inputIndex = params.get("index"); + } + first = false; + } + } + // do some work, simulating a REST API call + try { + Thread.sleep(2000); + } catch (InterruptedException e) {} + // Simulate response of created index + CreateIndexResponse response = new CreateIndexResponse(true, true, inputIndex); + future.complete(new WorkflowData() { + @Override + public Map getContent() { + return Map.of("index", response.index()); + } + }); + }); + + return future; + } + + @Override + public String getName() { + return name; + } +} diff --git a/src/main/java/demo/DataDemo.java b/src/main/java/demo/DataDemo.java new file mode 100644 index 000000000..f2d606f07 --- /dev/null +++ b/src/main/java/demo/DataDemo.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.PathUtils; +import org.opensearch.flowframework.template.ProcessNode; +import org.opensearch.flowframework.template.TemplateParser; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. + */ +public class DataDemo { + + private static final Logger logger = LogManager.getLogger(DataDemo.class); + + // This is temporary. We need a factory class to generate these workflow steps + // based on a field in the JSON. + private static Map workflowMap = new HashMap<>(); + static { + workflowMap.put("create_index", new CreateIndexWorkflowStep()); + workflowMap.put("create_another_index", new CreateIndexWorkflowStep()); + } + + /** + * Demonstrate parsing a JSON graph. + * + * @param args unused + */ + @SuppressForbidden(reason = "just a demo class that will be deleted") + public static void main(String[] args) { + String path = "src/test/resources/template/datademo.json"; + String json; + try { + json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + logger.error("Failed to read JSON at path {}", path); + return; + } + + logger.info("Parsing graph to sequence..."); + List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List> futureList = new ArrayList<>(); + + for (ProcessNode n : processSequence) { + Set predecessors = n.getPredecessors(); + logger.info( + "Queueing process [{}].{}", + n.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + futureList.add(n.execute()); + } + futureList.forEach(CompletableFuture::join); + logger.info("All done!"); + } + +} diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java new file mode 100644 index 000000000..58d977827 --- /dev/null +++ b/src/main/java/demo/Demo.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.PathUtils; +import org.opensearch.flowframework.template.ProcessNode; +import org.opensearch.flowframework.template.TemplateParser; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. + */ +public class Demo { + + private static final Logger logger = LogManager.getLogger(Demo.class); + + // This is temporary. We need a factory class to generate these workflow steps + // based on a field in the JSON. + private static Map workflowMap = new HashMap<>(); + static { + workflowMap.put("fetch_model", new DemoWorkflowStep(3000)); + workflowMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000)); + workflowMap.put("create_search_pipeline", new DemoWorkflowStep(5000)); + workflowMap.put("create_neural_search_index", new DemoWorkflowStep(2000)); + } + + /** + * Demonstrate parsing a JSON graph. + * + * @param args unused + */ + @SuppressForbidden(reason = "just a demo class that will be deleted") + public static void main(String[] args) { + String path = "src/test/resources/template/demo.json"; + String json; + try { + json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + logger.error("Failed to read JSON at path {}", path); + return; + } + + logger.info("Parsing graph to sequence..."); + List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List> futureList = new ArrayList<>(); + + for (ProcessNode n : processSequence) { + Set predecessors = n.getPredecessors(); + logger.info( + "Queueing process [{}].{}", + n.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + // TODO need to handle this better, passing an argument when we start them all at the beginning is silly + futureList.add(n.execute()); + } + futureList.forEach(CompletableFuture::join); + logger.info("All done!"); + } + +} diff --git a/src/main/java/demo/DemoWorkflowStep.java b/src/main/java/demo/DemoWorkflowStep.java new file mode 100644 index 000000000..037d9b6f6 --- /dev/null +++ b/src/main/java/demo/DemoWorkflowStep.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Demo workflowstep to show sequenced execution + */ +public class DemoWorkflowStep implements WorkflowStep { + + private final long delay; + private final String name; + + /** + * Instantiate a step with a delay. + * @param delay milliseconds to take pretending to do work while really sleeping + */ + public DemoWorkflowStep(long delay) { + this.delay = delay; + this.name = "DEMO_DELAY_" + delay; + } + + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + CompletableFuture.runAsync(() -> { + try { + Thread.sleep(this.delay); + future.complete(null); + } catch (InterruptedException e) { + future.completeExceptionally(e); + } + }); + return future; + } + + @Override + public String getName() { + return name; + } +} diff --git a/src/main/java/demo/README.txt b/src/main/java/demo/README.txt new file mode 100644 index 000000000..4fef77960 --- /dev/null +++ b/src/main/java/demo/README.txt @@ -0,0 +1,13 @@ + +DO NOT DEPEND ON CLASSES IN THIS PACKAGE. + +The contents of this folder are for demo/proof-of-concept use. + +Feel free to look at the classes in this folder for potential "how could I" scenarios. + +Tests will not be written against them. +Documentation may be incomplete, wrong, or outdated. +These are not for production use. +They will be deleted without notice at some point, and altered without notice at other points. + +DO NOT DEPEND ON CLASSES IN THIS PACKAGE. diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java new file mode 100644 index 000000000..08a7ec841 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can start execution. + */ +public class ProcessNode { + + private static final Logger logger = LogManager.getLogger(ProcessNode.class); + + private final String id; + private final WorkflowStep workflowStep; + private final WorkflowData input; + private CompletableFuture future = null; + + // will be populated during graph parsing + private Set predecessors = Collections.emptySet(); + + /** + * Create this node linked to its executing process. + * + * @param id A string identifying the workflow step + * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + */ + ProcessNode(String id, WorkflowStep workflowStep) { + this(id, workflowStep, WorkflowData.EMPTY); + } + + /** + * Create this node linked to its executing process. + * + * @param id A string identifying the workflow step + * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + * @param input Input required by the node + */ + public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input) { + this.id = id; + this.workflowStep = workflowStep; + this.input = input; + } + + /** + * Returns the node's id. + * @return the node id. + */ + public String id() { + return id; + } + + /** + * Returns the node's workflow implementation. + * @return the workflow step + */ + public WorkflowStep workflowStep() { + return workflowStep; + } + + /** + * Returns the input data for this node. + * @return the input data + */ + public WorkflowData input() { + return input; + } + + /** + * Returns a {@link CompletableFuture} if this process is executing. + * Relies on the node having been sorted and executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). + * + * @return A future indicating the processing state of this node. + * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. + */ + public CompletableFuture getFuture() { + return future; + } + + /** + * Returns the predecessors of this node in the workflow. + * The predecessor's {@link #getFuture()} must complete before execution begins on this node. + * + * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node. + */ + public Set getPredecessors() { + return predecessors; + } + + /** + * Sets the predecessor node. Called by {@link TemplateParser}. + * + * @param predecessors The predecessors of this node. + */ + void setPredecessors(Set predecessors) { + this.predecessors = Set.copyOf(predecessors); + } + + /** + * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the process completes. + * + * @return this node's future. This is returned immediately, while process execution continues asynchronously. + */ + public CompletableFuture execute() { + this.future = new CompletableFuture<>(); + // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!) + // the generic executor from that pool should be passed to this runAsync call + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + CompletableFuture.runAsync(() -> { + List> predFutures = predecessors.stream().map(p -> p.getFuture()).collect(Collectors.toList()); + if (!predecessors.isEmpty()) { + CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); + try { + // We need timeouts to be part of the user template or in settings + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/45 + waitForPredecessors.orTimeout(30, TimeUnit.SECONDS).get(); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + return; + } + } + logger.info(">>> Starting {}.", this.id); + // get the input data from predecessor(s) + List input = new ArrayList(); + input.add(this.input); + for (CompletableFuture cf : predFutures) { + try { + input.add(cf.get()); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + return; + } + } + CompletableFuture stepFuture = this.workflowStep.execute(input); + try { + stepFuture.join(); + future.complete(stepFuture.get()); + logger.debug("<<< Completed {}", this.id); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + } + }); + return this.future; + } + + private void handleException(Exception e) { + // TODO: better handling of getCause + this.future.completeExceptionally(e); + logger.debug("<<< Completed Exceptionally {}", this.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ProcessNode other = (ProcessNode) obj; + return Objects.equals(id, other.id); + } + + @Override + public String toString() { + return this.id; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java new file mode 100644 index 000000000..9544620fb --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import java.util.Objects; + +/** + * Representation of an edge between process nodes in a workflow graph. + */ +public class ProcessSequenceEdge { + private final String source; + private final String destination; + + /** + * Create this edge with the id's of the source and destination nodes. + * + * @param source The source node id. + * @param destination The destination node id. + */ + ProcessSequenceEdge(String source, String destination) { + this.source = source; + this.destination = destination; + } + + /** + * Gets the source node id. + * + * @return the source node id. + */ + public String getSource() { + return source; + } + + /** + * Gets the destination node id. + * + * @return the destination node id. + */ + public String getDestination() { + return destination; + } + + @Override + public int hashCode() { + return Objects.hash(destination, source); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ProcessSequenceEdge other = (ProcessSequenceEdge) obj; + return Objects.equals(destination, other.destination) && Objects.equals(source, other.source); + } + + @Override + public String toString() { + return this.source + "->" + this.destination; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java new file mode 100644 index 000000000..bce07c616 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java @@ -0,0 +1,165 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility class for parsing templates. + */ +public class TemplateParser { + + private static final Logger logger = LogManager.getLogger(TemplateParser.class); + + // Field names in the JSON. Package private for tests. + static final String WORKFLOW = "sequence"; + static final String NODES = "nodes"; + static final String NODE_ID = "id"; + static final String EDGES = "edges"; + static final String SOURCE = "source"; + static final String DESTINATION = "dest"; + + /** + * Prevent instantiating this class. + */ + private TemplateParser() {} + + /** + * Parse a JSON representation of nodes and edges into a topologically sorted list of process nodes. + * @param json A string containing a JSON representation of nodes and edges + * @param workflowSteps A map linking JSON node names to Java objects implementing {@link WorkflowStep} + * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. + */ + public static List parseJsonGraphToSequence(String json, Map workflowSteps) { + Gson gson = new Gson(); + JsonObject jsonObject = gson.fromJson(json, JsonObject.class); + + JsonObject graph = jsonObject.getAsJsonObject(WORKFLOW); + + List nodes = new ArrayList<>(); + List edges = new ArrayList<>(); + + for (JsonElement nodeJson : graph.getAsJsonArray(NODES)) { + JsonObject nodeObject = nodeJson.getAsJsonObject(); + String nodeId = nodeObject.get(NODE_ID).getAsString(); + // The below steps will be replaced by a generator class that instantiates a WorkflowStep + // based on user_input data from the template. + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/41 + WorkflowStep workflowStep = workflowSteps.get(nodeId); + // temporary demo POC of getting from a request to input data + // this will be refactored into something pulling from user template as part of the above issue + WorkflowData inputData = WorkflowData.EMPTY; + if (List.of("create_index", "create_another_index").contains(nodeId)) { + CreateIndexRequest request = new CreateIndexRequest(nodeObject.get("index_name").getAsString()); + inputData = new WorkflowData() { + + @Override + public Map getContent() { + // See CreateIndexRequest ParseFields for source of content keys needed + return Map.of("mappings", request.mappings(), "settings", request.settings(), "aliases", request.aliases()); + } + + @Override + public Map getParams() { + // See RestCreateIndexAction for source of param keys needed + return Map.of("index", request.index()); + } + + }; + } + nodes.add(new ProcessNode(nodeId, workflowStep, inputData)); + } + + for (JsonElement edgeJson : graph.getAsJsonArray(EDGES)) { + JsonObject edgeObject = edgeJson.getAsJsonObject(); + String sourceNodeId = edgeObject.get(SOURCE).getAsString(); + String destNodeId = edgeObject.get(DESTINATION).getAsString(); + if (sourceNodeId.equals(destNodeId)) { + throw new IllegalArgumentException("Edge connects node " + sourceNodeId + " to itself."); + } + edges.add(new ProcessSequenceEdge(sourceNodeId, destNodeId)); + } + + return topologicalSort(nodes, edges); + } + + private static List topologicalSort(List nodes, List edges) { + // Define the graph + Set graph = new HashSet<>(edges); + // Map node id string to object + Map nodeMap = nodes.stream().collect(Collectors.toMap(ProcessNode::id, Function.identity())); + // Build predecessor and successor maps + Map> predecessorEdges = new HashMap<>(); + Map> successorEdges = new HashMap<>(); + for (ProcessSequenceEdge edge : edges) { + ProcessNode source = nodeMap.get(edge.getSource()); + ProcessNode dest = nodeMap.get(edge.getDestination()); + predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge); + successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge); + } + // update predecessors on the node object + nodes.stream().filter(n -> predecessorEdges.containsKey(n)).forEach(n -> { + n.setPredecessors(predecessorEdges.get(n).stream().map(e -> nodeMap.get(e.getSource())).collect(Collectors.toSet())); + }); + + // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + // L <- Empty list that will contain the sorted elements + List sortedNodes = new ArrayList<>(); + // S <- Set of all nodes with no incoming edge + Queue sourceNodes = new ArrayDeque<>(); + nodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); + if (sourceNodes.isEmpty()) { + throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + } + logger.debug("Start node(s): {}", sourceNodes); + + // while S is not empty do + while (!sourceNodes.isEmpty()) { + // remove a node n from S + ProcessNode n = sourceNodes.poll(); + // add n to L + sortedNodes.add(n); + // for each node m with an edge e from n to m do + for (ProcessSequenceEdge e : successorEdges.getOrDefault(n, Collections.emptySet())) { + ProcessNode m = nodeMap.get(e.getDestination()); + // remove edge e from the graph + graph.remove(e); + // if m has no other incoming edges then + if (!predecessorEdges.get(m).stream().anyMatch(i -> graph.contains(i))) { + // insert m into S + sourceNodes.add(m); + } + } + } + if (!graph.isEmpty()) { + throw new IllegalArgumentException("Cycle detected: " + graph); + } + logger.debug("Execution sequence: {}", sortedNodes); + return sortedNodes; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 3e8dc81b2..09eb041fc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -8,7 +8,33 @@ */ package org.opensearch.flowframework.workflow; +import java.util.Collections; +import java.util.Map; + /** - * Interface for handling the input/output of the building blocks. + * Interface representing data provided as input to, and produced as output from, {@link WorkflowStep}s. */ -public interface WorkflowData {} +public interface WorkflowData { + + /** + * An object representing no data, useful when a workflow step has no required input or output. + */ + WorkflowData EMPTY = new WorkflowData() { + }; + + /** + * Accesses a map containing the content of the workflow step. This represents the data associated with a Rest API request. + * @return the content of this step. + */ + default Map getContent() { + return Collections.emptyMap(); + }; + + /** + * Accesses a map containing the params of this workflow step. This represents the params associated with a Rest API request, parsed from the URI. + * @return the params of this step. + */ + default Map getParams() { + return Collections.emptyMap(); + }; +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 6a65ce6e3..6cd5f5a28 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -8,8 +8,7 @@ */ package org.opensearch.flowframework.workflow; -import org.opensearch.common.Nullable; - +import java.util.List; import java.util.concurrent.CompletableFuture; /** @@ -18,11 +17,11 @@ public interface WorkflowStep { /** - * Triggers the processing of the building block. - * @param data for input/output params of the building blocks. - * @return CompletableFuture of the building block. + * Triggers the actual processing of the building block. + * @param data representing input params and content, or output content of previous steps. The first element of the list is data (if any) provided from parsing the template, and may be {@link WorkflowData#EMPTY}. + * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. */ - CompletableFuture execute(@Nullable WorkflowData data); + CompletableFuture execute(List data); /** * diff --git a/src/main/resources/log4j2.xml b/src/main/resources/log4j2.xml new file mode 100644 index 000000000..21a4c6fa5 --- /dev/null +++ b/src/main/resources/log4j2.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java index d54dc2c63..0dccc27ce 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java @@ -22,8 +22,6 @@ import java.util.Collection; import java.util.Collections; -import static org.hamcrest.Matchers.containsString; - @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE) public class FlowFrameworkPluginIT extends OpenSearchIntegTestCase { @@ -38,6 +36,6 @@ public void testPluginInstalled() throws IOException, ParseException { String body = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8); logger.info("response body: {}", body); - assertThat(body, containsString("flowframework")); + assertTrue(body.contains("flowframework")); } } diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java new file mode 100644 index 000000000..3feab9f3b --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope.Scope; + +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +@ThreadLeakScope(Scope.NONE) +public class ProcessNodeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testNode() throws InterruptedException, ExecutionException { + ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture f = new CompletableFuture<>(); + f.complete(WorkflowData.EMPTY); + return f; + } + + @Override + public String getName() { + return "test"; + } + }); + assertEquals("A", nodeA.id()); + assertEquals("test", nodeA.workflowStep().getName()); + assertEquals(WorkflowData.EMPTY, nodeA.input()); + assertEquals(Collections.emptySet(), nodeA.getPredecessors()); + assertEquals("A", nodeA.toString()); + + // TODO: Once we can get OpenSearch Thread Pool for this execute method, create an IT and don't test execute here + CompletableFuture f = nodeA.execute(); + assertEquals(f, nodeA.getFuture()); + f.orTimeout(5, TimeUnit.SECONDS); + assertTrue(f.isDone()); + assertEquals(WorkflowData.EMPTY, f.get()); + + ProcessNode nodeB = new ProcessNode("B", null); + assertNotEquals(nodeA, nodeB); + + ProcessNode nodeA2 = new ProcessNode("A", null); + assertEquals(nodeA, nodeA2); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java new file mode 100644 index 000000000..80cecd96e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.opensearch.test.OpenSearchTestCase; + +public class ProcessSequenceEdgeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testEdge() { + ProcessSequenceEdge edgeAB = new ProcessSequenceEdge("A", "B"); + assertEquals("A", edgeAB.getSource()); + assertEquals("B", edgeAB.getDestination()); + assertEquals("A->B", edgeAB.toString()); + + ProcessSequenceEdge edgeAB2 = new ProcessSequenceEdge("A", "B"); + assertEquals(edgeAB, edgeAB2); + + ProcessSequenceEdge edgeAC = new ProcessSequenceEdge("A", "C"); + assertNotEquals(edgeAB, edgeAC); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java new file mode 100644 index 000000000..24dcf0640 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.template.TemplateParser.DESTINATION; +import static org.opensearch.flowframework.template.TemplateParser.EDGES; +import static org.opensearch.flowframework.template.TemplateParser.NODES; +import static org.opensearch.flowframework.template.TemplateParser.NODE_ID; +import static org.opensearch.flowframework.template.TemplateParser.SOURCE; +import static org.opensearch.flowframework.template.TemplateParser.WORKFLOW; + +public class TemplateParserTests extends OpenSearchTestCase { + + private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; + private static final String CYCLE_DETECTED = "Cycle detected:"; + + // Input JSON generators + private static String node(String id) { + return "{\"" + NODE_ID + "\": \"" + id + "\"}"; + } + + private static String edge(String sourceId, String destId) { + return "{\"" + SOURCE + "\": \"" + sourceId + "\", \"" + DESTINATION + "\": \"" + destId + "\"}"; + } + + private static String workflow(List nodes, List edges) { + return "{\"" + WORKFLOW + "\": {" + arrayField(NODES, nodes) + ", " + arrayField(EDGES, edges) + "}}"; + } + + private static String arrayField(String fieldName, List objects) { + return "\"" + fieldName + "\": [" + objects.stream().collect(Collectors.joining(", ")) + "]"; + } + + // Output list elements + private static ProcessNode expectedNode(String id) { + return new ProcessNode(id, null, null); + } + + // Less verbose parser + private static List parse(String json) { + return TemplateParser.parseJsonGraphToSequence(json, Collections.emptyMap()); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testOrdering() { + List workflow; + + workflow = parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("C", "B"), edge("B", "A")))); + assertEquals(0, workflow.indexOf(expectedNode("C"))); + assertEquals(1, workflow.indexOf(expectedNode("B"))); + assertEquals(2, workflow.indexOf(expectedNode("A"))); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("C", "D")) + ) + ); + assertEquals(0, workflow.indexOf(expectedNode("A"))); + int b = workflow.indexOf(expectedNode("B")); + int c = workflow.indexOf(expectedNode("C")); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertEquals(3, workflow.indexOf(expectedNode("D"))); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D"), node("E")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("D", "E"), edge("C", "E")) + ) + ); + assertEquals(0, workflow.indexOf(expectedNode("A"))); + b = workflow.indexOf(expectedNode("B")); + c = workflow.indexOf(expectedNode("C")); + int d = workflow.indexOf(expectedNode("D")); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertTrue(d == 2 || d == 3); + assertEquals(4, workflow.indexOf(expectedNode("E"))); + } + + public void testCycles() { + Exception ex; + + ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); + assertEquals("Edge connects node A to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) + ); + assertEquals("Edge connects node B to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) + ); + assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->B")); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("B", "C"), edge("C", "D"), edge("D", "B")) + ) + ) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->D")); + assertTrue(ex.getMessage().contains("D->B")); + } + + public void testNoEdges() { + Exception ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(Collections.emptyList(), Collections.emptyList())) + ); + assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + + assertEquals(List.of(expectedNode("A")), parse(workflow(List.of(node("A")), Collections.emptyList()))); + + List workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList())); + assertEquals(2, workflow.size()); + assertTrue(workflow.contains(expectedNode("A"))); + assertTrue(workflow.contains(expectedNode("B"))); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java new file mode 100644 index 000000000..42a1a1a03 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; + +public class WorkflowDataTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testWorkflowData() { + WorkflowData data = new WorkflowData() { + }; + assertEquals(Collections.emptyMap(), data.getParams()); + assertEquals(Collections.emptyMap(), data.getContent()); + } +} diff --git a/src/test/resources/template/datademo.json b/src/test/resources/template/datademo.json new file mode 100644 index 000000000..a1323ed2c --- /dev/null +++ b/src/test/resources/template/datademo.json @@ -0,0 +1,20 @@ +{ + "sequence": { + "nodes": [ + { + "id": "create_index", + "index_name": "demo" + }, + { + "id": "create_another_index", + "index_name": "second_demo" + } + ], + "edges": [ + { + "source": "create_index", + "dest": "create_another_index" + } + ] + } +} diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json new file mode 100644 index 000000000..38f1d0644 --- /dev/null +++ b/src/test/resources/template/demo.json @@ -0,0 +1,36 @@ +{ + "sequence": { + "nodes": [ + { + "id": "fetch_model" + }, + { + "id": "create_ingest_pipeline" + }, + { + "id": "create_search_pipeline" + }, + { + "id": "create_neural_search_index" + } + ], + "edges": [ + { + "source": "fetch_model", + "dest": "create_ingest_pipeline" + }, + { + "source": "fetch_model", + "dest": "create_search_pipeline" + }, + { + "source": "create_ingest_pipeline", + "dest": "create_neural_search_index" + }, + { + "source": "create_search_pipeline", + "dest": "create_neural_search_index" + } + ] + } +}