From d85524305b9eb49cd7d647d5052ddc8eb8e0c729 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 9 Nov 2023 11:28:44 -0800 Subject: [PATCH] add text similarity dataset unittests Signed-off-by: HenryL27 --- .../dataset/TextSimilarityInputDataSet.java | 2 +- .../TextSimilarityInputDatasetTest.java | 58 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java index 4d3879b919..6c15243057 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextSimilarityInputDataSet.java @@ -52,7 +52,7 @@ public TextSimilarityInputDataSet(List> pairs) { public TextSimilarityInputDataSet(StreamInput in) throws IOException { super(MLInputDataType.TEXT_SIMILARITY); int size = in.readInt(); - this.pairs = new ArrayList>(size); + this.pairs = new ArrayList>(); for(int i = 0; i < size; i++) { String query = in.readString(); String context = in.readString(); diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java new file mode 100644 index 0000000000..2b2504038a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.ml.common.dataset;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Test;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.common.io.stream.BytesStreamInput;
+import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+
+/**
+ * Unit tests for {@link TextSimilarityInputDataSet}: a stream write/read round trip
+ * and the rejection of an empty pair list by the builder.
+ */
+public class TextSimilarityInputDatasetTest {
+
+    /**
+     * Writes a two-pair dataset with {@code writeTo}, reads it back through
+     * {@link MLInputDataset#fromStream}, and checks that the pairs are unchanged.
+     *
+     * @throws IOException if writing to or reading from the in-memory streams fails
+     */
+    @Test
+    public void testStreaming() throws IOException {
+        List<Pair<String, String>> pairs = List.of(
+            Pair.of("today is sunny", "That is a happy dog"),
+            Pair.of("today is sunny", "it's summer")
+        );
+        TextSimilarityInputDataSet dataset = TextSimilarityInputDataSet.builder().pairs(pairs).build();
+        BytesStreamOutput outbytes = new BytesStreamOutput();
+        StreamOutput osso = new OutputStreamStreamOutput(outbytes);
+        dataset.writeTo(osso);
+        StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes()));
+        TextSimilarityInputDataSet newDs = (TextSimilarityInputDataSet) MLInputDataset.fromStream(in);
+        // JUnit assertion rather than the `assert` keyword: it is evaluated even when the
+        // JVM runs without -ea, and it reports expected vs. actual on failure.
+        assertEquals(dataset.getPairs(), newDs.getPairs());
+    }
+
+    /**
+     * Checks that building a dataset from an empty pair list throws an
+     * {@link IllegalArgumentException} whose message is {@code "pairs must be nonempty"}.
+     */
+    @Test
+    public void noPairs_ThenFail() {
+        List<Pair<String, String>> pairs = List.of();
+        IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
+            () -> TextSimilarityInputDataSet.builder().pairs(pairs).build());
+        assertEquals("pairs must be nonempty", e.getMessage());
+    }
+}