From 0501de0ac1192dd7404e9f6ad9752058cc41f099 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frank=20Woo=28=E5=90=B4=E5=B3=BB=E7=94=B3=29?= Date: Mon, 25 Sep 2023 14:57:58 +0800 Subject: [PATCH 1/3] [FEATURE] tags on model group&version #1303 tags operation on model group MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Frank Woo(吴峻申) --- .../org/opensearch/ml/common/CommonValue.java | 644 ++++++++++-------- .../opensearch/ml/common/MLModelGroup.java | 28 +- .../ml/common/model/ModelGroupTag.java | 116 ++++ .../MLRegisterModelGroupInput.java | 37 +- .../MLRegisterModelGroupRequest.java | 4 +- .../model_group/MLUpdateModelGroupInput.java | 25 +- .../MLRegisterModelGroupInputTest.java | 21 +- .../MLRegisterModelGroupRequestTest.java | 11 +- .../MLUpdateModelGroupInputTest.java | 10 + .../MLUpdateModelGroupRequestTest.java | 18 +- .../TransportRegisterModelGroupAction.java | 7 + .../TransportUpdateModelGroupAction.java | 50 +- .../ml/model/MLModelGroupManager.java | 2 + .../RegisterModelGroupITTests.java | 25 +- .../model_group/SearchModelGroupITTests.java | 86 ++- ...ransportRegisterModelGroupActionTests.java | 33 +- .../TransportUpdateModelGroupActionTests.java | 34 +- .../model_group/UpdateModelGroupITTests.java | 23 +- .../ml/action/models/SearchModelITTests.java | 9 +- .../ml/model/MLModelGroupManagerTests.java | 16 +- 20 files changed, 840 insertions(+), 359 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/model/ModelGroupTag.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index dab60ce986..550426989c 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,8 +5,6 @@ package org.opensearch.ml.common; -import org.opensearch.ml.common.connector.AbstractConnector; - import static org.opensearch.ml.common.model.MLModelConfig.ALL_CONFIG_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD; @@ -15,310 +13,356 @@ import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_MODE_FIELD; -public class CommonValue { - - public static Integer NO_SCHEMA_VERSION = 0; - public static final String USER = "user"; - public static final String META = "_meta"; - public static final String SCHEMA_VERSION_FIELD = "schema_version"; - public static final String UNDEPLOYED = "undeployed"; - public static final String NOT_FOUND = "not_found"; +import org.opensearch.ml.common.connector.AbstractConnector; - public static final String MASTER_KEY = "master_key"; - public static final String CREATE_TIME_FIELD = "create_time"; +public class CommonValue { - public static final String BOX_TYPE_KEY = "box_type"; - //hot node - public static String HOT_BOX_TYPE = "hot"; - // warm node - public static String WARM_BOX_TYPE = "warm"; - public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; - public static final String ML_MODEL_INDEX = ".plugins-ml-model"; - public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7; - public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; - public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; - public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; - public static final String USER_FIELD_MAPPING = " \"" - + CommonValue.USER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n"; - public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + - " \"_meta\": {\n" + - " \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" + - " },\n" + - " \"properties\": {\n" + - " \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" + - " \"type\": \"integer\"\n" + - " },\n" + - " \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.ACCESS+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.OWNER+"\": {\n" + - " \"type\": \"nested\",\n" + - " \"properties\": {\n" + - " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + - " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + - " \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + - " }\n" + - "}"; + public static Integer NO_SCHEMA_VERSION = 0; + public static final String USER = "user"; + public static final String META = "_meta"; + public static final String SCHEMA_VERSION_FIELD = "schema_version"; + public static final String UNDEPLOYED = "undeployed"; + public static final String NOT_FOUND = "not_found"; - public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" - + " \"" - + AbstractConnector.NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + AbstractConnector.VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + AbstractConnector.PROTOCOL_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.CREDENTIAL_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.ACTIONS_FIELD - + "\" : {\"type\": \"flat_object\"}\n"; + public static final String MASTER_KEY = "master_key"; + public static final String CREATE_TIME_FIELD = "create_time"; - public static final String ML_MODEL_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_MODEL_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLModel.ALGORITHM_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLModel.OLD_MODEL_VERSION_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_GROUP_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_FIELD - + "\" : {\"type\": \"binary\"},\n" - + " \"" - + MLModel.CHUNK_NUMBER_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.TOTAL_CHUNKS_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLModel.MODEL_FORMAT_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_STATE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODES_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DEPLOY_TO_ALL_NODES_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.MODEL_CONFIG_FIELD - + "\" : {\"properties\":{\"" - + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" - + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + POOLING_MODE_FIELD + "\":{\"type\":\"keyword\"},\"" - + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" - + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" - + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" - + " \"" - + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_REGISTERED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_DEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UNDEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.CONNECTOR_FIELD - + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String BOX_TYPE_KEY = "box_type"; + // hot node + public static String HOT_BOX_TYPE = "hot"; + // warm node + public static String WARM_BOX_TYPE = "warm"; + public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; + public static final String ML_MODEL_INDEX = ".plugins-ml-model"; + public static final String ML_TASK_INDEX = ".plugins-ml-task"; + public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 1; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 6; + public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; + public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1; + public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1; + public static final String USER_FIELD_MAPPING = + " \"" + + CommonValue.USER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " }\n"; + public static final String ML_MODEL_GROUP_INDEX_MAPPING = + "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + MLModelGroup.MODEL_GROUP_NAME_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.DESCRIPTION_FIELD + + "\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"" + + MLModelGroup.LATEST_VERSION_FIELD + + "\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"" + + MLModelGroup.MODEL_GROUP_ID_FIELD + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.TAGS_FIELD + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"key\": {\"type\": \"keyword\"},\n" + + " \"type\": {\"type\": \"keyword\"}}\n" + + " },\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.CREATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModelGroup.LAST_UPDATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + + public static final String ML_CONNECTOR_INDEX_FIELDS = + " \"properties\": {\n" + + " \"" + + AbstractConnector.NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + AbstractConnector.VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + AbstractConnector.PROTOCOL_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.CREDENTIAL_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.ACTIONS_FIELD + + "\" : {\"type\": \"flat_object\"}\n"; - public static final String ML_TASK_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_TASK_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLTask.MODEL_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.TASK_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.FUNCTION_NAME_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.STATE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.INPUT_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.PROGRESS_FIELD - + "\": {\"type\": \"float\"},\n" - + " \"" - + MLTask.OUTPUT_INDEX_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.WORKER_NODE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.LAST_UPDATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.ERROR_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + MLTask.IS_ASYNC_TASK_FIELD - + "\" : {\"type\" : \"boolean\"}, \n" - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String ML_MODEL_INDEX_MAPPING = + "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModel.ALGORITHM_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLModel.OLD_MODEL_VERSION_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_GROUP_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_FIELD + + "\" : {\"type\": \"binary\"},\n" + + " \"" + + MLModel.CHUNK_NUMBER_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.TOTAL_CHUNKS_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLModel.MODEL_FORMAT_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_STATE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODES_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DEPLOY_TO_ALL_NODES_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.MODEL_CONFIG_FIELD + + "\" : {\"properties\":{\"" + + MODEL_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + EMBEDDING_DIMENSION_FIELD + + "\":{\"type\":\"integer\"},\"" + + FRAMEWORK_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + POOLING_MODE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + NORMALIZE_RESULT_FIELD + + "\":{\"type\":\"boolean\"},\"" + + MODEL_MAX_LENGTH_FIELD + + "\":{\"type\":\"integer\"},\"" + + ALL_CONFIG_FIELD + + "\":{\"type\":\"text\"}}},\n" + + " \"" + + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_REGISTERED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_DEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UNDEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.CONNECTOR_FIELD + + "\": {" + + ML_CONNECTOR_INDEX_FIELDS + + " }\n}," + + USER_FIELD_MAPPING + + " }\n" + + "}"; - public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONNECTOR_SCHEMA_VERSION - + "},\n" - + ML_CONNECTOR_INDEX_FIELDS + ",\n" - + " \"" - + MLModelGroup.BACKEND_ROLES_FIELD - + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" - + MLModelGroup.ACCESS - + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" - + MLModelGroup.OWNER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " },\n" - + " \"" - + AbstractConnector.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + AbstractConnector.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_TASK_INDEX_MAPPING = + "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_TASK_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLTask.MODEL_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.TASK_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.FUNCTION_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.STATE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.INPUT_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.PROGRESS_FIELD + + "\": {\"type\": \"float\"},\n" + + " \"" + + MLTask.OUTPUT_INDEX_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.WORKER_NODE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.LAST_UPDATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.ERROR_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + MLTask.IS_ASYNC_TASK_FIELD + + "\" : {\"type\" : \"boolean\"}, \n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; + public static final String ML_CONNECTOR_INDEX_MAPPING = + "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONNECTOR_SCHEMA_VERSION + + "},\n" + + ML_CONNECTOR_INDEX_FIELDS + + ",\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + AbstractConnector.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + AbstractConnector.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONFIG_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONFIG_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MASTER_KEY - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_CONFIG_INDEX_MAPPING = + "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 9e2fbb7133..cfb0f06ad4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -11,6 +11,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import lombok.Builder; import lombok.Getter; import lombok.Setter; @@ -21,6 +22,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.model.ModelGroupTag; @Getter public class MLModelGroup implements ToXContentObject { @@ -34,7 +36,7 @@ public class MLModelGroup implements ToXContentObject { public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created - + public static final String TAGS_FIELD = "tags"; @Setter private String name; @@ -50,10 +52,12 @@ public class MLModelGroup implements ToXContentObject { private Instant createdTime; private Instant lastUpdatedTime; + private List tags; @Builder(toBuilder = true) public MLModelGroup(String name, String description, int latestVersion, - List backendRoles, User owner, String access, + List backendRoles, User owner,List tags, + String access, String modelGroupId, Instant createdTime, Instant lastUpdatedTime) { @@ -69,6 +73,7 @@ public MLModelGroup(String name, String description, int latestVersion, this.modelGroupId = modelGroupId; this.createdTime = createdTime; this.lastUpdatedTime = lastUpdatedTime; + this.tags = tags; } @@ -84,6 +89,8 @@ public MLModelGroup(StreamInput input) throws IOException{ } else { this.owner = null; } + + tags = input.readList(ModelGroupTag::new); access = input.readOptionalString(); modelGroupId = input.readOptionalString(); createdTime = input.readOptionalInstant(); @@ -106,6 +113,8 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + + out.writeList(Objects.requireNonNullElseGet(tags, ArrayList::new)); out.writeOptionalString(access); out.writeOptionalString(modelGroupId); out.writeOptionalInstant(createdTime); @@ -126,6 +135,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (owner != null) { builder.field(OWNER, owner); } + + if (!CollectionUtils.isEmpty(tags)) { + builder.field(TAGS_FIELD, tags); + } if (access != null) { builder.field(ACCESS, access); } @@ -152,6 +165,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { String modelGroupId = null; Instant createdTime = null; Instant lastUpdateTime = null; + List tags = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -178,6 +192,15 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { case OWNER: owner = User.parse(parser); break; + + case TAGS_FIELD: + tags = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + tags.add(ModelGroupTag.parse(parser)); + } + + break; case ACCESS: access = parser.text(); break; @@ -201,6 +224,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { .backendRoles(backendRoles) .latestVersion(latestVersion) .owner(owner) + .tags(tags) .access(access) .modelGroupId(modelGroupId) .createdTime(createdTime) diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGroupTag.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGroupTag.java new file mode 100644 index 0000000000..9e89086e54 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGroupTag.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.*; +import org.opensearch.common.Nullable; +import org.opensearch.common.inject.internal.ToStringBuilder; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +public final class ModelGroupTag implements Writeable, ToXContent { + public static final String TAG_KEY_FIELD = "key"; + public static final String TAG_TYPE_FIELD = "type"; + + @Nullable private final String key; + @Nullable private final String type; + + public ModelGroupTag() { + key = ""; + type = ""; + } + + public ModelGroupTag(@Nullable final String key, @Nullable final String type) { + this.key = key; + this.type = type; + } + + public ModelGroupTag(String json) { + if (Strings.isNullOrEmpty(json)) { + throw new IllegalArgumentException("Response json cannot be null"); + } + + Map mapValue = + XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false); + key = (String) mapValue.get(TAG_KEY_FIELD); + type = (String) mapValue.get(TAG_TYPE_FIELD); + } + + public ModelGroupTag(StreamInput in) throws IOException { + this.key = in.readString(); + this.type = in.readString(); + } + + public static ModelGroupTag parse(XContentParser parser) throws IOException { + String key = ""; + String type = ""; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case TAG_KEY_FIELD: + key = parser.text(); + break; + case TAG_TYPE_FIELD: + type = parser.text(); + break; + default: + break; + } + } + return new ModelGroupTag(key, type); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject().field(TAG_KEY_FIELD, key).field(TAG_TYPE_FIELD, type); + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(key); + out.writeString(type); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this.getClass()); + builder.add(TAG_KEY_FIELD, key); + builder.add(TAG_TYPE_FIELD, type); + return builder.toString(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof ModelGroupTag)) { + return false; + } + ModelGroupTag that = (ModelGroupTag) obj; + return this.key.equals(that.key) && this.type.equals(that.type); + } + + @Nullable + public String getKey() { + return key; + } + + @Nullable + public String getType() { + return type; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java index 4595a16d77..740ea31bb7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java @@ -5,22 +5,23 @@ package org.opensearch.ml.common.transport.model_group; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; import lombok.Builder; import lombok.Data; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import org.opensearch.ml.common.model.ModelGroupTag; @Data public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{ @@ -30,15 +31,17 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{ public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional public static final String MODEL_ACCESS_MODE = "access_mode"; //optional public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String TAGS_FIELD = "tags"; //optional private String name; private String description; private List backendRoles; private AccessMode modelAccessMode; private Boolean isAddAllBackendRoles; + private List tags; @Builder(toBuilder = true) - public MLRegisterModelGroupInput(String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLRegisterModelGroupInput(String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles,List tags) { if (name == null) { throw new IllegalArgumentException("model group name is null"); } @@ -47,6 +50,7 @@ public MLRegisterModelGroupInput(String name, String description, List b this.backendRoles = backendRoles; this.modelAccessMode = modelAccessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.tags = tags; } public MLRegisterModelGroupInput(StreamInput in) throws IOException{ @@ -57,6 +61,7 @@ public MLRegisterModelGroupInput(StreamInput in) throws IOException{ modelAccessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.tags=in.readList(ModelGroupTag::new); } @Override @@ -76,6 +81,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + if(!CollectionUtils.isEmpty(tags)){ + out.writeList(tags); + } } @Override @@ -94,6 +102,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles); } + if(!CollectionUtils.isEmpty(tags)){ + builder.field(TAGS_FIELD, tags); + } builder.endObject(); return builder; } @@ -104,6 +115,7 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx List backendRoles = null; AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + List tags=null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -129,12 +141,19 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx case ADD_ALL_BACKEND_ROLES: isAddAllBackendRoles = parser.booleanValue(); break; + case TAGS_FIELD: + tags = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + tags.add(ModelGroupTag.parse(parser)); + } + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles); + return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles,tags); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java index 3bf3dabd03..dcc0a450f8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java @@ -62,8 +62,7 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action return (MLRegisterModelGroupRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelGroupRequest(input); @@ -71,6 +70,5 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action } catch (IOException e) { throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelMetaRequest", e); } - } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java index 22e612a5b1..1a86dace73 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java @@ -10,10 +10,12 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; import java.io.IOException; import java.util.ArrayList; @@ -31,7 +33,7 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional public static final String MODEL_ACCESS_MODE = "access_mode"; //optional public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; //optional - + public static final String TAGS_FIELD = "tags"; //optional private String modelGroupID; private String name; @@ -39,15 +41,17 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { private List backendRoles; private AccessMode modelAccessMode; private Boolean isAddAllBackendRoles; + private List tags; @Builder(toBuilder = true) - public MLUpdateModelGroupInput(String modelGroupID, String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLUpdateModelGroupInput(String modelGroupID, String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles,List tags) { this.modelGroupID = modelGroupID; this.name = name; this.description = description; this.backendRoles = backendRoles; this.modelAccessMode = modelAccessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.tags = tags; } public MLUpdateModelGroupInput(StreamInput in) throws IOException { @@ -59,6 +63,7 @@ public MLUpdateModelGroupInput(StreamInput in) throws IOException { modelAccessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.tags=in.readList(ModelGroupTag::new); } @Override @@ -80,6 +85,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES_FIELD, isAddAllBackendRoles); } + if(!CollectionUtils.isEmpty(tags)){ + builder.field(TAGS_FIELD, tags); + } builder.endObject(); return builder; } @@ -102,6 +110,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + if(!CollectionUtils.isEmpty(tags)){ + out.writeList(tags); + } } public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOException { @@ -111,6 +122,7 @@ public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOExce List backendRoles = null; AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + List tags=null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -139,11 +151,18 @@ public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOExce case ADD_ALL_BACKEND_ROLES_FIELD: isAddAllBackendRoles = parser.booleanValue(); break; + case TAGS_FIELD: + tags = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + tags.add(ModelGroupTag.parse(parser)); + } + break; default: parser.skipChildren(); break; } } - return new MLUpdateModelGroupInput(modelGroupID, name, description, backendRoles, modelAccessMode, isAddAllBackendRoles); + return new MLUpdateModelGroupInput(modelGroupID, name, description, backendRoles, modelAccessMode, isAddAllBackendRoles,tags); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java index a9a4969533..fd7e647be4 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java @@ -4,11 +4,13 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; public class MLRegisterModelGroupInputTest { @@ -16,14 +18,17 @@ public class MLRegisterModelGroupInputTest { @Before public void setUp() throws Exception { - - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + mlRegisterModelGroupInput = + MLRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .tags(List.of(new ModelGroupTag("tag1", "String"), + new ModelGroupTag("tag2", "Number"))) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index ea1695e96d..1262254997 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -7,7 +7,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Arrays; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -15,6 +15,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; public class MLRegisterModelGroupRequestTest { @@ -23,12 +24,14 @@ public class MLRegisterModelGroupRequestTest { @Before public void setUp(){ - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() + mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder() .name("name") .description("description") - .backendRoles(Arrays.asList("IT")) + .backendRoles(List.of("IT")) .modelAccessMode(AccessMode.RESTRICTED) .isAddAllBackendRoles(true) + .tags(List.of(new ModelGroupTag("tag1", "String"), + new ModelGroupTag("tag2", "Number"))) .build(); } @@ -45,6 +48,8 @@ public void writeTo_Success() throws IOException { assertEquals(request.getRegisterModelGroupInput().getBackendRoles().get(0), parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0)); assertEquals(request.getRegisterModelGroupInput().getModelAccessMode(), parsedRequest.getRegisterModelGroupInput().getModelAccessMode()); assertEquals(request.getRegisterModelGroupInput().getIsAddAllBackendRoles() ,parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles()); + assertEquals(request.getRegisterModelGroupInput().getTags().get(0) ,parsedRequest.getRegisterModelGroupInput().getTags().get(0)); + assertEquals(request.getRegisterModelGroupInput().getTags().get(1) ,parsedRequest.getRegisterModelGroupInput().getTags().get(1)); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java index 9dc0fc559c..82bdce6193 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java @@ -4,11 +4,13 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; public class MLUpdateModelGroupInputTest { @@ -24,6 +26,8 @@ public void setUp() throws Exception { .backendRoles(Arrays.asList("IT")) .modelAccessMode(AccessMode.RESTRICTED) .isAddAllBackendRoles(true) + .tags(List.of(new ModelGroupTag("tag1", "String"), + new ModelGroupTag("tag2", "Number"))) .build(); } @@ -34,5 +38,11 @@ public void readInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLUpdateModelGroupInput parsedInput = new MLUpdateModelGroupInput(streamInput); assertEquals(mlUpdateModelGroupInput.getName(), parsedInput.getName()); + assertEquals(mlUpdateModelGroupInput.getTags().get(0), parsedInput.getTags().get(0)); + assertEquals(mlUpdateModelGroupInput.getTags().get(0).getKey(), parsedInput.getTags().get(0).getKey()); + assertEquals(mlUpdateModelGroupInput.getTags().get(0).getType(), parsedInput.getTags().get(0).getType()); + assertEquals(mlUpdateModelGroupInput.getTags().get(1), parsedInput.getTags().get(1)); + assertEquals(mlUpdateModelGroupInput.getTags().get(1).getKey(), parsedInput.getTags().get(1).getKey()); + assertEquals(mlUpdateModelGroupInput.getTags().get(1).getType(), parsedInput.getTags().get(1).getType()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java index c5406cfa7a..43001cc58a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java @@ -7,7 +7,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Arrays; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -15,6 +15,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; public class MLUpdateModelGroupRequestTest { @@ -22,14 +23,15 @@ public class MLUpdateModelGroupRequestTest { @Before public void setUp(){ - - mlUpdateModelGroupInput = mlUpdateModelGroupInput.builder() + mlUpdateModelGroupInput = MLUpdateModelGroupInput.builder() .modelGroupID("modelGroupId") .name("name") .description("description") - .backendRoles(Arrays.asList("IT")) + .backendRoles(List.of("IT")) .modelAccessMode(AccessMode.RESTRICTED) .isAddAllBackendRoles(true) + .tags(List.of(new ModelGroupTag("tag1", "String"), + new ModelGroupTag("tag2", "Number"))) .build(); } @@ -48,6 +50,10 @@ public void writeTo_Success() throws IOException { assertEquals("IT", request.getUpdateModelGroupInput().getBackendRoles().get(0)); assertEquals(AccessMode.RESTRICTED, request.getUpdateModelGroupInput().getModelAccessMode()); assertEquals(true, request.getUpdateModelGroupInput().getIsAddAllBackendRoles()); + assertEquals("tag1", request.getUpdateModelGroupInput().getTags().get(0).getKey()); + assertEquals("String", request.getUpdateModelGroupInput().getTags().get(0).getType()); + assertEquals("tag2", request.getUpdateModelGroupInput().getTags().get(1).getKey()); + assertEquals("Number", request.getUpdateModelGroupInput().getTags().get(1).getType()); } @Test @@ -107,6 +113,10 @@ public void writeTo(StreamOutput out) throws IOException { MLUpdateModelGroupRequest result = MLUpdateModelGroupRequest.fromActionRequest(actionRequest); assertNotSame(result, request); assertEquals(request.getUpdateModelGroupInput().getName(), result.getUpdateModelGroupInput().getName()); + assertEquals(request.getUpdateModelGroupInput().getTags().get(0).getKey(), result.getUpdateModelGroupInput().getTags().get(0).getKey()); + assertEquals(request.getUpdateModelGroupInput().getTags().get(0).getType(), result.getUpdateModelGroupInput().getTags().get(0).getType()); + assertEquals(request.getUpdateModelGroupInput().getTags().get(1).getKey(), result.getUpdateModelGroupInput().getTags().get(1).getKey()); + assertEquals(request.getUpdateModelGroupInput().getTags().get(1).getType(), result.getUpdateModelGroupInput().getTags().get(1).getType()); } @Test(expected = UncheckedIOException.class) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java index 94d4b5a8a7..505b433f37 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -64,6 +65,12 @@ public TransportRegisterModelGroupAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLRegisterModelGroupRequest createModelGroupRequest = MLRegisterModelGroupRequest.fromActionRequest(request); MLRegisterModelGroupInput createModelGroupInput = createModelGroupRequest.getRegisterModelGroupInput(); + + if (!CollectionUtils.isEmpty(createModelGroupInput.getTags()) && createModelGroupInput.getTags().size() > 10) { + listener.onFailure(new IllegalArgumentException("The size of tags cannot be larger than 10")); + return; + } + mlModelGroupManager.createModelGroup(createModelGroupInput, ActionListener.wrap(modelGroupId -> { listener.onResponse(new MLRegisterModelGroupResponse(modelGroupId, MLTaskState.CREATED.name())); }, ex -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 5d53c4dea9..92591cfb85 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,10 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; +import java.util.*; import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; @@ -33,9 +30,11 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.ModelGroupTag; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupRequest; @@ -88,6 +87,12 @@ public TransportUpdateModelGroupAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLUpdateModelGroupRequest updateModelGroupRequest = MLUpdateModelGroupRequest.fromActionRequest(request); MLUpdateModelGroupInput updateModelGroupInput = updateModelGroupRequest.getUpdateModelGroupInput(); + + if (!CollectionUtils.isEmpty(updateModelGroupInput.getTags()) && updateModelGroupInput.getTags().size() > 10) { + listener.onFailure(new IllegalArgumentException("The size of tags cannot be larger than 10")); + return; + } + String modelGroupId = updateModelGroupInput.getModelGroupID(); User user = RestActionUtils.getUserContext(client); if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { @@ -102,6 +107,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener existedTags = mlModelGroup.getTags(); + List updatedTags = updateModelGroupInput.getTags(); + + List toDeleteTags = new ArrayList<>(); + + Set updatedTagKeys = new HashSet<>(); + for (ModelGroupTag tag : updatedTags) { + updatedTagKeys.add(tag.getKey()); + } + + for (ModelGroupTag tag : existedTags) { + if (!updatedTagKeys.contains(tag.getKey())) { + toDeleteTags.add(tag); + } + } + + // Get the model version to be updated based on the modelGroupId and the key value of the tag to + // be deleted. + List models = new ArrayList<>(); + + // Loop through each model version to update the tags and remove the ones that need to be + // removed. + for (MLModel model : models) { + // TODO + // model update tags to be deleted + } + } + private void updateModelGroup( String modelGroupId, Map source, @@ -148,6 +185,9 @@ private void updateModelGroup( if (Boolean.TRUE.equals(updateModelGroupInput.getIsAddAllBackendRoles())) { source.put(MLModelGroup.BACKEND_ROLES_FIELD, user.getBackendRoles()); } + if (updateModelGroupInput.getTags() != null && !CollectionUtils.isEmpty(updateModelGroupInput.getTags())) { + source.put(MLModelGroup.TAGS_FIELD, updateModelGroupInput.getTags()); + } if (StringUtils.isNotBlank(updateModelGroupInput.getDescription())) { source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription()); } @@ -179,7 +219,6 @@ private void updateModelGroup( } else { updateModelGroup(modelGroupId, source, listener); } - } private void updateModelGroup(String modelGroupId, Map source, ActionListener listener) { @@ -266,5 +305,4 @@ private void validateSecurityDisabledOrModelAccessControlDisabled(MLUpdateModelG ); } } - } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index efc78edf20..312e5d9186 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -95,6 +95,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener> tags = (List>) response.getHits().getHits()[0].getSourceAsMap().get("tags"); + assertEquals(2, tags.size()); + assertEquals("tag1", tags.get(0).get("key")); + assertEquals("String", tags.get(0).get("type")); + assertEquals("tag2", tags.get(1).get("key")); + assertEquals("Number", tags.get(1).get("type")); + } + + public void test_match_search() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("tags.key", "tag1"); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder("tags", matchQueryBuilder, ScoreMode.None); + searchRequest.source().query(nestedQueryBuilder); + + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + assertEquals(1, response.getHits().getTotalHits().value); + assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); + + List> tags = (List>) response.getHits().getHits()[0].getSourceAsMap().get("tags"); + assertEquals(2, tags.size()); + assertEquals("tag1", tags.get(0).get("key")); + assertEquals("String", tags.get(0).get("type")); + assertEquals("tag2", tags.get(1).get("key")); + assertEquals("Number", tags.get(1).get("type")); } public void test_bool_search() { @@ -79,6 +120,26 @@ public void test_term_search() { assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + public void test_term_search_tags() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("tags.key", "tag1"); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder("tags", termQueryBuilder, ScoreMode.None); + searchRequest.source().query(nestedQueryBuilder); + + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + assertEquals(1, response.getHits().getTotalHits().value); + assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); + + List> tags = (List>) response.getHits().getHits()[0].getSourceAsMap().get("tags"); + assertEquals(2, tags.size()); + assertEquals("tag1", tags.get(0).get("key")); + assertEquals("String", tags.get(0).get("type")); + assertEquals("tag2", tags.get(1).get("key")); + assertEquals("Number", tags.get(1).get("type")); + } + public void test_terms_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); @@ -89,6 +150,26 @@ public void test_terms_search() { assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } + public void test_terms_search_tags() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + TermsQueryBuilder termsQueryBuilder = QueryBuilders.termsQuery("tags.key", "tag1"); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder("tags", termsQueryBuilder, ScoreMode.None); + searchRequest.source().query(nestedQueryBuilder); + + SearchResponse response = client().execute(MLModelGroupSearchAction.INSTANCE, searchRequest).actionGet(); + assertEquals(1, response.getHits().getTotalHits().value); + assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); + + List> tags = (List>) response.getHits().getHits()[0].getSourceAsMap().get("tags"); + assertEquals(2, tags.size()); + assertEquals("tag1", tags.get(0).get("key")); + assertEquals("String", tags.get(0).get("type")); + assertEquals("tag2", tags.get(1).get("key")); + assertEquals("Number", tags.get(1).get("type")); + } + public void test_range_search() { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); @@ -118,5 +199,4 @@ public void test_queryString_search() { assertEquals(1, response.getHits().getTotalHits().value); assertEquals(modelGroupId, response.getHits().getHits()[0].getId()); } - } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java index 269ac30d95..eb503fd923 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java @@ -9,6 +9,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -26,6 +27,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; @@ -94,7 +96,6 @@ public void setup() { mlModelGroupManager ); assertNotNull(transportRegisterModelGroupAction); - } public void test_Success() { @@ -125,6 +126,35 @@ public void test_Failure() { assertEquals("Failed to init model group index", argumentCaptor.getValue().getMessage()); } + public void test_Failure_exceed_size() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("The size of tags cannot be larger than 10")); + return null; + }).when(mlModelGroupManager).createModelGroup(any(), any()); + + List tags = new ArrayList<>(); + for (int i = 0; i < 11; i++) { + tags.add(new ModelGroupTag("key" + i, "type" + i)); + } + MLRegisterModelGroupInput registerModelGroupInput = MLRegisterModelGroupInput + .builder() + .name("modelGroupName") + .description("This is a test model group") + .backendRoles(null) + .modelAccessMode(AccessMode.PUBLIC) + .isAddAllBackendRoles(null) + .tags(tags) + .build(); + MLRegisterModelGroupRequest actionRequest = new MLRegisterModelGroupRequest(registerModelGroupInput); + + transportRegisterModelGroupAction.doExecute(task, actionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("The size of tags cannot be larger than 10", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelGroupRequest prepareRequest( List backendRoles, AccessMode modelAccessMode, @@ -140,5 +170,4 @@ private MLRegisterModelGroupRequest prepareRequest( .build(); return new MLRegisterModelGroupRequest(registerModelGroupInput); } - } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 1a67977291..910bcd6f93 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -43,9 +44,8 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; -import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupRequest; -import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse; +import org.opensearch.ml.common.model.ModelGroupTag; +import org.opensearch.ml.common.transport.model_group.*; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.utils.TestHelper; @@ -353,6 +353,34 @@ public void test_FailedToFindModelGroupException() { assertEquals("Failed to find model group", argumentCaptor.getValue().getMessage()); } + public void test_FailedToExceedTagSizeException() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IllegalArgumentException("The size of tags cannot be larger than 10")); + return null; + }).when(client).get(any(), any()); + + List tags = new ArrayList<>(); + for (int i = 0; i < 11; i++) { + tags.add(new ModelGroupTag("key" + i, "type" + i)); + } + MLUpdateModelGroupInput updateModelGroupInput = MLUpdateModelGroupInput + .builder() + .name("modelGroupName") + .description("This is a test model group") + .backendRoles(null) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(null) + .tags(tags) + .build(); + MLUpdateModelGroupRequest actionRequest = new MLUpdateModelGroupRequest(updateModelGroupInput); + + transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("The size of tags cannot be larger than 10", argumentCaptor.getValue().getMessage()); + } + public void test_FailedToGetModelGroupException() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java index 19cf5b4bd5..ecfe015522 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java @@ -5,11 +5,14 @@ package org.opensearch.ml.action.model_group; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.ml.action.MLCommonsIntegTestCase; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; @@ -35,7 +38,14 @@ public void setUp() throws Exception { } private void registerModelGroup() { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock_model_group_desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock_model_group_desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); MLRegisterModelGroupResponse response = client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); this.modelGroupId = response.getModelGroupId(); @@ -49,7 +59,8 @@ public void test_update_public_model_group() { "mock_model_group_desc", null, AccessMode.PUBLIC, - false + false, + List.of(new ModelGroupTag("tag1", "String"), new ModelGroupTag("tag2", "Number")) ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -63,7 +74,8 @@ public void test_update_private_model_group() { "mock_model_group_desc", null, AccessMode.PRIVATE, - false + false, + List.of(new ModelGroupTag("tag1", "String"), new ModelGroupTag("tag2", "Number")) ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -76,6 +88,7 @@ public void test_update_model_group_without_access_fields() { "mock_model_group_desc", null, null, + null, null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); @@ -90,7 +103,8 @@ public void test_update_protected_model_group_with_addAllBackendRoles_true() { "mock_model_group_desc", null, AccessMode.RESTRICTED, - true + true, + List.of(new ModelGroupTag("tag1", "String"), new ModelGroupTag("tag2", "Number")) ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -104,6 +118,7 @@ public void test_update_protected_model_group_with_backendRoles_notEmpty() { "mock_model_group_desc", ImmutableList.of("role-1"), AccessMode.RESTRICTED, + null, null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index d5c1347e26..f7ee05fa88 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -46,7 +46,14 @@ public void setUp() throws Exception { } private void registerModelGroup() throws InterruptedException { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock model group desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock model group desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); MLRegisterModelGroupResponse response = client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); this.modelGroupId = response.getModelGroupId(); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f7eb759026..b325b95c38 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -34,6 +34,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.model.ModelGroupTag; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; @@ -129,7 +130,7 @@ public void test_SuccessAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - public void test_ModelGroupNameNotUnique() throws IOException {// + public void test_ModelGroupNameNotUnique() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -343,6 +344,7 @@ private MLRegisterModelGroupInput prepareRequest(List backendRoles, Acce .backendRoles(backendRoles) .modelAccessMode(modelAccessMode) .isAddAllBackendRoles(isAddAllBackendRoles) + .tags(List.of(new ModelGroupTag("tag1", "String"), new ModelGroupTag("tag2", "Number"))) .build(); } @@ -355,12 +357,20 @@ private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOE + " \"last_updated_time\": 1684981986069,\n" + " \"_id\": \"model_group_ID\",\n" + " \"name\": \"model_group_IT\",\n" - + " \"description\": \"This is an example description\"\n" + + " \"description\": \"This is an example description\",\n" + + " \"tags\": [\n" + + " {\n" + + " \"key\": \"tag1\",\n" + + " \"type\": \"String\"\n" + + " },\n" + + " { \"key\": \"tag2\",\n" + + " \"type\": \"Number\"\n" + + " }\n" + + " ]\n" + " }"; SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); when(searchResponse.getHits()).thenReturn(hits); return searchResponse; } - } From fcd6c955e2b785bfe6afe89656b59e4f25850650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frank=20Woo=28=E5=90=B4=E5=B3=BB=E7=94=B3=29?= Date: Mon, 25 Sep 2023 15:52:57 +0800 Subject: [PATCH 2/3] [FEATURE] tags on model group&version #1303 check update tags whether is null or empty MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Frank Woo(吴峻申) --- .../TransportUpdateModelGroupAction.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 92591cfb85..79cf299603 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,8 +9,9 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; +import com.google.common.collect.ImmutableList; import java.util.*; - +import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -47,10 +48,6 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableList; - -import lombok.extern.log4j.Log4j2; - @Log4j2 public class TransportUpdateModelGroupAction extends HandledTransportAction { @@ -137,6 +134,10 @@ private void deleteTagsOfModelVersion(MLUpdateModelGroupInput updateModelGroupIn List existedTags = mlModelGroup.getTags(); List updatedTags = updateModelGroupInput.getTags(); + if (CollectionUtils.isEmpty(updatedTags)) { + return; + } + List toDeleteTags = new ArrayList<>(); Set updatedTagKeys = new HashSet<>(); @@ -185,7 +186,7 @@ private void updateModelGroup( if (Boolean.TRUE.equals(updateModelGroupInput.getIsAddAllBackendRoles())) { source.put(MLModelGroup.BACKEND_ROLES_FIELD, user.getBackendRoles()); } - if (updateModelGroupInput.getTags() != null && !CollectionUtils.isEmpty(updateModelGroupInput.getTags())) { + if (!CollectionUtils.isEmpty(updateModelGroupInput.getTags())) { source.put(MLModelGroup.TAGS_FIELD, updateModelGroupInput.getTags()); } if (StringUtils.isNotBlank(updateModelGroupInput.getDescription())) { From 17822b60ab31603a3e1aa588369027fcab24afe0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frank=20Woo=28=E5=90=B4=E5=B3=BB=E7=94=B3=29?= Date: Mon, 25 Sep 2023 16:50:10 +0800 Subject: [PATCH 3/3] [FEATURE] tags on model group&version #1303 check update tags whether is null or empty MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Frank Woo(吴峻申) --- .../model_group/TransportUpdateModelGroupAction.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 79cf299603..fd8b63e99b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,9 +9,8 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import com.google.common.collect.ImmutableList; import java.util.*; -import lombok.extern.log4j.Log4j2; + import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -48,6 +47,10 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import com.google.common.collect.ImmutableList; + +import lombok.extern.log4j.Log4j2; + @Log4j2 public class TransportUpdateModelGroupAction extends HandledTransportAction {