diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index cc5645306..85a0190a2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -17,6 +17,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; @@ -28,10 +29,8 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -81,12 +80,6 @@ public CompletableFuture execute( CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -105,76 +98,41 @@ public void onFailure(Exception e) { } }; - String modelName = null; - String modelVersion = null; - String description = null; - MLModelFormat modelFormat = null; - String modelGroupId = null; - String modelContentHashValue = null; - String modelType = null; - String embeddingDimension = null; - FrameworkType frameworkType = null; - String allConfig = null; - String url = null; - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case VERSION_FIELD: - modelVersion = (String) content.get(VERSION_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case MODEL_FORMAT: - modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_TYPE: - modelType = (String) content.get(MODEL_TYPE); - break; - case EMBEDDING_DIMENSION: - embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); - break; - case FRAMEWORK_TYPE: - frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); - break; - case ALL_CONFIG: - allConfig = (String) content.get(ALL_CONFIG); - break; - case MODEL_CONTENT_HASH_VALUE: - modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); - break; - case URL: - url = (String) content.get(URL); - break; - default: - break; - - } - } - } - - if (Stream.of( - modelName, - modelVersion, - modelFormat, - modelGroupId, - modelType, - embeddingDimension, - frameworkType, - modelContentHashValue, - url - ).allMatch(x -> x != null)) { + Set requiredKeys = Set.of( + NAME_FIELD, + VERSION_FIELD, + MODEL_FORMAT, + MODEL_GROUP_ID, + MODEL_TYPE, + EMBEDDING_DIMENSION, + FRAMEWORK_TYPE, + MODEL_CONTENT_HASH_VALUE, + URL + ); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, ALL_CONFIG); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - // Create Model configudation + String modelName = (String) inputs.get(NAME_FIELD); + String modelVersion = (String) inputs.get(VERSION_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + MLModelFormat modelFormat = MLModelFormat.from((String) inputs.get(MODEL_FORMAT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String modelContentHashValue = (String) inputs.get(MODEL_CONTENT_HASH_VALUE); + String modelType = (String) inputs.get(MODEL_TYPE); + String embeddingDimension = (String) inputs.get(EMBEDDING_DIMENSION); + FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE)); + String allConfig = (String) inputs.get(ALL_CONFIG); + String url = (String) inputs.get(URL); + + // Create Model configuration TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() .modelType(modelType) .embeddingDimension(Integer.valueOf(embeddingDimension)) @@ -200,12 +158,9 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = mlInputBuilder.build(); mlClient.register(mlInput, actionListener); - } else { - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + registerLocalModelFuture.completeExceptionally(e); } - return registerLocalModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 27a77cb98..8bf61fdc8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -12,21 +12,18 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -88,48 +85,23 @@ public void onFailure(Exception e) { } }; - String modelName = null; - FunctionName functionName = null; - String modelGroupId = null; - String description = null; - String connectorId = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 - - for (WorkflowData workflowData : data) { - - Map content = workflowData.getContent(); - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case FUNCTION_NAME: - functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; - - } - } - } + Set requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) { + String modelName = (String) inputs.get(NAME_FIELD); + FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String connectorId = (String) inputs.get(CONNECTOR_ID); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(functionName) @@ -145,12 +117,10 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, actionListener); - } else { - registerRemoteModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); - } + } catch (FlowFrameworkException e) { + registerRemoteModelFuture.completeExceptionally(e); + } return registerRemoteModelFuture; } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index d169812a9..917abcc6f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -219,7 +219,7 @@ public void testRegisterLocalModelTaskFailure() { public void testMissingInputs() { CompletableFuture future = registerLocalModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -227,7 +227,19 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { + "model_format", + "name", + "model_type", + "embedding_dimension", + "framework_type", + "model_group_id", + "version", + "url", + "model_content_hash_value" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index cde194326..c4715b4db 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -113,7 +113,7 @@ public void testRegisterRemoteModelFailure() { public void testMissingInputs() { CompletableFuture future = this.registerRemoteModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -121,7 +121,11 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { "name", "function_name", "connector_id" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } }