diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java index b8086aa4c..e962f5758 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowState.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowState.java @@ -8,12 +8,15 @@ */ package org.opensearch.flowframework.model; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParseException; @@ -393,6 +396,22 @@ public static WorkflowState parse(XContentParser parser) throws IOException { .build(); } + /** + * Parse a JSON workflow state + * @param json A string containing a JSON representation of a workflow state + * @return A {@link WorkflowState} represented by the JSON + * @throws IOException on failure to parse + */ + public static WorkflowState parse(String json) throws IOException { + XContentParser parser = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + json + ); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return parse(parser); + } + /** * The workflowID associated with this workflow-state * @return the workflowId diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 147239c0f..37175e949 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -298,6 +298,10 @@ protected Response createWorkflow(Template template) throws Exception { return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI, ImmutableMap.of(), template.toJson(), null); } + protected Response createWorkflowWithProvision(Template template) throws Exception { + return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI + "?provision=true", ImmutableMap.of(), template.toJson(), null); + } + /** * Helper method to invoke the Create Workflow Rest Action with dry run validation * @param template the template to create @@ -343,6 +347,40 @@ protected Response provisionWorkflow(String workflowId) throws Exception { ); } + /** + * Helper method to invoke the Deprovision Workflow Rest Action + * @param workflowId the workflow ID to deprovision + * @return a rest response + * @throws Exception if the request fails + */ + protected Response deprovisionWorkflow(String workflowId) throws Exception { + return TestHelpers.makeRequest( + client(), + "POST", + String.format(Locale.ROOT, "%s/%s/%s", WORKFLOW_URI, workflowId, "_deprovision"), + ImmutableMap.of(), + "", + null + ); + } + + /** + * Helper method to invoke the Delete Workflow Rest Action + * @param workflowId the workflow ID to delete + * @return a rest response + * @throws Exception if the request fails + */ + protected Response deleteWorkflow(String workflowId) throws Exception { + return TestHelpers.makeRequest( + client(), + "DELETE", + String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, workflowId), + ImmutableMap.of(), + "", + null + ); + } + /** * Helper method to invoke the Get Workflow Rest Action * @param workflowId the workflow ID to get the status @@ -395,6 +433,31 @@ protected SearchResponse searchWorkflows(String query) throws Exception { } } + protected SearchResponse searchWorkflowState(String query) throws Exception { + Response restSearchResponse = TestHelpers.makeRequest( + client(), + "GET", + String.format(Locale.ROOT, "%s/state/_search", WORKFLOW_URI), + ImmutableMap.of(), + query, + null + ); + assertEquals(RestStatus.OK, TestHelpers.restStatus(restSearchResponse)); + + // Parse entity content into SearchResponse + MediaType mediaType = MediaType.fromMediaType(restSearchResponse.getEntity().getContentType().getValue()); + try ( + XContentParser parser = mediaType.xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + restSearchResponse.getEntity().getContent() + ) + ) { + return SearchResponse.fromXContent(parser); + } + } + /** * Helper method to invoke the Get Workflow Rest Action and assert the provisioning and state status * @param workflowId the workflow ID to get the status @@ -408,8 +471,8 @@ protected void getAndAssertWorkflowStatus(String workflowId, State stateStatus, assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); Map responseMap = entityAsMap(response); - assertEquals(stateStatus.name(), (String) responseMap.get(CommonValue.STATE_FIELD)); - assertEquals(provisioningStatus.name(), (String) responseMap.get(CommonValue.PROVISIONING_PROGRESS_FIELD)); + assertEquals(stateStatus.name(), responseMap.get(CommonValue.STATE_FIELD)); + assertEquals(provisioningStatus.name(), responseMap.get(CommonValue.PROVISIONING_PROGRESS_FIELD)); } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index f4124f28d..cfc5919d4 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -21,10 +21,15 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; @@ -174,4 +179,58 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { assertNotNull(resourcesCreated.get(0).resourceId()); } + public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { + Template template = TestHelpers.createTemplateFromFile("agent-framework.json"); + + // Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter + Response response = createWorkflowWithProvision(template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy(() -> { getAndAssertWorkflowStatus(workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, 30, TimeUnit.SECONDS); + + // Hit Search State API with the workflow id created above + String query = "{\"query\":{\"ids\":{\"values\":[\"" + workflowId + "\"]}}}"; + SearchResponse searchResponse = searchWorkflowState(query); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + String searchHitSource = searchResponse.getHits().getAt(0).getSourceAsString(); + WorkflowState searchHitWorkflowState = WorkflowState.parse(searchHitSource); + + // Assert based on the agent-framework template + List resourcesCreated = searchHitWorkflowState.resourcesCreated(); + Set expectedStepNames = new HashSet<>(); + expectedStepNames.add("root_agent"); + expectedStepNames.add("sub_agent"); + expectedStepNames.add("openAI_connector"); + expectedStepNames.add("gpt-3.5-model"); + expectedStepNames.add("deployed-gpt-3.5-model"); + Set stepNames = resourcesCreated.stream().map(ResourceCreated::workflowStepId).collect(Collectors.toSet()); + + assertEquals(5, resourcesCreated.size()); + assertEquals(stepNames, expectedStepNames); + assertNotNull(resourcesCreated.get(0).resourceId()); + + // Hit Deprovision API + Response deprovisionResponse = deprovisionWorkflow(workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + assertBusy( + () -> { getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 30, + TimeUnit.SECONDS + ); + + // Hit Delete API + Response deleteResponse = deleteWorkflow(workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + + // wait for deletion to complete + Thread.sleep(30000); + + // Search this workflow id in global_context index to make sure it's deleted + SearchResponse searchResponseAfterDeletion = searchWorkflows(query); + assertBusy(() -> assertEquals(0, searchResponseAfterDeletion.getHits().getTotalHits().value), 30, TimeUnit.SECONDS); + + } + } diff --git a/src/test/resources/template/agent-framework.json b/src/test/resources/template/agent-framework.json new file mode 100644 index 000000000..713214638 --- /dev/null +++ b/src/test/resources/template/agent-framework.json @@ -0,0 +1,201 @@ +{ + "name": "opensearch-assistant-workflow", + "description": "test case", + "use_case": "REGISTER_AGENT", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "openAI_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for GPT 3.5", + "protocol": "aws_sigv4", + "version": "1", + "parameters": { + "region": "us-east-1", + "service_name": "bedrock", + "anthropic_version": "bedrock-2023-05-31", + "endpoint": "bedrock.us-east-1.amazonaws.com", + "auth": "Sig_V4", + "content_type": "application/json", + "max_tokens_to_sample": 8000, + "temperature": 0.0001 + }, + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-v2/invoke", + "request_body": "{\"prompt\":\"${parameters.prompt}\", \"max_tokens_to_sample\":${parameters.max_tokens_to_sample}, \"temperature\":${parameters.temperature}, \"anthropic_version\":\"${parameters.anthropic_version}\" }" + } + ] + } + }, + { + "id": "gpt-3.5-model", + "type": "register_remote_model", + "previous_node_inputs": { + "openAI_connector": "connector_id" + }, + "user_inputs": { + "name": "flow-register-remote-test-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model" + } + }, + { + "id": "deployed-gpt-3.5-model", + "type": "deploy_model", + "previous_node_inputs": { + "gpt-3.5-model": "model_id" + } + }, + { + "id": "ml_model_tool", + "type": "create_tool", + "previous_node_inputs": { + "deployed-gpt-3.5-model": "model_id" + }, + "user_inputs": { + "name": "MLModelTool", + "type": "MLModelTool", + "alias": "language_model_tool", + "description": "A general tool to answer any question. But it can't handle math problem", + "parameters": { + "prompt": "Answer the question as best you can.", + "response_filter": "choices[0].message.content" + } + } + }, + { + "id": "math_tool", + "type": "create_tool", + "user_inputs": { + "name": "MathTool", + "type": "MathTool", + "description": "A general tool to calculate any math problem. The action input must be valid math expression, like 2+3", + "parameters": { + "max_iteration": 5 + } + } + }, + { + "id": "sub_agent", + "type": "register_agent", + "previous_node_inputs": { + "math_tool": "tools" + }, + "user_inputs": { + "name": "Sub Agent", + "type": "cot", + "description": "this is a test agent", + "parameters": { + "hello": "world" + }, + "llm.model_id": "ldzS04kBxRPZ5cnWrqpd", + "llm.parameters": { + "max_iteration": "5", + "stop_when_no_tool_found": "true" + }, + "memory": { + "type": "conversation_index" + }, + "app_type": "chatbot", + "created_time": 1689793598499, + "last_updated_time": 1689793598530 + } + }, + { + "id": "agent_tool", + "type": "create_tool", + "previous_node_inputs": { + "sub_agent": "agent_id" + }, + "user_inputs": { + "name": "AgentTool", + "type": "AgentTool", + "description": "Root Agent Tool", + "parameters": { + "max_iteration": 5 + } + } + }, + { + "id": "root_agent", + "type": "register_agent", + "previous_node_inputs": { + "deployed-gpt-3.5-model": "model_id", + "ml_model_tool": "tools", + "agent_tool": "tools" + }, + "user_inputs": { + "name": "DEMO-Test_Agent_For_CoT", + "type": "cot", + "description": "this is a test agent", + "parameters": { + "hello": "world" + }, + "llm.parameters": { + "max_iteration": "5", + "stop_when_no_tool_found": "true" + }, + "memory": { + "type": "conversation_index" + }, + "app_type": "chatbot", + "created_time": 1689793598499, + "last_updated_time": 1689793598530 + } + } + ], + "edges": [ + { + "source": "openAI_connector", + "dest": "gpt-3.5-model" + }, + { + "source": "gpt-3.5-model", + "dest": "deployed-gpt-3.5-model" + }, + { + "source": "deployed-gpt-3.5-model", + "dest": "root_agent" + }, + { + "source": "deployed-gpt-3.5-model", + "dest": "ml_model_tool" + }, + { + "source": "ml_model_tool", + "dest": "root_agent" + }, + { + "source": "math_tool", + "dest": "sub_agent" + }, + { + "source": "sub_agent", + "dest": "agent_tool" + }, + { + "source": "agent_tool", + "dest": "root_agent" + } + ] + } + } +}