From e0cd1cfdedbd1322fff1a96642208f0c3a486207 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 11 Oct 2023 15:12:02 -0700 Subject: [PATCH] Addressed PR comments Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 9 +++++++++ .../workflow/DeployModelStep.java | 12 +++++------- .../workflow/RegisterModelStep.java | 18 +++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index a8fdf2929..0bf8ae890 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -19,4 +19,13 @@ public class CommonValue { public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context"; public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + public static final String MODEL_ID = "model_id"; + public static final String FUNCTION_NAME = "function_name"; + public static final String MODEL_NAME = "name"; + public static final String MODEL_VERSION = "model_version"; + public static final String MODEL_GROUP_ID = "model_group_id"; + public static final String DESCRIPTION = "description"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String MODEL_FORMAT = "model_format"; + public static final String MODEL_CONFIG = "model_config"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 81024785d..e4c9b1a14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -20,6 +20,8 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; + /** * Step to deploy a model */ @@ -27,7 +29,6 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); private Client client; - private static final String MODEL_ID = "model_id"; static final String NAME = "deploy_model"; /** @@ -64,12 +65,9 @@ public void onFailure(Exception e) { String modelId = null; for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - for (Map.Entry entry : content.entrySet()) { - if (entry.getKey() == MODEL_ID) { - modelId = (String) content.get(MODEL_ID); - } - + if (workflowData.getContent().containsKey(MODEL_ID)) { + modelId = (String) workflowData.getContent().get(MODEL_ID); + break; } } machineLearningNodeClient.deploy(modelId, actionListener); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index b9b071a69..b97c56d57 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -28,6 +28,15 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; + /** * Step to register a remote model */ @@ -39,15 +48,6 @@ public class RegisterModelStep implements WorkflowStep { static final String NAME = "register_model"; - private static final String FUNCTION_NAME = "function_name"; - private static final String MODEL_NAME = "name"; - private static final String MODEL_VERSION = "model_version"; - private static final String MODEL_GROUP_ID = "model_group_id"; - private static final String DESCRIPTION = "description"; - private static final String CONNECTOR_ID = "connector_id"; - private static final String MODEL_FORMAT = "model_format"; - private static final String MODEL_CONFIG = "model_config"; - /** * Instantiate this class * @param client client to instantiate MLClient