From 43d3df1ecf8d8239c9b36267712a6afeb703625c Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 27 Mar 2024 16:29:37 -0700 Subject: [PATCH 1/5] Added new field guarddail for remote model Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 2 + .../workflow/RegisterRemoteModelStep.java | 48 +++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index d3960d90b..8df5613d4 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -168,6 +168,8 @@ private CommonValue() {} public static final String PIPELINE_ID = "pipeline_id"; /** Pipeline Configurations */ public static final String CONFIGURATIONS = "configurations"; + /** Guardrails field */ + public static final String GUARDRAILS_FIELD = "guardrails"; /* * Constants associated with resource provisioning / state diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index c32a7f0bd..ab0b98e79 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -20,15 +20,18 @@ import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.Guardrail; +import org.opensearch.ml.common.model.Guardrails; +import org.opensearch.ml.common.model.StopWords; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.Map; -import java.util.Set; +import java.util.*; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -71,7 +74,7 @@ public PlainActionFuture execute( PlainActionFuture registerRemoteModelFuture = PlainActionFuture.newFuture(); Set requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID); - Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -87,6 +90,7 @@ public PlainActionFuture execute( String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); + Guardrails guardRails = getGuardRails(inputs.get(GUARDRAILS_FIELD)); final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() @@ -103,6 +107,10 @@ public PlainActionFuture execute( if (deploy != null) { builder.deployModel(deploy); } + if (guardRails != null) { + builder.guardrails(guardRails); + } + MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, new ActionListener() { @@ -188,6 +196,40 @@ public void onFailure(Exception e) { return registerRemoteModelFuture; } + private Guardrails getGuardRails(Object guardRails) { + Map map = (Map) guardRails; + + String type = null; + Guardrail inputGuardRail = null; + Guardrail outputGuardRail = null; + + type = (String) map.get(Guardrails.TYPE_FIELD); + inputGuardRail = getGuardRail(map.get(Guardrails.INPUT_GUARDRAIL_FIELD)); + outputGuardRail = getGuardRail(map.get(Guardrails.OUTPUT_GUARDRAIL_FIELD)); + + return new Guardrails(type, inputGuardRail, outputGuardRail); + } + + private Guardrail getGuardRail(Object guardRail) { + Map map = (Map) guardRail; + + List stopWords = new ArrayList<>(); + String[] regex = {}; + + List> stopWordsList = (List>) map.get(Guardrail.STOP_WORDS_FIELD); + + for (Map stopWord : stopWordsList) { + String indexName = (String) stopWord.get("index_name"); + String[] sourceFields = (String[]) stopWord.get("source_fields"); + StopWords stopWordsObject = new StopWords(indexName, sourceFields); + stopWords.add(stopWordsObject); + } + + regex = (String[]) map.get(Guardrail.REGEX_FIELD); + + return new Guardrail(stopWords, regex); + } + @Override public String getName() { return NAME; From b8535f454061eed3f2d3b81db7ae27d9704c52e1 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 27 Mar 2024 17:25:33 -0700 Subject: [PATCH 2/5] Fixed parsing Signed-off-by: Owais Kazi --- .../flowframework/model/WorkflowNode.java | 5 ++- .../workflow/RegisterRemoteModelStep.java | 43 ++----------------- 2 files changed, 7 insertions(+), 41 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 15d52ccd1..e9f582a6f 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -32,6 +32,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; +import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD; import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD; import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap; import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap; @@ -156,13 +157,13 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (CONFIGURATIONS.equals(inputFieldName)) { + if (CONFIGURATIONS.equals(inputFieldName) || GUARDRAILS_FIELD.equals(inputFieldName)) { Map configurationsMap = parser.map(); try { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); userInputs.put(inputFieldName, configurationsString); } catch (Exception ex) { - String errorMessage = "Failed to parse configuration map"; + String errorMessage = "Failed to parse" + inputFieldName + "map"; logger.error(errorMessage, ex); throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index ab0b98e79..7c37db7ca 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -27,7 +27,10 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -90,7 +93,6 @@ public PlainActionFuture execute( String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); - Guardrails guardRails = getGuardRails(inputs.get(GUARDRAILS_FIELD)); final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() @@ -107,9 +109,6 @@ public PlainActionFuture execute( if (deploy != null) { builder.deployModel(deploy); } - if (guardRails != null) { - builder.guardrails(guardRails); - } MLRegisterModelInput mlInput = builder.build(); @@ -196,40 +195,6 @@ public void onFailure(Exception e) { return registerRemoteModelFuture; } - private Guardrails getGuardRails(Object guardRails) { - Map map = (Map) guardRails; - - String type = null; - Guardrail inputGuardRail = null; - Guardrail outputGuardRail = null; - - type = (String) map.get(Guardrails.TYPE_FIELD); - inputGuardRail = getGuardRail(map.get(Guardrails.INPUT_GUARDRAIL_FIELD)); - outputGuardRail = getGuardRail(map.get(Guardrails.OUTPUT_GUARDRAIL_FIELD)); - - return new Guardrails(type, inputGuardRail, outputGuardRail); - } - - private Guardrail getGuardRail(Object guardRail) { - Map map = (Map) guardRail; - - List stopWords = new ArrayList<>(); - String[] regex = {}; - - List> stopWordsList = (List>) map.get(Guardrail.STOP_WORDS_FIELD); - - for (Map stopWord : stopWordsList) { - String indexName = (String) stopWord.get("index_name"); - String[] sourceFields = (String[]) stopWord.get("source_fields"); - StopWords stopWordsObject = new StopWords(indexName, sourceFields); - stopWords.add(stopWordsObject); - } - - regex = (String[]) map.get(Guardrail.REGEX_FIELD); - - return new Guardrail(stopWords, regex); - } - @Override public String getName() { return NAME; From 7cd96e3ad3016665cd60bbcb9ec2d527b3e27bba Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 27 Mar 2024 17:51:26 -0700 Subject: [PATCH 3/5] Deserialize Signed-off-by: Owais Kazi --- .../flowframework/model/WorkflowNode.java | 2 +- .../workflow/RegisterRemoteModelStep.java | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index e9f582a6f..fc11f5f8e 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -163,7 +163,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); userInputs.put(inputFieldName, configurationsString); } catch (Exception ex) { - String errorMessage = "Failed to parse" + inputFieldName + "map"; + String errorMessage = "Failed to parse" + inputFieldName + "map"; logger.error(errorMessage, ex); throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 7c37db7ca..05fa2f64b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -14,21 +14,21 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.Guardrail; import org.opensearch.ml.common.model.Guardrails; -import org.opensearch.ml.common.model.StopWords; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Set; @@ -93,8 +93,21 @@ public PlainActionFuture execute( String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); + String guardRails = (String) inputs.get(GUARDRAILS_FIELD); final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); + byte[] byteArr = guardRails.getBytes(StandardCharsets.UTF_8); + BytesReference guardRailsBytes = new BytesArray(byteArr); + Guardrails guardrail = null; + + try { + guardrail = new Guardrails(guardRailsBytes.streamInput()); + } catch (IOException e) { + String errorMessage = "Failed to add guardrails"; + logger.error(errorMessage, e); + registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); + } + MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(FunctionName.REMOTE) .modelName(modelName) @@ -110,6 +123,10 @@ public PlainActionFuture execute( builder.deployModel(deploy); } + if (guardRails != null) { + builder.guardrails(guardrail); + } + MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, new ActionListener() { From fb7b9b1190fb584b905d89b0982369fcc9c4dcdc Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Thu, 28 Mar 2024 01:56:57 +0000 Subject: [PATCH 4/5] fixing guardrails Signed-off-by: Joshua Palis --- .../flowframework/model/WorkflowNode.java | 8 +++++++- .../workflow/RegisterRemoteModelStep.java | 20 ++----------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index fc11f5f8e..5d2c19fb1 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.ml.common.model.Guardrails; import java.io.IOException; import java.util.ArrayList; @@ -96,6 +97,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws xContentBuilder.field(e.getKey()); if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) { xContentBuilder.value(e.getValue()); + } else if (GUARDRAILS_FIELD.equals(e.getKey())) { + Guardrails g = (Guardrails) e.getValue(); + xContentBuilder.value(g); } else if (e.getValue() instanceof Map) { buildStringToStringMap(xContentBuilder, (Map) e.getValue()); } else if (e.getValue() instanceof Object[]) { @@ -157,7 +161,9 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { userInputs.put(inputFieldName, parser.text()); break; case START_OBJECT: - if (CONFIGURATIONS.equals(inputFieldName) || GUARDRAILS_FIELD.equals(inputFieldName)) { + if (GUARDRAILS_FIELD.equals(inputFieldName)) { + userInputs.put(inputFieldName, Guardrails.parse(parser)); + } else if (CONFIGURATIONS.equals(inputFieldName)) { Map configurationsMap = parser.map(); try { String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 05fa2f64b..cc3800284 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -14,8 +14,6 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -27,8 +25,6 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Set; @@ -93,21 +89,9 @@ public PlainActionFuture execute( String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); - String guardRails = (String) inputs.get(GUARDRAILS_FIELD); + Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD); final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); - byte[] byteArr = guardRails.getBytes(StandardCharsets.UTF_8); - BytesReference guardRailsBytes = new BytesArray(byteArr); - Guardrails guardrail = null; - - try { - guardrail = new Guardrails(guardRailsBytes.streamInput()); - } catch (IOException e) { - String errorMessage = "Failed to add guardrails"; - logger.error(errorMessage, e); - registerRemoteModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e))); - } - MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(FunctionName.REMOTE) .modelName(modelName) @@ -124,7 +108,7 @@ public PlainActionFuture execute( } if (guardRails != null) { - builder.guardrails(guardrail); + builder.guardrails(guardRails); } MLRegisterModelInput mlInput = builder.build(); From 7fd499b9847f313490dae1c6f5dd68190f60c661 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 27 Mar 2024 19:04:22 -0700 Subject: [PATCH 5/5] Added break Signed-off-by: Owais Kazi --- .../java/org/opensearch/flowframework/model/WorkflowNode.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java index 5d2c19fb1..899167ac8 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowNode.java @@ -163,6 +163,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException { case START_OBJECT: if (GUARDRAILS_FIELD.equals(inputFieldName)) { userInputs.put(inputFieldName, Guardrails.parse(parser)); + break; } else if (CONFIGURATIONS.equals(inputFieldName)) { Map configurationsMap = parser.map(); try {