Skip to content

Commit

Permalink
Addressing PR comments (Part 2), adding globalcontexthandler to creat…
Browse files Browse the repository at this point in the history
…e components, added updateTemplate(), indexExists() methods to handler and createIndex step respecitvely. Implemented CreateWorkflow/ProvisionWorkflow transport actions

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis committed Oct 9, 2023
1 parent 791f943 commit c7c819b
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.flowframework.indices.GlobalContextHandler;
import org.opensearch.flowframework.rest.RestCreateWorkflowAction;
import org.opensearch.flowframework.rest.RestProvisionWorkflowAction;
import org.opensearch.flowframework.transport.CreateWorkflowAction;
import org.opensearch.flowframework.transport.CreateWorkflowTransportAction;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
import org.opensearch.flowframework.transport.ProvisionWorkflowTransportAction;
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.plugins.ActionPlugin;
Expand Down Expand Up @@ -76,7 +79,10 @@ public Collection<Object> createComponents(
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter);
// TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep
GlobalContextHandler globalContextHandler = new GlobalContextHandler(client, new CreateIndexStep(clusterService, client));

return ImmutableList.of(workflowStepFactory, workflowProcessSorter, globalContextHandler);
}

@Override
Expand Down Expand Up @@ -106,10 +112,9 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
FixedExecutorBuilder provisionThreadPool = new FixedExecutorBuilder(
settings,
PROVISION_THREAD_POOL,
1,
OpenSearchExecutors.allocatedProcessors(settings),
10,
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL,
false
FLOW_FRAMEWORK_THREAD_POOL_PREFIX + PROVISION_THREAD_POOL
);
return ImmutableList.of(provisionThreadPool);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ public class CommonValue {
public static final String AI_FLOW_FRAMEWORK_BASE_URI = "/_plugins/_flow_framework";
/** The URI for this plugin's workflow rest actions */
public static final String WORKFLOWS_URI = AI_FLOW_FRAMEWORK_BASE_URI + "/workflows";
/** Field name for workflow Id, the document Id of the indexed use case template */
public static final String WORKFLOW_ID = "workflow_id";
/** The field name for provision workflow within a use case template*/
public static final String PROVISION_WORKFLOW = "provision";

/** Flow Framework plugin thread pool name prefix */
public static final String FLOW_FRAMEWORK_THREAD_POOL_PREFIX = "thread_pool.flow_framework.";
/** The provision workflow thread pool name */
public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision";
/** Field name for workflow Id, the document Id of the indexed use case template */
public static final String WORKFLOW_ID = "workflow_id";

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
Expand Down Expand Up @@ -94,6 +95,37 @@ public void putTemplateToGlobalContext(Template template, ActionListener<IndexRe
}));
}

/**
* Replaces a document in the global context index
* @param documentId the document Id
* @param template the use-case template
* @param listener action listener
*/
public void updateTemplate(String documentId, Template template, ActionListener<IndexResponse> listener) {
if (!createIndexStep.doesIndexExist(GLOBAL_CONTEXT_INDEX)) {
String exceptionMessage = String.format(
Locale.ROOT,
"Failed to update template {}, global_context index does not exist.",
documentId
);
logger.error(exceptionMessage);
listener.onFailure(new Exception(exceptionMessage));
} else {
IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId);
try (
XContentBuilder builder = XContentFactory.jsonBuilder();
ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()
) {
request.source(template.toXContent(builder, ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage());
listener.onFailure(e);
}
}
}

/**
* Update global context index for specific fields
* @param documentId global context index document id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import com.google.common.collect.ImmutableList;
import org.opensearch.client.node.NodeClient;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.transport.ProvisionWorkflowAction;
import org.opensearch.flowframework.transport.WorkflowRequest;
import org.opensearch.rest.BaseRestHandler;
Expand All @@ -31,12 +30,6 @@ public class RestProvisionWorkflowAction extends BaseRestHandler {

private static final String PROVISION_WORKFLOW_ACTION = "provision_workflow_action";

// TODO : move to common values class, pending implementation
/**
* Field name for workflow Id, the document Id of the indexed use case template
*/
public static final String WORKFLOW_ID = "workflow_id";

/**
* Instantiates a new RestProvisionWorkflowAction
*/
Expand All @@ -52,8 +45,6 @@ public String getName() {
@Override
public List<Route> routes() {
return ImmutableList.of(
// Provision workflow from inline use case template
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s", WORKFLOWS_URI, "_provision")),
// Provision workflow from indexed use case template
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/{%s}/%s", WORKFLOWS_URI, WORKFLOW_ID, "_provision"))
);
Expand All @@ -62,20 +53,19 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

String workflowId = request.param(WORKFLOW_ID);
Template template = null;

// Validate content
if (request.hasContent()) {
template = Template.parse(request.content().utf8ToString());
throw new IOException("Invalid request format");
}

// Validate workflow request inputs
if (workflowId == null && template == null) {
throw new IOException("workflow_id and template cannot be both null");
// Validate params
String workflowId = request.param(WORKFLOW_ID);
if (workflowId == null) {
throw new IOException("workflow_id cannot be null");
}

// Create request and provision
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.indices.GlobalContextHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -25,41 +25,46 @@ public class CreateWorkflowTransportAction extends HandledTransportAction<Workfl

private final Logger logger = LogManager.getLogger(CreateWorkflowTransportAction.class);

private final Client client;
private final GlobalContextHandler globalContextHandler;

/**
* Intantiates a new CreateWorkflowTransportAction
* @param transportService the TransportService
* @param actionFilters action filters
* @param client the node client to interact with an index
* @param globalContextHandler The handler for the global context index
*/
@Inject
public CreateWorkflowTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public CreateWorkflowTransportAction(
TransportService transportService,
ActionFilters actionFilters,
GlobalContextHandler globalContextHandler
) {
super(CreateWorkflowAction.NAME, transportService, actionFilters, WorkflowRequest::new);
this.client = client;
this.globalContextHandler = globalContextHandler;
}

@Override
protected void doExecute(Task task, WorkflowRequest request, ActionListener<WorkflowResponse> listener) {

String workflowId;
// TODO : Check if global context index exists, and if it does not then create

if (request.getWorkflowId() == null) {
// TODO : Create new entry
// TODO : Insert doc

// TODO : get document ID
workflowId = "";
// TODO : check if state index exists, and if it does not, then create
// TODO : insert state index doc, mapped with documentId, defaulted to NOT_STARTED
// Create new global context and state index entries
globalContextHandler.putTemplateToGlobalContext(request.getTemplate(), ActionListener.wrap(response -> {
// TODO : Check if state index exists, create if not
// TODO : Create StateIndexRequest for workflowId, default to NOT_STARTED
listener.onResponse(new WorkflowResponse(response.getId()));
}, exception -> {
logger.error("Failed to save use case template : {}", exception.getMessage());
listener.onFailure(exception);
}));
} else {
// TODO : Update existing entry
workflowId = request.getWorkflowId();
// TODO : Update state index entry, default back to NOT_STARTED
// Update existing entry, full document replacement
globalContextHandler.updateTemplate(request.getWorkflowId(), request.getTemplate(), ActionListener.wrap(response -> {
// TODO : Create StateIndexRequest for workflowId to reset entry to NOT_STARTED
listener.onResponse(new WorkflowResponse(request.getWorkflowId()));
}, exception -> {
logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage());
listener.onFailure(exception);
}));
}

listener.onResponse(new WorkflowResponse(workflowId));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
Expand All @@ -31,7 +33,9 @@
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;

/**
* Transport Action to provision a workflow from a stored use case template
Expand All @@ -40,12 +44,6 @@ public class ProvisionWorkflowTransportAction extends HandledTransportAction<Wor

private final Logger logger = LogManager.getLogger(ProvisionWorkflowTransportAction.class);

// TODO : Move to common values class, pending implementation
/**
* The name of the provision workflow within the use case template
*/
private static final String PROVISION_WORKFLOW = "provision";

private final ThreadPool threadPool;
private final Client client;
private final WorkflowProcessSorter workflowProcessSorter;
Expand Down Expand Up @@ -75,31 +73,30 @@ public ProvisionWorkflowTransportAction(
@Override
protected void doExecute(Task task, WorkflowRequest request, ActionListener<WorkflowResponse> listener) {

if (request.getWorkflowId() == null) {
// Workflow provisioning from inline template, first parse and then index the given use case template
client.execute(CreateWorkflowAction.INSTANCE, request, ActionListener.wrap(workflowResponse -> {
String workflowId = workflowResponse.getWorkflowId();
Template template = request.getTemplate();

// TODO : Use node client to update state index to PROVISIONING, given workflowId

listener.onResponse(new WorkflowResponse(workflowId));

// Asychronously begin provision workflow excecution
executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW));
// Retrieve use case template from global context
String workflowId = request.getWorkflowId();
GetRequest getRequest = new GetRequest(GLOBAL_CONTEXT_INDEX, workflowId);

}, exception -> { listener.onFailure(exception); }));
} else {
// Use case template has been previously saved, retrieve entry and execute
String workflowId = request.getWorkflowId();
// Stash thread context to interact with system index
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(response -> {
context.restore();

// TODO : Retrieve template from global context index using node client
Template template = null; // temporary, remove later
// Parse template from document source
Template template = Template.parse(response.getSourceAsString());

// TODO : use node client to update state index entry to PROVISIONING, given workflowId
// TODO : Update state index entry to PROVISIONING, given workflowId

listener.onResponse(new WorkflowResponse(workflowId));
executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW));
// Respond to rest action then execute provisioning workflow async
listener.onResponse(new WorkflowResponse(workflowId));
executeWorkflowAsync(workflowId, template.workflows().get(PROVISION_WORKFLOW));
}, exception -> {
logger.error("Failed to retrieve template from global context.", exception);
listener.onFailure(exception);
}));
} catch (Exception e) {
logger.error("Failed to retrieve template from global context.", e);
listener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ public String getName() {
return NAME;
}

// TODO : Move to index management class, pending implementation
/**
* Checks if the given index exists
* @param indexName the name of the index
* @return boolean indicating the existence of an index
*/
public boolean doesIndexExist(String indexName) {
return clusterService.state().metadata().hasIndex(indexName);
}

/**
* Create Index if it's absent
* @param index The index that needs to be created
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -24,13 +25,15 @@ public class FlowFrameworkPluginTests extends OpenSearchTestCase {

private Client client;
private ThreadPool threadPool;
private Settings settings;

@Override
public void setUp() throws Exception {
super.setUp();
client = mock(Client.class);
when(client.admin()).thenReturn(mock(AdminClient.class));
threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName());
settings = Settings.EMPTY;
}

@Override
Expand All @@ -41,10 +44,10 @@ public void tearDown() throws Exception {

public void testPlugin() throws IOException {
try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) {
assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size());
assertEquals(3, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size());
assertEquals(2, ffp.getRestHandlers(null, null, null, null, null, null, null).size());
assertEquals(2, ffp.getActions().size());
assertEquals(1, ffp.getExecutorBuilders(null).size());
assertEquals(1, ffp.getExecutorBuilders(settings).size());
}
}
}
Loading

0 comments on commit c7c819b

Please sign in to comment.