Skip to content

Commit

Permalink
Model & user level throttling (opensearch-project#1800)
Browse files Browse the repository at this point in the history
* Enable in-place update model
---------

Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo authored Dec 26, 2023
1 parent 75155b9 commit 50788de
Show file tree
Hide file tree
Showing 111 changed files with 8,466 additions and 483 deletions.
25 changes: 24 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.controller.MLModelController;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
Expand Down Expand Up @@ -54,12 +55,14 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
public static final String ML_MODEL_CONTROLLER_INDEX = ".plugins-ml-controller";
public static final Integer ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String ML_AGENT_INDEX = ".plugins-ml-agent";
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1;
Expand Down Expand Up @@ -222,6 +225,15 @@ public class CommonValue {
+ MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.IS_ENABLED_FIELD
+ "\" : {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD
+ "\" : {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.MODEL_RATE_LIMITER_CONFIG_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
Expand Down Expand Up @@ -350,6 +362,17 @@ public class CommonValue {
+ " }\n"
+ "}";

/**
 * JSON index mapping for the model controller index.
 * The {@code _meta.schema_version} value is taken from
 * {@code ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION}, and the only mapped property is
 * {@code MLModelController.USER_RATE_LIMITER_CONFIG}, declared with type {@code flat_object}.
 */
public static final String ML_MODEL_CONTROLLER_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION
+ "},\n"
+ " \"properties\": {\n"
+ " \""
+ MLModelController.USER_RATE_LIMITER_CONFIG
+ "\" : {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";

public static final String ML_AGENT_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_AGENT_INDEX_SCHEMA_VERSION
Expand Down
53 changes: 52 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand Down Expand Up @@ -50,9 +51,13 @@ public class MLModel implements ToXContentObject {
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
//SHA256 hash value of model content.
// SHA256 hash value of model content.
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";

// Model level quota and throttling control
public static final String IS_ENABLED_FIELD = "is_enabled";
public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config";
public static final String IS_MODEL_CONTROLLER_ENABLED_FIELD = "is_model_controller_enabled";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
Expand Down Expand Up @@ -94,6 +99,9 @@ public class MLModel implements ToXContentObject {
private Long modelContentSizeInBytes;
private String modelContentHash;
private MLModelConfig modelConfig;
private Boolean isEnabled;
private Boolean isModelControllerEnabled;
private MLRateLimiter modelRateLimiterConfig;
private Instant createdTime;
private Instant lastUpdateTime;
private Instant lastRegisteredTime;
Expand Down Expand Up @@ -131,6 +139,9 @@ public MLModel(String name,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
Boolean isEnabled,
Boolean isModelControllerEnabled,
MLRateLimiter modelRateLimiterConfig,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Expand Down Expand Up @@ -158,6 +169,9 @@ public MLModel(String name,
this.modelState = modelState;
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHash = modelContentHash;
this.isEnabled = isEnabled;
this.isModelControllerEnabled = isModelControllerEnabled;
this.modelRateLimiterConfig = modelRateLimiterConfig;
this.modelConfig = modelConfig;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
Expand Down Expand Up @@ -204,6 +218,11 @@ public MLModel(StreamInput input) throws IOException{
modelConfig = new TextEmbeddingModelConfig(input);
}
}
isEnabled = input.readOptionalBoolean();
isModelControllerEnabled = input.readOptionalBoolean();
if (input.readBoolean()) {
modelRateLimiterConfig = new MLRateLimiter(input);
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
lastRegisteredTime = input.readOptionalInstant();
Expand Down Expand Up @@ -258,6 +277,14 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isEnabled);
out.writeOptionalBoolean(isModelControllerEnabled);
if (modelRateLimiterConfig != null) {
out.writeBoolean(true);
modelRateLimiterConfig.writeTo(out);
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalInstant(lastRegisteredTime);
Expand Down Expand Up @@ -321,6 +348,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
}
if (isEnabled != null) {
builder.field(IS_ENABLED_FIELD, isEnabled);
}
if (isModelControllerEnabled != null) {
builder.field(IS_MODEL_CONTROLLER_ENABLED_FIELD, isModelControllerEnabled);
}
if (modelRateLimiterConfig != null) {
builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig);
}
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
Expand Down Expand Up @@ -389,6 +425,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Long modelContentSizeInBytes = null;
String modelContentHash = null;
MLModelConfig modelConfig = null;
Boolean isEnabled = null;
Boolean isModelControllerEnabled = null;
MLRateLimiter modelRateLimiterConfig = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
Instant lastUploadedTime = null;
Expand Down Expand Up @@ -474,6 +513,15 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
break;
case IS_ENABLED_FIELD:
isEnabled = parser.booleanValue();
break;
case IS_MODEL_CONTROLLER_ENABLED_FIELD:
isModelControllerEnabled = parser.booleanValue();
break;
case MODEL_RATE_LIMITER_CONFIG_FIELD:
modelRateLimiterConfig = MLRateLimiter.parse(parser);
break;
case PLANNING_WORKER_NODE_COUNT_FIELD:
planningWorkerNodeCount = parser.intValue();
break;
Expand Down Expand Up @@ -540,6 +588,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.modelContentSizeInBytes(modelContentSizeInBytes)
.modelContentHash(modelContentHash)
.modelConfig(modelConfig)
.isEnabled(isEnabled)
.isModelControllerEnabled(isModelControllerEnabled)
.modelRateLimiterConfig(modelRateLimiterConfig)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.controller;

import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

/**
 * Per-model controller holding user-level rate limiter configuration.
 *
 * <p>An instance carries the id of the model it belongs to and a map from user name to that
 * user's {@link MLRateLimiter}. It can be parsed from / rendered to XContent and transported
 * over the wire via {@link Writeable}.
 *
 * <p>{@code userRateLimiterConfig} may be {@code null}: the stream constructor leaves it unset
 * when the sender wrote no map, and the builder accepts {@code null}. All instance methods in
 * this class tolerate that state.
 */
@Data
public class MLModelController implements ToXContentObject, Writeable {

    public static final String MODEL_ID_FIELD = "model_id"; // mandatory
    public static final String USER_RATE_LIMITER_CONFIG = "user_rate_limiter_config";

    @Getter
    private String modelId;
    // The String is the username field where the MLRateLimiter is its corresponding rate limiter config.
    private Map<String, MLRateLimiter> userRateLimiterConfig;

    /**
     * Creates a controller for the given model.
     *
     * @param modelId               id of the model this controller applies to
     * @param userRateLimiterConfig map of user name to rate limiter config; may be {@code null}
     */
    @Builder(toBuilder = true)
    public MLModelController(String modelId, Map<String, MLRateLimiter> userRateLimiterConfig) {
        this.modelId = modelId;
        this.userRateLimiterConfig = userRateLimiterConfig;
    }

    /**
     * Parses a controller from XContent. The parser must be positioned on the object's
     * {@code START_OBJECT} token.
     *
     * <p>Unknown fields are skipped. Each value under {@link #USER_RATE_LIMITER_CONFIG} is
     * converted to a string by {@code getParameterMap} and re-parsed as JSON into an
     * {@link MLRateLimiter}; entries for which {@code MLRateLimiter.isEmpty()} returns true are
     * dropped.
     *
     * @param parser the XContent parser to read from
     * @return the parsed controller; its user map is never {@code null} but may be empty, and its
     *         model id is {@code null} when the content carries no {@link #MODEL_ID_FIELD}
     * @throws IOException      if reading from {@code parser} fails
     * @throws RuntimeException wrapping an {@link IOException} raised while parsing an individual
     *                          user's rate limiter
     */
    public static MLModelController parse(XContentParser parser) throws IOException {
        String modelId = null;
        Map<String, MLRateLimiter> userRateLimiterConfig = new HashMap<>();

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();

            switch (fieldName) {
                case MODEL_ID_FIELD:
                    modelId = parser.text();
                    break;
                case USER_RATE_LIMITER_CONFIG:
                    Map<String, String> userRateLimiterConfigStringMap = getParameterMap(parser.map());
                    userRateLimiterConfigStringMap.forEach((user, rateLimiterString) -> {
                        // try-with-resources: the nested parser is a Closeable and was previously leaked.
                        try (
                            XContentParser rateLimiterParser = XContentType.JSON
                                .xContent()
                                .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString)
                        ) {
                            rateLimiterParser.nextToken();
                            MLRateLimiter rateLimiter = MLRateLimiter.parse(rateLimiterParser);
                            if (!rateLimiter.isEmpty()) {
                                userRateLimiterConfig.put(user, rateLimiter);
                            }
                        } catch (IOException e) {
                            // Lambdas cannot throw checked exceptions; keep the cause attached.
                            throw new RuntimeException(e);
                        }
                    });
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }
        // Model ID can only be set through RestRequest.
        return new MLModelController(modelId, userRateLimiterConfig);
    }

    /**
     * Reads a controller from a stream written by {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    public MLModelController(StreamInput in) throws IOException {
        modelId = in.readString();
        // A leading boolean flags whether a map follows; when absent the field stays null.
        if (in.readBoolean()) {
            userRateLimiterConfig = in.readMap(StreamInput::readString, MLRateLimiter::new);
        }
    }

    /**
     * Writes the model id, then a presence flag, then (if present) the user rate limiter map.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // NOTE(review): modelId is written as a non-optional string, so it is assumed non-null here - confirm callers set it.
        out.writeString(modelId);
        if (userRateLimiterConfig != null) {
            out.writeBoolean(true);
            out.writeMap(userRateLimiterConfig, StreamOutput::writeString, (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput));
        } else {
            out.writeBoolean(false);
        }
    }

    /**
     * Renders this controller as an XContent object containing {@link #MODEL_ID_FIELD} and, when
     * the map is non-null, {@link #USER_RATE_LIMITER_CONFIG}.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(MODEL_ID_FIELD, modelId);
        if (userRateLimiterConfig != null) {
            builder.field(USER_RATE_LIMITER_CONFIG, userRateLimiterConfig);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Checks if a deployment is required after updating the MLModelController.
     *
     * <p>For a user already present in this controller the decision is delegated to
     * {@code MLRateLimiter.isDeployRequiredAfterUpdate(old, new)}; for a user not yet present
     * (including when this controller has no map at all) a deployment is required when the new
     * rate limiter reports {@code isValid()}.
     *
     * @param updateContent The updated MLModelController object; may be {@code null}.
     * @return True if a deployment is required, false otherwise.
     */
    public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) {
        if (updateContent == null) {
            return false;
        }
        Map<String, MLRateLimiter> updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig();
        if (updateUserRateLimiterConfig == null || updateUserRateLimiterConfig.isEmpty()) {
            return false;
        }
        for (Map.Entry<String, MLRateLimiter> entry : updateUserRateLimiterConfig.entrySet()) {
            String newUser = entry.getKey();
            MLRateLimiter newRateLimiter = entry.getValue();
            // The current map can be null (e.g. after stream deserialization); treat that as "no known users".
            boolean isKnownUser = this.userRateLimiterConfig != null && this.userRateLimiterConfig.containsKey(newUser);
            if (isKnownUser) {
                MLRateLimiter oldRateLimiter = this.userRateLimiterConfig.get(newUser);
                if (MLRateLimiter.isDeployRequiredAfterUpdate(oldRateLimiter, newRateLimiter)) {
                    return true;
                }
            } else if (newRateLimiter.isValid()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Merges the user rate limiter entries of {@code updateContent} into this controller.
     *
     * <p>Existing users have their rate limiter updated in place via
     * {@code MLRateLimiter.update}; new users are added. A {@code null} argument, or one with a
     * {@code null} or empty map, leaves this controller unchanged. The model id is never modified.
     *
     * @param updateContent the controller carrying the changes; may be {@code null}
     */
    public void update(MLModelController updateContent) {
        if (updateContent == null) {
            return;
        }
        Map<String, MLRateLimiter> updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig();
        if (updateUserRateLimiterConfig == null || updateUserRateLimiterConfig.isEmpty()) {
            return;
        }
        // Lazily create the target map so a controller deserialized without one can still be updated.
        if (this.userRateLimiterConfig == null) {
            this.userRateLimiterConfig = new HashMap<>();
        }
        updateUserRateLimiterConfig.forEach((user, rateLimiter) -> {
            // rateLimiter can't be null due to parsing exception
            if (this.userRateLimiterConfig.containsKey(user)) {
                this.userRateLimiterConfig.get(user).update(rateLimiter);
            } else {
                this.userRateLimiterConfig.put(user, rateLimiter);
            }
        });
    }
}
Loading

0 comments on commit 50788de

Please sign in to comment.