Skip to content

Commit

Permalink
Refactor tests, run full test suite based on flag
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Mar 6, 2024
1 parent 1f2dc6f commit 9758474
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 30 deletions.
6 changes: 6 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ task integTest(type: RestIntegTestTask) {
description = "Run tests against a cluster"
testClassesDirs = sourceSets.test.output.classesDirs
classpath = sourceSets.test.runtimeClasspath
boolean runCompleteAggsTestSuite = Boolean.parseBoolean(System.getProperty('test_aggs', "false"))
if (!runCompleteAggsTestSuite) {
filter {
excludeTestsMatching "org.opensearch.neuralsearch.query.aggregation.*IT"
}
}
}
tasks.named("check").configure { dependsOn(integTest) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import java.util.List;
import java.util.Map;

/**
* Abstraction of aggregation processor for hybrid query. Main responsibility is to register custom
* collector manager on preProcess point and run collectorManager.reduce for non-concurrent search
*/
@AllArgsConstructor
public class HybridAggregationProcessor implements AggregationProcessor {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,14 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
Expand All @@ -33,7 +29,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -48,7 +43,6 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
private final boolean isSingleShard;
private final int trackTotalHitsUpTo;
private final SortAndFormats sortAndFormats;
private final Optional<Weight> filteringWeightOptional;

public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
Expand All @@ -57,30 +51,20 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

Weight filterWeight = null;
// check for post filter
if (Objects.nonNull(searchContext.parsedPostFilter())) {
Query filterQuery = searchContext.parsedPostFilter().query();
ContextIndexSearcher searcher = searchContext.searcher();
filterWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
}

return searchContext.shouldUseConcurrentSearch()
? new HybridCollectorConcurrentSearchManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filterWeight
searchContext.sort()
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
isSingleShard,
trackTotalHitsUpTo,
searchContext.sort(),
filterWeight
searchContext.sort()
);
}

Expand All @@ -89,12 +73,7 @@ public static CollectorManager createHybridCollectorManager(final SearchContext

Collector getCollector() {
Collector hybridcollector = new HybridTopScoreDocCollector<>(numHits, hitsThresholdChecker);
if (filteringWeightOptional.isEmpty()) {
// this is plain hybrid query scores collector
return hybridcollector;
}
// this is hybrid query scores collector with post filter applied
return new FilteredCollector(hybridcollector, filteringWeightOptional.get());
return hybridcollector;
}

@Override
Expand Down Expand Up @@ -217,10 +196,9 @@ public HybridCollectorNonConcurrentManager(
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
SortAndFormats sortAndFormats
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
}

@Override
Expand All @@ -243,10 +221,9 @@ public HybridCollectorConcurrentSearchManager(
HitsThresholdChecker hitsThresholdChecker,
boolean isSingleShard,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight
SortAndFormats sortAndFormats
) {
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, Optional.ofNullable(filteringWeight));
super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats);
}

@Override
Expand Down

0 comments on commit 9758474

Please sign in to comment.