diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 1312d1638..603bfde57 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -36,6 +36,7 @@ import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; +import static org.opensearch.flowframework.common.CommonValue.INTERFACE_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -68,7 +69,11 @@ public void setUp() throws Exception { Map.entry("function_name", "ignored"), Map.entry("name", "xyz"), Map.entry("description", "description"), - Map.entry(CONNECTOR_ID, "abcdefg") + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry( + INTERFACE_FIELD, + "{\"output\":{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"output\":{\"description\":\"This is a test description field\",\"type\":\"array\",\"items\":{\"properties\":{\"name\":{\"description\":\"This is a test description field\",\"type\":\"string\"},\"dataAsMap\":{\"description\":\"This is a test description field\",\"type\":\"object\"}}}},\"status_code\":{\"description\":\"This is a test description field\",\"type\":\"integer\"}}}}}},\"input\":{\"properties\":{\"parameters\":{\"properties\":{\"messages\":{\"description\":\"This is a test description field\",\"type\":\"string\"}}}}}}" + ) ), "test-id", "test-node-id" @@ -205,6 +210,38 @@ public void testRegisterRemoteModelFailure() { } + public void testReisterRemoteModelInterfaceFailure() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IllegalArgumentException("Failed to register remote model")); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + WorkflowData incorrectWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("function_name", "ignored"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(INTERFACE_FIELD, "{\"output\":") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + incorrectWorkflowData.getNodeId(), + incorrectWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + assertTrue(future.isDone()); + ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create model interface", ex.getCause().getMessage()); + } + public void testRegisterRemoteModelUnSafeFailure() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1);