From dc95b85f2bd8c740c812b7ac7a1b8b303996b46d Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Tue, 27 Feb 2024 13:42:24 -0800
Subject: [PATCH] Adding simple aggregation processor, custom scenario works, base is failing

Signed-off-by: Martin Gaievski

---
 .../search/HybridTopScoreDocCollector.java    |   2 +-
 .../search/query/HybridCollectorManager.java  | 257 ++++++++++++++++++
 .../query/HybridQueryPhaseSearcher.java       | 208 ++++----------
 .../ml/MLCommonsClientAccessorTests.java      |   2 +-
 4 files changed, 317 insertions(+), 152 deletions(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java

diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
index 3b39c0046..050a28422 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -30,7 +30,7 @@
  * Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
  */
 @Log4j2
-public class HybridTopScoreDocCollector implements Collector {
+public class HybridTopScoreDocCollector<C extends Collector> implements Collector {
     private static final TopDocs EMPTY_TOPDOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
     private int docBase;
     private final HitsThresholdChecker hitsThresholdChecker;
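Note for reviewers unfamiliar with Lucene's collector plumbing: the new HybridCollectorManager introduced below plugs into Lucene's standard CollectorManager contract, where the searcher asks for one collector per search slice and then folds all of them into a single result via reduce(). A minimal sketch of that contract, using Lucene's stock TotalHitCountCollector purely for illustration (not part of this patch):

    import java.io.IOException;
    import java.util.Collection;
    import org.apache.lucene.search.CollectorManager;
    import org.apache.lucene.search.TotalHitCountCollector;

    class TotalHitsManager implements CollectorManager<TotalHitCountCollector, Integer> {
        @Override
        public TotalHitCountCollector newCollector() {
            return new TotalHitCountCollector(); // one fresh collector per search slice
        }

        @Override
        public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
            // fold per-slice results into a single answer
            return collectors.stream().mapToInt(TotalHitCountCollector::getTotalHits).sum();
        }
    }

    // usage: int totalHits = searcher.search(query, new TotalHitsManager());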
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
new file mode 100644
index 000000000..3d991a4f7
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -0,0 +1,257 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.search.query;
+
+import lombok.RequiredArgsConstructor;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.Weight;
+import org.opensearch.common.lucene.search.FilteredCollector;
+import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
+import org.opensearch.neuralsearch.search.HitsThresholdChecker;
+import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
+import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.internal.ContextIndexSearcher;
+import org.opensearch.search.internal.SearchContext;
+import org.opensearch.search.query.MultiCollectorWrapper;
+import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
+import org.opensearch.search.sort.SortAndFormats;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
+
+@RequiredArgsConstructor
+public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
+
+    private final int numHits;
+    private final HitsThresholdChecker hitsThresholdChecker;
+    private final boolean isSingleShard;
+    private final int trackTotalHitsUpTo;
+    private final SortAndFormats sortAndFormats;
+    private final Optional<Weight> filteringWeightOptional;
+
+    public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        final int totalNumDocs = Math.max(0, reader.numDocs());
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+
+        Weight filterWeight = null;
+        // check for post filter
+        if (Objects.nonNull(searchContext.parsedPostFilter())) {
+            Query filterQuery = searchContext.parsedPostFilter().query();
+            ContextIndexSearcher searcher = searchContext.searcher();
+            filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
+        }
+
+        return searchContext.shouldUseConcurrentSearch()
+            ? new HybridCollectorConcurrentSearchManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            )
+            : new HybridCollectorNonConcurrentManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort(),
+                filterWeight
+            );
+    }
+
+    @Override
+    abstract public Collector newCollector();
+
+    Collector getCollector() {
+        Collector hybridcollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
+        if (filteringWeightOptional.isEmpty()) {
+            // this is a plain hybrid query scores collector
+            return hybridcollector;
+        }
+        // this is a hybrid query scores collector with post filter applied
+        return new FilteredCollector(hybridcollector, filteringWeightOptional.get());
+    }
+
+    @Override
+    public ReduceableSearchResult reduce(Collection<Collector> collectors) {
+        final List<HybridTopScoreDocCollector<?>> hybridTopScoreDocCollectors = new ArrayList<>();
+
+        for (final Collector collector : collectors) {
+            if (collector instanceof MultiCollectorWrapper) {
+                for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
+                    if (sub instanceof HybridTopScoreDocCollector) {
+                        hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) sub);
+                    }
+                }
+            } else if (collector instanceof HybridTopScoreDocCollector) {
+                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) collector);
+            } else if (collector instanceof FilteredCollector
+                && ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) {
+                hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector<?>) ((FilteredCollector) collector).getCollector());
+            }
+        }
+
+        if (!hybridTopScoreDocCollectors.isEmpty()) {
+            HybridTopScoreDocCollector<?> hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream()
+                .findFirst()
+                .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query"));
+            List<TopDocs> topDocs = hybridTopScoreDocCollector.topDocs();
+            TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs);
+            float maxScore = getMaxScore(topDocs);
+            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
+            return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
+        }
+        throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
+    }
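The lambda returned from reduce() above is an OpenSearch ReduceableSearchResult, a single-method callback that writes the folded TopDocs into the shard-level QuerySearchResult. Roughly how a caller drives the manager end to end (a hedged sketch with exception handling omitted; the real call site is reduceCollectorResults later in this patch):

    // collect, fold, then write into the shard result
    Collector collector = collectorManager.newCollector();
    searcher.search(query, collector);                               // collection happens here
    ReduceableSearchResult reduced = collectorManager.reduce(List.of(collector));
    reduced.reduce(searchContext.queryResult());                     // sets TopDocs and max score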
+
+    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
+        ScoreDoc[] scoreDocs = new ScoreDoc[0];
+        if (Objects.nonNull(topDocs)) {
+            // for a single shard case we need to do score processing at coordinator level.
+            // this is a workaround for current core behaviour: for a single shard the fetch phase is executed
+            // right after the query phase, and processors are called after the actual fetch is done
+            // find any valid doc Id, or set it to -1 if there is not a single match
+            int delimiterDocId = topDocs.stream()
+                .filter(Objects::nonNull)
+                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
+                .map(topDoc -> topDoc.scoreDocs)
+                .filter(scoreDoc -> scoreDoc.length > 0)
+                .map(scoreDoc -> scoreDoc[0].doc)
+                .findFirst()
+                .orElse(-1);
+            if (delimiterDocId == -1) {
+                return new TopDocs(totalHits, scoreDocs);
+            }
+            // format scores using the following template:
+            // doc_id | magic_number_1
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_1
+            List<ScoreDoc> result = new ArrayList<>();
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            for (TopDocs topDoc : topDocs) {
+                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
+                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                    continue;
+                }
+                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                result.addAll(Arrays.asList(topDoc.scoreDocs));
+            }
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
+        }
+        return new TopDocs(totalHits, scoreDocs);
+    }
+
+    private TotalHits getTotalHits(int trackTotalHitsUpTo, final List<TopDocs> topDocs, final boolean isSingleShard) {
+        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
+            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
+            : TotalHits.Relation.EQUAL_TO;
+        if (topDocs == null || topDocs.isEmpty()) {
+            return new TotalHits(0, relation);
+        }
+
+        List<ScoreDoc[]> scoreDocs = topDocs.stream()
+            .map(topDoc -> topDoc.scoreDocs)
+            .filter(Objects::nonNull)
+            .collect(Collectors.toList());
+        Set<Integer> uniqueDocIds = new HashSet<>();
+        for (ScoreDoc[] scoreDocsArray : scoreDocs) {
+            uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList()));
+        }
+        long maxTotalHits = uniqueDocIds.size();
+
+        return new TotalHits(maxTotalHits, relation);
+    }
+
+    private float getMaxScore(final List<TopDocs> topDocs) {
+        if (topDocs.isEmpty()) {
+            return 0.0f;
+        } else {
+            return topDocs.stream()
+                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
+                .map(scoreDoc -> scoreDoc.score)
+                .max(Float::compare)
+                .get();
+        }
+    }
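To make the template comment in getNewTopDocs concrete, here is a worked example of the composite array it builds. The marker scores are the magic-number constants from HybridSearchResultFormatUtil (their literal values are not spelled out here), and the hit scores are made up:

    // two sub-queries: sub-query 1 matched docs 10 and 12, sub-query 2 matched doc 12;
    // delimiterDocId resolved to 10 (the first valid doc id found)
    //
    //   ScoreDoc(10, magic_number_1)   <- start marker
    //   ScoreDoc(10, magic_number_2)   <- delimiter, sub-query 1 results follow
    //   ScoreDoc(10, 0.92f)
    //   ScoreDoc(12, 0.47f)
    //   ScoreDoc(10, magic_number_2)   <- delimiter, sub-query 2 results follow
    //   ScoreDoc(12, 0.81f)
    //   ScoreDoc(10, magic_number_1)   <- stop marker
    //
    // the normalization processor later splits on the delimiters to recover
    // per-sub-query results before normalizing and combining scores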
+
+    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
+        return sortAndFormats == null ? null : sortAndFormats.formats;
+    }
+
+    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
+        Collector maxScoreCollector;
+
+        public HybridCollectorNonConcurrentManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            if (Objects.isNull(maxScoreCollector)) {
+                maxScoreCollector = getCollector();
+                return maxScoreCollector;
+            } else {
+                Collector toReturnCollector = maxScoreCollector;
+                maxScoreCollector = null;
+                return toReturnCollector;
+            }
+        }
+    }
+
+    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {
+
+        public HybridCollectorConcurrentSearchManager(
+            int numHits,
+            HitsThresholdChecker hitsThresholdChecker,
+            boolean isSingleShard,
+            int trackTotalHitsUpTo,
+            SortAndFormats sortAndFormats,
+            Weight filteringWeight
+        ) {
+            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
+        }
+
+        @Override
+        public Collector newCollector() {
+            return getCollector();
+        }
+    }
+}
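A subtlety in HybridCollectorNonConcurrentManager above: newCollector() deliberately hands out the same cached instance twice. The assumption (consistent with reduceCollectorResults later in this patch) is that the query phase calls it once to run the search, and the aggregation processor calls it once more to obtain the collector for reduce(). A sketch of that contract; the constructor arguments here are arbitrary placeholder values:

    // non-concurrent path: both callers must see the same collector instance,
    // otherwise reduce() would fold an empty, never-searched collector
    HybridCollectorManager manager = new HybridCollectorManager.HybridCollectorNonConcurrentManager(
        10, new HitsThresholdChecker(10), true, 1000, null, null
    );
    Collector usedForSearch = manager.newCollector();   // created and cached
    // ... query phase runs searcher.search(query, usedForSearch) ...
    Collector usedForReduce = manager.newCollector();   // cached instance returned, cache cleared
    assert usedForSearch == usedForReduce;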
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index bf05fdc9d..a30e4c27b 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -11,13 +11,18 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 
+import lombok.AllArgsConstructor;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
@@ -33,12 +38,15 @@
 import org.opensearch.neuralsearch.search.HitsThresholdChecker;
 import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
 import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.aggregations.AggregationProcessor;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.query.QueryCollectorContext;
 import org.opensearch.search.query.QueryPhase;
+import org.opensearch.search.query.QueryPhaseExecutionException;
 import org.opensearch.search.query.QueryPhaseSearcherWrapper;
 import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
 import org.opensearch.search.query.TopDocsCollectorContext;
 import org.opensearch.search.rescore.RescoreContext;
 import org.opensearch.search.sort.SortAndFormats;
@@ -66,11 +74,9 @@ public boolean searchWith(
         final boolean hasFilterCollector,
         final boolean hasTimeout
     ) throws IOException {
-        if (isHybridQuery(query, searchContext)) {
-            Query hybridQuery = extractHybridQuery(searchContext, query);
-            return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
+        if (!isHybridQuery(query, searchContext)) {
+            validateQuery(searchContext, query);
         }
-        validateQuery(searchContext, query);
         return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
     }
 
@@ -122,20 +128,6 @@ private boolean isWrappedHybridQuery(final Query query) {
             && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
     }
 
-    private Query extractHybridQuery(final SearchContext searchContext, final Query query) {
-        if (hasNestedFieldOrNestedDocs(query, searchContext)
-            && isWrappedHybridQuery(query)
-            && ((BooleanQuery) query).clauses().size() > 0) {
-            // extract hybrid query and replace bool with hybrid query
-            List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
-            if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) {
-                throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query");
-            }
-            return booleanClauses.get(0).getQuery();
-        }
-        return query;
-    }
-
     /**
      * Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases
      * when hybrid query is wrapped into other compound queries.
@@ -180,152 +172,68 @@ private void validateNestedBooleanQuery(final Query query, final int level) {
         }
     }
 
-    @VisibleForTesting
-    protected boolean searchWithCollector(
-        final SearchContext searchContext,
-        final ContextIndexSearcher searcher,
-        final Query query,
-        final LinkedList<QueryCollectorContext> collectors,
-        final boolean hasFilterCollector,
-        final boolean hasTimeout
-    ) throws IOException {
-        log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId());
+    private int getMaxDepthLimit(final SearchContext searchContext) {
+        Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
+        return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
+    }
 
-        final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
-        collectors.addFirst(topDocsFactory);
-        if (searchContext.size() == 0) {
-            final TotalHitCountCollector collector = new TotalHitCountCollector();
-            searcher.search(query, collector);
-            return false;
-        }
-        final IndexReader reader = searchContext.searcher().getIndexReader();
-        int totalNumDocs = Math.max(0, reader.numDocs());
-        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
-        final boolean shouldRescore = !searchContext.rescore().isEmpty();
-        if (shouldRescore) {
-            for (RescoreContext rescoreContext : searchContext.rescore()) {
-                numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
-            }
-        }
+    @Override
+    public AggregationProcessor aggregationProcessor(SearchContext searchContext) {
+        AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext);
+        return new HybridAggregationProcessor(coreAggProcessor);
+    }
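The override above swaps in a decorator: hybrid-specific work happens around delegation to the core AggregationProcessor, the OpenSearch interface whose preProcess/postProcess hooks bracket query execution. The intended flow, sketched:

    AggregationProcessor processor = queryPhaseSearcher.aggregationProcessor(context);
    processor.preProcess(context);    // hybrid: register the collector manager, then delegate
    // ... the searcher executes the (hybrid) query ...
    processor.postProcess(context);   // hybrid: reduce results and fix up fetch size, then delegate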
-        final QuerySearchResult queryResult = searchContext.queryResult();
+    @AllArgsConstructor
+    public class HybridAggregationProcessor implements AggregationProcessor {
 
-        final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
-            numDocs,
-            new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
-        );
+        private final AggregationProcessor delegateAggsProcessor;
 
-        searcher.search(query, collector);
+        @Override
+        public void preProcess(SearchContext context) {
+            delegateAggsProcessor.preProcess(context);
 
-        if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
-            queryResult.terminatedEarly(false);
+            if (isHybridQuery(context.query(), context)) {
+                // adding collector manager for hybrid query
+                CollectorManager collectorManager;
+                try {
+                    collectorManager = HybridCollectorManager.createHybridCollectorManager(context);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+                Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByManagerClass = context
+                    .queryCollectorManagers();
+                collectorManagersByManagerClass.put(HybridCollectorManager.class, collectorManager);
+            }
         }
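preProcess above communicates with the rest of the query phase through the context-scoped registry returned by queryCollectorManagers(); postProcess later looks the manager back up under the same class key. In isolation (names as in the patch):

    Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> registry =
        context.queryCollectorManagers();
    registry.put(HybridCollectorManager.class, collectorManager);                     // preProcess side
    CollectorManager<?, ?> sameManager = registry.get(HybridCollectorManager.class);  // postProcess side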
-        setTopDocsInQueryResult(queryResult, collector, searchContext);
-
-        return shouldRescore;
-    }
-
-    private void setTopDocsInQueryResult(
-        final QuerySearchResult queryResult,
-        final HybridTopScoreDocCollector collector,
-        final SearchContext searchContext
-    ) {
-        final List<TopDocs> topDocs = collector.topDocs();
-        final float maxScore = getMaxScore(topDocs);
-        final boolean isSingleShard = searchContext.numberOfShards() == 1;
-        final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
-        final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
-        queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
-    }
-
-    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
-        ScoreDoc[] scoreDocs = new ScoreDoc[0];
-        if (Objects.nonNull(topDocs)) {
-            // for a single shard case we need to do score processing at coordinator level.
-            // this is workaround for current core behaviour, for single shard fetch phase is executed
-            // right after query phase and processors are called after actual fetch is done
-            // find any valid doc Id, or set it to -1 if there is not a single match
-            int delimiterDocId = topDocs.stream()
-                .filter(Objects::nonNull)
-                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
-                .map(topDoc -> topDoc.scoreDocs)
-                .filter(scoreDoc -> scoreDoc.length > 0)
-                .map(scoreDoc -> scoreDoc[0].doc)
-                .findFirst()
-                .orElse(-1);
-            if (delimiterDocId == -1) {
-                return new TopDocs(totalHits, scoreDocs);
-            }
-            // format scores using following template:
-            // doc_id | magic_number_1
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_2
-            // ...
-            // doc_id | magic_number_1
-            List<ScoreDoc> result = new ArrayList<>();
-            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
-            for (TopDocs topDoc : topDocs) {
-                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
-                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
-                    continue;
+        @Override
+        public void postProcess(SearchContext context) {
+            if (isHybridQuery(context.query(), context)) {
+                if (!context.shouldUseConcurrentSearch()) {
+                    reduceCollectorResults(context);
                 }
-                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
-                result.addAll(Arrays.asList(topDoc.scoreDocs));
+                updateQueryResult(context.queryResult(), context);
             }
-            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
-            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
-        }
-        return new TopDocs(totalHits, scoreDocs);
-    }
 
-    private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs, final boolean isSingleShard) {
-        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
-        final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
-            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
-            : TotalHits.Relation.EQUAL_TO;
-        if (topDocs == null || topDocs.isEmpty()) {
-            return new TotalHits(0, relation);
+            delegateAggsProcessor.postProcess(context);
         }
-        long maxTotalHits = topDocs.get(0).totalHits.value;
-        int totalSize = 0;
-        for (TopDocs topDoc : topDocs) {
-            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
-            if (isSingleShard) {
-                totalSize += topDoc.totalHits.value + 1;
+
+        private void reduceCollectorResults(SearchContext context) {
+            CollectorManager collectorManager = context.queryCollectorManagers()
+                .get(HybridCollectorManager.class);
+            try {
+                final Collection collectors = List.of(collectorManager.newCollector());
+                collectorManager.reduce(collectors).reduce(context.queryResult());
+            } catch (IOException e) {
+                throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e);
             }
         }
-        // add 1 qty per each sub-query and + 2 for start and stop delimiters
-        totalSize += 2;
-        if (isSingleShard) {
-            // for single shard we need to update total size as this is how many docs are fetched in Fetch phase
-            searchContext.size(totalSize);
-        }
-        return new TotalHits(maxTotalHits, relation);
-    }
-
-    private float getMaxScore(final List<TopDocs> topDocs) {
-        if (topDocs.isEmpty()) {
-            return 0.0f;
-        } else {
-            return topDocs.stream()
-                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
-                .map(scoreDoc -> scoreDoc.score)
-                .max(Float::compare)
-                .get();
+        private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) {
+            boolean isSingleShard = searchContext.numberOfShards() == 1;
+            if (isSingleShard) {
+                searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length);
+            }
         }
     }
-
-    private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
-        return sortAndFormats == null ? null : sortAndFormats.formats;
-    }
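Worked example for updateQueryResult above, under the single-shard assumption: two sub-queries returning 2 and 1 hits produce 2 + 1 + 4 marker docs = 7 composite ScoreDocs (two delimiters plus the start and stop markers). On a single shard the fetch phase runs immediately after the query phase, so it must fetch all 7 entries for the normalization processor to decode them, hence:

    // illustrative numbers, not from the patch
    int compositeLength = queryResult.topDocs().topDocs.scoreDocs.length; // 7 in this example
    searchContext.size(compositeLength); // fetch every composite entry, markers included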
-
-    private int getMaxDepthLimit(final SearchContext searchContext) {
-        Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
-        return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
-    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
index a2fd41ac6..6d31ea6a6 100644
--- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java
@@ -328,7 +328,7 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR
             .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
         Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
     }
-
+
     public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() {
         List<ModelTensors> tensorsList = new ArrayList<>();
         List mlModelTensorList = List.of(