-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Sicheng Song <[email protected]>
- Loading branch information
Showing
9 changed files
with
713 additions
and
2 deletions.
There are no files selected for viewing
18 changes: 18 additions & 0 deletions
18
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class MLUpdateModelAction extends ActionType<UpdateResponse> { | ||
public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/models/update"; | ||
|
||
private MLUpdateModelAction() { | ||
super(NAME, UpdateResponse::new); | ||
} | ||
} |
172 changes: 172 additions & 0 deletions
172
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.Data; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.common.io.stream.Writeable; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.connector.Connector.createConnector; | ||
|
||
@Data | ||
public class MLUpdateModelInput implements ToXContentObject, Writeable { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; // mandatory | ||
public static final String DESCRIPTION_FIELD = "description"; // optional | ||
public static final String MODEL_NAME_FIELD = "name"; // optional | ||
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional | ||
public static final String MODEL_FORMAT_FIELD = "model_format"; // optional | ||
public static final String MODEL_CONFIG_FIELD = "model_config"; // optional | ||
public static final String CONNECTOR_FIELD = "connector"; // optional | ||
public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional | ||
|
||
@Getter | ||
private String modelId; | ||
private String description; | ||
private String name; | ||
private String modelGroupId; | ||
private MLModelFormat modelFormat; | ||
private MLModelConfig modelConfig; | ||
private Connector connector; | ||
private String connectorId; | ||
|
||
@Builder(toBuilder = true) | ||
public MLUpdateModelInput(String modelId, String description, String name, String modelGroupId, MLModelFormat modelFormat, MLModelConfig modelConfig, Connector connector, String connectorId) { | ||
this.modelId = modelId; | ||
this.description = description; | ||
this.name = name; | ||
this.modelGroupId = modelGroupId; | ||
this.modelFormat = modelFormat; | ||
this.modelConfig = modelConfig; | ||
this.connector = connector; | ||
this.connectorId = connectorId; | ||
} | ||
|
||
public MLUpdateModelInput(StreamInput in) throws IOException { | ||
this.modelId = in.readString(); | ||
this.description = in.readOptionalString(); | ||
this.name = in.readOptionalString(); | ||
this.modelGroupId = in.readOptionalString(); | ||
if (in.readBoolean()) { | ||
modelFormat = in.readEnum(MLModelFormat.class); | ||
} | ||
if (in.readBoolean()) { | ||
modelConfig = new TextEmbeddingModelConfig(in); | ||
} | ||
if (in.readBoolean()) { | ||
connector = Connector.fromStream(in); | ||
} | ||
this.connectorId = in.readOptionalString(); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(MODEL_ID_FIELD, modelId); | ||
if (description != null) { | ||
builder.field(DESCRIPTION_FIELD, description); | ||
} | ||
if (name != null) { | ||
builder.field(MODEL_NAME_FIELD, name); | ||
} | ||
if (modelGroupId != null) { | ||
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); | ||
} | ||
if (modelFormat != null) { | ||
builder.field(MODEL_FORMAT_FIELD, modelFormat); | ||
} | ||
if (modelConfig != null) { | ||
builder.field(MODEL_CONFIG_FIELD, modelConfig); | ||
} | ||
if (connector != null) { | ||
builder.field(CONNECTOR_FIELD, connector); | ||
} | ||
if (connectorId != null) { | ||
builder.field(CONNECTOR_ID_FIELD, connectorId); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(modelId); | ||
out.writeOptionalString(description); | ||
out.writeOptionalString(name); | ||
out.writeOptionalString(modelGroupId); | ||
if (modelFormat != null) { | ||
out.writeBoolean(true); | ||
out.writeEnum(modelFormat); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
if (modelConfig != null) { | ||
out.writeBoolean(true); | ||
modelConfig.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
if (connector != null) { | ||
out.writeBoolean(true); | ||
connector.writeTo(out); | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
out.writeOptionalString(connectorId); | ||
} | ||
|
||
public static MLUpdateModelInput parse(XContentParser parser, String modelId) throws IOException { | ||
MLUpdateModelInput input = new MLUpdateModelInput(modelId, null, null, null, null, null, null, null); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
switch (fieldName) { | ||
case MODEL_ID_FIELD: | ||
input.setModelId(parser.text()); | ||
break; | ||
case DESCRIPTION_FIELD: | ||
input.setDescription(parser.text()); | ||
break; | ||
case MODEL_NAME_FIELD: | ||
input.setName(parser.text()); | ||
break; | ||
case MODEL_GROUP_ID_FIELD: | ||
input.setModelGroupId(parser.text()); | ||
break; | ||
case MODEL_FORMAT_FIELD: | ||
input.setModelFormat(MLModelFormat.from(parser.text())); | ||
break; | ||
case MODEL_CONFIG_FIELD: | ||
input.setModelConfig(TextEmbeddingModelConfig.parse(parser)); | ||
break; | ||
case CONNECTOR_FIELD: | ||
input.setConnector(createConnector(parser)); | ||
break; | ||
case CONNECTOR_ID_FIELD: | ||
input.setConnectorId(parser.text()); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
return input; | ||
} | ||
} |
75 changes: 75 additions & 0 deletions
75
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.ToString; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@ToString | ||
public class MLUpdateModelRequest extends ActionRequest { | ||
|
||
MLUpdateModelInput updateModelInput; | ||
|
||
@Builder | ||
public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { | ||
this.updateModelInput = updateModelInput; | ||
} | ||
|
||
public MLUpdateModelRequest(StreamInput in) throws IOException { | ||
super(in); | ||
updateModelInput = new MLUpdateModelInput(in); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
if (updateModelInput == null) { | ||
exception = addValidationError("Update Model Input can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
this.updateModelInput.writeTo(out); | ||
} | ||
|
||
public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ | ||
if (actionRequest instanceof MLUpdateModelRequest) { | ||
return (MLUpdateModelRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLUpdateModelRequest(in); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); | ||
} | ||
} | ||
} |
45 changes: 45 additions & 0 deletions
45
common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.model; | ||
|
||
import lombok.Getter; | ||
import org.opensearch.core.action.ActionResponse; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.ToXContentObject; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
|
||
import java.io.IOException; | ||
|
||
@Getter | ||
public class MLUpdateModelResponse extends ActionResponse implements ToXContentObject { | ||
|
||
public static final String STATUS_FIELD = "status"; | ||
|
||
private String status; | ||
|
||
public MLUpdateModelResponse(StreamInput in) throws IOException { | ||
super(in); | ||
this.status = in.readString(); | ||
} | ||
|
||
public MLUpdateModelResponse(String status) { | ||
this.status = status; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(status); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(STATUS_FIELD, status); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.