Skip to content

Commit

Permalink
Merge branch 'main' into backport/backport-1329-to-main
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored Sep 27, 2023
2 parents c27ebd3 + b42104c commit 9f9bafa
Show file tree
Hide file tree
Showing 41 changed files with 1,201 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class CommonValue {
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
REMOTE;

Expand All @@ -33,7 +35,7 @@ public static FunctionName from(String value) {
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Optional;

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

@Getter
Expand Down Expand Up @@ -101,7 +102,7 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
return;
}
if (response instanceof String && isJson((String)response)) {
Map<String, Object> data = StringUtils.fromJson((String) response, "response");
Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else {
Map<String, Object> map = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
}
}
MLInputDataset inputDataSet = null;
if (algorithm == FunctionName.TEXT_EMBEDDING) {
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) {
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* ML input class which supports a list of text docs.
* This class can be used for TEXT_EMBEDDING, SPARSE_ENCODING and SPARSE_TOKENIZE models.
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
public class TextDocsMLInput extends MLInput {
public static final String TEXT_DOCS_FIELD = "text_docs";
public static final String RESULT_FILTER_FIELD = "result_filter";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,47 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

@Getter
public class MLDeployModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String TASK_TYPE_FIELD = "task_type";
public static final String STATUS_FIELD = "status";

private String taskId;
private MLTaskType taskType;
private String status;

/**
 * Deserialising constructor: rebuilds the response from a transport stream.
 *
 * Fields are read in the order task id, task type, status, which is the same
 * order in which writeTo emits them; the two must be kept in step.
 *
 * @param in stream positioned at the start of a serialised response
 * @throws IOException if reading from the stream fails
 */
public MLDeployModelResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
// NOTE(review): the enum is read unconditionally (no presence flag), so the
// writer side must always have a non-null task type -- confirm with callers.
this.taskType = in.readEnum(MLTaskType.class);
this.status = in.readString();
}

public MLDeployModelResponse(String taskId, String status) {
public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) {
this.taskId = taskId;
this.taskType = mlTaskType;
this.status= status;
}

/**
 * Serialises this response to a transport stream as task id, task type,
 * status -- the order the StreamInput constructor reads them back in.
 *
 * @param out destination stream
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
// NOTE(review): taskType is written unconditionally here, while toXContent
// treats it as optional (null-checked). Presumably writeEnum cannot handle
// null -- confirm that every caller supplies a task type.
out.writeEnum(taskType);
out.writeString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
if (taskType != null) {
builder.field(TASK_TYPE_FIELD, taskType);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -104,7 +103,7 @@ public MLRegisterModelInput(FunctionName functionName,
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (url != null && modelConfig == null) {
if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
throw new IllegalArgumentException("model config is null");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
if (modelContentHashValue == null) {
throw new IllegalArgumentException("model content hash value is null");
}
if (modelConfig == null) {
if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
throw new IllegalArgumentException("model config is null");
}
if (totalChunks == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
assertArrayEquals(new long[]{1, 2}, metrics);
}

@Test
public void testClassLoader_MLInput() throws IOException {
assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
assertNotNull(mlInput);
assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
assertEquals(functionName, mlInput.getFunctionName());
assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
}

/**
 * Runs the shared ML-input class-loader check once for every function name
 * that should be initialised as a {@code TextDocsMLInput}.
 */
@Test
public void testClassLoader_MLInput() throws IOException {
    FunctionName[] textDocFunctions = {
        FunctionName.TEXT_EMBEDDING,
        FunctionName.SPARSE_TOKENIZE,
        FunctionName.SPARSE_ENCODING
    };
    // Same order as before; the first failing assertion still aborts the rest.
    for (FunctionName candidate : textDocFunctions) {
        testClassLoader_MLInput_DlModel(candidate);
    }
}

// Minimal enum with a single constant.
// NOTE(review): presumably a fixture for the class-loader tests in this file;
// its usage is not visible in this excerpt -- confirm before relying on it.
public enum TestEnum {
TEST
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.core.common.Strings;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -110,19 +108,19 @@ public void parse_LinearRegression() throws IOException {
});
}

@Test
public void parse_TextEmbedding() throws IOException {
private void parse_NLPModel(FunctionName functionName) throws IOException {
String sentence = "test sentence";
String column = "column1";
Integer position = 1;
ModelResultFilter resultFilter = ModelResultFilter.builder()
.targetResponse(Arrays.asList(column))
.targetResponsePositions(Arrays.asList(position))
.build();
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
.resultFilter(resultFilter).build();
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {

TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
assertEquals(1, parsedInputDataSet.getDocs().size());
Expand All @@ -134,19 +132,33 @@ public void parse_TextEmbedding() throws IOException {
}

@Test
public void parse_TextEmbedding_NullResultFilter() throws IOException {
public void parse_NLP_Related() throws IOException {
parse_NLPModel(FunctionName.TEXT_EMBEDDING);
parse_NLPModel(FunctionName.SPARSE_TOKENIZE);
parse_NLPModel(FunctionName.SPARSE_ENCODING);
}

private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
String sentence = "test sentence";
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
});
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
Consumer<MLInput> verify) throws IOException {

/**
 * Exercises parsing without a result filter for each function name that
 * takes text-docs input.
 */
@Test
public void parse_NLPRelated_NullResultFilter() throws IOException {
    FunctionName[] textDocFunctions = {
        FunctionName.TEXT_EMBEDDING,
        FunctionName.SPARSE_TOKENIZE,
        FunctionName.SPARSE_ENCODING
    };
    // Same order as before; the first failing assertion still aborts the rest.
    for (FunctionName candidate : textDocFunctions) {
        parse_NLPModel_NullResultFilter(candidate);
    }
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

Expand All @@ -19,37 +20,40 @@ public class MLDeployModelResponseTest {

private String taskId;
private String status;
private MLTaskType taskType;

/** Initialises the fixture values shared by the tests in this class. */
@Before
public void setUp() throws Exception {
    this.taskType = MLTaskType.DEPLOY_MODEL;
    this.status = "test";
    this.taskId = "test_id";
}

@Test
public void writeTo_Success() throws IOException {
// Setup
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
response.writeTo(bytesStreamOutput);
MLDeployModelResponse parsedResponse = new MLDeployModelResponse(bytesStreamOutput.bytes().streamInput());
// Verify the results
assertEquals(response.getTaskId(), parsedResponse.getTaskId());
assertEquals(response.getTaskType(), parsedResponse.getTaskType());
assertEquals(response.getStatus(), parsedResponse.getStatus());
}

@Test
public void testToXContent() throws IOException {
// Setup
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
// Verify the results
assertEquals("{\"task_id\":\"test_id\"," +
assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," +
"\"status\":\"test\"}", jsonStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
"models": {
"zwla5YUB1qmVrJFlwzXJ": { # model id
"model_state": "LOADED",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
"target_worker_nodes": [ # plan to deploy model to these nodes
"0TLL4hHxRv6_G3n6y1l0BQ"
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
* @param modelContentHash model content hash value
* @param listener action listener
*/
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
Expand All @@ -200,7 +200,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
String hash = calculateFileHash(modelZipFile);
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
Expand All @@ -222,7 +222,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
}
}

public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
boolean hasPtFile = false;
boolean hasOnnxFile = false;
boolean hasTokenizerFile = false;
Expand All @@ -237,7 +237,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
}
}
}
if (!hasPtFile && !hasOnnxFile) {
if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) { // sparse tokenizer model doesn't need model file.
throw new IllegalArgumentException("Can't find model file");
}
if (!hasTokenizerFile) {
Expand Down
Loading

0 comments on commit 9f9bafa

Please sign in to comment.