diff --git a/.github/workflows/backport.yml b/.github/workflows/backport.yml index b269557a5..9fed34a7d 100644 --- a/.github/workflows/backport.yml +++ b/.github/workflows/backport.yml @@ -37,3 +37,4 @@ jobs: with: github_token: ${{ steps.github_app_token.outputs.token }} head_template: backport/backport-<%= number %>-to-<%= base %> + failure_labels: backport-failed diff --git a/.github/workflows/test-api-consistency.yml b/.github/workflows/test-api-consistency.yml new file mode 100644 index 000000000..4c71b4d3b --- /dev/null +++ b/.github/workflows/test-api-consistency.yml @@ -0,0 +1,26 @@ +name: Daily API Consistency Test + +on: + schedule: + - cron: '0 8 * * *' # Runs daily at 8 AM UTC + workflow_dispatch: + +jobs: + API-consistency-test: + runs-on: ubuntu-latest + strategy: + matrix: + java: [21] + + steps: + - name: Checkout Flow Framework + uses: actions/checkout@v3 + + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v3 + with: + distribution: 'temurin' + java-version: ${{ matrix.java }} + + - name: Run API Consistency Tests + run: ./gradlew test --tests "org.opensearch.flowframework.workflow.*" diff --git a/CHANGELOG.md b/CHANGELOG.md index 30e4c2250..6a4108ce2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,17 +8,34 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Enhancements ### Bug Fixes ### Infrastructure +- Set Java target compatibility to JDK 21 ([#730](https://github.com/opensearch-project/flow-framework/pull/730)) + ### Documentation +- Add alert summary agent template ([#873](https://github.com/opensearch-project/flow-framework/pull/873)) +- Add alert summary with log pattern agent template ([#945](https://github.com/opensearch-project/flow-framework/pull/945)) + ### Maintenance ### Refactoring ## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.17...2.x) ### Features +- Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651)) +- Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899)) +- Add API Consistency Tests with ML-Common and Set Up Daily GitHub Action Trigger([#908](https://github.com/opensearch-project/flow-framework/issues/908)) + ### Enhancements +- Incrementally remove resources from workflow state during deprovisioning ([#898](https://github.com/opensearch-project/flow-framework/pull/898)) + ### Bug Fixes - Remove useCase and defaultParams field in WorkflowRequest ([#758](https://github.com/opensearch-project/flow-framework/pull/758)) +- Fixed Template Update Location and Improved Logger Statements in ReprovisionWorkflowTransportAction ([#918](https://github.com/opensearch-project/flow-framework/pull/918)) ### Infrastructure ### Documentation +- Add knowledge base alert agent into sample templates ([#874](https://github.com/opensearch-project/flow-framework/pull/874)) +- Add query assist data summary agent into sample templates ([#875](https://github.com/opensearch-project/flow-framework/pull/875)) +- Add suggest anomaly detector agent into sample templates ([#944](https://github.com/opensearch-project/flow-framework/pull/944)) + ### Maintenance ### Refactoring +- Update workflow state without using painless script ([#894](https://github.com/opensearch-project/flow-framework/pull/894)) diff --git a/build.gradle b/build.gradle index 8fe85bc39..97fa00352 100644 --- a/build.gradle +++ b/build.gradle @@ -7,7 +7,7 @@ import java.nio.file.Paths 
 buildscript {
     ext {
-        opensearch_version = System.getProperty("opensearch.version", "2.17.0-SNAPSHOT")
+        opensearch_version = System.getProperty("opensearch.version", "2.19.0-SNAPSHOT")
         buildVersionQualifier = System.getProperty("build.version_qualifier", "")
         isSnapshot = "true" == System.getProperty("build.snapshot", "true")
         version_tokens = opensearch_version.tokenize('-')
@@ -24,7 +24,6 @@ buildscript {
         opensearch_no_snapshot = opensearch_build.replace("-SNAPSHOT","")
         System.setProperty('tests.security.manager', 'false')
         common_utils_version = System.getProperty("common_utils.version", opensearch_build)
-
         bwcVersionShort = "2.12.0"
         bwcVersion = bwcVersionShort + ".0"
         bwcOpenSearchFFDownload = 'https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/' + bwcVersionShort + '/latest/linux/x64/tar/builds/' +
@@ -34,6 +34,10 @@ buildscript {
         bwcFlowFrameworkPath = bwcFilePath + "flowframework/"
         isSameMajorVersion = opensearch_version.split("\\.")[0] == bwcVersionShort.split("\\.")[0]
+        swaggerVersion = "2.1.23"
+        jacksonVersion = "2.18.1"
+        swaggerCoreVersion = "2.2.25"
+
     }

     repositories {
@@ -52,8 +56,8 @@

 plugins {
     id "de.undercouch.download" version "5.6.0"
-    id "org.gradle.test-retry" version "1.5.10" apply false
-    id "io.github.surpsg.delta-coverage" version "2.4.0"
+    id "org.gradle.test-retry" version "1.6.0" apply false
+    id "io.github.surpsg.delta-coverage" version "2.5.0"
 }

 apply plugin: 'java'
@@ -164,17 +168,37 @@ configurations {
 dependencies {
     implementation "org.opensearch:opensearch:${opensearch_version}"
-    implementation 'org.junit.jupiter:junit-jupiter:5.11.0'
+    implementation 'org.junit.jupiter:junit-jupiter:5.11.3'
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
     api group: 'org.opensearch.client', name: 'opensearch-rest-client', version: "${opensearch_version}"
+    api group: 'org.slf4j', name: 'slf4j-api', version: '1.7.36'
     implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.17.0'
     implementation "org.opensearch:common-utils:${common_utils_version}"
-    implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
-    implementation 'org.bouncycastle:bcprov-jdk18on:1.78'
+    implementation "com.amazonaws:aws-encryption-sdk-java:3.0.1"
+    implementation "software.amazon.cryptography:aws-cryptographic-material-providers:1.7.0"
+    implementation "org.dafny:DafnyRuntime:4.9.0"
+    implementation "software.amazon.smithy.dafny:conversion:0.1.1"
+    implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
     implementation "jakarta.json.bind:jakarta.json.bind-api:3.0.1"
     implementation "org.glassfish:jakarta.json:2.0.1"
     implementation "org.eclipse:yasson:3.0.4"
     implementation "com.google.code.gson:gson:2.11.0"
+    // Swagger-Parser dependencies for API consistency tests
+    implementation "io.swagger.core.v3:swagger-models:${swaggerCoreVersion}"
+    implementation "io.swagger.core.v3:swagger-core:${swaggerCoreVersion}"
+    implementation "io.swagger.parser.v3:swagger-parser-core:${swaggerVersion}"
+    implementation "io.swagger.parser.v3:swagger-parser:${swaggerVersion}"
+    implementation "io.swagger.parser.v3:swagger-parser-v3:${swaggerVersion}"
+    // Declare and force Jackson dependencies for tests
+    testImplementation("com.fasterxml.jackson.core:jackson-databind") {
+        version { strictly("${jacksonVersion}") }
+    }
+    testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310") {
+        version { strictly("${jacksonVersion}") }
+    }
+
testImplementation("com.fasterxml.jackson.core:jackson-annotations") { + version { strictly("${jacksonVersion}") } + } // ZipArchive dependencies used for integration tests zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" @@ -184,8 +208,8 @@ dependencies { configurations.all { resolutionStrategy { - force("com.google.guava:guava:33.3.0-jre") // CVE for 31.1, keep to force transitive dependencies - force("com.fasterxml.jackson.core:jackson-core:2.17.2") // Dependency Jar Hell + force("com.google.guava:guava:33.3.1-jre") // CVE for 31.1, keep to force transitive dependencies + force("com.fasterxml.jackson.core:jackson-core:${jacksonVersion}") // Dependency Jar Hell } } } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 2b189974c..fb602ee2a 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionSha256Sum=5b9c5eb3f9fc2c94abaea57d90bd78747ca117ddbbf96c859d3741181a12bf2a -distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip +distributionSha256Sum=31c55713e40233a8303827ceb42ca48a47267a0ad4bab9177123121e71524c26 +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/release-notes/opensearch-flow-framework.release-notes-2.17.1.0.md b/release-notes/opensearch-flow-framework.release-notes-2.17.1.0.md new file mode 100644 index 000000000..19fb154a3 --- /dev/null +++ b/release-notes/opensearch-flow-framework.release-notes-2.17.1.0.md @@ -0,0 +1,6 @@ +## Version 2.17.1.0 + +Compatible with OpenSearch 2.17.1 + +### Maintenance +- Fix flaky integ test reprovisioning before template update ([#880](https://github.com/opensearch-project/flow-framework/pull/880)) diff --git a/release-notes/opensearch-flow-framework.release-notes-2.18.0.0.md b/release-notes/opensearch-flow-framework.release-notes-2.18.0.0.md new file mode 100644 index 000000000..87ff82c0b --- /dev/null +++ b/release-notes/opensearch-flow-framework.release-notes-2.18.0.0.md @@ -0,0 +1,19 @@ +## Version 2.18.0.0 + +Compatible with OpenSearch 2.18.0 + +### Features +- Add ApiSpecFetcher for Fetching and Comparing API Specifications ([#651](https://github.com/opensearch-project/flow-framework/issues/651)) +- Add optional config field to tool step ([#899](https://github.com/opensearch-project/flow-framework/pull/899)) + +### Enhancements +- Incrementally remove resources from workflow state during deprovisioning ([#898](https://github.com/opensearch-project/flow-framework/pull/898)) + +### Bug Fixes +- Fixed Template Update Location and Improved Logger Statements in ReprovisionWorkflowTransportAction ([#918](https://github.com/opensearch-project/flow-framework/pull/918)) + +### Documentation +- Add query assist data summary agent into sample templates ([#875](https://github.com/opensearch-project/flow-framework/pull/875)) + +### Refactoring +- Update workflow state without using painless script ([#894](https://github.com/opensearch-project/flow-framework/pull/894)) diff --git a/sample-templates/alert-summary-agent-claude-tested.json b/sample-templates/alert-summary-agent-claude-tested.json new file mode 100644 index 000000000..ab064f5a4 --- /dev/null +++ b/sample-templates/alert-summary-agent-claude-tested.json @@ -0,0 +1,94 @@ +{ + "name": "Alert Summary 
Agent", + "description": "Create Alert Summary Agent using Claude on BedRock", + "use_case": "REGISTER_AGENT", + "version": { + "template": "1.0.0", + "compatibility": ["2.17.0", "3.0.0"] + }, + "workflows": { + "provision": { + "user_params": {}, + "nodes": [ + { + "id": "create_claude_connector", + "type": "create_connector", + "previous_node_inputs": {}, + "user_inputs": { + "version": "1", + "name": "Claude instant runtime Connector", + "protocol": "aws_sigv4", + "description": "The connector to BedRock service for Claude model", + "actions": [ + { + "headers": { + "x-amz-content-sha256": "required", + "content-type": "application/json" + }, + "method": "POST", + "request_body": "{\"prompt\":\"${parameters.prompt}\", \"max_tokens_to_sample\":${parameters.max_tokens_to_sample}, \"temperature\":${parameters.temperature}, \"anthropic_version\":\"${parameters.anthropic_version}\" }", + "action_type": "predict", + "url": "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke" + } + ], + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "us-west-2", + "endpoint": "bedrock-runtime.us-west-2.amazonaws.com", + "content_type": "application/json", + "auth": "Sig_V4", + "max_tokens_to_sample": "8000", + "service_name": "bedrock", + "temperature": "0.0001", + "response_filter": "$.completion", + "anthropic_version": "bedrock-2023-05-31" + } + } + }, + { + "id": "register_claude_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_claude_connector": "connector_id" + }, + "user_inputs": { + "description": "Claude model", + "deploy": true, + "name": "claude-instant" + } + }, + { + "id": "create_alert_summary_ml_model_tool", + "type": "create_tool", + "previous_node_inputs": { + "register_claude_model": "model_id" + }, + "user_inputs": { + "parameters": { + "prompt": "You are an OpenSearch Alert Assistant to help summarize the alerts.\n Here is the detail of alert: ${parameters.context};\n The question is: ${parameters.question}." 
+ }, + "name": "MLModelTool", + "type": "MLModelTool" + } + }, + { + "id": "create_alert_summary_agent", + "type": "register_agent", + "previous_node_inputs": { + "create_alert_summary_ml_model_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "Alert Summary Agent", + "description": "this is an alert summary agent" + } + } + ] + } + } +} diff --git a/sample-templates/alert-summary-agent-claude-tested.yml b/sample-templates/alert-summary-agent-claude-tested.yml new file mode 100644 index 000000000..ce596e071 --- /dev/null +++ b/sample-templates/alert-summary-agent-claude-tested.yml @@ -0,0 +1,71 @@ +--- +name: Alert Summary Agent +description: Create Alert Summary Agent using Claude on BedRock +use_case: REGISTER_AGENT +version: + template: 1.0.0 + compatibility: + - 2.17.0 + - 3.0.0 +workflows: + provision: + user_params: {} + nodes: + - id: create_claude_connector + type: create_connector + previous_node_inputs: {} + user_inputs: + version: '1' + name: Claude instant runtime Connector + protocol: aws_sigv4 + description: The connector to BedRock service for Claude model + actions: + - headers: + x-amz-content-sha256: required + content-type: application/json + method: POST + request_body: '{"prompt":"${parameters.prompt}", "max_tokens_to_sample":${parameters.max_tokens_to_sample}, + "temperature":${parameters.temperature}, "anthropic_version":"${parameters.anthropic_version}" + }' + action_type: predict + url: https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke + credential: + access_key: "" + secret_key: "" + session_token: "" + parameters: + region: us-west-2 + endpoint: bedrock-runtime.us-west-2.amazonaws.com + content_type: application/json + auth: Sig_V4 + max_tokens_to_sample: '8000' + service_name: bedrock + temperature: '0.0001' + response_filter: "$.completion" + anthropic_version: bedrock-2023-05-31 + - id: register_claude_model + type: register_remote_model + previous_node_inputs: + create_claude_connector: connector_id + user_inputs: + description: Claude model + deploy: true + name: claude-instant + - id: create_alert_summary_ml_model_tool + type: create_tool + previous_node_inputs: + register_claude_model: model_id + user_inputs: + parameters: + prompt: "You are an OpenSearch Alert Assistant to help summarize the alerts.\n Here is the detail of alert: ${parameters.context};\n The question is: ${parameters.question}." 
+ name: MLModelTool + type: MLModelTool + - id: create_alert_summary_agent + type: register_agent + previous_node_inputs: + create_alert_summary_ml_model_tool: tools + user_inputs: + parameters: {} + type: flow + name: Alert Summary Agent + description: this is an alert summary agent diff --git a/sample-templates/alert-summary-log-pattern-agent.json b/sample-templates/alert-summary-log-pattern-agent.json new file mode 100644 index 000000000..041518f89 --- /dev/null +++ b/sample-templates/alert-summary-log-pattern-agent.json @@ -0,0 +1,94 @@ +{ + "name": "Alert Summary With Log Pattern Agent", + "description": "Create Alert Summary with Log Pattern Agent using Claude on BedRock", + "use_case": "REGISTER_AGENT", + "version": { + "template": "1.0.0", + "compatibility": ["2.17.0", "3.0.0"] + }, + "workflows": { + "provision": { + "user_params": {}, + "nodes": [ + { + "id": "create_claude_connector", + "type": "create_connector", + "previous_node_inputs": {}, + "user_inputs": { + "version": "1", + "name": "Claude instant runtime Connector", + "protocol": "aws_sigv4", + "description": "The connector to BedRock service for Claude model", + "actions": [ + { + "headers": { + "x-amz-content-sha256": "required", + "content-type": "application/json" + }, + "method": "POST", + "request_body": "{\"prompt\":\"\\n\\nHuman: ${parameters.prompt}\\n\\nAssistant:\", \"max_tokens_to_sample\":${parameters.max_tokens_to_sample}, \"temperature\":${parameters.temperature}, \"anthropic_version\":\"${parameters.anthropic_version}\" }", + "action_type": "predict", + "url": "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke" + } + ], + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "us-west-2", + "endpoint": "bedrock-runtime.us-west-2.amazonaws.com", + "content_type": "application/json", + "auth": "Sig_V4", + "max_tokens_to_sample": "8000", + "service_name": "bedrock", + "temperature": "0.0001", + "response_filter": "$.completion", + "anthropic_version": "bedrock-2023-05-31" + } + } + }, + { + "id": "register_claude_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_claude_connector": "connector_id" + }, + "user_inputs": { + "description": "Claude model", + "deploy": true, + "name": "claude-instant" + } + }, + { + "id": "create_alert_summary_with_log_pattern_ml_model_tool", + "type": "create_tool", + "previous_node_inputs": { + "register_claude_model": "model_id" + }, + "user_inputs": { + "parameters": { + "prompt": " You are an OpenSearch Alert Assistant to help summarize the alerts.\n Here is the detail of alert: \n ${parameters.context};\n \n And help detect if there is any common pattern or trend or outlier for the log pattern output. 
Log pattern groups the alert trigger logs by their generated patterns, the output contains some sample logs for each top-k patterns.\n Here is the log pattern output:\n ${parameters.topNLogPatternData};" + }, + "name": "MLModelTool", + "type": "MLModelTool" + } + }, + { + "id": "create_alert_summary_with_log_pattern_agent", + "type": "register_agent", + "previous_node_inputs": { + "create_alert_summary_with_log_pattern_ml_model_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "Alert Summary With Log Pattern Agent", + "description": "this is an alert summary with log pattern agent" + } + } + ] + } + } +} diff --git a/sample-templates/alert-summary-log-pattern-agent.yml b/sample-templates/alert-summary-log-pattern-agent.yml new file mode 100644 index 000000000..83b23b6d9 --- /dev/null +++ b/sample-templates/alert-summary-log-pattern-agent.yml @@ -0,0 +1,88 @@ +# This template creates a connector to the BedRock service for Claude model +# It then registers a model using the connector and deploys it. +# Finally, it creates a flow agent base agent with ML Model tool to generate alert summary from log patterns. +# +# To use: +# - update the "credential" fields under the create_claude_connector node. +# - if needed, update region +# +# After provisioning: +# - returns a workflow ID +# - use the status API to get the deployed agent ID +--- +name: Alert Summary With Log Pattern Agent +description: Create Alert Summary with Log Pattern Agent using Claude on BedRock +use_case: REGISTER_AGENT +version: + template: 1.0.0 + compatibility: + - 2.17.0 + - 3.0.0 +workflows: + provision: + user_params: {} + nodes: + - id: create_claude_connector + type: create_connector + previous_node_inputs: {} + user_inputs: + version: '1' + name: Claude instant runtime Connector + protocol: aws_sigv4 + description: The connector to BedRock service for Claude model + actions: + - headers: + x-amz-content-sha256: required + content-type: application/json + method: POST + request_body: '{"prompt":"\n\nHuman: ${parameters.prompt}\n\nAssistant:", + "max_tokens_to_sample":${parameters.max_tokens_to_sample}, "temperature":${parameters.temperature}, "anthropic_version":"${parameters.anthropic_version}" + }' + action_type: predict + url: https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke + credential: + access_key: "" + secret_key: "" + session_token: "" + parameters: + region: us-west-2 + endpoint: bedrock-runtime.us-west-2.amazonaws.com + content_type: application/json + auth: Sig_V4 + max_tokens_to_sample: '8000' + service_name: bedrock + temperature: '0.0001' + response_filter: "$.completion" + anthropic_version: bedrock-2023-05-31 + - id: register_claude_model + type: register_remote_model + previous_node_inputs: + create_claude_connector: connector_id + user_inputs: + description: Claude model + deploy: true + name: claude-instant + - id: create_alert_summary_with_log_pattern_ml_model_tool + type: create_tool + previous_node_inputs: + register_claude_model: model_id + user_inputs: + parameters: + prompt: " You are an OpenSearch Alert Assistant to help summarize + the alerts.\n Here is the detail of alert: \n ${parameters.context};\n + \ \n And help detect if there is any common pattern + or trend or outlier for the log pattern output. 
Log pattern groups the + alert trigger logs by their generated patterns, the output contains some + sample logs for each top-k patterns.\n Here is the log + pattern output:\n ${parameters.topNLogPatternData};" + name: MLModelTool + type: MLModelTool + - id: create_alert_summary_with_log_pattern_agent + type: register_agent + previous_node_inputs: + create_alert_summary_with_log_pattern_ml_model_tool: tools + user_inputs: + parameters: {} + type: flow + name: Alert Summary With Log Pattern Agent + description: this is an alert summary with log pattern agent diff --git a/sample-templates/anomaly-detector-suggestion-agent-claude.json b/sample-templates/anomaly-detector-suggestion-agent-claude.json new file mode 100644 index 000000000..78909fa5f --- /dev/null +++ b/sample-templates/anomaly-detector-suggestion-agent-claude.json @@ -0,0 +1,99 @@ +{ + "name": "Anomaly detector suggestion agent", + "description": "Create an anomaly detector suggestion agent using Claude on BedRock", + "use_case": "REGISTER_AGENT", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.16.0", + "2.17.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "user_params": {}, + "nodes": [ + { + "id": "create_claude_connector", + "type": "create_connector", + "previous_node_inputs": {}, + "user_inputs": { + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "endpoint": "bedrock-runtime.us-west-2.amazonaws.com", + "content_type": "application/json", + "auth": "Sig_V4", + "max_tokens_to_sample": "8000", + "service_name": "bedrock", + "temperature": 0, + "response_filter": "$.completion", + "region": "us-west-2", + "anthropic_version": "bedrock-2023-05-31" + }, + "version": "1", + "name": "Claude instant runtime Connector", + "protocol": "aws_sigv4", + "description": "The connector to BedRock service for claude model", + "actions": [ + { + "headers": { + "x-amz-content-sha256": "required", + "content-type": "application/json" + }, + "method": "POST", + "request_body": "{\"prompt\":\"${parameters.prompt}\", \"max_tokens_to_sample\":${parameters.max_tokens_to_sample}, \"temperature\":${parameters.temperature}, \"anthropic_version\":\"${parameters.anthropic_version}\" }", + "action_type": "predict", + "url": "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke" + } + ] + } + }, + { + "id": "register_claude_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_claude_connector": "connector_id" + }, + "user_inputs": { + "name": "claude-instant", + "description": "Claude model", + "deploy": true + } + }, + { + "id": "create_anomoly_detectors_tool", + "type": "create_tool", + "previous_node_inputs": { + "register_claude_model": "model_id" + }, + "user_inputs": { + "parameters": { + "model_type":"", + "prompt": "Human:\" turn\": Here are some examples of the create anomaly detector API in OpenSearch: Example 1. POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"ecommerce\"],\"feature_attributes\":[{\"feature_name\":\"feature1\",\"aggregation_query\":{\"avg_total_revenue\":{\"avg\":{\"field\":\"total_revenue_usd\"}}}},{\"feature_name\":\"feature2\",\"aggregation_query\":{\"max_total_revenue\":{\"max\":{\"field\":\"total_revenue_usd\"}}}}]}, Example 2. 
POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"@timestamp\",\"indices\":[\"access_log*\"],\"feature_attributes\":[{\"feature_name\":\"feature1\",\"feature_enabled\":true,\"aggregation_query\":{\"latencyAvg\":{\"sum\":{\"field\":\"responseLatency\"}}}}]} and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are value_count, avg, min, max and sum, note that value_count can perform on both numeric and keyword type fields, and other aggregation methods can only perform on numeric type fields. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types(long, integer, double, float, short etc.) and the suitable aggregation method for each field, you should give at most 3 aggregation fields and corresponding aggregation methods, if there are no numeric type fields, both the aggregation field and method are empty string, and also give at most 1 category field if there exists a keyword type field whose name is just like region, country, city or currency, if not exist, the category field is empty string, note the category field must be keyword type. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. \n\nAssistant:\" turn\"" + }, + "name": "CreateAnomalyDetectorTool", + "type": "CreateAnomalyDetectorTool" + } + }, + { + "id": "anomaly_detector_suggestion_agent", + "type": "register_agent", + "previous_node_inputs": { + "create_anomoly_detectors_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "Anomaly detector suggestion agent", + "description": "this is the anomaly detector suggestion agent" + } + } + ] + } + } +} diff --git a/sample-templates/anomaly-detector-suggestion-agent-claude.yml b/sample-templates/anomaly-detector-suggestion-agent-claude.yml new file mode 100644 index 000000000..5f715f533 --- /dev/null +++ b/sample-templates/anomaly-detector-suggestion-agent-claude.yml @@ -0,0 +1,94 @@ +--- +name: Anomaly detector suggestion agent +description: Create an anomaly detector suggestion agent using Claude on BedRock +use_case: REGISTER_AGENT +version: + template: 1.0.0 + compatibility: + - 2.16.0 + - 2.17.0 + - 3.0.0 +workflows: + provision: + user_params: {} + nodes: + - id: create_claude_connector + type: create_connector + previous_node_inputs: {} + user_inputs: + credential: + access_key: "" + secret_key: "" + session_token: "" + parameters: + endpoint: bedrock-runtime.us-west-2.amazonaws.com + content_type: application/json + auth: Sig_V4 + max_tokens_to_sample: '8000' + service_name: bedrock + temperature: 0 + response_filter: "$.completion" + region: us-west-2 + anthropic_version: bedrock-2023-05-31 + version: '1' + name: Claude instant runtime Connector + protocol: aws_sigv4 + description: The connector to BedRock service for claude model + actions: + - headers: + x-amz-content-sha256: required + content-type: application/json + method: POST + request_body: '{"prompt":"${parameters.prompt}", "max_tokens_to_sample":${parameters.max_tokens_to_sample}, + "temperature":${parameters.temperature}, "anthropic_version":"${parameters.anthropic_version}" + }' + action_type: 
predict + url: https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke + - id: register_claude_model + type: register_remote_model + previous_node_inputs: + create_claude_connector: connector_id + user_inputs: + name: claude-instant + description: Claude model + deploy: true + - id: create_anomoly_detectors_tool + type: create_tool + previous_node_inputs: + register_claude_model: model_id + user_inputs: + parameters: + model_type: '' + prompt: "Human:\" turn\": Here are some examples of the create anomaly detector + API in OpenSearch: Example 1. POST _plugins/_anomaly_detection/detectors, + {\"time_field\":\"timestamp\",\"indices\":[\"ecommerce\"],\"feature_attributes\":[{\"feature_name\":\"feature1\",\"aggregation_query\":{\"avg_total_revenue\":{\"avg\":{\"field\":\"total_revenue_usd\"}}}},{\"feature_name\":\"feature2\",\"aggregation_query\":{\"max_total_revenue\":{\"max\":{\"field\":\"total_revenue_usd\"}}}}]}, + Example 2. POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"@timestamp\",\"indices\":[\"access_log*\"],\"feature_attributes\":[{\"feature_name\":\"feature1\",\"feature_enabled\":true,\"aggregation_query\":{\"latencyAvg\":{\"sum\":{\"field\":\"responseLatency\"}}}}]} + and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: + ${indexInfo.indexMapping}, and the optional aggregation methods are value_count, + avg, min, max and sum, note that value_count can perform on both numeric + and keyword type fields, and other aggregation methods can only perform + on numeric type fields. Please give me some suggestion about creating + an anomaly detector for the index ${indexInfo.indexName}, you need to + give the key information: the top 3 suitable aggregation fields which + are numeric types(long, integer, double, float, short etc.) and the suitable + aggregation method for each field, you should give at most 3 aggregation + fields and corresponding aggregation methods, if there are no numeric + type fields, both the aggregation field and method are empty string, and + also give at most 1 category field if there exists a keyword type field + whose name is just like region, country, city or currency, if not exist, + the category field is empty string, note the category field must be keyword + type. Show me a format of keyed and pipe-delimited list wrapped in a curly + bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited + list of all the aggregation field names|aggregation_method=comma-delimited + list of all the aggregation methods}. 
\n\nAssistant:\" turn\"" + name: CreateAnomalyDetectorTool + type: CreateAnomalyDetectorTool + - id: anomaly_detector_suggestion_agent + type: register_agent + previous_node_inputs: + create_anomoly_detectors_tool: tools + user_inputs: + parameters: {} + type: flow + name: Anomaly detector suggestion agent + description: this is the anomaly detector suggestion agent diff --git a/sample-templates/create-knowledge-base-alert-agent.json b/sample-templates/create-knowledge-base-alert-agent.json new file mode 100644 index 000000000..da0f60c40 --- /dev/null +++ b/sample-templates/create-knowledge-base-alert-agent.json @@ -0,0 +1,93 @@ +{ + "name": "Olly II Agents", + "description": "This template is to create all Agents required for olly II features ", + "use_case": "REGISTER_AGENTS", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.15.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "user_params": {}, + "nodes": [ + { + "id": "create_knowledge_base_connector", + "type": "create_connector", + "previous_node_inputs": {}, + "user_inputs": { + "name": "Amazon Bedrock Connector: knowledge base", + "description": "The connector to the Bedrock knowledge base", + "version": "1", + "protocol": "aws_sigv4", + "parameters": { + "region": "us-west-2", + "service_name": "bedrock", + "knowledgeBaseId": "PUT_YOUR_KNOWLEDGE_BASE_ID_HERE", + "model_arn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0" + }, + "credential": { + "access_key": "PUT_YOUR_ACCESS_KEY_HERE", + "secret_key": "PUT_YOUR_SECRET_KEY_HERE" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-agent-runtime.us-west-2.amazonaws.com/retrieveAndGenerate", + "headers": { + "content-type": "application/json" + }, + "request_body": "{\"input\": {\"text\": \"${parameters.text}\"}, \"retrieveAndGenerateConfiguration\": {\"type\": \"KNOWLEDGE_BASE\", \"knowledgeBaseConfiguration\": {\"knowledgeBaseId\": \"${parameters.knowledgeBaseId}\", \"modelArn\": \"${parameters.model_arn}\"}}}", + "post_process_function": "return params.output.text;" + } + ] + } + }, + { + "id": "register_knowledge_base_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_knowledge_base_connector": "connector_id" + }, + "user_inputs": { + "name": "Claude model on bedrock", + "function_name": "remote", + "version": "1.0.0", + "description": "Claude model on bedrock", + "deploy": "true" + } + }, + { + "id": "create_kb_ml_model_tool", + "type": "create_tool", + "previous_node_inputs": { + "register_knowledge_base_model": "model_id" + }, + "user_inputs": { + "parameters": { + "text": "You are an OpenSearch Alert Assistant to provide your insight on this alert to help users understand the alert, find potential causes and give feasible solutions to address it.\n Here is the detail of alert: ${parameters.context};\n The alert summary is: ${parameters.summary};\n The question is: ${parameters.question}." 
+ }, + "name": "MLModelTool", + "type": "MLModelTool" + } + }, + { + "id": "create_knowledge_base_agent", + "type": "register_agent", + "previous_node_inputs": { + "create_kb_ml_model_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "Bedrock knowledge base agent", + "description": "this is an agent to call retrieveAndGenerate API in bedrock knowledge base suggestion agent" + } + } + ] + } + } +} diff --git a/sample-templates/create-knowledge-base-alert-agent.yml b/sample-templates/create-knowledge-base-alert-agent.yml new file mode 100644 index 000000000..e4126cd70 --- /dev/null +++ b/sample-templates/create-knowledge-base-alert-agent.yml @@ -0,0 +1,83 @@ +# This template creates a connector to the BedRock service for Knowledge base +# It then registers a model using the connector and deploys it. +# Finally, it creates a flow agent base agent with ML Model tool to access the knowledge base. +# +# To use: +# - update the "credential" and "knowledgeBaseId" fields under the create_knowledge_base_connector node. +# - if needed, update region +# +# After provisioning: +# - returns a workflow ID +# - use the status API to get the deployed agent ID +--- +name: Olly II Agents +description: 'This template is to create all Agents required for olly II features ' +use_case: REGISTER_AGENTS +version: + template: 1.0.0 + compatibility: + - 2.15.0 + - 3.0.0 +workflows: + provision: + user_params: {} + nodes: + - id: create_knowledge_base_connector + type: create_connector + previous_node_inputs: {} + user_inputs: + name: 'Amazon Bedrock Connector: knowledge base' + description: The connector to the Bedrock knowledge base + version: '1' + protocol: aws_sigv4 + parameters: + region: us-west-2 + service_name: bedrock + knowledgeBaseId: PUT_YOUR_KNOWLEDGE_BASE_ID_HERE + model_arn: arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0 + credential: + access_key: PUT_YOUR_ACCESS_KEY_HERE + secret_key: PUT_YOUR_SECRET_KEY_HERE + actions: + - action_type: predict + method: POST + url: https://bedrock-agent-runtime.us-west-2.amazonaws.com/retrieveAndGenerate + headers: + content-type: application/json + request_body: '{"input": {"text": "${parameters.text}"}, "retrieveAndGenerateConfiguration": + {"type": "KNOWLEDGE_BASE", "knowledgeBaseConfiguration": {"knowledgeBaseId": + "${parameters.knowledgeBaseId}", "modelArn": "${parameters.model_arn}"}}}' + post_process_function: return params.output.text; + - id: register_knowledge_base_model + type: register_remote_model + previous_node_inputs: + create_knowledge_base_connector: connector_id + user_inputs: + name: Claude model on bedrock + function_name: remote + version: 1.0.0 + description: Claude model on bedrock + deploy: 'true' + - id: create_kb_ml_model_tool + type: create_tool + previous_node_inputs: + register_knowledge_base_model: model_id + user_inputs: + parameters: + text: |- + You are an OpenSearch Alert Assistant to provide your insight on this alert to help users understand the alert, find potential causes and give feasible solutions to address it. + Here is the detail of alert: ${parameters.context}; + The alert summary is: ${parameters.summary}; + The question is: ${parameters.question}. 
+ name: MLModelTool + type: MLModelTool + - id: create_knowledge_base_agent + type: register_agent + previous_node_inputs: + create_kb_ml_model_tool: tools + user_inputs: + parameters: {} + type: flow + name: Bedrock knowledge base agent + description: this is an agent to call retrieveAndGenerate API in bedrock knowledge + base suggestion agent diff --git a/sample-templates/query-assist-data-summary-agent-claude-tested.json b/sample-templates/query-assist-data-summary-agent-claude-tested.json new file mode 100644 index 000000000..50c339c56 --- /dev/null +++ b/sample-templates/query-assist-data-summary-agent-claude-tested.json @@ -0,0 +1,94 @@ +{ + "name": "Query Assist Data Summary Agent", + "description": "Create Query Assist Data Summary Agent using Claude on BedRock", + "use_case": "REGISTER_AGENT", + "version": { + "template": "1.0.0", + "compatibility": ["2.17.0", "3.0.0"] + }, + "workflows": { + "provision": { + "user_params": {}, + "nodes": [ + { + "id": "create_claude_connector", + "type": "create_connector", + "previous_node_inputs": {}, + "user_inputs": { + "version": "1", + "name": "Claude instant runtime Connector", + "protocol": "aws_sigv4", + "description": "The connector to BedRock service for Claude model", + "actions": [ + { + "headers": { + "x-amz-content-sha256": "required", + "content-type": "application/json" + }, + "method": "POST", + "request_body": "{\"prompt\":\"${parameters.prompt}\", \"max_tokens_to_sample\":${parameters.max_tokens_to_sample}, \"temperature\":${parameters.temperature}, \"anthropic_version\":\"${parameters.anthropic_version}\" }", + "action_type": "predict", + "url": "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke" + } + ], + "credential": { + "access_key": "", + "secret_key": "", + "session_token": "" + }, + "parameters": { + "region": "us-west-2", + "endpoint": "bedrock-runtime.us-west-2.amazonaws.com", + "content_type": "application/json", + "auth": "Sig_V4", + "max_tokens_to_sample": "8000", + "service_name": "bedrock", + "temperature": "0.0001", + "response_filter": "$.completion", + "anthropic_version": "bedrock-2023-05-31" + } + } + }, + { + "id": "register_claude_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_claude_connector": "connector_id" + }, + "user_inputs": { + "description": "Claude model", + "deploy": true, + "name": "claude-instant" + } + }, + { + "id": "create_query_assist_data_summary_ml_model_tool", + "type": "create_tool", + "previous_node_inputs": { + "register_claude_model": "model_id" + }, + "user_inputs": { + "parameters": { + "prompt": "Human: You are an assistant that helps to summarize the data and provide data insights.\nThe data are queried from OpenSearch index through user's question which was translated into PPL query.\nHere is a sample PPL query: `source= | where = `.\nNow you are given ${parameters.sample_count} sample data out of ${parameters.total_count} total data.\nThe user's question is `${parameters.question}`, the translated PPL query is `${parameters.ppl}` and sample data are:\n```\n${parameters.sample_data}\n```\nCould you help provide a summary of the sample data and provide some useful insights with precise wording and in plain text format, do not use markdown format.\nYou don't need to echo my requirements in response.\n\nAssistant:" + }, + "name": "MLModelTool", + "type": "MLModelTool" + } + }, + { + "id": "create_query_assist_data_summary_agent", + "type": "register_agent", + "previous_node_inputs": { + 
"create_query_assist_data_summary_ml_model_tool": "tools" + }, + "user_inputs": { + "parameters": {}, + "type": "flow", + "name": "Query Assist Data Summary Agent", + "description": "this is an query assist data summary agent" + } + } + ] + } + } + } diff --git a/sample-templates/query-assist-data-summary-agent-claude-tested.yml b/sample-templates/query-assist-data-summary-agent-claude-tested.yml new file mode 100644 index 000000000..16c5036ab --- /dev/null +++ b/sample-templates/query-assist-data-summary-agent-claude-tested.yml @@ -0,0 +1,71 @@ +--- +name: Query Assist Data Summary Agent +description: Create Query Assist Data Summary Agent using Claude on BedRock +use_case: REGISTER_AGENT +version: + template: 1.0.0 + compatibility: + - 2.17.0 + - 3.0.0 +workflows: + provision: + user_params: {} + nodes: + - id: create_claude_connector + type: create_connector + previous_node_inputs: {} + user_inputs: + version: '1' + name: Claude instant runtime Connector + protocol: aws_sigv4 + description: The connector to BedRock service for Claude model + actions: + - headers: + x-amz-content-sha256: required + content-type: application/json + method: POST + request_body: '{"prompt":"${parameters.prompt}", "max_tokens_to_sample":${parameters.max_tokens_to_sample}, + "temperature":${parameters.temperature}, "anthropic_version":"${parameters.anthropic_version}" + }' + action_type: predict + url: https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-instant-v1/invoke + credential: + access_key: "" + secret_key: "" + session_token: "" + parameters: + region: us-west-2 + endpoint: bedrock-runtime.us-west-2.amazonaws.com + content_type: application/json + auth: Sig_V4 + max_tokens_to_sample: '8000' + service_name: bedrock + temperature: '0.0001' + response_filter: "$.completion" + anthropic_version: bedrock-2023-05-31 + - id: register_claude_model + type: register_remote_model + previous_node_inputs: + create_claude_connector: connector_id + user_inputs: + description: Claude model + deploy: true + name: claude-instant + - id: create_query_assist_data_summary_ml_model_tool + type: create_tool + previous_node_inputs: + register_claude_model: model_id + user_inputs: + parameters: + prompt: "Human: You are an assistant that helps to summarize the data and provide data insights.\nThe data are queried from OpenSearch index through user's question which was translated into PPL query.\nHere is a sample PPL query: `source= | where = `.\nNow you are given ${parameters.sample_count} sample data out of ${parameters.total_count} total data.\nThe user's question is `${parameters.question}`, the translated PPL query is `${parameters.ppl}` and sample data are:\n```\n${parameters.sample_data}\n```\nCould you help provide a summary of the sample data and provide some useful insights with precise wording and in plain text format, do not use markdown format.\nYou don't need to echo my requirements in response.\n\nAssistant:" + name: MLModelTool + type: MLModelTool + - id: create_query_assist_data_summary_agent + type: register_agent + previous_node_inputs: + create_alert_summary_ml_model_tool: tools + user_inputs: + parameters: {} + type: flow + name: Query Assist Data Summary Agent + description: this is an query assist data summary agent diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index f291cff1c..9c88788b3 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ 
b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -164,6 +164,8 @@ private CommonValue() {} public static final String TOOLS_FIELD = "tools"; /** The tools order field for an agent */ public static final String TOOLS_ORDER_FIELD = "tools_order"; + /** The tools config field */ + public static final String CONFIG_FIELD = "config"; /** The memory field for an agent */ public static final String MEMORY_FIELD = "memory"; /** The app type field for an agent */ @@ -233,4 +235,9 @@ private CommonValue() {} public static final String CREATE_INGEST_PIPELINE_MODEL_ID = "create_ingest_pipeline.model_id"; /** The field name for reindex source index substitution */ public static final String REINDEX_SOURCE_INDEX = "reindex.source_index"; + + /**URI for the YAML file of the ML Commons API specification.*/ + public static final String ML_COMMONS_API_SPEC_YAML_URI = + "https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/refs/heads/main/spec/namespaces/ml.yaml"; + } diff --git a/src/main/java/org/opensearch/flowframework/exception/ApiSpecParseException.java b/src/main/java/org/opensearch/flowframework/exception/ApiSpecParseException.java new file mode 100644 index 000000000..ae77452c7 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/exception/ApiSpecParseException.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.exception; + +import org.opensearch.OpenSearchException; + +import java.util.List; + +/** + * Custom exception to be thrown when an error occurs during the parsing of an API specification. + */ +public class ApiSpecParseException extends OpenSearchException { + + /** + * Constructor with message. + * + * @param message The detail message. + */ + public ApiSpecParseException(String message) { + super(message); + } + + /** + * Constructor with message and cause. + * + * @param message The detail message. + * @param cause The cause of the exception. + */ + public ApiSpecParseException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructor with message and list of detailed errors. + * + * @param message The detail message. + * @param details The list of errors encountered during the parsing process. 
+     */
+    public ApiSpecParseException(String message, List<String> details) {
+        super(message + ": " + String.join(", ", details));
+    }
+}
diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java
index 02ef8a825..f05a162ff 100644
--- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java
+++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java
@@ -10,7 +10,9 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessageFactory;
 import org.opensearch.ExceptionsHelper;
+import org.opensearch.action.DocWriteRequest.OpType;
 import org.opensearch.action.admin.indices.create.CreateIndexRequest;
 import org.opensearch.action.admin.indices.create.CreateIndexResponse;
 import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest;
@@ -34,6 +36,7 @@
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.flowframework.exception.FlowFrameworkException;
@@ -45,12 +48,13 @@
 import org.opensearch.flowframework.util.EncryptorUtils;
 import org.opensearch.flowframework.util.ParseUtils;
 import org.opensearch.flowframework.workflow.WorkflowData;
-import org.opensearch.script.Script;
-import org.opensearch.script.ScriptType;
+import org.opensearch.index.engine.VersionConflictEngineException;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -80,6 +84,8 @@ public class FlowFrameworkIndicesHandler {
     private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
     private static final Map<String, Object> indexSettings = Map.of("index.auto_expand_replicas", "0-1");
     private final NamedXContentRegistry xContentRegistry;
+    // Retries in case of simultaneous updates
+    private static final int RETRIES = 5;

     /**
      * constructor
@@ -576,7 +582,39 @@ public void canDeleteWorkflowStateDoc(
     }

     /**
-     * Updates a document in the workflow state index
+     * Updates a complete document in the workflow state index
+     * @param documentId the document ID
+     * @param updatedDocument a complete document to update the global state index with
+     * @param listener action listener
+     */
+    public void updateFlowFrameworkSystemIndexDoc(
+        String documentId,
+        ToXContentObject updatedDocument,
+        ActionListener<UpdateResponse> listener
+    ) {
+        if (!doesIndexExist(WORKFLOW_STATE_INDEX)) {
+            String errorMessage = "Failed to update document " + documentId + " due to missing " + WORKFLOW_STATE_INDEX + " index";
+            logger.error(errorMessage);
+            listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST));
+        } else {
+            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
+                UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, documentId);
+                XContentBuilder builder = XContentFactory.jsonBuilder();
+                updatedDocument.toXContent(builder, null);
+                updateRequest.doc(builder);
+                updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
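+                // On a version conflict, retryOnConflict re-fetches the document and re-applies this update before failing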
+                updateRequest.retryOnConflict(RETRIES);
+                client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
+            } catch (Exception e) {
+                String errorMessage = "Failed to update " + WORKFLOW_STATE_INDEX + " entry : " + documentId;
+                logger.error(errorMessage, e);
+                listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
+            }
+        }
+    }
+
+    /**
+     * Updates a partial document in the workflow state index
      * @param documentId the document ID
      * @param updatedFields the fields to update the global state index with
      * @param listener action listener
@@ -596,7 +634,7 @@ public void updateFlowFrameworkSystemIndexDoc(
             Map<String, Object> updatedContent = new HashMap<>(updatedFields);
             updateRequest.doc(updatedContent);
             updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            updateRequest.retryOnConflict(5);
+            updateRequest.retryOnConflict(RETRIES);
             // TODO: decide what condition can be considered as an update conflict and add retry strategy
             client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
         } catch (Exception e) {
@@ -631,111 +669,151 @@ public void deleteFlowFrameworkSystemIndexDoc(String documentId, ActionListener<DeleteResponse>
     }

     /**
-     * Updates a document in the workflow state index
-     * @param indexName the index that we will be updating a document of.
-     * @param documentId the document ID
-     * @param script the given script to update doc
-     * @param listener action listener
+     * Adds a resource to the state index, including common exception handling
+     * @param currentNodeInputs Inputs to the current node
+     * @param nodeId current process node (workflow step) id
+     * @param workflowStepName the workflow step name that created the resource
+     * @param resourceId the id of the newly created resource
+     * @param listener the ActionListener for this step to handle completing the future after update
      */
-    public void updateFlowFrameworkSystemIndexDocWithScript(
-        String indexName,
-        String documentId,
-        Script script,
-        ActionListener<UpdateResponse> listener
+    public void addResourceToStateIndex(
+        WorkflowData currentNodeInputs,
+        String nodeId,
+        String workflowStepName,
+        String resourceId,
+        ActionListener<WorkflowData> listener
     ) {
-        if (!doesIndexExist(indexName)) {
-            String errorMessage = "Failed to update document for given workflow due to missing " + indexName + " index";
+        String workflowId = currentNodeInputs.getWorkflowId();
+        String resourceName = getResourceByWorkflowStep(workflowStepName);
+        ResourceCreated newResource = new ResourceCreated(workflowStepName, nodeId, resourceName, resourceId);
+        if (!doesIndexExist(WORKFLOW_STATE_INDEX)) {
+            String errorMessage = "Failed to update state for " + workflowId + " due to missing " + WORKFLOW_STATE_INDEX + " index";
             logger.error(errorMessage);
-            listener.onFailure(new Exception(errorMessage));
+            listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
         } else {
             try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
-                UpdateRequest updateRequest = new UpdateRequest(indexName, documentId);
-                // TODO: Also add ability to change other fields at the same time when adding detailed provision progress
-                updateRequest.script(script);
-                updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-                updateRequest.retryOnConflict(3);
-                // TODO: Implement our own concurrency control to improve on retry mechanism
-                client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
-            } catch (Exception e) {
-                String errorMessage = "Failed to update " + indexName + " entry : " + documentId;
-                logger.error(errorMessage, e);
-                listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e)));
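+                // Append the new resource via the shared get-and-conditional-update helper; OpType.INDEX adds to the list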
"Failed to update " + indexName + " entry : " + documentId; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + getAndUpdateResourceInStateDocumentWithRetries( + workflowId, + newResource, + OpType.INDEX, + RETRIES, + ActionListener.runBefore(listener, context::restore) + ); } } } /** - * Creates a new ResourceCreated object and a script to update the state index - * @param workflowId workflowId for the relevant step - * @param nodeId current process node (workflow step) id - * @param workflowStepName the workflowstep name that created the resource - * @param resourceId the id of the newly created resource + * Removes a resource from the state index, including common exception handling + * @param workflowId The workflow document id in the state index + * @param resourceToDelete The resource to delete * @param listener the ActionListener for this step to handle completing the future after update - * @throws IOException if parsing fails on new resource */ - private void updateResourceInStateIndex( - String workflowId, - String nodeId, - String workflowStepName, - String resourceId, - ActionListener listener - ) throws IOException { - ResourceCreated newResource = new ResourceCreated( - workflowStepName, - nodeId, - getResourceByWorkflowStep(workflowStepName), - resourceId - ); - - // The script to append a new object to the resources_created array - Script script = new Script( - ScriptType.INLINE, - "painless", - "ctx._source.resources_created.add(params.newResource);", - Collections.singletonMap("newResource", newResource.resourceMap()) - ); - - updateFlowFrameworkSystemIndexDocWithScript(WORKFLOW_STATE_INDEX, workflowId, script, ActionListener.wrap(updateResponse -> { - logger.info("updated resources created of {}", workflowId); - listener.onResponse(updateResponse); - }, listener::onFailure)); + public void deleteResourceFromStateIndex(String workflowId, ResourceCreated resourceToDelete, ActionListener listener) { + if (!doesIndexExist(WORKFLOW_STATE_INDEX)) { + String errorMessage = "Failed to update state for " + workflowId + " due to missing " + WORKFLOW_STATE_INDEX + " index"; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND)); + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + getAndUpdateResourceInStateDocumentWithRetries( + workflowId, + resourceToDelete, + OpType.DELETE, + RETRIES, + ActionListener.runBefore(listener, context::restore) + ); + } + } } /** - * Adds a resource to the state index, including common exception handling - * @param currentNodeInputs Inputs to the current node - * @param nodeId current process node (workflow step) id - * @param workflowStepName the workflow step name that created the resource - * @param resourceId the id of the newly created resource - * @param listener the ActionListener for this step to handle completing the future after update + * Performs a get and update of a State Index document adding or removing a resource with strong consistency and retries + * @param workflowId The document id to update + * @param resource The resource to add or remove from the resources created list + * @param operation The operation to perform on the resource (INDEX to append to the list or DELETE to remove) + * @param retries The number of retries on update version conflicts + * @param listener The listener to complete on success or failure */ - public void 
addResourceToStateIndex( - WorkflowData currentNodeInputs, - String nodeId, - String workflowStepName, - String resourceId, + private void getAndUpdateResourceInStateDocumentWithRetries( + String workflowId, + ResourceCreated resource, + OpType operation, + int retries, ActionListener listener ) { - String resourceName = getResourceByWorkflowStep(workflowStepName); - try { - updateResourceInStateIndex( - currentNodeInputs.getWorkflowId(), - nodeId, - workflowStepName, - resourceId, - ActionListener.wrap(updateResponse -> { - logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - listener.onResponse(new WorkflowData(Map.of(resourceName, resourceId), currentNodeInputs.getWorkflowId(), nodeId)); - }, exception -> { - String errorMessage = "Failed to update new created " + nodeId + " resource " + workflowStepName + " id " + resourceId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) + GetRequest getRequest = new GetRequest(WORKFLOW_STATE_INDEX, workflowId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + listener.onFailure(new FlowFrameworkException("Workflow state not found for " + workflowId, RestStatus.NOT_FOUND)); + return; + } + WorkflowState currentState = WorkflowState.parse(getResponse.getSourceAsString()); + List resourcesCreated = new ArrayList<>(currentState.resourcesCreated()); + if (operation == OpType.DELETE) { + resourcesCreated.removeIf(r -> r.resourceMap().equals(resource.resourceMap())); + } else { + resourcesCreated.add(resource); + } + XContentBuilder builder = XContentFactory.jsonBuilder(); + WorkflowState newState = WorkflowState.builder(currentState).resourcesCreated(resourcesCreated).build(); + newState.toXContent(builder, null); + UpdateRequest updateRequest = new UpdateRequest(WORKFLOW_STATE_INDEX, workflowId).doc(builder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setIfSeqNo(getResponse.getSeqNo()) + .setIfPrimaryTerm(getResponse.getPrimaryTerm()); + client.update( + updateRequest, + ActionListener.wrap( + r -> handleStateUpdateSuccess(workflowId, resource, operation, listener), + e -> handleStateUpdateException(workflowId, resource, operation, retries, listener, e) + ) ); - } catch (Exception e) { - String errorMessage = "Failed to parse and update new created resource"; - logger.error(errorMessage, e); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); + }, ex -> handleStateUpdateException(workflowId, resource, operation, 0, listener, ex))); + } + + private void handleStateUpdateSuccess( + String workflowId, + ResourceCreated newResource, + OpType operation, + ActionListener listener + ) { + String resourceName = newResource.resourceType(); + String resourceId = newResource.resourceId(); + String nodeId = newResource.workflowStepId(); + logger.info( + "Updated resources created for {} on step {} to {} resource {} {}", + workflowId, + nodeId, + operation.equals(OpType.DELETE) ? 
"delete" : "add", + resourceName, + resourceId + ); + listener.onResponse(new WorkflowData(Map.of(resourceName, resourceId), workflowId, nodeId)); + } + + private void handleStateUpdateException( + String workflowId, + ResourceCreated newResource, + OpType operation, + int retries, + ActionListener listener, + Exception e + ) { + if (e instanceof VersionConflictEngineException && retries > 0) { + // Retry if we haven't exhausted retries + getAndUpdateResourceInStateDocumentWithRetries(workflowId, newResource, operation, retries - 1, listener); + return; } + String errorMessage = ParameterizedMessageFactory.INSTANCE.newMessage( + "Failed to update workflow state for {} on step {} to {} resource {} {}", + workflowId, + newResource.workflowStepId(), + operation.equals(OpType.DELETE) ? "delete" : "add", + newResource.resourceType(), + newResource.resourceId() + ).getFormattedMessage(); + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(e))); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index 1b58e66db..2b8db025c 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -47,6 +47,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ALLOW_DELETE; @@ -214,19 +215,32 @@ private void executeDeprovisionSequence( // Repeat attempting to delete resources as long as at least one is successful int resourceCount = deprovisionProcessSequence.size(); while (resourceCount > 0) { + PlainActionFuture stateUpdateFuture; Iterator iter = deprovisionProcessSequence.iterator(); - while (iter.hasNext()) { + do { ProcessNode deprovisionNode = iter.next(); ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourcesCreated); String resourceNameAndId = getResourceNameAndId(resource); PlainActionFuture deprovisionFuture = deprovisionNode.execute(); + stateUpdateFuture = PlainActionFuture.newFuture(); try { deprovisionFuture.get(); logger.info("Successful {} for {}", deprovisionNode.id(), resourceNameAndId); + // Remove from state index resource list + flowFrameworkIndicesHandler.deleteResourceFromStateIndex(workflowId, resource, stateUpdateFuture); + try { + // Wait at most 1 second for state index update. 
+ stateUpdateFuture.actionGet(1, TimeUnit.SECONDS); + } catch (Exception e) { + // Ignore incremental resource removal failures (or timeouts) as we catch up at the end with remainingResources + } // Remove from list so we don't try again iter.remove(); // Pause briefly before next step Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; } catch (Throwable t) { // If any deprovision fails due to not found, it's a success if (t.getCause() instanceof OpenSearchStatusException @@ -238,7 +252,7 @@ private void executeDeprovisionSequence( logger.info("Failed {} for {}", deprovisionNode.id(), resourceNameAndId); } } - } + } while (iter.hasNext()); if (deprovisionProcessSequence.size() < resourceCount) { // If we've deleted something, decrement and try again if not zero resourceCount = deprovisionProcessSequence.size(); @@ -259,6 +273,7 @@ private void executeDeprovisionSequence( try { Thread.sleep(1000); } catch (InterruptedException e) { + Thread.currentThread().interrupt(); break; } } else { @@ -274,6 +289,7 @@ private void executeDeprovisionSequence( if (!deleteNotAllowed.isEmpty()) { logger.info("Resources requiring allow_delete: {}.", deleteNotAllowed); } + // This is a redundant best-effort backup to the incremental deletion done earlier updateWorkflowState(workflowId, remainingResources, deleteNotAllowed, listener); } diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index 8d024d180..867c61f60 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -31,13 +31,12 @@ import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.plugins.PluginsService; -import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -58,7 +57,6 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; -import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.flowframework.util.ParseUtils.getUserContext; import static org.opensearch.flowframework.util.ParseUtils.resolveUserAndExecute; @@ -210,24 +208,14 @@ private void executeReprovisionRequest( // Remove error field if any prior to subsequent execution if (response.getWorkflowState().getError() != null) { - Script script = new Script( - ScriptType.INLINE, - "painless", - "if(ctx._source.containsKey('error')){ctx._source.remove('error')}", - Collections.emptyMap() - ); - 
flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDocWithScript( - WORKFLOW_STATE_INDEX, - workflowId, - script, - ActionListener.wrap(updateResponse -> { - - }, exception -> { - String errorMessage = "Failed to update workflow state: " + workflowId; - logger.error(errorMessage, exception); - listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); - }) - ); + WorkflowState newState = WorkflowState.builder(response.getWorkflowState()).error(null).build(); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc(workflowId, newState, ActionListener.wrap(updateResponse -> { + + }, exception -> { + String errorMessage = "Failed to update workflow state: " + workflowId; + logger.error(errorMessage, exception); + listener.onFailure(new FlowFrameworkException(errorMessage, ExceptionsHelper.status(exception))); + })); } // Update State Index, maintain resources created for subsequent execution @@ -282,12 +270,28 @@ private void executeWorkflowAsync( ActionListener listener ) { try { - threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { executeWorkflow(template, workflowSequence, workflowId); }); + threadPool.executor(PROVISION_WORKFLOW_THREAD_POOL).execute(() -> { + updateTemplate(template, workflowId); + executeWorkflow(template, workflowSequence, workflowId); + }); } catch (Exception exception) { listener.onFailure(new FlowFrameworkException("Failed to execute workflow " + workflowId, ExceptionsHelper.status(exception))); } } + /** + * Replace template document + * @param template The template to store after reprovisioning completes successfully + * @param workflowId The workflowId associated with the workflow that is executing + */ + private void updateTemplate(Template template, String workflowId) { + flowFrameworkIndicesHandler.updateTemplateInGlobalContext(workflowId, template, ActionListener.wrap(templateResponse -> { + logger.info("Updated template for {}", workflowId); + }, exception -> { logger.error("Failed to update use case template for {}", workflowId, exception); }), + true // ignores NOT_STARTED state if request is to reprovision + ); + } + /** * Executes the given workflow sequence * @param template The template to store after reprovisioning completes successfully @@ -301,8 +305,9 @@ private void executeWorkflow(Template template, List workflowSequen for (ProcessNode processNode : workflowSequence) { List predecessors = processNode.predecessors(); logger.info( - "Queueing process [{}].{}", + "Queueing Process [{} (type: {})].{}", processNode.id(), + processNode.workflowStep().getName(), predecessors.isEmpty() ? " Can start immediately!" 
: String.format( @@ -333,18 +338,6 @@ private void executeWorkflow(Template template, List workflowSequen logger.info("updated workflow {} state to {}", workflowId, State.COMPLETED); - // Replace template document - flowFrameworkIndicesHandler.updateTemplateInGlobalContext( - workflowId, - template, - ActionListener.wrap(templateResponse -> { - logger.info("Updated template for {}", workflowId, State.COMPLETED); - }, exception -> { - String errorMessage = "Failed to update use case template for " + workflowId; - logger.error(errorMessage, exception); - }), - true // ignores NOT_STARTED state if request is to reprovision - ); }, exception -> { logger.error("Failed to update workflow state for workflow {}", workflowId, exception); }) ); } catch (Exception ex) { diff --git a/src/main/java/org/opensearch/flowframework/util/ApiSpecFetcher.java b/src/main/java/org/opensearch/flowframework/util/ApiSpecFetcher.java new file mode 100644 index 000000000..12630b6c3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/util/ApiSpecFetcher.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.flowframework.exception.ApiSpecParseException; +import org.opensearch.rest.RestRequest; + +import java.util.List; + +import io.swagger.v3.oas.models.OpenAPI; +import io.swagger.v3.oas.models.Operation; +import io.swagger.v3.oas.models.PathItem; +import io.swagger.v3.oas.models.media.Content; +import io.swagger.v3.oas.models.media.MediaType; +import io.swagger.v3.oas.models.media.Schema; +import io.swagger.v3.oas.models.parameters.RequestBody; +import io.swagger.v3.parser.OpenAPIV3Parser; +import io.swagger.v3.parser.core.models.ParseOptions; +import io.swagger.v3.parser.core.models.SwaggerParseResult; + +/** + * Utility class for fetching and parsing OpenAPI specifications. + */ +public class ApiSpecFetcher { + private static final Logger logger = LogManager.getLogger(ApiSpecFetcher.class); + private static final ParseOptions PARSE_OPTIONS = new ParseOptions(); + private static final OpenAPIV3Parser OPENAPI_PARSER = new OpenAPIV3Parser(); + + static { + PARSE_OPTIONS.setResolve(true); + PARSE_OPTIONS.setResolveFully(true); + } + + /** + * Parses the OpenAPI specification directly from the URI. + * + * @param apiSpecUri URI to the API specification (can be file path or web URI). + * @return Parsed OpenAPI object. + * @throws ApiSpecParseException If parsing fails. + */ + public static OpenAPI fetchApiSpec(String apiSpecUri) { + logger.info("Parsing API spec from URI: {}", apiSpecUri); + SwaggerParseResult result = OPENAPI_PARSER.readLocation(apiSpecUri, null, PARSE_OPTIONS); + OpenAPI openApi = result.getOpenAPI(); + + if (openApi == null) { + throw new ApiSpecParseException("Unable to parse spec from URI: " + apiSpecUri, result.getMessages()); + } + + return openApi; + } + + /** + * Compares the required fields in the API spec with the required enum parameters. + * + * @param requiredEnumParams List of required parameters from the enum. + * @param apiSpecUri URI of the API spec to fetch and compare. + * @param path The API path to check. 
+ * @param method The HTTP method (POST, GET, etc.). + * @return true if every field the spec marks as required is present in the supplied required parameters; false otherwise, including when the spec declares no required fields for the request body. + */ + public static boolean compareRequiredFields(List requiredEnumParams, String apiSpecUri, String path, RestRequest.Method method) + throws IllegalArgumentException, ApiSpecParseException { + OpenAPI openAPI = fetchApiSpec(apiSpecUri); + + PathItem pathItem = openAPI.getPaths().get(path); + Content content = getContent(method, pathItem); + MediaType mediaType = content.get(XContentType.JSON.mediaTypeWithoutParameters()); + if (mediaType != null) { + Schema schema = mediaType.getSchema(); + + List requiredApiParams = schema.getRequired(); + if (requiredApiParams != null && !requiredApiParams.isEmpty()) { + return requiredApiParams.stream().allMatch(requiredEnumParams::contains); + } + } + return false; + } + + private static Content getContent(RestRequest.Method method, PathItem pathItem) throws IllegalArgumentException, ApiSpecParseException { + Operation operation; + switch (method) { + case POST: + operation = pathItem.getPost(); + break; + case GET: + operation = pathItem.getGet(); + break; + case PUT: + operation = pathItem.getPut(); + break; + case DELETE: + operation = pathItem.getDelete(); + break; + default: + throw new IllegalArgumentException("Unsupported HTTP method: " + method); + } + + if (operation == null) { + throw new IllegalArgumentException("No operation found for the specified method: " + method); + } + + RequestBody requestBody = operation.getRequestBody(); + if (requestBody == null) { + throw new ApiSpecParseException("No requestBody defined for this operation."); + } + + return requestBody.getContent(); + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 45e2ee240..9d13c6953 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -17,10 +17,13 @@ import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.Set; +import static org.opensearch.flowframework.common.CommonValue.CONFIG_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; @@ -38,7 +41,21 @@ public class ToolStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ToolStep.class); PlainActionFuture toolFuture = PlainActionFuture.newFuture(); - static final String NAME = "create_tool"; + + /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory} */ + public static final String NAME = "create_tool"; + /** Required input keys */ + public static final Set REQUIRED_INPUTS = Set.of(TYPE); + /** Optional input keys */ + public static final Set OPTIONAL_INPUTS = Set.of( + NAME_FIELD, + DESCRIPTION_FIELD, + PARAMETERS_FIELD, + CONFIG_FIELD, + INCLUDE_OUTPUT_IN_AGENT_RESPONSE + ); + /** Provided output keys */ + public static final Set PROVIDED_OUTPUTS = Set.of(TOOLS_FIELD); @Override public PlainActionFuture execute( @@ -48,13 +65,10 @@ public PlainActionFuture execute( Map previousNodeInputs, Map params ) { - Set requiredKeys = Set.of(TYPE); - Set optionalKeys = 
Set.of(NAME_FIELD, DESCRIPTION_FIELD, PARAMETERS_FIELD, INCLUDE_OUTPUT_IN_AGENT_RESPONSE); - try { Map inputs = ParseUtils.getInputsFromPreviousSteps( - requiredKeys, - optionalKeys, + REQUIRED_INPUTS, + OPTIONAL_INPUTS, currentNodeInputs, outputs, previousNodeInputs, @@ -69,11 +83,13 @@ public PlainActionFuture execute( // parse connector_id, model_id and agent_id from previous node inputs Set toolParameterKeys = Set.of(CONNECTOR_ID, MODEL_ID, AGENT_ID); Map parameters = getToolsParametersMap( - inputs.get(PARAMETERS_FIELD), + inputs.getOrDefault(PARAMETERS_FIELD, new HashMap<>()), previousNodeInputs, outputs, toolParameterKeys ); + @SuppressWarnings("unchecked") + Map config = (Map) inputs.getOrDefault(CONFIG_FIELD, Collections.emptyMap()); MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); @@ -90,6 +106,7 @@ public PlainActionFuture execute( if (includeOutputInAgentResponse != null) { builder.includeOutputInAgentResponse(includeOutputInAgentResponse); } + builder.configMap(config); MLToolSpec mlToolSpec = builder.build(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 9fc8baada..65e8dea78 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -52,7 +52,6 @@ import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.SOURCE_INDEX; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; -import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TYPE; import static org.opensearch.flowframework.common.CommonValue.URL; import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; @@ -231,7 +230,7 @@ public enum WorkflowSteps { DELETE_AGENT(DeleteAgentStep.NAME, List.of(AGENT_ID), List.of(AGENT_ID), List.of(OPENSEARCH_ML), null), /** Create Tool Step */ - CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null), + CREATE_TOOL(ToolStep.NAME, ToolStep.REQUIRED_INPUTS, ToolStep.PROVIDED_OUTPUTS, List.of(OPENSEARCH_ML), null), /** Create Ingest Pipeline Step */ CREATE_INGEST_PIPELINE( diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index 0e8abbfab..3b37cd94b 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -807,6 +807,17 @@ protected List getResourcesCreated(RestClient client, String wo TimeUnit.SECONDS ); + return getResourcesCreated(client, workflowId); + } + + /** + * Helper method retrieve any resources created incrementally without waiting for completion + * @param client the rest client + * @param workflowId the workflow id to retrieve resources from + * @return a list of created resources + * @throws Exception if the request fails + */ + protected List getResourcesCreated(RestClient client, String workflowId) throws Exception { Response response = getWorkflowStatus(client, workflowId, true); // Parse workflow state from response and retrieve resources created diff --git a/src/test/java/org/opensearch/flowframework/exception/ApiSpecParseExceptionTests.java 
b/src/test/java/org/opensearch/flowframework/exception/ApiSpecParseExceptionTests.java new file mode 100644 index 000000000..ab93bd66c --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/exception/ApiSpecParseExceptionTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.exception; + +import org.opensearch.OpenSearchException; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Arrays; +import java.util.List; + +public class ApiSpecParseExceptionTests extends OpenSearchTestCase { + + public void testApiSpecParseException() { + ApiSpecParseException exception = new ApiSpecParseException("API spec parsing failed"); + assertTrue(exception instanceof OpenSearchException); + assertEquals("API spec parsing failed", exception.getMessage()); + } + + public void testApiSpecParseExceptionWithCause() { + Throwable cause = new RuntimeException("Underlying issue"); + ApiSpecParseException exception = new ApiSpecParseException("API spec parsing failed", cause); + assertTrue(exception instanceof OpenSearchException); + assertEquals("API spec parsing failed", exception.getMessage()); + assertEquals(cause, exception.getCause()); + } + + public void testApiSpecParseExceptionWithDetailedErrors() { + String message = "API spec parsing failed"; + List details = Arrays.asList("Missing required field", "Invalid type"); + ApiSpecParseException exception = new ApiSpecParseException(message, details); + assertTrue(exception instanceof OpenSearchException); + String expectedMessage = "API spec parsing failed: Missing required field, Invalid type"; + assertEquals(expectedMessage, exception.getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java index ecaec46b5..a7dd7f75e 100644 --- a/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java +++ b/src/test/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandlerTests.java @@ -35,6 +35,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.common.WorkflowResources; @@ -47,6 +48,7 @@ import org.opensearch.flowframework.workflow.CreateConnectorStep; import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -445,6 +447,63 @@ public void testUpdateFlowFrameworkSystemIndexDoc() throws IOException { ); } + public void testUpdateFlowFrameworkSystemIndexFullDoc() throws IOException { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + 
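+ // Stub the index-exists check to true so the handler proceeds to the update path.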
when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, Result.UPDATED)); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + ToXContentObject fooBar = new ToXContentObject() { + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field("foo", "bar"); + xContentBuilder.endObject(); + return builder; + } + }; + + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(Result.UPDATED, responseCaptor.getValue().getResult()); + + // test failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to update state")); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to update state", exceptionCaptor.getValue().getMessage()); + + // test no index + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); + flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc("1", fooBar, listener); + + verify(listener, times(2)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update document 1 due to missing .plugins-flow-framework-state index", + exceptionCaptor.getValue().getMessage() + ); + } + public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { ClusterState mockClusterState = mock(ClusterState.class); Metadata mockMetaData = mock(Metadata.class); @@ -492,7 +551,7 @@ public void testDeleteFlowFrameworkSystemIndexDoc() throws IOException { ); } - public void testAddResourceToStateIndex() throws IOException { + public void testAddResourceToStateIndex() { ClusterState mockClusterState = mock(ClusterState.class); Metadata mockMetaData = mock(Metadata.class); when(clusterService.state()).thenReturn(mockClusterState); @@ -502,6 +561,16 @@ public void testAddResourceToStateIndex() throws IOException { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + XContentBuilder builder = XContentFactory.jsonBuilder(); + WorkflowState state = WorkflowState.builder().build(); + state.toXContent(builder, null); + BytesReference workflowBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", 1, 1, 1, true, workflowBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 
0, Result.UPDATED)); @@ -509,7 +578,7 @@ public void testAddResourceToStateIndex() throws IOException { }).when(client).update(any(UpdateRequest.class), any()); flowFrameworkIndicesHandler.addResourceToStateIndex( - new WorkflowData(Collections.emptyMap(), null, null), + new WorkflowData(Collections.emptyMap(), "this_id", null), "node_id", CreateConnectorStep.NAME, "this_id", @@ -528,7 +597,7 @@ public void testAddResourceToStateIndex() throws IOException { }).when(client).update(any(UpdateRequest.class), any()); flowFrameworkIndicesHandler.addResourceToStateIndex( - new WorkflowData(Collections.emptyMap(), null, null), + new WorkflowData(Collections.emptyMap(), "this_id", null), "node_id", CreateConnectorStep.NAME, "this_id", @@ -537,6 +606,313 @@ public void testAddResourceToStateIndex() throws IOException { ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to update new created node_id resource create_connector id this_id", exceptionCaptor.getValue().getMessage()); + assertEquals( + "Failed to update workflow state for this_id on step node_id to add resource connector_id this_id", + exceptionCaptor.getValue().getMessage() + ); + + // test document not found + @SuppressWarnings("unchecked") + ActionListener notFoundListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", -2, 0, 1, false, null, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), "this_id", null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + notFoundListener + ); + + exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(notFoundListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Workflow state not found for this_id", exceptionCaptor.getValue().getMessage()); + + // test index not found + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener indexNotFoundListener = mock(ActionListener.class); + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), "this_id", null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + indexNotFoundListener + ); + + exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(indexNotFoundListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update state for this_id due to missing .plugins-flow-framework-state index", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testDeleteResourceFromStateIndex() { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + ResourceCreated resourceToDelete = new ResourceCreated("", "node_id", "connector_id", "this_id"); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + // test success + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + XContentBuilder builder = 
XContentFactory.jsonBuilder(); + WorkflowState state = WorkflowState.builder().build(); + state.toXContent(builder, null); + BytesReference workflowBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", 1, 1, 1, true, workflowBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED)); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(new Exception("Failed to update state")); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update workflow state for this_id on step node_id to delete resource connector_id this_id", + exceptionCaptor.getValue().getMessage() + ); + + // test document not found + @SuppressWarnings("unchecked") + ActionListener notFoundListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", -2, 0, 1, false, null, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, notFoundListener); + + exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(notFoundListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Workflow state not found for this_id", exceptionCaptor.getValue().getMessage()); + + // test index not found + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener indexNotFoundListener = mock(ActionListener.class); + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, indexNotFoundListener); + + exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(indexNotFoundListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update state for this_id due to missing .plugins-flow-framework-state index", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testAddResourceToStateIndexWithRetries() { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + VersionConflictEngineException 
conflictException = new VersionConflictEngineException( + new ShardId(WORKFLOW_STATE_INDEX, "", 1), + "this_id", + null + ); + UpdateResponse updateResponse = new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + XContentBuilder builder = XContentFactory.jsonBuilder(); + WorkflowState state = WorkflowState.builder().build(); + state.toXContent(builder, null); + BytesReference workflowBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", 1, 1, 1, true, workflowBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + // test success on retry + @SuppressWarnings("unchecked") + ActionListener retryListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), "this_id", null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + retryListener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(retryListener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure on the 6th attempt, after 5 retries, even though a 7th attempt would have succeeded + @SuppressWarnings("unchecked") + ActionListener exhaustedRetryListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + // we'll never get here + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.addResourceToStateIndex( + new WorkflowData(Collections.emptyMap(), "this_id", null), + "node_id", + CreateConnectorStep.NAME, + "this_id", + exhaustedRetryListener + ); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(exhaustedRetryListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update workflow state for this_id on step node_id to add resource connector_id this_id", + exceptionCaptor.getValue().getMessage() + ); + } + + public void testDeleteResourceFromStateIndexWithRetries() { + ClusterState mockClusterState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); + when(clusterService.state()).thenReturn(mockClusterState); + when(mockClusterState.metadata()).thenReturn(mockMetaData); + when(mockMetaData.hasIndex(WORKFLOW_STATE_INDEX)).thenReturn(true); + VersionConflictEngineException conflictException = new VersionConflictEngineException( + new ShardId(WORKFLOW_STATE_INDEX, "", 1), + "this_id", + null + ); + UpdateResponse updateResponse = new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "this_id", -2, 0, 0, Result.UPDATED); + ResourceCreated resourceToDelete = new ResourceCreated("", "node_id", "connector_id", "this_id"); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + XContentBuilder builder = XContentFactory.jsonBuilder(); + WorkflowState state = WorkflowState.builder().build(); + state.toXContent(builder, null); + BytesReference workflowBytesRef = BytesReference.bytes(builder); + GetResult getResult = new GetResult(WORKFLOW_STATE_INDEX, "this_id", 1, 1, 1, true, workflowBytesRef, null, null); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + // test success on retry + @SuppressWarnings("unchecked") + ActionListener retryListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, retryListener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowData.class); + verify(retryListener, times(1)).onResponse(responseCaptor.capture()); + assertEquals("this_id", responseCaptor.getValue().getContent().get(WorkflowResources.CONNECTOR_ID)); + + // test failure on the 6th attempt, after 5 retries, even though a 7th attempt would have succeeded + @SuppressWarnings("unchecked") + ActionListener exhaustedRetryListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + responseListener.onFailure(conflictException); + return null; + }).doAnswer(invocation -> { + // we'll never get here + ActionListener responseListener = invocation.getArgument(1); + responseListener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), any()); + + flowFrameworkIndicesHandler.deleteResourceFromStateIndex("this_id", resourceToDelete, exhaustedRetryListener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(exhaustedRetryListener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals( + "Failed to update workflow state for this_id on step node_id to delete resource connector_id this_id", + exceptionCaptor.getValue().getMessage() + ); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index 072a480dd..c437c32e3 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -299,8 +299,13 @@ public void testCreateAndProvisionAgentFrameworkWorkflow() throws Exception { assertNotNull(resourcesCreated.get(0).resourceId()); // Hit Deprovision API - // By design, this may not completely deprovision the first time if it takes >2s to process removals Response deprovisionResponse = deprovisionWorkflow(client(), workflowId); + // Test for incremental removal + assertBusy(() -> { + List resourcesRemaining = getResourcesCreated(client(), workflowId); + assertTrue(resourcesRemaining.size() < 5); + }, 30, TimeUnit.SECONDS); + // By design, this may not completely deprovision the first time if it takes >2s to process removals try { assertBusy( () -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); }, @@ -446,7 +451,6 @@ public void testReprovisionWorkflow() throws Exception { assertTrue(getPipelineResponse.pipelines().get(0).getConfigAsMap().toString().contains(modelId)); // Reprovision template to add index which uses default ingest pipeline - Instant preUpdateTime = Instant.now(); // Store a timestamp template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-createindex.json"); response = reprovisionWorkflow(client(), workflowId, template); assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); @@ -464,17 +468,6 @@ public void testReprovisionWorkflow() throws Exception { Map indexSettings = getIndexSettingsAsMap(indexName); assertEquals(pipelineId, indexSettings.get("index.default_pipeline")); - // The template doesn't get updated until after the resources are created which can cause a race condition and flaky failure - // See https://github.com/opensearch-project/flow-framework/issues/870 - // Making sure the template got updated before reprovisioning again. 
- // Quick fix to stop this from being flaky, needs a more permanent fix to synchronize template update with COMPLETED provisioning - assertBusy(() -> { - Response r = getWorkflow(client(), workflowId); - assertEquals(RestStatus.OK.getStatus(), r.getStatusLine().getStatusCode()); - Template t = Template.parse(EntityUtils.toString(r.getEntity(), StandardCharsets.UTF_8)); - assertTrue(t.lastUpdatedTime().isAfter(preUpdateTime)); - }, 30, TimeUnit.SECONDS); - // Reprovision template to remove default ingest pipeline template = TestHelpers.createTemplateFromFile("registerremotemodel-ingestpipeline-updateindex.json"); response = reprovisionWorkflow(client(), workflowId, template); diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index 0d6168855..ba76bc833 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -71,6 +71,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -707,7 +708,7 @@ public void testUpdateWorkflow() throws IOException { ActionListener updateResponseListener = invocation.getArgument(2); updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(anyString(), any(), any()); + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(anyString(), anyMap(), any()); createWorkflowTransportAction.doExecute(mock(Task.class), updateWorkflow, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 203255361..4841871aa 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -175,6 +175,7 @@ public void testDeprovisionWorkflow() throws Exception { ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + verify(flowFrameworkIndicesHandler, times(1)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); } public void testFailToDeprovision() throws Exception { @@ -208,6 +209,7 @@ public void testFailToDeprovision() throws Exception { verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals(RestStatus.ACCEPTED, exceptionCaptor.getValue().getRestStatus()); assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage()); + verify(flowFrameworkIndicesHandler, times(0)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); } public void 
testAllowDeleteRequired() throws Exception { @@ -248,6 +250,7 @@ public void testAllowDeleteRequired() throws Exception { "These resources require the allow_delete parameter to deprovision: [index_name test-index].", exceptionCaptor.getValue().getMessage() ); + verify(flowFrameworkIndicesHandler, times(0)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); // Test (2nd) failure with wrong allow_delete param workflowRequest = new WorkflowRequest(workflowId, null, Map.of(ALLOW_DELETE, "wrong-index")); @@ -264,6 +267,7 @@ public void testAllowDeleteRequired() throws Exception { "These resources require the allow_delete parameter to deprovision: [index_name test-index].", exceptionCaptor.getValue().getMessage() ); + verify(flowFrameworkIndicesHandler, times(0)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); // Test success with correct allow_delete param workflowRequest = new WorkflowRequest(workflowId, null, Map.of(ALLOW_DELETE, "wrong-index,test-index,other-index")); @@ -280,6 +284,7 @@ public void testAllowDeleteRequired() throws Exception { ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); verify(listener, times(1)).onResponse(responseCaptor.capture()); assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + verify(flowFrameworkIndicesHandler, times(1)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); } public void testFailToDeprovisionAndAllowDeleteRequired() throws Exception { @@ -323,5 +328,6 @@ public void testFailToDeprovisionAndAllowDeleteRequired() throws Exception { + " These resources require the allow_delete parameter to deprovision: [index_name test-index].", exceptionCaptor.getValue().getMessage() ); + verify(flowFrameworkIndicesHandler, times(0)).deleteResourceFromStateIndex(anyString(), any(ResourceCreated.class), any()); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java index a6eacc069..623270a27 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportActionTests.java @@ -54,6 +54,7 @@ import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -164,7 +165,7 @@ public void testProvisionWorkflow() { ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mock(UpdateResponse.class)); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), anyMap(), any()); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(2); @@ -211,7 +212,7 @@ public void testProvisionWorkflowTwice() { ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mock(UpdateResponse.class)); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + 
}).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), anyMap(), any()); provisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index e654b0482..6e1e65d3b 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -44,6 +44,7 @@ import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -147,7 +148,7 @@ public void testReprovisionWorkflow() throws Exception { ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(mock(UpdateResponse.class)); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), anyMap(), any()); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); @@ -275,7 +276,7 @@ public void testFailedStateUpdate() throws Exception { ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new Exception("failed")); return null; - }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), any(), any()); + }).when(flowFrameworkIndicesHandler).updateFlowFrameworkSystemIndexDoc(any(), anyMap(), any()); @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/util/ApiSpecFetcherTests.java b/src/test/java/org/opensearch/flowframework/util/ApiSpecFetcherTests.java new file mode 100644 index 000000000..fb60ae08d --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/ApiSpecFetcherTests.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.flowframework.util; + +import org.opensearch.flowframework.exception.ApiSpecParseException; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.util.Arrays; +import java.util.List; + +import io.swagger.v3.oas.models.OpenAPI; + +import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI; +import static org.opensearch.rest.RestRequest.Method.DELETE; +import static org.opensearch.rest.RestRequest.Method.PATCH; +import static org.opensearch.rest.RestRequest.Method.POST; +import static org.opensearch.rest.RestRequest.Method.PUT; + +public class ApiSpecFetcherTests extends OpenSearchTestCase { + + private ApiSpecFetcher apiSpecFetcher; + + @Before + public void setUp() throws Exception { + super.setUp(); + } + + public void testFetchApiSpecSuccess() throws Exception { + + OpenAPI result = ApiSpecFetcher.fetchApiSpec(ML_COMMONS_API_SPEC_YAML_URI); + + assertNotNull("The fetched OpenAPI spec should not be null", result); + } + + public void testFetchApiSpecThrowsException() throws Exception { + String invalidUri = "http://invalid-url.com/fail.yaml"; + + ApiSpecParseException exception = expectThrows(ApiSpecParseException.class, () -> { ApiSpecFetcher.fetchApiSpec(invalidUri); }); + + assertNotNull("Exception should be thrown for invalid URI", exception); + assertTrue(exception.getMessage().contains("Unable to parse spec")); + } + + public void testCompareRequiredFieldsSuccess() throws Exception { + + String path = "/_plugins/_ml/agents/_register"; + RestRequest.Method method = POST; + + // Assuming REGISTER_AGENT step in the enum has these required fields + List expectedRequiredParams = Arrays.asList("name", "type"); + + boolean comparisonResult = ApiSpecFetcher.compareRequiredFields(expectedRequiredParams, ML_COMMONS_API_SPEC_YAML_URI, path, method); + + assertTrue("The required fields should match between API spec and enum", comparisonResult); + } + + public void testCompareRequiredFieldsFailure() throws Exception { + + String path = "/_plugins/_ml/agents/_register"; + RestRequest.Method method = POST; + + List wrongRequiredParams = Arrays.asList("nonexistent_param"); + + boolean comparisonResult = ApiSpecFetcher.compareRequiredFields(wrongRequiredParams, ML_COMMONS_API_SPEC_YAML_URI, path, method); + + assertFalse("The required fields should not match for incorrect input", comparisonResult); + } + + public void testCompareRequiredFieldsThrowsException() throws Exception { + String invalidUri = "http://invalid-url.com/fail.yaml"; + String path = "/_plugins/_ml/agents/_register"; + RestRequest.Method method = PUT; + + Exception exception = expectThrows( + Exception.class, + () -> { ApiSpecFetcher.compareRequiredFields(List.of(), invalidUri, path, method); } + ); + + assertNotNull("An exception should be thrown for an invalid API spec Uri", exception); + assertTrue(exception.getMessage().contains("Unable to parse spec")); + } + + public void testUnsupportedMethodException() throws IllegalArgumentException { + Exception exception = expectThrows(Exception.class, () -> { + ApiSpecFetcher.compareRequiredFields( + List.of("name", "type"), + ML_COMMONS_API_SPEC_YAML_URI, + "/_plugins/_ml/agents/_register", + PATCH + ); + }); + + assertEquals("Unsupported HTTP method: PATCH", exception.getMessage()); + } + + public void testNoOperationFoundException() throws Exception { + Exception exception = expectThrows(IllegalArgumentException.class, () -> { + 
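+ // The spec defines no DELETE operation on the register-agent path, so getContent is expected to throw.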
+    public void testNoOperationFoundException() throws Exception {
+        Exception exception = expectThrows(IllegalArgumentException.class, () -> {
+            ApiSpecFetcher.compareRequiredFields(
+                List.of("name", "type"),
+                ML_COMMONS_API_SPEC_YAML_URI,
+                "/_plugins/_ml/agents/_register",
+                DELETE
+            );
+        });
+
+        assertEquals("No operation found for the specified method: DELETE", exception.getMessage());
+    }
+
+    public void testNoRequestBodyDefinedException() throws ApiSpecParseException {
+        Exception exception = expectThrows(ApiSpecParseException.class, () -> {
+            ApiSpecFetcher.compareRequiredFields(
+                List.of("name", "type"),
+                ML_COMMONS_API_SPEC_YAML_URI,
+                "/_plugins/_ml/model_groups/{model_group_id}",
+                RestRequest.Method.GET
+            );
+        });
+
+        assertEquals("No requestBody defined for this operation.", exception.getMessage());
+    }
+
+}
diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java
index dd9eb369d..3ed8d15ec 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java
@@ -14,20 +14,24 @@
 import org.opensearch.flowframework.common.CommonValue;
 import org.opensearch.flowframework.exception.FlowFrameworkException;
 import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ApiSpecFetcher;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.connector.ConnectorAction;
 import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
 import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
+import org.opensearch.rest.RestRequest;
 import org.opensearch.test.OpenSearchTestCase;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI;
 import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -134,4 +138,17 @@ public void testCreateConnectorFailure() throws IOException {
         assertEquals("Failed to create connector", ex.getCause().getMessage());
     }
 
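+    // Verifies the CREATE_CONNECTOR step's declared inputs stay consistent with the ML Commons connector-create API spec.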
+    public void testApiSpecCreateConnectorInputParamComparison() throws Exception {
+        List<String> requiredEnumParams = WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.inputs();
+
+        boolean isMatch = ApiSpecFetcher.compareRequiredFields(
+            requiredEnumParams,
+            ML_COMMONS_API_SPEC_YAML_URI,
+            "/_plugins/_ml/connectors/_create",
+            RestRequest.Method.POST
+        );
+
+        assertTrue(isMatch);
+    }
+
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java
index 626dfdfa1..8def95f58 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java
@@ -13,6 +13,7 @@
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.flowframework.exception.FlowFrameworkException;
 import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ApiSpecFetcher;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.MLAgentType;
 import org.opensearch.ml.common.agent.LLMSpec;
@@ -20,10 +21,12 @@
 import org.opensearch.ml.common.agent.MLMemorySpec;
 import org.opensearch.ml.common.agent.MLToolSpec;
 import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
+import org.opensearch.rest.RestRequest;
 import org.opensearch.test.OpenSearchTestCase;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 
@@ -31,6 +34,7 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI;
 import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -52,7 +56,13 @@ public void setUp() throws Exception {
         this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
         MockitoAnnotations.openMocks(this);
 
-        MLToolSpec tools = new MLToolSpec("tool1", "CatIndexTool", "desc", Collections.emptyMap(), false);
+        MLToolSpec tools = MLToolSpec.builder()
+            .type("tool1")
+            .name("CatIndexTool")
+            .description("desc")
+            .parameters(Collections.emptyMap())
+            .includeOutputInAgentResponse(false)
+            .build();
 
         LLMSpec llmSpec = new LLMSpec("xyz", Collections.emptyMap());
 
@@ -150,4 +160,18 @@ public void testRegisterAgentFailure() throws IOException {
         assertTrue(ex.getCause() instanceof FlowFrameworkException);
         assertEquals("Failed to register the agent", ex.getCause().getMessage());
     }
+
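+    // Verifies the REGISTER_AGENT step's declared inputs stay consistent with the ML Commons agent-register API spec.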
+    public void testApiSpecRegisterAgentInputParamComparison() throws Exception {
+        List<String> requiredEnumParams = WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.inputs();
+
+        boolean isMatch = ApiSpecFetcher.compareRequiredFields(
+            requiredEnumParams,
+            ML_COMMONS_API_SPEC_YAML_URI,
+            "/_plugins/_ml/agents/_register",
+            RestRequest.Method.POST
+        );
+
+        assertTrue(isMatch);
+    }
+
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java
index 6a6809d07..d42a1ae21 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java
@@ -20,11 +20,13 @@
 import org.opensearch.flowframework.exception.FlowFrameworkException;
 import org.opensearch.flowframework.exception.WorkflowStepException;
 import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ApiSpecFetcher;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.MLTask;
 import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
 import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
+import org.opensearch.rest.RestRequest;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.ScalingExecutorBuilder;
 import org.opensearch.threadpool.TestThreadPool;
@@ -33,6 +35,7 @@
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
@@ -43,6 +46,7 @@
 import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
 import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
+import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI;
 import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL;
 import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
 import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL;
@@ -398,4 +402,17 @@ public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
         assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage());
         assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
     }
+
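+    // Verifies the REGISTER_LOCAL_CUSTOM_MODEL step's declared inputs stay consistent with the ML Commons model-register API spec.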
+ "/_plugins/_ml/models/_register", + RestRequest.Method.POST + ); + + assertTrue(isMatch); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 79d7bb883..e3157b20b 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -20,11 +20,13 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ApiSpecFetcher; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -33,6 +35,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -42,6 +45,7 @@ import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; +import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -310,4 +314,17 @@ public void testBoolParseFail() throws IOException, ExecutionException, Interrup assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); } + + public void testApiSpecRegisterLocalSparseEncodingModelInputParamComparison() throws Exception { + List requiredEnumParams = WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.inputs(); + + boolean isMatch = ApiSpecFetcher.compareRequiredFields( + requiredEnumParams, + ML_COMMONS_API_SPEC_YAML_URI, + "/_plugins/_ml/models/_register", + RestRequest.Method.POST + ); + + assertTrue(isMatch); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 7f7adf44b..05eeb8500 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -14,11 +14,13 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ApiSpecFetcher; import org.opensearch.ml.client.MachineLearningNodeClient; import 
+    public void testApiSpecRegisterLocalSparseEncodingModelInputParamComparison() throws Exception {
+        List<String> requiredEnumParams = WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.inputs();
+
+        boolean isMatch = ApiSpecFetcher.compareRequiredFields(
+            requiredEnumParams,
+            ML_COMMONS_API_SPEC_YAML_URI,
+            "/_plugins/_ml/models/_register",
+            RestRequest.Method.POST
+        );
+
+        assertTrue(isMatch);
+    }
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java
index 7f7adf44b..05eeb8500 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java
@@ -14,11 +14,13 @@
 import org.opensearch.flowframework.exception.FlowFrameworkException;
 import org.opensearch.flowframework.exception.WorkflowStepException;
 import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ApiSpecFetcher;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.AccessMode;
 import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
 import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
+import org.opensearch.rest.RestRequest;
 import org.opensearch.test.OpenSearchTestCase;
 
 import java.io.IOException;
@@ -31,6 +33,7 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI;
 import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS;
 import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID;
 import static org.mockito.ArgumentMatchers.any;
@@ -204,4 +207,17 @@ public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
         assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage());
         assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
     }
+
+    public void testApiSpecRegisterModelGroupInputParamComparison() throws Exception {
+        List<String> requiredEnumParams = WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.inputs();
+
+        boolean isMatch = ApiSpecFetcher.compareRequiredFields(
+            requiredEnumParams,
+            ML_COMMONS_API_SPEC_YAML_URI,
+            "/_plugins/_ml/model_groups/_register",
+            RestRequest.Method.POST
+        );
+
+        assertTrue(isMatch);
+    }
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java
index 0e2ab91e9..362601264 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java
@@ -17,15 +17,18 @@
 import org.opensearch.flowframework.exception.FlowFrameworkException;
 import org.opensearch.flowframework.exception.WorkflowStepException;
 import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
+import org.opensearch.flowframework.util.ApiSpecFetcher;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
 import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
+import org.opensearch.rest.RestRequest;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.transport.RemoteTransportException;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -35,6 +38,7 @@
 import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
 import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD;
+import static org.opensearch.flowframework.common.CommonValue.ML_COMMONS_API_SPEC_YAML_URI;
 import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
 import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
 import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;
@@ -416,4 +420,17 @@ public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
         assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage());
         assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
     }
+
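+    // Verifies the REGISTER_REMOTE_MODEL step's declared inputs stay consistent with the ML Commons model-register API spec.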
+    public void testApiSpecRegisterRemoteModelInputParamComparison() throws Exception {
+        List<String> requiredEnumParams = WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.inputs();
+
+        boolean isMatch = ApiSpecFetcher.compareRequiredFields(
+            requiredEnumParams,
+            ML_COMMONS_API_SPEC_YAML_URI,
+            "/_plugins/_ml/models/_register",
+            RestRequest.Method.POST
+        );
+
+        assertTrue(isMatch);
+    }
 }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java
index 029b5c835..2b5e5b7fa 100644
--- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java
+++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java
@@ -61,6 +61,7 @@ public void setUp() throws Exception {
             Map.entry("name", "name"),
             Map.entry("description", "description"),
             Map.entry("parameters", Collections.emptyMap()),
+            Map.entry("config", Map.of("foo", "bar")),
             Map.entry("include_output_in_agent_response", "false")
         ),
         "test-id",
@@ -102,6 +103,7 @@ public void testTool() throws ExecutionException, InterruptedException {
         );
         assertTrue(future.isDone());
         assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass());
+        assertEquals(Map.of("foo", "bar"), ((MLToolSpec) future.get().getContent().get("tools")).getConfigMap());
     }
 
     public void testBoolParseFail() {