Skip to content

Commit

Permalink
Added new Guardrail field for remote model (#622)
Browse files Browse the repository at this point in the history
* Added new field guarddail for remote model

Signed-off-by: Owais Kazi <[email protected]>

* Fixed parsing

Signed-off-by: Owais Kazi <[email protected]>

* Deserialize

Signed-off-by: Owais Kazi <[email protected]>

* fixing guardrails

Signed-off-by: Joshua Palis <[email protected]>

* Added break

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
Signed-off-by: Joshua Palis <[email protected]>
Co-authored-by: Joshua Palis <[email protected]>
  • Loading branch information
owaiskazi19 and joshpalis authored Mar 28, 2024
1 parent f9d832f commit 4a12730
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ private CommonValue() {}
public static final String PIPELINE_ID = "pipeline_id";
/** Pipeline Configurations */
public static final String CONFIGURATIONS = "configurations";
/** Guardrails field */
public static final String GUARDRAILS_FIELD = "guardrails";

/*
* Constants associated with resource provisioning / state
Expand Down
12 changes: 10 additions & 2 deletions src/main/java/org/opensearch/flowframework/model/WorkflowNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.common.model.Guardrails;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -32,6 +33,7 @@
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToStringMap;
Expand Down Expand Up @@ -95,6 +97,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
xContentBuilder.field(e.getKey());
if (e.getValue() instanceof String || e.getValue() instanceof Number || e.getValue() instanceof Boolean) {
xContentBuilder.value(e.getValue());
} else if (GUARDRAILS_FIELD.equals(e.getKey())) {
Guardrails g = (Guardrails) e.getValue();
xContentBuilder.value(g);
} else if (e.getValue() instanceof Map<?, ?>) {
buildStringToStringMap(xContentBuilder, (Map<?, ?>) e.getValue());
} else if (e.getValue() instanceof Object[]) {
Expand Down Expand Up @@ -156,13 +161,16 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
userInputs.put(inputFieldName, parser.text());
break;
case START_OBJECT:
if (CONFIGURATIONS.equals(inputFieldName)) {
if (GUARDRAILS_FIELD.equals(inputFieldName)) {
userInputs.put(inputFieldName, Guardrails.parse(parser));
break;
} else if (CONFIGURATIONS.equals(inputFieldName)) {
Map<String, Object> configurationsMap = parser.map();
try {
String configurationsString = ParseUtils.parseArbitraryStringToObjectMapToString(configurationsMap);
userInputs.put(inputFieldName, configurationsString);
} catch (Exception ex) {
String errorMessage = "Failed to parse configuration map";
String errorMessage = "Failed to parse" + inputFieldName + "map";
logger.error(errorMessage, ex);
throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
Expand All @@ -29,6 +30,7 @@

import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.GUARDRAILS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID;
Expand Down Expand Up @@ -71,7 +73,7 @@ public PlainActionFuture<WorkflowData> execute(
PlainActionFuture<WorkflowData> registerRemoteModelFuture = PlainActionFuture.newFuture();

Set<String> requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD, GUARDRAILS_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
Expand All @@ -87,6 +89,7 @@ public PlainActionFuture<WorkflowData> execute(
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
Expand All @@ -103,6 +106,11 @@ public PlainActionFuture<WorkflowData> execute(
if (deploy != null) {
builder.deployModel(deploy);
}

if (guardRails != null) {
builder.guardrails(guardRails);
}

MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, new ActionListener<MLRegisterModelResponse>() {
Expand Down

0 comments on commit 4a12730

Please sign in to comment.