Skip to content

Commit

Permalink
ML Model Interface (#2357)
Browse files Browse the repository at this point in the history
* ML Interface poc

Signed-off-by: Sicheng Song <[email protected]>

* Style fix

Signed-off-by: Sicheng Song <[email protected]>

* Adjust model interface minimal support version

Signed-off-by: Sicheng Song <[email protected]>

* Fix '*' import

Signed-off-by: Sicheng Song <[email protected]>

* Fix UT

Signed-off-by: Sicheng Song <[email protected]>

* Address review concern

Signed-off-by: Sicheng Song <[email protected]>

* Change json schema pacakge to highest stars library

Signed-off-by: Sicheng Song <[email protected]>

* style fix

Signed-off-by: Sicheng Song <[email protected]>

* Move model interface from connector to model index

Signed-off-by: Sicheng Song <[email protected]>

* Fix compilation and UT

Signed-off-by: Sicheng Song <[email protected]>

* Move schema validation to rest layer

Signed-off-by: Sicheng Song <[email protected]>

* Remove unnecessary dependencies

Signed-off-by: Sicheng Song <[email protected]>

* Initiate modelInterface to null object

Signed-off-by: Sicheng Song <[email protected]>

* Change sout to log

Signed-off-by: Sicheng Song <[email protected]>

* Remove debug info

Signed-off-by: Sicheng Song <[email protected]>

* Fix minor styles

Signed-off-by: Sicheng Song <[email protected]>

* Fix minor styles

Signed-off-by: Sicheng Song <[email protected]>

* Fix minor styles

Signed-off-by: Sicheng Song <[email protected]>

* Fix build

Signed-off-by: Sicheng Song <[email protected]>

* Fix UTs

Signed-off-by: Sicheng Song <[email protected]>

* Fix UT

Signed-off-by: Sicheng Song <[email protected]>

* Validate whole output schema

Signed-off-by: Sicheng Song <[email protected]>

* Fix doc

Signed-off-by: Sicheng Song <[email protected]>

* Fix style

Signed-off-by: Sicheng Song <[email protected]>

* Rebase

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored Apr 30, 2024
1 parent 94a113d commit f9454e8
Show file tree
Hide file tree
Showing 25 changed files with 582 additions and 193 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ buildscript {
}

common_utils_version = System.getProperty("common_utils.version", opensearch_build)
kotlin_version = System.getProperty("kotlin.version", "1.8.21")
kotlin_version = System.getProperty("kotlin.version", "1.9.23")
}

repositories {
Expand Down
184 changes: 92 additions & 92 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
Expand All @@ -82,59 +82,56 @@ public class CommonValue {
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+ " }\n"
+ " }\n";
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" +
" \"_meta\": {\n" +
" \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" +
" },\n" +
" \"properties\": {\n" +
" \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" +
" \"type\": \"integer\"\n" +
" },\n" +
" \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" +
" \"type\": \"text\",\n" +
" \"fields\": {\n" +
" \"keyword\": {\n" +
" \"type\": \"keyword\",\n" +
" \"ignore_above\": 256\n" +
" }\n" +
" }\n" +
" },\n" +
" \"" + MLModelGroup.ACCESS + "\": {\n" +
" \"type\": \"keyword\"\n" +
" },\n" +
" \"" + MLModelGroup.OWNER + "\": {\n" +
" \"type\": \"nested\",\n" +
" \"properties\": {\n" +
" \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n"
+
" \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
+
" \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" +
" \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+
" }\n" +
" },\n" +
" \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" +
" \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" +
" \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" +
" }\n" +
"}";
public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\n"
+ " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n"
+ " },\n"
+ " \"properties\": {\n"
+ " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n"
+ " \"type\": \"text\",\n"
+ " \"fields\": {\n"
+ " \"keyword\": {\n"
+ " \"type\": \"keyword\",\n"
+ " \"ignore_above\": 256\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n"
+ " \"type\": \"integer\"\n"
+ " },\n"
+ " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n"
+ " \"type\": \"keyword\"\n"
+ " },\n"
+ " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n"
+ " \"type\": \"text\",\n"
+ " \"fields\": {\n"
+ " \"keyword\": {\n"
+ " \"type\": \"keyword\",\n"
+ " \"ignore_above\": 256\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"" + MLModelGroup.ACCESS + "\": {\n"
+ " \"type\": \"keyword\"\n"
+ " },\n"
+ " \"" + MLModelGroup.OWNER + "\": {\n"
+ " \"type\": \"nested\",\n"
+ " \"properties\": {\n"
+ " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n"
+ " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
+ " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n"
+ " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n"
+ " }\n"
+ " },\n"
+ " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n"
+ " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n"
+ " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";

public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n"
+ " \""
Expand Down Expand Up @@ -265,45 +262,48 @@ public class CommonValue {
+ MLModel.LAST_UNDEPLOYED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ MLModel.INTERFACE_FIELD
+ "\": {\"type\": \"flat_object\"},\n"
+ " \""
+ MLModel.GUARDRAILS_FIELD
+ "\" : {\n" +
" \"properties\": {\n" +
" \"input_guardrail\": {\n" +
" \"properties\": {\n" +
" \"regex\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"stop_words\": {\n" +
" \"properties\": {\n" +
" \"index_name\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"source_fields\": {\n" +
" \"type\": \"text\"\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" },\n" +
" \"output_guardrail\": {\n" +
" \"properties\": {\n" +
" \"regex\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"stop_words\": {\n" +
" \"properties\": {\n" +
" \"index_name\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"source_fields\": {\n" +
" \"type\": \"text\"\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" },\n"
+ "\" : {\n"
+ " \"properties\": {\n"
+ " \"input_guardrail\": {\n"
+ " \"properties\": {\n"
+ " \"regex\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"stop_words\": {\n"
+ " \"properties\": {\n"
+ " \"index_name\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"source_fields\": {\n"
+ " \"type\": \"text\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"output_guardrail\": {\n"
+ " \"properties\": {\n"
+ " \"regex\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"stop_words\": {\n"
+ " \"properties\": {\n"
+ " \"index_name\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"source_fields\": {\n"
+ " \"type\": \"text\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \""
+ MLModel.CONNECTOR_FIELD
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"
Expand Down
61 changes: 60 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.USER;
import static org.opensearch.ml.common.connector.Connector.createConnector;
import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap;

@Getter
public class MLModel implements ToXContentObject {
Expand Down Expand Up @@ -89,6 +95,9 @@ public class MLModel implements ToXContentObject {
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";
public static final String GUARDRAILS_FIELD = "guardrails";
public static final String INTERFACE_FIELD = "interface";

public static final Set<String> allowedInterfaceFieldKeys = new HashSet<>(Arrays.asList("input", "output"));

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -134,6 +143,36 @@ public class MLModel implements ToXContentObject {
private String connectorId;
private Guardrails guardrails;

/**
* Model interface is a map that contains the input and output fields of the model, with JSON schema as the value.
* Sample model interface:
* {
* "interface": {
* "input": {
* "properties": {
* "parameters": {
* "properties": {
* "messages": {
* "type": "string",
* "description": "This is a test description field"
* }
* }
* }
* }
* },
* "output": {
* "properties": {
* "inference_results": {
* "type": "array",
* "description": "This is a test description field"
* }
* }
* }
* }
* }
*/
private Map<String, String> modelInterface;

@Builder(toBuilder = true)
public MLModel(String name,
String modelGroupId,
Expand Down Expand Up @@ -166,7 +205,8 @@ public MLModel(String name,
Boolean isHidden,
Connector connector,
String connectorId,
Guardrails guardrails) {
Guardrails guardrails,
Map<String, String> modelInterface) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand Down Expand Up @@ -200,6 +240,7 @@ public MLModel(String name,
this.connector = connector;
this.connectorId = connectorId;
this.guardrails = guardrails;
this.modelInterface = modelInterface;
}

public MLModel(StreamInput input) throws IOException {
Expand Down Expand Up @@ -261,6 +302,9 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.guardrails = new Guardrails(input);
}
if (input.readBoolean()) {
modelInterface = input.readMap(StreamInput::readString, StreamInput::readString);
}
}
}

Expand Down Expand Up @@ -338,6 +382,12 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (modelInterface != null) {
out.writeBoolean(true);
out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -442,6 +492,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (guardrails != null) {
builder.field(GUARDRAILS_FIELD, guardrails);
}
if (modelInterface != null) {
builder.field(INTERFACE_FIELD, modelInterface);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -486,6 +539,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Connector connector = null;
String connectorId = null;
Guardrails guardrails = null;
Map<String, String> modelInterface = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -617,6 +671,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case GUARDRAILS_FIELD:
guardrails = Guardrails.parse(parser);
break;
case INTERFACE_FIELD:
modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -656,11 +713,13 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.connector(connector)
.connectorId(connectorId)
.guardrails(guardrails)
.modelInterface(modelInterface)
.build();
}

public static MLModel fromStream(StreamInput in) throws IOException {
MLModel mlModel = new MLModel(in);
return mlModel;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
Expand Down Expand Up @@ -361,7 +361,7 @@ public void decrypt(Function<String, String> function) {

@Override
public Connector cloneConnector() {
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){
try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
this.writeTo(bytesStreamOutput);
StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
return new HttpConnector(streamInput);
Expand Down
Loading

0 comments on commit f9454e8

Please sign in to comment.