diff --git a/build.gradle b/build.gradle index 58f9d2195e..e98ea6ec58 100644 --- a/build.gradle +++ b/build.gradle @@ -23,7 +23,7 @@ buildscript { } common_utils_version = System.getProperty("common_utils.version", opensearch_build) - kotlin_version = System.getProperty("kotlin.version", "1.8.21") + kotlin_version = System.getProperty("kotlin.version", "1.9.23") } repositories { diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 15f9389644..2f22dac12b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -56,7 +56,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; @@ -82,59 +82,56 @@ public class CommonValue { + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + " }\n" + " }\n"; - public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + - " \"_meta\": {\n" + - " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" + - " },\n" + - " \"properties\": {\n" + - " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" + - " \"type\": \"integer\"\n" + - " },\n" + - " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \"" + MLModelGroup.ACCESS + "\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \"" + MLModelGroup.OWNER + "\": {\n" + - " \"type\": \"nested\",\n" + - " \"properties\": {\n" + - " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + - " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + - " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + - " }\n" + - " },\n" + - " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + - " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + - " }\n" + - "}"; + public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + MLModelGroup.ACCESS + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + MLModelGroup.OWNER + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" + " \"" @@ -265,45 +262,48 @@ public class CommonValue { + MLModel.LAST_UNDEPLOYED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" + + MLModel.INTERFACE_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + MLModel.GUARDRAILS_FIELD - + "\" : {\n" + - " \"properties\": {\n" + - " \"input_guardrail\": {\n" + - " \"properties\": {\n" + - " \"regex\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"stop_words\": {\n" + - " \"properties\": {\n" + - " \"index_name\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"source_fields\": {\n" + - " \"type\": \"text\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - " \"output_guardrail\": {\n" + - " \"properties\": {\n" + - " \"regex\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"stop_words\": {\n" + - " \"properties\": {\n" + - " \"index_name\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"source_fields\": {\n" + - " \"type\": \"text\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + + "\" : {\n" + + " \"properties\": {\n" + + " \"input_guardrail\": {\n" + + " \"properties\": {\n" + + " \"regex\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"stop_words\": {\n" + + " \"properties\": {\n" + + " \"index_name\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"source_fields\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"output_guardrail\": {\n" + + " \"properties\": {\n" + + " \"regex\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"stop_words\": {\n" + + " \"properties\": {\n" + + " \"index_name\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"source_fields\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + " \"" + MLModel.CONNECTOR_FIELD + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index c160306550..363fa4bb7d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -29,12 +29,18 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.USER; import static org.opensearch.ml.common.connector.Connector.createConnector; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @Getter public class MLModel implements ToXContentObject { @@ -89,6 +95,9 @@ public class MLModel implements ToXContentObject { public static final String CONNECTOR_FIELD = "connector"; public static final String CONNECTOR_ID_FIELD = "connector_id"; public static final String GUARDRAILS_FIELD = "guardrails"; + public static final String INTERFACE_FIELD = "interface"; + + public static final Set allowedInterfaceFieldKeys = new HashSet<>(Arrays.asList("input", "output")); private String name; private String modelGroupId; @@ -134,6 +143,36 @@ public class MLModel implements ToXContentObject { private String connectorId; private Guardrails guardrails; + /** + * Model interface is a map that contains the input and output fields of the model, with JSON schema as the value. + * Sample model interface: + * { + * "interface": { + * "input": { + * "properties": { + * "parameters": { + * "properties": { + * "messages": { + * "type": "string", + * "description": "This is a test description field" + * } + * } + * } + * } + * }, + * "output": { + * "properties": { + * "inference_results": { + * "type": "array", + * "description": "This is a test description field" + * } + * } + * } + * } + * } + */ + private Map modelInterface; + @Builder(toBuilder = true) public MLModel(String name, String modelGroupId, @@ -166,7 +205,8 @@ public MLModel(String name, Boolean isHidden, Connector connector, String connectorId, - Guardrails guardrails) { + Guardrails guardrails, + Map modelInterface) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -200,6 +240,7 @@ public MLModel(String name, this.connector = connector; this.connectorId = connectorId; this.guardrails = guardrails; + this.modelInterface = modelInterface; } public MLModel(StreamInput input) throws IOException { @@ -261,6 +302,9 @@ public MLModel(StreamInput input) throws IOException { if (input.readBoolean()) { this.guardrails = new Guardrails(input); } + if (input.readBoolean()) { + modelInterface = input.readMap(StreamInput::readString, StreamInput::readString); + } } } @@ -338,6 +382,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (modelInterface != null) { + out.writeBoolean(true); + out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } } @Override @@ -442,6 +492,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (guardrails != null) { builder.field(GUARDRAILS_FIELD, guardrails); } + if (modelInterface != null) { + builder.field(INTERFACE_FIELD, modelInterface); + } builder.endObject(); return builder; } @@ -486,6 +539,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws Connector connector = null; String connectorId = null; Guardrails guardrails = null; + Map modelInterface = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -617,6 +671,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case GUARDRAILS_FIELD: guardrails = Guardrails.parse(parser); break; + case INTERFACE_FIELD: + modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); + break; default: parser.skipChildren(); break; @@ -656,6 +713,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .connector(connector) .connectorId(connectorId) .guardrails(guardrails) + .modelInterface(modelInterface) .build(); } @@ -663,4 +721,5 @@ public static MLModel fromStream(StreamInput in) throws IOException { MLModel mlModel = new MLModel(in); return mlModel; } + } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 256aa673a3..5bb00560a2 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -11,9 +11,9 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.text.StringSubstitutor; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; @@ -361,7 +361,7 @@ public void decrypt(Function function) { @Override public Connector cloneConnector() { - try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){ + try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { this.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); return new HttpConnector(streamInput); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index b7461e35c8..03047cf692 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -15,6 +15,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLDeploySetting; @@ -26,8 +27,12 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @Data public class MLUpdateModelInput implements ToXContentObject, Writeable { @@ -66,10 +71,13 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private Instant lastUpdateTime; private Guardrails guardrails; + private Map modelInterface; + @Builder(toBuilder = true) public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig, MLDeploySetting deploySetting, - Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, Guardrails guardrails) { + Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, + Guardrails guardrails, Map modelInterface) { this.modelId = modelId; this.description = description; this.version = version; @@ -84,6 +92,7 @@ public MLUpdateModelInput(String modelId, String description, String version, St this.connector = connector; this.lastUpdateTime = lastUpdateTime; this.guardrails = guardrails; + this.modelInterface = modelInterface; } public MLUpdateModelInput(StreamInput in) throws IOException { @@ -116,56 +125,15 @@ public MLUpdateModelInput(StreamInput in) throws IOException { this.deploySetting = new MLDeploySetting(in); } } + if (streamInputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (in.readBoolean()) { + modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); + } + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(MODEL_ID_FIELD, modelId); - if (name != null) { - builder.field(MODEL_NAME_FIELD, name); - } - if (description != null) { - builder.field(DESCRIPTION_FIELD, description); - } - if (version != null) { - builder.field(MODEL_VERSION_FIELD, version); - } - if (modelGroupId != null) { - builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); - } - if (isEnabled != null) { - builder.field(IS_ENABLED_FIELD, isEnabled); - } - if (rateLimiter != null) { - builder.field(RATE_LIMITER_FIELD, rateLimiter); - } - if (modelConfig != null) { - builder.field(MODEL_CONFIG_FIELD, modelConfig); - } - if (deploySetting != null) { - builder.field(DEPLOY_SETTING_FIELD, deploySetting); - } - if (updatedConnector != null) { - builder.field(UPDATED_CONNECTOR_FIELD, updatedConnector); - } - if (connectorId != null) { - builder.field(CONNECTOR_ID_FIELD, connectorId); - } - if (connector != null) { - builder.field(CONNECTOR_FIELD, connector); - } - if (lastUpdateTime != null) { - builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); - } - if (guardrails != null) { - builder.field(GUARDRAILS_FIELD, guardrails); - } - builder.endObject(); - return builder; - } - - public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(MODEL_ID_FIELD, modelId); if (name != null) { @@ -205,6 +173,9 @@ public XContentBuilder toXContentForUpdateRequestDoc(XContentBuilder builder, Pa if (guardrails != null) { builder.field(GUARDRAILS_FIELD, guardrails); } + if (modelInterface != null) { + builder.field(MLModel.INTERFACE_FIELD, modelInterface); + } builder.endObject(); return builder; } @@ -258,6 +229,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (modelInterface != null) { + out.writeBoolean(true); + out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } } public static MLUpdateModelInput parse(XContentParser parser) throws IOException { @@ -275,6 +254,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException MLCreateConnectorInput connector = null; Instant lastUpdateTime = null; Guardrails guardrails = null; + Map modelInterface = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -311,6 +291,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case GUARDRAILS_FIELD: guardrails = Guardrails.parse(parser); break; + case MLModel.INTERFACE_FIELD: + modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); + break; default: parser.skipChildren(); break; @@ -319,6 +302,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException // Model ID can only be set through RestRequest. Model version can only be set // automatically. return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter, - modelConfig, deploySetting, updatedConnector, connectorId, connector, lastUpdateTime, guardrails); + modelConfig, deploySetting, updatedConnector, connectorId, connector, lastUpdateTime, guardrails, + modelInterface); } } \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 1d2bbf0943..bffa04328b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -29,12 +29,16 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; import static org.opensearch.ml.common.connector.Connector.createConnector; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; /** * ML input data: algirithm name, parameters and input data set. @@ -67,6 +71,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final Version MINIMAL_SUPPORTED_VERSION_FOR_DOES_VERSION_CREATE_MODEL_GROUP = Version.V_2_11_0; public static final Version MINIMAL_SUPPORTED_VERSION_FOR_AGENT_FRAMEWORK = Version.V_2_12_0; public static final Version MINIMAL_SUPPORTED_VERSION_FOR_GUARDRAILS_AND_AUTO_DEPLOY = Version.V_2_13_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE = Version.V_2_14_0; private FunctionName functionName; private String modelName; @@ -95,6 +100,8 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Boolean isHidden; private Guardrails guardrails; + private Map modelInterface; + @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, String modelName, @@ -117,7 +124,8 @@ public MLRegisterModelInput(FunctionName functionName, AccessMode accessMode, Boolean doesVersionCreateModelGroup, Boolean isHidden, - Guardrails guardrails) { + Guardrails guardrails, + Map modelInterface) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -155,6 +163,7 @@ public MLRegisterModelInput(FunctionName functionName, this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; this.isHidden = isHidden; this.guardrails = guardrails; + this.modelInterface = modelInterface; } public MLRegisterModelInput(StreamInput in) throws IOException { @@ -209,6 +218,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.deploySetting = new MLDeploySetting(in); } } + if (streamInputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (in.readBoolean()) { + this.modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); + } + } } @Override @@ -282,6 +296,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (modelInterface != null) { + out.writeBoolean(true); + out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } } @Override @@ -347,6 +369,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (guardrails != null) { builder.field(GUARDRAILS_FIELD, guardrails); } + if (modelInterface != null) { + builder.field(MLModel.INTERFACE_FIELD, modelInterface); + } builder.endObject(); return builder; } @@ -372,6 +397,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName Boolean doesVersionCreateModelGroup = null; Boolean isHidden = null; Guardrails guardrails = null; + Map modelInterface = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -445,6 +471,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case GUARDRAILS_FIELD: guardrails = Guardrails.parse(parser); break; + case MLModel.INTERFACE_FIELD: + modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); + break; default: parser.skipChildren(); break; @@ -453,7 +482,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, rateLimiter, url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden, guardrails); + isHidden, guardrails, modelInterface); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -478,6 +507,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo Boolean doesVersionCreateModelGroup = null; Boolean isHidden = null; Guardrails guardrails = null; + Map modelInterface = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -558,13 +588,17 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case GUARDRAILS_FIELD: guardrails = Guardrails.parse(parser); break; + case MLModel.INTERFACE_FIELD: + modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); + break; default: parser.skipChildren(); break; } } return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, rateLimiter, - url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), connector, - connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden, guardrails); + url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), + connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, + isHidden, guardrails, modelInterface); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index 50a6390cbb..d56120aa5c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -29,10 +29,14 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Locale; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @Data public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { @@ -79,6 +83,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { private Boolean doesVersionCreateModelGroup; private Boolean isHidden; + private Map modelInterface; + @Builder(toBuilder = true) public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelFormat modelFormat, @@ -86,7 +92,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m MLModelConfig modelConfig, MLDeploySetting deploySetting, Integer totalChunks, List backendRoles, AccessMode accessMode, Boolean isAddAllBackendRoles, - Boolean doesVersionCreateModelGroup, Boolean isHidden) { + Boolean doesVersionCreateModelGroup, Boolean isHidden, Map modelInterface) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -129,6 +135,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.isAddAllBackendRoles = isAddAllBackendRoles; this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; this.isHidden = isHidden; + this.modelInterface = modelInterface; } public MLRegisterModelMetaInput(StreamInput in) throws IOException { @@ -174,6 +181,11 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException { this.deploySetting = new MLDeploySetting(in); } } + if (streamInputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (in.readBoolean()) { + this.modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); + } + } } @Override @@ -239,6 +251,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(MLRegisterModelInput.MINIMAL_SUPPORTED_VERSION_FOR_INTERFACE)) { + if (modelInterface != null) { + out.writeBoolean(true); + out.writeMap(modelInterface, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } } @Override @@ -289,6 +309,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isHidden != null) { builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); } + if (modelInterface != null) { + builder.field(MLModel.INTERFACE_FIELD, modelInterface); + } builder.endObject(); return builder; } @@ -313,6 +336,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc Boolean isAddAllBackendRoles = null; Boolean doesVersionCreateModelGroup = null; Boolean isHidden = null; + Map modelInterface = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -384,6 +408,8 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case MLModel.IS_HIDDEN_FIELD: isHidden = parser.booleanValue(); break; + case MLModel.INTERFACE_FIELD: + modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); default: parser.skipChildren(); break; @@ -391,7 +417,8 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc } return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, isEnabled, rateLimiter, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, - deploySetting, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); + deploySetting, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, + isHidden, modelInterface); } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 4b15444a58..b9c69cd194 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -14,18 +14,28 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import org.opensearch.OpenSearchParseException; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; @Log4j2 public class StringUtils { @@ -95,6 +105,28 @@ public static Map fromJson(String jsonStr, String defaultKey) { return result; } + public static Map filteredParameterMap(Map parameterObjs, Set allowedList) { + Map parameters = new HashMap<>(); + Set filteredKeys = new HashSet<>(parameterObjs.keySet()); + filteredKeys.retainAll(allowedList); + for (String key : filteredKeys) { + Object value = parameterObjs.get(key); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + if (value instanceof String) { + parameters.put(key, (String)value); + } else { + parameters.put(key, gson.toJson(value)); + } + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + @SuppressWarnings("removal") public static Map getParameterMap(Map parameterObjs) { Map parameters = new HashMap<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 6f4018165f..d656ae5134 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -36,7 +36,6 @@ import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; -import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.search.SearchModule; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -85,7 +84,7 @@ public class MLUpdateModelInputTest { "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + - "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}}"; + "\"test-connector_id\"}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -167,24 +166,9 @@ public void readInputStreamSuccessWithNullFields() throws IOException { @Test public void testToXContent() throws Exception { String jsonStr = serializationWithToXContent(updateModelInput); - assertEquals(expectedInputStr, jsonStr); - } - - @Test - public void testToXContentForUpdateRequestDoc() throws Exception { - String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput); assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr); } - @Test - public void testToXContenttForUpdateRequestDocIncomplete() throws Exception { - String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; - updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id").build(); - String jsonStr = serializationWithToXContentForUpdateRequestDoc(updateModelInput); - assertEquals(expectedIncompleteInputStr, jsonStr); - } - @Test public void testToXContentIncomplete() throws Exception { String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; @@ -270,10 +254,4 @@ private String serializationWithToXContent(MLUpdateModelInput input) throws IOEx return builder.toString(); } - private String serializationWithToXContentForUpdateRequestDoc(MLUpdateModelInput input) throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder(); - input.toXContentForUpdateRequestDoc(builder, ToXContent.EMPTY_PARAMS); - assertNotNull(builder); - return builder.toString(); - } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index d4e95a7244..667d4276c5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -45,7 +45,7 @@ public void setup() { mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, null, 2, - null, null, false, false, false); + null, null, false, false, false, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 433ce3acdb..bbf64f5688 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, null, 2, null, null, null, null, null); + "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, null, 2, null, null, null, null, null, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index 6fe56fd5ce..b2aa45e068 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -10,9 +10,12 @@ import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import static java.util.stream.Collectors.toList; import static org.junit.Assert.assertEquals; public class StringUtilsTest { @@ -117,6 +120,21 @@ public void getParameterMap() { assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); } + @Test + public void getInterfaceMap() { + final Set allowedInterfaceFieldNameList = new HashSet<>(Arrays.asList("input","output")); + Map parameters = new HashMap<>(); + parameters.put("input", "value1"); + parameters.put("output", 2); + parameters.put("key3", 2.1); + parameters.put("key4", new int[]{10, 20}); + parameters.put("key5", new Object[]{1.01, "abc"}); + Map interfaceMap = StringUtils.filteredParameterMap(parameters, allowedInterfaceFieldNameList); + Assert.assertEquals(2, interfaceMap.size()); + Assert.assertEquals("value1", interfaceMap.get("input")); + Assert.assertEquals("2", interfaceMap.get("output")); + } + @Test public void processTextDocs() { List processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]")); diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index bb900fe99e..251ea79a1e 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -24,9 +24,7 @@ dependencies { implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" testImplementation "org.opensearch.test:framework:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" - implementation ("org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}") { - exclude group: "org.jetbrains", module: "annotations" - } + implementation ("org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}") implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' implementation group: 'org.reflections', name: 'reflections', version: '0.9.12' implementation group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.2.1' diff --git a/plugin/build.gradle b/plugin/build.gradle index c2ff6931bd..6d2986358e 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -52,6 +52,7 @@ dependencies { implementation "org.opensearch:common-utils:${common_utils_version}" implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") implementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") + implementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' implementation group: 'com.google.guava', name: 'guava', version: '32.1.2-jre' implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' @@ -352,6 +353,7 @@ configurations.all { resolutionStrategy.force 'commons-codec:commons-codec:1.15' resolutionStrategy.force 'org.slf4j:slf4j-api:1.7.36' resolutionStrategy.force 'org.codehaus.plexus:plexus-utils:3.3.0' + exclude group: "org.jetbrains", module: "annotations" } apply plugin: 'com.netflix.nebula.ospackage' diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index f328a39d16..2037996ffe 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -206,13 +206,13 @@ private void updateRemoteOrTextEmbeddingModel( String newConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; boolean isModelDeployed = isModelDeployed(mlModel.getModelState()); // This flag is used to decide if we need to re-deploy the predictor(model) when updating the model cache. - // If one of the internal connector, stand-alone connector id, model quota flag, as well as the model rate limiter and guardrails - // need update, we - // need to perform a re-deploy. + // If one of the internal connector, stand-alone connector id, model quota flag, model rate limiter, model interface, + // and guardrails need update, we need to perform a re-deployment. boolean isPredictorUpdate = (updateModelInput.getConnector() != null) || (newConnectorId != null) || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()) - || (updateModelInput.getGuardrails() != null); + || (updateModelInput.getGuardrails() != null) + || (updateModelInput.getModelInterface() != null); if (MLRateLimiter.updateValidityPreCheck(mlModel.getRateLimiter(), updateModelInput.getRateLimiter())) { MLRateLimiter updatedRateLimiterConfig = MLRateLimiter.update(mlModel.getRateLimiter(), updateModelInput.getRateLimiter()); updateModelInput.setRateLimiter(updatedRateLimiterConfig); @@ -379,7 +379,7 @@ private void buildUpdateRequest( ) { try { updateModelInput.setLastUpdateTime(Instant.now()); - updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { @@ -420,7 +420,7 @@ private void buildUpdateRequest( Integer.parseInt(updatedVersion) ); try { - updateRequest.doc(updateModelInput.toXContentForUpdateRequestDoc(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); if (isUpdateModelCache) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index b396b39016..94ed36214a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -17,13 +17,16 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -33,6 +36,7 @@ import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLPredictTaskRunner; import org.opensearch.ml.task.MLTaskRunner; +import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -152,9 +156,11 @@ public void onResponse(MLModel mlModel) { ) ); } else { + validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput()); executePredict(mlPredictionTaskRequest, wrappedListener, modelId); } } else { + validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput()); executePredict(mlPredictionTaskRequest, wrappedListener, modelId); } } @@ -228,4 +234,20 @@ private void executePredict( }) ); } + + public void validateInputSchema(String modelId, MLInput mlInput) { + if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) { + String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input"); + try { + MLNodeUtils + .validateSchema( + inputSchemaString, + mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString() + ); + } catch (Exception e) { + throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST); + } + } + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 61d27576d1..732d28663b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -48,6 +48,7 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationGPU; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLGuard mlGuard; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Map modelInterface; // In rare case, this could be null, e.g. model info not synced up yet a predict request comes in. @Setter @@ -171,6 +172,7 @@ public void clear() { rateLimiter = null; userRateLimiterMap = null; mlGuard = null; + modelInterface = null; } public void addModelInferenceDuration(double duration, long maxRequestCount) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 429323297c..6c4b12fcf4 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -187,10 +187,44 @@ public TokenBucket getUserRateLimiter(String modelId, String user) { return userRateLimiterMap.get(user); } + /** + * Set the ml interface for the model + * + * @param modelId model id + * @param modelInterface model interface + */ + public synchronized void setModelInterface(String modelId, Map modelInterface) { + log.debug("Setting ML Interface {} for Model {}", modelInterface, modelId); + getExistingModelCache(modelId).setModelInterface(modelInterface); + } + + /** + * Get the current ml interface for the model + * + * @param modelId model id + */ + public Map getModelInterface(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache == null) { + return null; + } + return modelCache.getModelInterface(); + } + + /** + * Remove the ml interface from cache + * + * @param modelId model id + */ + public synchronized void removeModelInterface(String modelId) { + log.debug("Removing the ML Interface from Model {}", modelId); + getExistingModelCache(modelId).setModelInterface(null); + } + /** * Set a ml guard * - * @param modelId model id + * @param modelId model id * @param mlGuard mlGuard */ public synchronized void setMLGuard(String modelId, MLGuard mlGuard) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 8ff53fb890..10b6af2333 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -294,6 +294,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .version(version) .modelGroupId(mlRegisterModelMetaInput.getModelGroupId()) .description(mlRegisterModelMetaInput.getDescription()) + .isEnabled(mlRegisterModelMetaInput.getIsEnabled()) .rateLimiter(mlRegisterModelMetaInput.getRateLimiter()) .isEnabled(mlRegisterModelMetaInput.getIsEnabled()) .modelFormat(mlRegisterModelMetaInput.getModelFormat()) @@ -304,6 +305,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .modelContentHash(mlRegisterModelMetaInput.getModelContentHashValue()) .modelContentSizeInBytes(mlRegisterModelMetaInput.getModelContentSizeInBytes()) .isHidden(mlRegisterModelMetaInput.getIsHidden()) + .modelInterface(mlRegisterModelMetaInput.getModelInterface()) .createdTime(now) .lastUpdateTime(now) .build(); @@ -539,6 +541,7 @@ private void indexRemoteModel( .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) + .modelInterface(registerModelInput.getModelInterface()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); @@ -606,6 +609,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) + .modelInterface(registerModelInput.getModelInterface()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { @@ -673,6 +677,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) + .modelInterface(registerModelInput.getModelInterface()) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (functionName == FunctionName.METRICS_CORRELATION) { @@ -750,6 +755,7 @@ private void registerModel( .version(version) .modelFormat(registerModelInput.getModelFormat()) .rateLimiter(registerModelInput.getRateLimiter()) + .isEnabled(registerModelInput.getIsEnabled()) .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) @@ -757,6 +763,7 @@ private void registerModel( .lastUpdateTime(now) .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) + .modelInterface(registerModelInput.getModelInterface()) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { @@ -1023,6 +1030,7 @@ public void deployModel( setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); setupMLGuard(modelId, mlModel.getGuardrails()); + setupModelInterface(modelId, mlModel.getModelInterface()); deployControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); @@ -1104,6 +1112,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou String modelId = mlModel.getModelId(); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); setupMLGuard(modelId, mlModel.getGuardrails()); + setupModelInterface(modelId, mlModel.getModelInterface()); if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1160,7 +1169,6 @@ private Map setUpParameterMap(String modelId) { params.put(GUARDRAILS, mlGuard); log.info("Setting up ML guard parameter for ML predictor."); } - return Collections.unmodifiableMap(params); } @@ -1184,6 +1192,7 @@ public synchronized void updateModelCache(String modelId, ActionListener modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); setupMLGuard(modelId, mlModel.getGuardrails()); + setupModelInterface(modelId, mlModel.getModelInterface()); if (mlModel.getAlgorithm() == FunctionName.REMOTE) { if (mlModel.getConnector() != null) { setupParamsAndPredictable(modelId, mlModel); @@ -1480,6 +1489,35 @@ public Map getUserRateLimiterMap(String modelId) { return modelCacheHelper.getUserRateLimiterMap(modelId); } + /** + * Set up model interface with model id. + */ + private void setupModelInterface(String modelId, Map modelInterface) { + log.debug("Model interface for model: {} loaded into cache.", modelId); + if (modelInterface != null) { + modelCacheHelper.setModelInterface(modelId, modelInterface); + } else { + modelCacheHelper.removeModelInterface(modelId); + } + } + + /** + * Get model interface with model id. + * + * @param modelId model id + * @return a Map containing the model interface + */ + public Map getModelInterface(String modelId) { + return modelCacheHelper.getModelInterface(modelId); + } + + /** + * Set up ML guard with model id. + * + * @param modelId + * @param guardrails + */ + private void setupMLGuard(String modelId, Guardrails guardrails) { if (guardrails != null) { modelCacheHelper.setMLGuard(modelId, createMLGuard(guardrails, xContentRegistry, client)); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 16176f197e..ec3fd03b32 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -20,6 +20,7 @@ import java.util.UUID; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.get.GetRequest; @@ -31,10 +32,13 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; @@ -48,6 +52,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; @@ -61,6 +66,7 @@ import org.opensearch.ml.stats.MLActionLevelStat; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -324,6 +330,9 @@ private void runPredict( if (mlInput.getAlgorithm() == FunctionName.REMOTE) { long startTime = System.nanoTime(); ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + if (output.getOutput() instanceof ModelTensorOutput) { + validateOutputSchema(modelId, (ModelTensorOutput) output.getOutput()); + } handleAsyncMLTaskComplete(mlTask); mlModelManager.trackPredictDuration(modelId, startTime); internalListener.onResponse(output); @@ -334,6 +343,9 @@ private void runPredict( if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); } + if (output instanceof ModelTensorOutput) { + validateOutputSchema(modelId, (ModelTensorOutput) output); + } // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state handleAsyncMLTaskComplete(mlTask); internalListener.onResponse(new MLTaskResponse(output)); @@ -383,7 +395,9 @@ private void runPredict( if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); } - + if (output instanceof ModelTensorOutput) { + validateOutputSchema(modelId, (ModelTensorOutput) output); + } // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state handleAsyncMLTaskComplete(mlTask); MLTaskResponse response = MLTaskResponse.builder().output(output).build(); @@ -439,4 +453,19 @@ private void handlePredictFailure( handleAsyncMLTaskFailure(mlTask, e); listener.onFailure(e); } + + public void validateOutputSchema(String modelId, ModelTensorOutput output) { + if (mlModelManager.getModelInterface(modelId) != null && mlModelManager.getModelInterface(modelId).get("output") != null) { + String outputSchemaString = mlModelManager.getModelInterface(modelId).get("output"); + try { + MLNodeUtils + .validateSchema( + outputSchemaString, + output.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString() + ); + } catch (Exception e) { + throw new OpenSearchStatusException("Error validating output schema: " + e.getMessage(), RestStatus.BAD_REQUEST); + } + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index f5c0f5ba52..227518aabf 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -9,9 +9,11 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME; import java.io.IOException; +import java.util.Arrays; import java.util.Set; import java.util.function.Function; +import org.opensearch.OpenSearchParseException; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; @@ -25,6 +27,13 @@ import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.networknt.schema.JsonSchema; +import com.networknt.schema.JsonSchemaFactory; +import com.networknt.schema.SpecVersion.VersionFlag; +import com.networknt.schema.ValidationMessage; + import lombok.experimental.UtilityClass; @UtilityClass @@ -56,6 +65,29 @@ public static void parseField(XContentParser parser, Set set, Function errors = schema.validate(jsonNode); + if (!errors.isEmpty()) { + throw new OpenSearchParseException( + "Validation failed: " + + Arrays.toString(errors.toArray(new ValidationMessage[0])) + + " for instance: " + + instanceString + + " with schema: " + + schemaString + ); + } + } + public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) { ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB(); if (openCircuitBreaker != null) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index da1280f46c..942a968cf0 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -648,9 +648,7 @@ public void testUpdateRequestDocIOException() throws IOException { doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); - doThrow(new IOException("Exception occurred during building update request.")) - .when(mockUpdateModelInput) - .toXContentForUpdateRequestDoc(any(), any()); + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -699,9 +697,7 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); - doThrow(new IOException("Exception occurred during building update request.")) - .when(mockUpdateModelInput) - .toXContentForUpdateRequestDoc(any(), any()); + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); verify(actionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index a1832dcd62..aa7afdce6e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -41,6 +42,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; @@ -233,4 +235,50 @@ public void testPrediction_MLResourceNotFoundException() { assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage()); } + public void testValidateInputSchemaSuccess() { + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters( + Map + .of( + "messages", + "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"}," + + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]" + ) + ) + .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + Map modelInterface = Map + .of( + "input", + "{\"properties\":{\"parameters\":{\"properties\":{\"messages\":{" + + "\"description\":\"This is a test description field\",\"type\":\"string\"}}}}}" + ); + when(modelCacheHelper.getModelInterface(any())).thenReturn(modelInterface); + transportPredictionTaskAction.validateInputSchema("testId", mlInput); + } + + public void testValidateInputSchemaFailed() { + exceptionRule.expect(OpenSearchStatusException.class); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters( + Map + .of( + "messages", + "[{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"You are a helpful assistant.\\\"}," + + "{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"Hello!\\\"}]" + ) + ) + .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); + Map modelInterface = Map + .of( + "input", + "{\"properties\":{\"parameters\":{\"properties\":{\"messages\":{" + + "\"description\":\"This is a test description field\",\"type\":\"integer\"}}}}}" + ); + when(modelCacheHelper.getModelInterface(any())).thenReturn(modelInterface); + transportPredictionTaskAction.validateInputSchema("testId", mlInput); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 0f90528fab..cbde703543 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -26,6 +26,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.Version; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; @@ -56,7 +57,9 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -390,6 +393,45 @@ public void testExecuteTask_OnLocalNode_NullGetResponse() { assertEquals("No model found, please check the modelId.", argumentCaptor.getValue().getMessage()); } + public void testValidateModelTensorOutputSuccess() { + ModelTensor modelTensor = ModelTensor + .builder() + .name("response") + .dataAsMap(Map.of("id", "chatcmpl-9JUSY2myXUjGBUrG0GO5niEAY5NKm")) + .build(); + Map modelInterface = Map + .of( + "output", + "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"array\"}}}" + ); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + when(mlModelManager.getModelInterface(any())).thenReturn(modelInterface); + taskRunner.validateOutputSchema("testId", modelTensorOutput); + } + + public void testValidateModelTensorOutputFailed() { + exceptionRule.expect(OpenSearchStatusException.class); + ModelTensor modelTensor = ModelTensor + .builder() + .name("response") + .dataAsMap(Map.of("id", "chatcmpl-9JUSY2myXUjGBUrG0GO5niEAY5NKm")) + .build(); + Map modelInterface = Map + .of( + "output", + "{\"properties\":{\"inference_results\":{\"description\":\"This is a test description field\"," + "\"type\":\"string\"}}}" + ); + ModelTensorOutput modelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(List.of(modelTensor)).build())) + .build(); + when(mlModelManager.getModelInterface(any())).thenReturn(modelInterface); + taskRunner.validateOutputSchema("testId", modelTensorOutput); + } + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index 11be98488e..7838308834 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -13,6 +13,7 @@ import java.util.Set; import org.junit.Assert; +import org.junit.Test; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; @@ -49,4 +50,17 @@ public void testCreateXContentParserFromRegistry() throws IOException { MLTask parsedMLTask = MLTask.parse(xContentParser); assertEquals(mlTask, parsedMLTask); } + + @Test + public void testValidateSchema() throws IOException { + String schema = "{" + + "\"type\": \"object\"," + + "\"properties\": {" + + " \"key1\": {\"type\": \"string\"}," + + " \"key2\": {\"type\": \"integer\"}" + + "}" + + "}"; + String json = "{\"key1\": \"foo\", \"key2\": 123}"; + MLNodeUtils.validateSchema(schema, json); + } }