Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] tags on model group #1370

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
644 changes: 344 additions & 300 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not change the format of this file, the format change caused a lot of differences.

Large diffs are not rendered by default.

28 changes: 26 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModelGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.model.ModelGroupTag;

@Getter
public class MLModelGroup implements ToXContentObject {
Expand All @@ -34,7 +36,7 @@ public class MLModelGroup implements ToXContentObject {
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group
public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created

public static final String TAGS_FIELD = "tags";

@Setter
private String name;
Expand All @@ -50,10 +52,12 @@ public class MLModelGroup implements ToXContentObject {
private Instant createdTime;
private Instant lastUpdatedTime;

private List<ModelGroupTag> tags;

@Builder(toBuilder = true)
public MLModelGroup(String name, String description, int latestVersion,
List<String> backendRoles, User owner, String access,
List<String> backendRoles, User owner,List<ModelGroupTag> tags,
String access,
String modelGroupId,
Instant createdTime,
Instant lastUpdatedTime) {
Expand All @@ -69,6 +73,7 @@ public MLModelGroup(String name, String description, int latestVersion,
this.modelGroupId = modelGroupId;
this.createdTime = createdTime;
this.lastUpdatedTime = lastUpdatedTime;
this.tags = tags;
}


Expand All @@ -84,6 +89,8 @@ public MLModelGroup(StreamInput input) throws IOException{
} else {
this.owner = null;
}

tags = input.readList(ModelGroupTag::new);
access = input.readOptionalString();
modelGroupId = input.readOptionalString();
createdTime = input.readOptionalInstant();
Expand All @@ -106,6 +113,8 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}

out.writeList(Objects.requireNonNullElseGet(tags, ArrayList::new));
out.writeOptionalString(access);
out.writeOptionalString(modelGroupId);
out.writeOptionalInstant(createdTime);
Expand All @@ -126,6 +135,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (owner != null) {
builder.field(OWNER, owner);
}

if (!CollectionUtils.isEmpty(tags)) {
builder.field(TAGS_FIELD, tags);
}
if (access != null) {
builder.field(ACCESS, access);
}
Expand All @@ -152,6 +165,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
String modelGroupId = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
List<ModelGroupTag> tags = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -178,6 +192,15 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
case OWNER:
owner = User.parse(parser);
break;

case TAGS_FIELD:
tags = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
tags.add(ModelGroupTag.parse(parser));
}

break;
case ACCESS:
access = parser.text();
break;
Expand All @@ -201,6 +224,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException {
.backendRoles(backendRoles)
.latestVersion(latestVersion)
.owner(owner)
.tags(tags)
.access(access)
.modelGroupId(modelGroupId)
.createdTime(createdTime)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.common.model;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.*;
import org.opensearch.common.Nullable;
import org.opensearch.common.inject.internal.ToStringBuilder;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

public final class ModelGroupTag implements Writeable, ToXContent {
public static final String TAG_KEY_FIELD = "key";
public static final String TAG_TYPE_FIELD = "type";

@Nullable private final String key;
@Nullable private final String type;

public ModelGroupTag() {
key = "";
type = "";
}

public ModelGroupTag(@Nullable final String key, @Nullable final String type) {
this.key = key;
this.type = type;
}

public ModelGroupTag(String json) {
if (Strings.isNullOrEmpty(json)) {
throw new IllegalArgumentException("Response json cannot be null");
}

Map<String, Object> mapValue =
XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false);
key = (String) mapValue.get(TAG_KEY_FIELD);
type = (String) mapValue.get(TAG_TYPE_FIELD);
}

public ModelGroupTag(StreamInput in) throws IOException {
this.key = in.readString();
this.type = in.readString();
}

public static ModelGroupTag parse(XContentParser parser) throws IOException {
String key = "";
String type = "";

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
switch (fieldName) {
case TAG_KEY_FIELD:
key = parser.text();
break;
case TAG_TYPE_FIELD:
type = parser.text();
break;
default:
break;
}
}
return new ModelGroupTag(key, type);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject().field(TAG_KEY_FIELD, key).field(TAG_TYPE_FIELD, type);
return builder.endObject();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(key);
out.writeString(type);
}

@Override
public String toString() {
ToStringBuilder builder = new ToStringBuilder(this.getClass());
builder.add(TAG_KEY_FIELD, key);
builder.add(TAG_TYPE_FIELD, type);
return builder.toString();
}

@Override
public boolean equals(Object obj) {
if (!(obj instanceof ModelGroupTag)) {
return false;
}
ModelGroupTag that = (ModelGroupTag) obj;
return this.key.equals(that.key) && this.type.equals(that.type);
}

@Nullable
public String getKey() {
return key;
}

@Nullable
public String getType() {
return type;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@

package org.opensearch.ml.common.transport.model_group;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import lombok.Builder;
import lombok.Data;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import org.opensearch.ml.common.model.ModelGroupTag;

@Data
public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{
Expand All @@ -30,15 +31,17 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String MODEL_ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional
public static final String TAGS_FIELD = "tags"; //optional

private String name;
private String description;
private List<String> backendRoles;
private AccessMode modelAccessMode;
private Boolean isAddAllBackendRoles;
private List<ModelGroupTag> tags;

@Builder(toBuilder = true)
public MLRegisterModelGroupInput(String name, String description, List<String> backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) {
public MLRegisterModelGroupInput(String name, String description, List<String> backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles,List<ModelGroupTag> tags) {
if (name == null) {
throw new IllegalArgumentException("model group name is null");
}
Expand All @@ -47,6 +50,7 @@ public MLRegisterModelGroupInput(String name, String description, List<String> b
this.backendRoles = backendRoles;
this.modelAccessMode = modelAccessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
this.tags = tags;
}

public MLRegisterModelGroupInput(StreamInput in) throws IOException{
Expand All @@ -57,6 +61,7 @@ public MLRegisterModelGroupInput(StreamInput in) throws IOException{
modelAccessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
this.tags=in.readList(ModelGroupTag::new);
}

@Override
Expand All @@ -76,6 +81,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
if(!CollectionUtils.isEmpty(tags)){
out.writeList(tags);
}
}

@Override
Expand All @@ -94,6 +102,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
}
if(!CollectionUtils.isEmpty(tags)){
builder.field(TAGS_FIELD, tags);
}
builder.endObject();
return builder;
}
Expand All @@ -104,6 +115,7 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx
List<String> backendRoles = null;
AccessMode modelAccessMode = null;
Boolean isAddAllBackendRoles = null;
List<ModelGroupTag> tags=null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -129,12 +141,19 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
break;
case TAGS_FIELD:
tags = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
tags.add(ModelGroupTag.parse(parser));
}
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles);
return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles,tags);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action
return (MLRegisterModelGroupRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterModelGroupRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelMetaRequest", e);
}

}
}
Loading