Skip to content

Commit

Permalink
Add cross encoder support (opensearch-project#1615)
Browse files Browse the repository at this point in the history
* add text similarity inputs and function name

Signed-off-by: HenryL27 <[email protected]>

* add text similarity cross encoder model

Signed-off-by: HenryL27 <[email protected]>

* add text similarity unit tests

Signed-off-by: HenryL27 <[email protected]>

* add text similarity input unittests

Signed-off-by: HenryL27 <[email protected]>

* add text similarity dataset unittests

Signed-off-by: HenryL27 <[email protected]>

* add function name annotation

Signed-off-by: HenryL27 <[email protected]>

* refactor API to use single query

Signed-off-by: HenryL27 <[email protected]>

* omit private from class vars

Co-authored-by: Navneet Verma <[email protected]>
Signed-off-by: HenryL27 <[email protected]>

* change output name from logits to similarity

Signed-off-by: HenryL27 <[email protected]>

* hashify isDLModel

Signed-off-by: HenryL27 <[email protected]>

* add error message for non-torchscript cross encoders

Signed-off-by: HenryL27 <[email protected]>

* allow onnx, actually.

Signed-off-by: HenryL27 <[email protected]>

* apply spotless after rebase

Signed-off-by: HenryL27 <[email protected]>

* add unittest for new mlinput toXcontent clause

Signed-off-by: HenryL27 <[email protected]>

* static DLModels

Signed-off-by: HenryL27 <[email protected]>

* add tests and error message tweaks

Signed-off-by: HenryL27 <[email protected]>

* name test models w framework

Signed-off-by: HenryL27 <[email protected]>

* change pt->torch_script

Signed-off-by: HenryL27 <[email protected]>

---------

Signed-off-by: HenryL27 <[email protected]>
Co-authored-by: Navneet Verma <[email protected]>
  • Loading branch information
2 people authored and austintlee committed Feb 29, 2024
1 parent 21dcd7c commit 7451b31
Show file tree
Hide file tree
Showing 13 changed files with 932 additions and 4 deletions.
16 changes: 12 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/FunctionName.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.common;

import java.util.HashSet;
import java.util.Set;

public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
Expand All @@ -17,6 +20,7 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
Expand All @@ -30,14 +34,18 @@ public static FunctionName from(String value) {
}
}

private static final HashSet<FunctionName> DL_MODELS = new HashSet<>(Set.of(
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE
));

/**
* Check if model is deep learning model.
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
return true;
}
return false;
return DL_MODELS.contains(functionName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME,
TEXT_DOCS,
TEXT_SIMILARITY,
REMOTE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.TEXT_SIMILARITY)
public class TextSimilarityInputDataSet extends MLInputDataset {

List<String> textDocs;

String queryText;

@Builder(toBuilder = true)
public TextSimilarityInputDataSet(String queryText, List<String> textDocs) {
super(MLInputDataType.TEXT_SIMILARITY);
Objects.requireNonNull(textDocs);
Objects.requireNonNull(queryText);
if(textDocs.isEmpty()) {
throw new IllegalArgumentException("No text documents were provided");
}
this.textDocs = textDocs;
this.queryText = queryText;
}

public TextSimilarityInputDataSet(StreamInput in) throws IOException {
super(MLInputDataType.TEXT_SIMILARITY);
this.queryText = in.readString();
int size = in.readInt();
this.textDocs = new ArrayList<String>();
for(int i = 0; i < size; i++) {
String context = in.readString();
this.textDocs.add(context);
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(queryText);
out.writeInt(this.textDocs.size());
for (String doc : this.textDocs) {
out.writeString(doc);
}
}
}
25 changes: 25 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand Down Expand Up @@ -55,6 +57,8 @@ public class MLInput implements Input {
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";
// Input query text to compare against for text similarity model
public static final String QUERY_TEXT_FIELD = "query_text";

// Algorithm name
protected FunctionName algorithm;
Expand Down Expand Up @@ -157,6 +161,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
break;
case TEXT_SIMILARITY:
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> tdocs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (tdocs != null && !tdocs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : tdocs) {
builder.value(d);
}
builder.endArray();
}
break;
default:
break;
}
Expand Down Expand Up @@ -186,6 +204,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
List<String> targetResponse = new ArrayList<>();
List<Integer> targetResponsePositions = new ArrayList<>();
List<String> textDocs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -233,6 +252,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
textDocs.add(parser.text());
}
break;
case QUERY_TEXT_FIELD:
queryText = parser.text();
break;
default:
parser.skipChildren();
break;
Expand All @@ -243,6 +265,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
if (algorithm == FunctionName.TEXT_SIMILARITY) {
inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.input.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


/**
* MLInput which supports a text similarity algorithm
* Inputs are a query and a list of texts. Outputs are real numbers
* Use this for Cross Encoder models
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_SIMILARITY})
public class TextSimilarityMLInput extends MLInput {

public TextSimilarityMLInput(FunctionName algorithm, MLInputDataset dataset) {
super(algorithm, null, dataset);
}

public TextSimilarityMLInput(StreamInput in) throws IOException {
super(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ALGORITHM_FIELD, algorithm.name());
if(parameters != null) {
builder.field(ML_PARAMETERS_FIELD, parameters);
}
if(inputDataset != null) {
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> docs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (docs != null && !docs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : docs) {
builder.value(d);
}
builder.endArray();
}
}
builder.endObject();
return builder;
}

public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
List<String> docs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case TEXT_DOCS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
String context = parser.text();
docs.add(context);
}
break;
case QUERY_TEXT_FIELD:
queryText = parser.text();
default:
parser.skipChildren();
break;
}
}
if(docs.isEmpty()) {
throw new IllegalArgumentException("No text documents were provided");
}
if(queryText == null) {
throw new IllegalArgumentException("No query text was provided");
}
inputDataset = new TextSimilarityInputDataSet(queryText, docs);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import static org.junit.Assert.assertThrows;

import java.io.IOException;
import java.util.List;

import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class TextSimilarityInputDatasetTest {

@Test
public void testStreaming() throws IOException {
List<String> docs = List.of("That is a happy dog", "it's summer");
String queryText = "today is sunny";
TextSimilarityInputDataSet dataset = TextSimilarityInputDataSet.builder().queryText(queryText).textDocs(docs).build();
BytesStreamOutput outbytes = new BytesStreamOutput();
StreamOutput osso = new OutputStreamStreamOutput(outbytes);
dataset.writeTo(osso);
StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes()));
TextSimilarityInputDataSet newDs = (TextSimilarityInputDataSet) MLInputDataset.fromStream(in);
assert (dataset.getTextDocs().equals(newDs.getTextDocs()));
assert (dataset.getQueryText().equals(newDs.getQueryText()));
}

@Test
public void noPairs_ThenFail() {
List<String> docs = List.of();
String queryText = "today is sunny";
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
assert (e.getMessage().equals("No text documents were provided"));
}

@Test
public void noQuery_ThenFail() {
List<String> docs = List.of("That is a happy dog", "it's summer");
String queryText = null;
assertThrows(NullPointerException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
}
}
Loading

0 comments on commit 7451b31

Please sign in to comment.