-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update Model API POC Signed-off-by: Sicheng Song <[email protected]> * Using GetRequest to get model Signed-off-by: Sicheng Song <[email protected]> * Finalize model update API Signed-off-by: Sicheng Song <[email protected]> * Fix compile Signed-off-by: Sicheng Song <[email protected]> * Fix compileTest Signed-off-by: Sicheng Song <[email protected]> * Add Unit Test Cases for Update Model API Signed-off-by: Sicheng Song <[email protected]> * Tune back test coverage thereshold Signed-off-by: Sicheng Song <[email protected]> * Add more unit tests on Update model API Signed-off-by: Sicheng Song <[email protected]> * Add unit test for TransportUpdateModelAction class Signed-off-by: Sicheng Song <[email protected]> * Fix a test error Signed-off-by: Sicheng Song <[email protected]> * Change exception thrown to failure response Signed-off-by: Sicheng Song <[email protected]> * Move the function judgement to the outter block Signed-off-by: Sicheng Song <[email protected]> * Check if model is undeployed before update model Signed-off-by: Sicheng Song <[email protected]> * Add more unit test for update model API Signed-off-by: Sicheng Song <[email protected]> * Fix unit test due to blocking java 11 CI workflow Signed-off-by: Sicheng Song <[email protected]> * Enabling auto bumping model version during registering to a new model group and address reviewers' other concern Signed-off-by: Sicheng Song <[email protected]> * Autobump new model groups' latest version when register to a new model Signed-off-by: Sicheng Song <[email protected]> * Change the REST API method from POST to PUT Signed-off-by: Sicheng Song <[email protected]> * Change the update REST API endpoint Signed-off-by: Sicheng Song <[email protected]> --------- Signed-off-by: Sicheng Song <[email protected]>
- Loading branch information
Showing
14 changed files
with
2,007 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class MLUpdateModelAction extends ActionType<UpdateResponse> { | ||
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/models/update"; | ||
|
||
private MLUpdateModelAction() { | ||
super(NAME, UpdateResponse::new); | ||
} | ||
} |
155 changes: 155 additions & 0 deletions
155
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.Data; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.common.io.stream.Writeable; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.connector.Connector.createConnector; | ||
|
||
@Data | ||
public class MLUpdateModelInput implements ToXContentObject, Writeable { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; // mandatory | ||
public static final String DESCRIPTION_FIELD = "description"; // optional | ||
public static final String MODEL_VERSION_FIELD = "model_version"; // optional | ||
public static final String MODEL_NAME_FIELD = "name"; // optional | ||
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional | ||
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional | ||
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional | ||
|
||
@Getter | ||
private String modelId; | ||
private String description; | ||
private String version; | ||
private String name; | ||
private String modelGroupId; | ||
private MLModelConfig modelConfig; | ||
private String connectorId; | ||
|
||
@Builder(toBuilder = true) | ||
public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { | ||
this.modelId = modelId; | ||
this.description = description; | ||
this.version = version; | ||
this.name = name; | ||
this.modelGroupId = modelGroupId; | ||
this.modelConfig = modelConfig; | ||
this.connectorId = connectorId; | ||
} | ||
|
||
public MLUpdateModelInput(StreamInput in) throws IOException { | ||
this.modelId = in.readString(); | ||
this.description = in.readOptionalString(); | ||
this.version = in.readOptionalString(); | ||
this.name = in.readOptionalString(); | ||
this.modelGroupId = in.readOptionalString(); | ||
if (in.readBoolean()) { | ||
modelConfig = new TextEmbeddingModelConfig(in); | ||
} | ||
this.connectorId = in.readOptionalString(); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(MODEL_ID_FIELD, modelId); | ||
if (name != null) { | ||
builder.field(MODEL_NAME_FIELD, name); | ||
} | ||
if (description != null) { | ||
builder.field(DESCRIPTION_FIELD, description); | ||
} | ||
if (version != null) { | ||
builder.field(MODEL_VERSION_FIELD, version); | ||
} | ||
if (modelGroupId != null) { | ||
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); | ||
} | ||
if (modelConfig != null) { | ||
builder.field(MODEL_CONFIG_FIELD, modelConfig); | ||
} | ||
if (connectorId != null) { | ||
builder.field(CONNECTOR_ID_FIELD, connectorId); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(modelId); | ||
out.writeOptionalString(description); | ||
out.writeOptionalString(version); | ||
out.writeOptionalString(name); | ||
out.writeOptionalString(modelGroupId); | ||
if (modelConfig != null) { | ||
out.writeBoolean(true); | ||
modelConfig.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
out.writeOptionalString(connectorId); | ||
} | ||
|
||
public static MLUpdateModelInput parse(XContentParser parser) throws IOException { | ||
String modelId = null; | ||
String description = null; | ||
String version = null; | ||
String name = null; | ||
String modelGroupId = null; | ||
MLModelConfig modelConfig = null; | ||
String connectorId = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
switch (fieldName) { | ||
case MODEL_ID_FIELD: | ||
modelId = parser.text(); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
description = parser.text(); | ||
break; | ||
case MODEL_NAME_FIELD: | ||
name = parser.text(); | ||
break; | ||
case MODEL_VERSION_FIELD: | ||
version = parser.text(); | ||
break; | ||
case MODEL_GROUP_ID_FIELD: | ||
modelGroupId = parser.text(); | ||
break; | ||
case MODEL_CONFIG_FIELD: | ||
modelConfig = TextEmbeddingModelConfig.parse(parser); | ||
break; | ||
case CONNECTOR_ID_FIELD: | ||
connectorId = parser.text(); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
// Model ID can only be set through RestRequest. Model version can only be set automatically. | ||
return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLUpdateModelRequest extends ActionRequest { | ||
|
||
MLUpdateModelInput updateModelInput; | ||
|
||
@Builder | ||
public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { | ||
this.updateModelInput = updateModelInput; | ||
} | ||
|
||
public MLUpdateModelRequest(StreamInput in) throws IOException { | ||
super(in); | ||
updateModelInput = new MLUpdateModelInput(in); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (updateModelInput == null) { | ||
exception = addValidationError("Update Model Input can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
this.updateModelInput.writeTo(out); | ||
} | ||
|
||
public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ | ||
if (actionRequest instanceof MLUpdateModelRequest) { | ||
return (MLUpdateModelRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLUpdateModelRequest(in); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); | ||
} | ||
} | ||
} |
Oops, something went wrong.