From aca02d13c88a5a2eab045664dfa02444859c6090 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 28 Sep 2023 11:15:04 -0700 Subject: [PATCH] allow input null for text docs input (#1402) Signed-off-by: Yaliang Wu --- .../ml/common/input/nlp/TextDocsMLInput.java | 6 +++++- .../ml/common/input/nlp/TextDocsMLInputTest.java | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index ec9e31b3a9..deeb5ef81f 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -114,7 +114,11 @@ public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws case TEXT_DOCS_FIELD: ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - docs.add(parser.text()); + if (parser.currentToken() == null || parser.currentToken() == XContentParser.Token.VALUE_NULL) { + docs.add(null); + } else { + docs.add(parser.text()); + } } break; case RESULT_FILTER_FIELD: diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index 6578cd42bd..ec00aef17f 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; public class TextDocsMLInputTest { @@ -47,22 +48,22 @@ public void parseTextDocsMLInput() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - parseMLInput(jsonStr); + parseMLInput(jsonStr, 2); } @Test public void parseTextDocsMLInput_OldWay() throws IOException { - String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\" ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; - parseMLInput(jsonStr); + String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; + parseMLInput(jsonStr, 3); } @Test public void parseTextDocsMLInput_NewWay() throws IOException { String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; - parseMLInput(jsonStr); + parseMLInput(jsonStr, 2); } - private void parseMLInput(String jsonStr) throws IOException { + private void parseMLInput(String jsonStr, int docSize) throws IOException { XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); @@ -72,9 +73,12 @@ private void parseMLInput(String jsonStr) throws IOException { assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); assertEquals(input.getInputDataset().getInputDataType(), parsedInput.getInputDataset().getInputDataType()); TextDocsInputDataSet inputDataset = (TextDocsInputDataSet) parsedInput.getInputDataset(); - assertEquals(2, inputDataset.getDocs().size()); + assertEquals(docSize, inputDataset.getDocs().size()); assertEquals("doc1", inputDataset.getDocs().get(0)); assertEquals("doc2", inputDataset.getDocs().get(1)); + if (inputDataset.getDocs().size() > 2) { + assertNull(inputDataset.getDocs().get(2)); + } assertNotNull(inputDataset.getResultFilter()); assertTrue(inputDataset.getResultFilter().isReturnBytes()); assertTrue(inputDataset.getResultFilter().isReturnNumber());