From b3f9d6514d2c0530c8bffac654fe175452673169 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Thu, 29 Aug 2024 01:26:54 +0800 Subject: [PATCH] feat: parse connector id from tool parameters map (#846) * feat: parse connector id from tool parameters map Signed-off-by: yuye-aws * update changelog Signed-off-by: yuye-aws * implement unit test for connector, model and agent id Signed-off-by: yuye-aws * tool step id: make node id unique Signed-off-by: yuye-aws * integration test: create agent with connector tool Signed-off-by: yuye-aws * integration test: update with get agent and get workflow Signed-off-by: yuye-aws * optimize: iterate through connector_id model_id and agent_id Signed-off-by: yuye-aws * update changelog Signed-off-by: yuye-aws --------- Signed-off-by: yuye-aws --- CHANGELOG.md | 1 + .../flowframework/workflow/ToolStep.java | 53 ++++---- .../FlowFrameworkRestTestCase.java | 18 ++- .../rest/FlowFrameworkRestApiIT.java | 77 ++++++++++- .../flowframework/workflow/ToolStepTests.java | 75 +++++++++- ...r-createconnectortool-createflowagent.json | 76 +++++++++++ ...ector-registerremotemodel-deploymodel.json | 128 +++++++++--------- 7 files changed, 329 insertions(+), 99 deletions(-) create mode 100644 src/test/resources/template/createconnector-createconnectortool-createflowagent.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 192782355..8723fe35d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Features - Adds reprovision API to support updating search pipelines, ingest pipelines index settings ([#804](https://github.com/opensearch-project/flow-framework/pull/804)) - Adds user level access control based on backend roles ([#838](https://github.com/opensearch-project/flow-framework/pull/838)) +- Support parsing connector_id when creating tools ([#846](https://github.com/opensearch-project/flow-framework/pull/846)) ### Enhancements ### Bug Fixes diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 7f9bd609d..45e2ee240 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -28,6 +28,7 @@ import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; +import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; /** @@ -64,7 +65,15 @@ public PlainActionFuture execute( String name = (String) inputs.get(NAME_FIELD); String description = (String) inputs.get(DESCRIPTION_FIELD); Boolean includeOutputInAgentResponse = ParseUtils.parseIfExists(inputs, INCLUDE_OUTPUT_IN_AGENT_RESPONSE, Boolean.class); - Map parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs); + + // parse connector_id, model_id and agent_id from previous node inputs + Set toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID); + Map parameters = getToolsParametersMap( + inputs.get(PARAMETERS_FIELD), + previousNodeInputs, + outputs, + toolParameterKeys + ); MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); @@ -110,39 +119,29 @@ public String getName() { private Map getToolsParametersMap( Object parameters, Map previousNodeInputs, - Map outputs + Map outputs, + Set toolParameterKeys ) { @SuppressWarnings("unchecked") Map parametersMap = (Map) parameters; - Optional previousNodeModel = previousNodeInputs.entrySet() - .stream() - .filter(e -> MODEL_ID.equals(e.getValue())) - .map(Map.Entry::getKey) - .findFirst(); - - Optional previousNodeAgent = previousNodeInputs.entrySet() - .stream() - .filter(e -> AGENT_ID.equals(e.getValue())) - .map(Map.Entry::getKey) - .findFirst(); - - // Case when modelId is passed through previousSteps and not present already in parameters - if (previousNodeModel.isPresent() && !parametersMap.containsKey(MODEL_ID)) { - WorkflowData previousNodeOutput = outputs.get(previousNodeModel.get()); - if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(MODEL_ID)) { - parametersMap.put(MODEL_ID, previousNodeOutput.getContent().get(MODEL_ID).toString()); - } - } - // Case when agentId is passed through previousSteps and not present already in parameters - if (previousNodeAgent.isPresent() && !parametersMap.containsKey(AGENT_ID)) { - WorkflowData previousNodeOutput = outputs.get(previousNodeAgent.get()); - if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(AGENT_ID)) { - parametersMap.put(AGENT_ID, previousNodeOutput.getContent().get(AGENT_ID).toString()); + for (String toolParameterKey : toolParameterKeys) { + Optional previousNodeParameter = previousNodeInputs.entrySet() + .stream() + .filter(e -> toolParameterKey.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + + // Case when toolParameterKey is passed through previousSteps and not present already in parameters + if (previousNodeParameter.isPresent() && !parametersMap.containsKey(toolParameterKey)) { + WorkflowData previousNodeOutput = outputs.get(previousNodeParameter.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(toolParameterKey)) { + parametersMap.put(toolParameterKey, previousNodeOutput.getContent().get(toolParameterKey).toString()); + } } } - // For other cases where modelId is already present in the parameters or not return the parametersMap + // For other cases where toolParameterKey is already present in the parameters or not return the parametersMap return parametersMap; } diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 5876ca1d7..49fb46a7a 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -660,6 +660,23 @@ protected Response getWorkflowStep(RestClient client) throws Exception { ); } + /** + * Helper method to invoke the Get Agent Rest Action + * @param client the rest client + * @return rest response + * @throws Exception + */ + protected Response getAgent(RestClient client, String agentId) throws Exception { + return TestHelpers.makeRequest( + client, + "GET", + String.format(Locale.ROOT, "/_plugins/_ml/agents/%s", agentId), + Collections.emptyMap(), + "", + null + ); + } + /** * Helper method to invoke the Search Workflow Rest Action with the given query * @param client the rest client @@ -668,7 +685,6 @@ protected Response getWorkflowStep(RestClient client) throws Exception { * @throws Exception if the request fails */ protected SearchResponse searchWorkflows(RestClient client, String query) throws Exception { - // Execute search Response restSearchResponse = TestHelpers.makeRequest( client, diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index d176adc3b..11ba096c1 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -30,6 +30,7 @@ import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; @@ -56,7 +57,6 @@ public void waitToStart() throws Exception { } public void testSearchWorkflows() throws Exception { - // Create a Workflow that has a credential 12345 Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); Response response = createWorkflow(client(), template); @@ -228,7 +228,6 @@ public void testCreateAndProvisionCyclicalTemplate() throws Exception { } public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { - // Using a 3 step template to create a connector, register remote model and deploy model Template template = TestHelpers.createTemplateFromFile("createconnector-registerremotemodel-deploymodel.json"); @@ -331,6 +330,79 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS); } + public void testCreateAndProvisionConnectorToolAgentFrameworkWorkflow() throws Exception { + // Create a Workflow that has a credential 12345 + Template template = TestHelpers.createTemplateFromFile("createconnector-createconnectortool-createflowagent.json"); + + // Hit Create Workflow API to create agent-framework template, with template validation check and provision parameter + Response response = createWorkflowWithProvision(client(), template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + // wait and ensure state is completed/done + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + 120, + TimeUnit.SECONDS + ); + + // Assert based on the agent-framework template + List resourcesCreated = getResourcesCreated(client(), workflowId, 120); + Map resourceMap = resourcesCreated.stream() + .collect(Collectors.toMap(ResourceCreated::workflowStepName, r -> r)); + assertEquals(2, resourceMap.size()); + assertTrue(resourceMap.containsKey("create_connector")); + assertTrue(resourceMap.containsKey("register_agent")); + String connectorId = resourceMap.get("create_connector").resourceId(); + String agentId = resourceMap.get("register_agent").resourceId(); + assertNotNull(connectorId); + assertNotNull(agentId); + + // Assert that the agent contains the correct connector_id + response = getAgent(client(), agentId); + Map agentResponse = entityAsMap(response); + assertTrue(agentResponse.containsKey("tools")); + @SuppressWarnings("unchecked") + ArrayList> tools = (ArrayList>) agentResponse.get("tools"); + assertEquals(1, tools.size()); + Map tool = tools.getFirst(); + assertTrue(tool.containsKey("parameters")); + @SuppressWarnings("unchecked") + Map toolParameters = (Map) tool.get("parameters"); + assertEquals(toolParameters, Map.of("connector_id", connectorId)); + + // Hit Deprovision API + // By design, this may not completely deprovision the first time if it takes >2s to process removals + Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + try { + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 30, + TimeUnit.SECONDS + ); + } catch (ComparisonFailure e) { + // 202 return if still processing + assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse)); + } + if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) { + // Short wait before we try again + Thread.sleep(10000); + deprovisionResponse = deprovisionWorkflow(client(), workflowId); + assertBusy( + () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, + 30, + TimeUnit.SECONDS + ); + } + assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse)); + // Hit Delete API + Response deleteResponse = deleteWorkflow(client(), workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse)); + + // Verify state doc is deleted + assertBusy(() -> { getAndAssertWorkflowStatusNotFound(client(), workflowId); }, 30, TimeUnit.SECONDS); + } + public void testReprovisionWorkflow() throws Exception { // Begin with a template to register a local pretrained model Template template = TestHelpers.createTemplateFromFile("registerremotemodel.json"); @@ -650,7 +722,6 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception { } public void testDefaultCohereUseCase() throws Exception { - // Hit Create Workflow API with original template Response response = createWorkflowWithUseCaseWithNoValidation( client(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 45cb3816c..029b5c835 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -14,13 +14,26 @@ import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; +import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; +import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; + public class ToolStepTests extends OpenSearchTestCase { private WorkflowData inputData; + private WorkflowData inputDataWithConnectorId; + private WorkflowData inputDataWithModelId; + private WorkflowData inputDataWithAgentId; + private static final String mockedConnectorId = "mocked-connector-id"; + private static final String mockedModelId = "mocked-model-id"; + private static final String mockedAgentId = "mocked-agent-id"; + private static final String createConnectorNodeId = "create_connector_node_id"; + private static final String createModelNodeId = "create_model_node_id"; + private static final String createAgentNodeId = "create_agent_node_id"; + private WorkflowData boolStringInputData; private WorkflowData badBoolInputData; @@ -39,6 +52,9 @@ public void setUp() throws Exception { "test-id", "test-node-id" ); + inputDataWithConnectorId = new WorkflowData(Map.of(CONNECTOR_ID, mockedConnectorId), "test-id", createConnectorNodeId); + inputDataWithModelId = new WorkflowData(Map.of(MODEL_ID, mockedModelId), "test-id", createModelNodeId); + inputDataWithAgentId = new WorkflowData(Map.of(AGENT_ID, mockedAgentId), "test-id", createAgentNodeId); boolStringInputData = new WorkflowData( Map.ofEntries( Map.entry("type", "type"), @@ -63,7 +79,7 @@ public void setUp() throws Exception { ); } - public void testTool() throws IOException, ExecutionException, InterruptedException { + public void testTool() throws ExecutionException, InterruptedException { ToolStep toolStep = new ToolStep(); PlainActionFuture future = toolStep.execute( @@ -88,7 +104,7 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + public void testBoolParseFail() { ToolStep toolStep = new ToolStep(); PlainActionFuture future = toolStep.execute( @@ -100,10 +116,61 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup ); assertTrue(future.isDone()); - ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + ExecutionException e = assertThrows(ExecutionException.class, future::get); assertEquals(WorkflowStepException.class, e.getCause().getClass()); WorkflowStepException w = (WorkflowStepException) e.getCause(); assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage()); assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); } + + public void testToolWithConnectorId() throws ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Map.of(createConnectorNodeId, inputDataWithConnectorId), + Map.of(createConnectorNodeId, CONNECTOR_ID), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + Object tools = future.get().getContent().get("tools"); + assertEquals(MLToolSpec.class, tools.getClass()); + MLToolSpec mlToolSpec = (MLToolSpec) tools; + assertEquals(mlToolSpec.getParameters(), Map.of(CONNECTOR_ID, mockedConnectorId)); + } + + public void testToolWithModelId() throws ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Map.of(createModelNodeId, inputDataWithModelId), + Map.of(createModelNodeId, MODEL_ID), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + Object tools = future.get().getContent().get("tools"); + assertEquals(MLToolSpec.class, tools.getClass()); + MLToolSpec mlToolSpec = (MLToolSpec) tools; + assertEquals(mlToolSpec.getParameters(), Map.of(MODEL_ID, mockedModelId)); + } + + public void testToolWithAgentId() throws ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + inputData.getNodeId(), + inputData, + Map.of(createAgentNodeId, inputDataWithAgentId), + Map.of(createAgentNodeId, AGENT_ID), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + Object tools = future.get().getContent().get("tools"); + assertEquals(MLToolSpec.class, tools.getClass()); + MLToolSpec mlToolSpec = (MLToolSpec) tools; + assertEquals(mlToolSpec.getParameters(), Map.of(AGENT_ID, mockedAgentId)); + } } diff --git a/src/test/resources/template/createconnector-createconnectortool-createflowagent.json b/src/test/resources/template/createconnector-createconnectortool-createflowagent.json new file mode 100644 index 000000000..e54310cf3 --- /dev/null +++ b/src/test/resources/template/createconnector-createconnectortool-createflowagent.json @@ -0,0 +1,76 @@ +{ + "name": "createconnector-createconnectortool-createflowagent", + "description": "test case", + "use_case": "TEST_CASE", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.15.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for GPT 3.5", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo" + }, + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] + } + }, + { + "id": "create_tool", + "type": "create_tool", + "previous_node_inputs": { + "create_connector": "connector_id" + }, + "user_inputs": { + "parameters": {}, + "name": "ConnectorTool", + "type": "ConnectorTool" + } + }, + { + "id": "create_flow_agent", + "type": "register_agent", + "previous_node_inputs": { + "create_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "OpenAI Chat Agent" + } + } + ], + "edges": [ + { + "source": "create_connector", + "dest": "create_tool" + }, + { + "source": "create_tool", + "dest": "create_flow_agent" + } + ] + } + } +} diff --git a/src/test/resources/template/createconnector-registerremotemodel-deploymodel.json b/src/test/resources/template/createconnector-registerremotemodel-deploymodel.json index d889e6b9f..4a5513660 100644 --- a/src/test/resources/template/createconnector-registerremotemodel-deploymodel.json +++ b/src/test/resources/template/createconnector-registerremotemodel-deploymodel.json @@ -1,71 +1,71 @@ { - "name": "createconnector-registerremotemodel-deploymodel", - "description": "test case", - "use_case": "TEST_CASE", - "version": { - "template": "1.0.0", - "compatibility": [ - "2.12.0", - "3.0.0" - ] - }, - "workflows": { - "provision": { - "nodes": [ - { - "id": "workflow_step_1", - "type": "create_connector", - "user_inputs": { - "name": "OpenAI Chat Connector", - "description": "The connector to public OpenAI model service for GPT 3.5", - "version": "1", - "protocol": "http", - "parameters": { - "endpoint": "api.openai.com", - "model": "gpt-3.5-turbo" - }, - "credential": { - "openAI_key": "12345" - }, - "actions": [ - { - "action_type": "predict", - "method": "POST", - "url": "https://${parameters.endpoint}/v1/chat/completions" - } - ] - } - }, - { - "id": "workflow_step_2", - "type": "register_remote_model", - "previous_node_inputs": { - "workflow_step_1": "connector_id" + "name": "createconnector-registerremotemodel-deploymodel", + "description": "test case", + "use_case": "TEST_CASE", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "workflow_step_1", + "type": "create_connector", + "user_inputs": { + "name": "OpenAI Chat Connector", + "description": "The connector to public OpenAI model service for GPT 3.5", + "version": "1", + "protocol": "http", + "parameters": { + "endpoint": "api.openai.com", + "model": "gpt-3.5-turbo" }, - "user_inputs": { - "name": "openAI-gpt-3.5-turbo", - "function_name": "remote", - "description": "test model" - } - }, - { - "id": "workflow_step_3", - "type": "deploy_model", - "previous_node_inputs": { - "workflow_step_2": "model_id" - } + "credential": { + "openAI_key": "12345" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://${parameters.endpoint}/v1/chat/completions" + } + ] } - ], - "edges": [ - { - "source": "workflow_step_1", - "dest": "workflow_step_2" + }, + { + "id": "workflow_step_2", + "type": "register_remote_model", + "previous_node_inputs": { + "workflow_step_1": "connector_id" }, - { - "source": "workflow_step_2", - "dest": "workflow_step_3" + "user_inputs": { + "name": "openAI-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model" + } + }, + { + "id": "workflow_step_3", + "type": "deploy_model", + "previous_node_inputs": { + "workflow_step_2": "model_id" } - ] - } + } + ], + "edges": [ + { + "source": "workflow_step_1", + "dest": "workflow_step_2" + }, + { + "source": "workflow_step_2", + "dest": "workflow_step_3" + } + ] } } +}