diff --git a/CHANGELOG.md b/CHANGELOG.md index 987b3c945..4dd2b8a79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658)) - Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671)) - Add optional delay parameter to no-op step ([#674](https://github.com/opensearch-project/flow-framework/pull/674)) +- Add model interface support for remote and local custom models ([#701](https://github.com/opensearch-project/flow-framework/pull/701)) ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) diff --git a/release-notes/opensearch-flow-framework.release-notes-2.14.0.0.md b/release-notes/opensearch-flow-framework.release-notes-2.14.0.0.md index 92a15d60f..13b7e8982 100644 --- a/release-notes/opensearch-flow-framework.release-notes-2.14.0.0.md +++ b/release-notes/opensearch-flow-framework.release-notes-2.14.0.0.md @@ -6,6 +6,7 @@ Compatible with OpenSearch 2.14.0 - Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658)) - Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671)) - Add optional delay parameter to no-op step ([#674](https://github.com/opensearch-project/flow-framework/pull/674)) +- Add model interface support for remote and local custom models ([#701](https://github.com/opensearch-project/flow-framework/pull/701)) ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index c3a7afb51..ac0291687 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -172,6 +172,8 @@ private CommonValue() {} public static final String GUARDRAILS_FIELD = "guardrails"; /** Delay field */ public static final String DELAY_FIELD = "delay"; + /** Model Interface Field */ + public static final String INTERFACE_FIELD = "interface"; /* * Constants associated with resource provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 899167ac8..980e9755b 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -34,6 +34,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -164,7 +165,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { if (GUARDRAILS_FIELD.equals(inputFieldName)) { userInputs.put(inputFieldName, Guardrails.parse(parser)); break; - } else if (CONFIGURATIONS.equals(inputFieldName)) { + } else if (CONFIGURATIONS.equals(inputFieldName) || INTERFACE_FIELD.equals(inputFieldName)) { Map configurationsMap = parser.map(); try { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index ccf9ab686..b4b908aeb 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -451,4 +451,25 @@ public static String removingBackslashesAndQuotesInArrayInJsonString(String inpu matcher.appendTail(result); return result.toString(); } + + /** + * Takes a String to json object map and converts this to a String to String map + * @param stringToObjectMap The string to object map to be transformed + * @return the transformed map + * @throws Exception for issues processing map + */ + public static Map convertStringToObjectMapToStringToStringMap(Map stringToObjectMap) throws Exception { + try (Jsonb jsonb = JsonbBuilder.create()) { + Map stringToStringMap = new HashMap<>(); + for (Map.Entry entry : stringToObjectMap.entrySet()) { + Object value = entry.getValue(); + if (value instanceof String) { + stringToStringMap.put(entry.getKey(), (String) value); + } else { + stringToStringMap.put(entry.getKey(), jsonb.toJson(value)); + } + } + return stringToStringMap; + } + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index fe4e54b6a..80294a7cd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -13,8 +13,12 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.Booleans; +import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -30,6 +34,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.threadpool.ThreadPool; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Set; import java.util.stream.Stream; @@ -40,6 +45,7 @@ import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; @@ -116,6 +122,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String allConfig = (String) inputs.get(ALL_CONFIG); + String modelInterface = (String) inputs.get(INTERFACE_FIELD); final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; // Build register model input @@ -149,6 +156,27 @@ public PlainActionFuture execute( if (modelGroupId != null) { mlInputBuilder.modelGroupId(modelGroupId); } + if (modelInterface != null) { + try { + // Convert model interface string to map + BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8)); + Map modelInterfaceAsMap = XContentHelper.convertToMap( + modelInterfaceBytes, + false, + MediaTypeRegistry.JSON + ).v2(); + + // Convert to string to string map + Map parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap); + mlInputBuilder.modelInterface(parameters); + + } catch (Exception ex) { + String errorMessage = "Failed to create model interface"; + logger.error(errorMessage, ex); + registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST)); + } + + } if (deploy != null) { mlInputBuilder.deployModel(deploy); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStep.java index 0efa458c8..56b4f61f7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStep.java @@ -22,6 +22,7 @@ import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; @@ -71,7 +72,7 @@ protected Set getRequiredKeys() { @Override protected Set getOptionalKeys() { - return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD); + return Set.of(DESCRIPTION_FIELD, MODEL_GROUP_ID, ALL_CONFIG, DEPLOY_FIELD, INTERFACE_FIELD); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index cce5d6ee8..db5d290ca 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -14,8 +14,12 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.Booleans; +import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -27,12 +31,14 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Set; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -76,7 +82,7 @@ public PlainActionFuture execute( PlainActionFuture registerRemoteModelFuture = PlainActionFuture.newFuture(); Set requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID); - Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD, INTERFACE_FIELD); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -93,6 +99,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD); + String modelInterface = (String) inputs.get(INTERFACE_FIELD); final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() @@ -112,6 +119,27 @@ public PlainActionFuture execute( if (guardRails != null) { builder.guardrails(guardRails); } + if (modelInterface != null) { + try { + // Convert model interface string to map + BytesReference modelInterfaceBytes = new BytesArray(modelInterface.getBytes(StandardCharsets.UTF_8)); + Map modelInterfaceAsMap = XContentHelper.convertToMap( + modelInterfaceBytes, + false, + MediaTypeRegistry.JSON + ).v2(); + + // Convert to string to string map + Map parameters = ParseUtils.convertStringToObjectMapToStringToStringMap(modelInterfaceAsMap); + builder.modelInterface(parameters); + + } catch (Exception ex) { + String errorMessage = "Failed to create model interface"; + logger.error(errorMessage, ex); + registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, RestStatus.BAD_REQUEST)); + } + + } MLRegisterModelInput mlInput = builder.build(); diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 7ece1e463..29167c740 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -90,6 +90,12 @@ public void testParseArbitraryStringToObjectMapToString() throws Exception { assertEquals("{\"test-1\":{\"test-1\":\"test-1\"}}", parsedMap); } + public void testConvertStringToObjectMapToStringToStringMap() throws Exception { + Map map = Map.ofEntries(Map.entry("test", Map.of("test-1", "{'test-2', 'test-3'}"))); + Map convertedMap = ParseUtils.convertStringToObjectMapToStringToStringMap(map); + assertEquals("{test={\"test-1\":\"{'test-2', 'test-3'}\"}}", convertedMap.toString()); + } + public void testConditionallySubstituteWithNoPlaceholders() { String input = "This string has no placeholders"; Map outputs = new HashMap<>(); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 1312d1638..603bfde57 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -36,6 +36,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -68,7 +69,11 @@ public void setUp() throws Exception { Map.entry("function_name", "ignored"), Map.entry("name", "xyz"), Map.entry("description", "description"), - Map.entry(CONNECTOR_ID, "abcdefg") + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry( + INTERFACE_FIELD, + "{\"output\":{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"output\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"properties\":{\"name\":{\"description\":\"This is a test description field\",\"type\":\"string\"},\"dataAsMap\":{\"description\":\"This is a test description field\",\"type\":\"object\"}}}},\"status_code\":{\"description\":\"This is a test description field\",\"type\":\"integer\"}}}}}},\"input\":{\"properties\":{\"parameters\":{\"properties\":{\"messages\":{\"description\":\"This is a test description field\",\"type\":\"string\"}}}}}}" + ) ), "test-id", "test-node-id" @@ -205,6 +210,38 @@ public void testRegisterRemoteModelFailure() { } + public void testReisterRemoteModelInterfaceFailure() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IllegalArgumentException("Failed to register remote model")); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + WorkflowData incorrectWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("function_name", "ignored"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(INTERFACE_FIELD, "{\"output\":") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + incorrectWorkflowData.getNodeId(), + incorrectWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create model interface", ex.getCause().getMessage()); + } + public void testRegisterRemoteModelUnSafeFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1);