diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index f600c233b9273..4c9b592abe250 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -499,7 +499,7 @@ ReducedQueryPhase reducedQueryPhase( for (SearchPhaseResult entry : queryResults) { QuerySearchResult result = entry.queryResult(); if (entry instanceof StreamSearchResult) { - tickets.addAll(((StreamSearchResult)entry).getFlightTickets()); + tickets.addAll(((StreamSearchResult) entry).getFlightTickets()); } from = result.from(); // sorted queries can set the size to 0 if they have enough competitive hits. @@ -728,7 +728,7 @@ public static final class ReducedQueryPhase { this.from = from; this.isEmptyResult = isEmptyResult; this.sortValueFormats = sortValueFormats; - this.osTickets = osTickets; + this.osTickets = osTickets; } /** diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java index 561fa83793d4f..8c30d37076cc9 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java @@ -243,7 +243,6 @@ public void sendExecuteQuery( // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request // this used to be the QUERY_AND_FETCH which doesn't exist anymore. - if (request.isStreamRequest()) { Writeable.Reader reader = StreamSearchResult::new; final ActionListener handler = responseWrapper.apply(connection, listener); diff --git a/server/src/main/java/org/opensearch/action/search/SearchType.java b/server/src/main/java/org/opensearch/action/search/SearchType.java index a8ada789adf22..a8e75c5f89113 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchType.java +++ b/server/src/main/java/org/opensearch/action/search/SearchType.java @@ -89,7 +89,7 @@ public static SearchType fromId(byte id) { } else if (id == 1 || id == 3) { // TODO this bwc layer can be removed once this is back-ported to 5.3 QUERY_AND_FETCH is removed // now return QUERY_THEN_FETCH; - } else if (id == 5) { + } else if (id == 5) { return STREAM; } else { throw new IllegalArgumentException("No search type for [" + id + "]"); diff --git a/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java b/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java index 67b6c7c11ce81..720cd6135fece 100644 --- a/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java @@ -33,30 +33,19 @@ package org.opensearch.action.search; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.TopFieldDocs; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.common.util.concurrent.AbstractRunnable; -import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.core.action.ActionListener; -import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchPhaseResult; -import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; -import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.search.profile.SearchProfileShardResults; -import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.stream.OSTicket; import org.opensearch.search.stream.StreamSearchResult; -import org.opensearch.search.suggest.Suggest; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.transport.Transport; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -72,8 +61,46 @@ */ class StreamAsyncAction extends SearchQueryThenFetchAsyncAction { - public StreamAsyncAction(Logger logger, SearchTransportService searchTransportService, BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, Map> indexRoutings, SearchPhaseController searchPhaseController, Executor executor, QueryPhaseResultConsumer resultConsumer, SearchRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext, Tracer tracer) { - super(logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, resultConsumer, request, listener, shardsIts, timeProvider, clusterState, task, clusters, searchRequestContext, tracer); + public StreamAsyncAction( + Logger logger, + SearchTransportService searchTransportService, + BiFunction nodeIdToConnection, + Map aliasFilter, + Map concreteIndexBoosts, + Map> indexRoutings, + SearchPhaseController searchPhaseController, + Executor executor, + QueryPhaseResultConsumer resultConsumer, + SearchRequest request, + ActionListener listener, + GroupShardsIterator shardsIts, + TransportSearchAction.SearchTimeProvider timeProvider, + ClusterState clusterState, + SearchTask task, + SearchResponse.Clusters clusters, + SearchRequestContext searchRequestContext, + Tracer tracer + ) { + super( + logger, + searchTransportService, + nodeIdToConnection, + aliasFilter, + concreteIndexBoosts, + indexRoutings, + searchPhaseController, + executor, + resultConsumer, + request, + listener, + shardsIts, + timeProvider, + clusterState, + task, + clusters, + searchRequestContext, + tracer + ); } @Override @@ -83,6 +110,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults r class StreamSearchReducePhase extends SearchPhase { private SearchPhaseContext context; + protected StreamSearchReducePhase(String name, SearchPhaseContext context) { super(name); this.context = context; @@ -97,10 +125,12 @@ public void run() { class StreamReduceAction extends AbstractRunnable { private SearchPhaseContext context; private SearchPhase phase; + StreamReduceAction(SearchPhaseContext context, SearchPhase phase) { this.context = context; } + @Override protected void doRun() throws Exception { List tickets = new ArrayList<>(); @@ -109,7 +139,17 @@ protected void doRun() throws Exception { tickets.addAll(((StreamSearchResult) entry).getFlightTickets()); } } - InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(),null, null, null, false, false, 1, Collections.emptyList(), tickets); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + SearchHits.empty(), + null, + null, + null, + false, + false, + 1, + Collections.emptyList(), + tickets + ); context.sendSearchResponse(internalSearchResponse, results.getAtomicArray()); } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 5df862dd78017..3213704d4d4ec 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -124,7 +124,8 @@ import java.util.stream.StreamSupport; import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN; -import static org.opensearch.action.search.SearchType.*; +import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; +import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH; import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; /** diff --git a/server/src/main/java/org/opensearch/bootstrap/Security.java b/server/src/main/java/org/opensearch/bootstrap/Security.java index a0bcf7086b6d5..29ea1624c5b69 100644 --- a/server/src/main/java/org/opensearch/bootstrap/Security.java +++ b/server/src/main/java/org/opensearch/bootstrap/Security.java @@ -40,7 +40,6 @@ import org.opensearch.http.HttpTransportSettings; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; -import org.opensearch.secure_sm.SecureSM; import org.opensearch.transport.TcpTransport; import java.io.IOException; @@ -138,15 +137,15 @@ static void configure(Environment environment, boolean filterBadDefaults) throws // enable security policy: union of template and environment-based paths, and possibly plugin permissions Map codebases = getCodebaseJarMap(JarHell.parseClassPath()); -// Policy.setPolicy( -// new OpenSearchPolicy( -// codebases, -// createPermissions(environment), -// getPluginPermissions(environment), -// filterBadDefaults, -// createRecursiveDataPathPermission(environment) -// ) -// ); + // Policy.setPolicy( + // new OpenSearchPolicy( + // codebases, + // createPermissions(environment), + // getPluginPermissions(environment), + // filterBadDefaults, + // createRecursiveDataPathPermission(environment) + // ) + // ); // enable security manager final String[] classesThatCanExit = new String[] { diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 254b3edbf4d35..645c79ad223e2 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -872,7 +872,6 @@ public void executeQueryPhase( }, wrapFailureListener(listener, readerContext, markAsUsed)); } - public void executeStreamPhase(QuerySearchRequest request, SearchShardTask task, ActionListener listener) { final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest()); diff --git a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java index baf93cb6cf741..8f53612e98727 100644 --- a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java +++ b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java @@ -114,7 +114,7 @@ public InternalSearchResponse(StreamInput in) throws IOException { in.readOptionalWriteable(SearchProfileShardResults::new), in.readVInt(), readSearchExtBuildersOnOrAfter(in), - (in.readBoolean()? in.readList(OSTicket::new): null) + (in.readBoolean() ? in.readList(OSTicket::new) : null) ); } diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index 3fdec3d4bf9bf..753ba309564d1 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -31,10 +31,13 @@ package org.opensearch.search.internal; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.arrow.FlightService; @@ -81,6 +84,7 @@ import org.opensearch.search.stream.StreamSearchResult; import org.opensearch.search.suggest.SuggestionSearchContext; +import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -124,7 +128,17 @@ public List toInternalAggregations(Collection co } }; - public static final ArrowCollector NO_OP_ARROW_COLLECTOR = new ArrowCollector(); + public static final ArrowCollector NO_OP_ARROW_COLLECTOR = new ArrowCollector(new Collector() { + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + return null; + } + + @Override + public ScoreMode scoreMode() { + return null; + } + }, 1000); private final List releasables = new CopyOnWriteArrayList<>(); private final AtomicBoolean closed = new AtomicBoolean(false); diff --git a/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java b/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java index 0a53e30ce2ac3..6d8e3330bc042 100644 --- a/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java +++ b/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java @@ -51,7 +51,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class /**/SearchLookup { +public class /**/ SearchLookup { /** * The maximum depth of field dependencies. * When a runtime field's doc values depends on another runtime field's doc values, diff --git a/server/src/main/java/org/opensearch/search/query/ArrowCollector.java b/server/src/main/java/org/opensearch/search/query/ArrowCollector.java index 3b58e8b10a8da..5a09b41254db0 100644 --- a/server/src/main/java/org/opensearch/search/query/ArrowCollector.java +++ b/server/src/main/java/org/opensearch/search/query/ArrowCollector.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.ScoreMode; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.index.fielddata.IndexNumericFieldData; +import org.opensearch.search.query.stream.StreamCollector; import java.io.IOException; import java.util.ArrayList; @@ -49,17 +50,28 @@ public class ArrowCollector extends StreamCollector { List projectionFields; VectorSchemaRoot root; - final int BATCH_SIZE = 1000; + Map vectors; - public ArrowCollector(){ - super(); + private final int batchSize; + public ArrowCollector(Collector in, int batchSize) { + super(in); + this.batchSize = batchSize; } public ArrowCollector(Collector in, List projectionFields, int batchSize) { - // super(delegateCollector); + super(in, batchSize); allocator = new RootAllocator(); this.projectionFields = projectionFields; + this.batchSize = batchSize; + } + + public Map getVectors() { + return this.vectors; + } + + public void setVectors(Map vectors) { + this.vectors = vectors; } private Field createArrowField(String fieldName, IndexNumericFieldData.NumericType type) { @@ -162,6 +174,7 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept } iterators.put(field.fieldName, numericDocValues[0]); }); + setVectors(vectors); schema = new Schema(arrowFields.values()); root = new VectorSchemaRoot(new ArrayList<>(arrowFields.values()), new ArrayList<>(vectors.values())); final int[] i = { 0 }; @@ -184,8 +197,8 @@ public void collect(int docId) throws IOException { } if (iterator.advanceExact(docId)) { index[0] = i[0] / iterators.size(); - if (index[0] > BATCH_SIZE || vector.getValueCapacity() == 0) { - vector.allocateNew(BATCH_SIZE); + if (index[0] > batchSize || vector.getValueCapacity() == 0) { + vector.allocateNew(batchSize); } setValue(vector, index[0], iterator.longValue()); i[0]++; @@ -205,20 +218,18 @@ public void finish() throws IOException { @Override public void onNewBatch() { - + for (FieldVector vector : getVectors().values()) { + ((BaseFixedWidthVector) vector).allocateNew(batchSize); + } } @Override public VectorSchemaRoot getVectorSchemaRoot(BufferAllocator allocator) { - return null; + return root; } @Override public ScoreMode scoreMode() { return ScoreMode.COMPLETE_NO_SCORES; } - - public VectorSchemaRoot getRootVector() { - return root; - } } diff --git a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java index 5b86a70d64fff..e369900d90bb2 100644 --- a/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java +++ b/server/src/main/java/org/opensearch/search/query/EarlyTerminatingCollector.java @@ -47,7 +47,7 @@ * @opensearch.internal */ public class EarlyTerminatingCollector extends FilterCollector { - static final class EarlyTerminationException extends RuntimeException { + public static final class EarlyTerminationException extends RuntimeException { EarlyTerminationException(String msg) { super(msg); } diff --git a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java index 32f4a45813503..db9b5d09c1f2c 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/QueryCollectorContext.java @@ -135,7 +135,7 @@ protected InternalProfileCollectorManager createWithProfiler(InternalProfileColl * * @param result The query search result to populate */ - void postProcess(QuerySearchResult result) throws IOException {} + public void postProcess(QuerySearchResult result) throws IOException {} /** * Creates the collector tree from the provided collectors diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index f5921b26debb1..3f3965791bb8e 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -32,7 +32,6 @@ package org.opensearch.search.query; -import org.apache.arrow.flight.Ticket; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.index.IndexReader; @@ -66,16 +65,13 @@ import org.opensearch.search.internal.SearchContext; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.SearchProfileShardResults; -import org.opensearch.search.query.stream.StreamResultFlightProducer; +import org.opensearch.search.profile.query.InternalProfileCollector; import org.opensearch.search.rescore.RescoreProcessor; import org.opensearch.search.sort.SortAndFormats; -import org.opensearch.search.stream.OSTicket; -import org.opensearch.search.stream.StreamSearchResult; import org.opensearch.search.suggest.SuggestProcessor; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -83,13 +79,11 @@ import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; -import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_TOP_HITS; import static org.opensearch.search.query.QueryCollectorContext.createEarlyTerminationCollectorContext; import static org.opensearch.search.query.QueryCollectorContext.createFilteredCollectorContext; import static org.opensearch.search.query.QueryCollectorContext.createMinScoreCollectorContext; import static org.opensearch.search.query.QueryCollectorContext.createMultiCollectorContext; import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; -import static org.opensearch.search.query.stream.StreamSearchPhase.DefaultStreamSearchPhaseSearcher.createQueryCollector; /** * Query phase of a search request, used to run the query and get back from each shard information about the matching documents @@ -103,7 +97,6 @@ public class QueryPhase { // TODO: remove this property public static final boolean SYS_PROP_REWRITE_SORT = Booleans.parseBoolean(System.getProperty("opensearch.search.rewrite_sort", "true")); public static final QueryPhaseSearcher DEFAULT_QUERY_PHASE_SEARCHER = new DefaultQueryPhaseSearcher(); - public static final QueryPhaseSearcher STREAM_QUERY_PHASE_SEARCHER = new StreamQueryPhaseSearcher(); private final QueryPhaseSearcher queryPhaseSearcher; private final SuggestProcessor suggestProcessor; private final RescoreProcessor rescoreProcessor; @@ -190,146 +183,6 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe return executeInternal(searchContext, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER); } - public static boolean executeStreamInternal( - SearchContext searchContext, - QueryPhaseSearcher queryPhaseSearcher, - List projectionFields - ) { - return executeInternal(searchContext, QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields); - } - - /** - * In a package-private method so that it can be tested without having to - * wire everything (mapperService, etc.) - * @return whether the rescoring phase should be executed - * - * TODO: refactor this - */ - public static boolean executeInternal( - SearchContext searchContext, - QueryPhaseSearcher queryPhaseSearcher, - List projectionFields - ) throws QueryPhaseExecutionException { - final ContextIndexSearcher searcher = searchContext.searcher(); - final IndexReader reader = searcher.getIndexReader(); - QuerySearchResult queryResult = searchContext.queryResult(); - queryResult.searchTimedOut(false); - try { - queryResult.from(searchContext.from()); - queryResult.size(searchContext.size()); - Query query = searchContext.query(); - assert query == searcher.rewrite(query); // already rewritten - - final ScrollContext scrollContext = searchContext.scrollContext(); - if (scrollContext != null) { - if (scrollContext.totalHits == null) { - // first round - assert scrollContext.lastEmittedDoc == null; - // there is not much that we can optimize here since we want to collect all - // documents in order to get the total number of hits - - } else { - final ScoreDoc after = scrollContext.lastEmittedDoc; - if (canEarlyTerminate(reader, searchContext.sort())) { - // now this gets interesting: since the search sort is a prefix of the index sort, we can directly - // skip to the desired doc - if (after != null) { - query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.MUST) - .add(new SearchAfterSortedDocQuery(searchContext.sort().sort, (FieldDoc) after), BooleanClause.Occur.FILTER) - .build(); - } - } - } - } - - final LinkedList collectors = new LinkedList<>(); - // whether the chain contains a collector that filters documents - boolean hasFilterCollector = false; - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER) { - // add terminate_after before the filter collectors - // it will only be applied on documents accepted by these filter collectors - collectors.add(createEarlyTerminationCollectorContext(searchContext.terminateAfter())); - // this collector can filter documents during the collection - hasFilterCollector = true; - } - if (searchContext.parsedPostFilter() != null) { - // add post filters before aggregations - // it will only be applied to top hits - collectors.add(createFilteredCollectorContext(searcher, searchContext.parsedPostFilter().query())); - // this collector can filter documents during the collection - hasFilterCollector = true; - } - - // plug in additional collectors, like aggregations except global aggregations - final List> managersExceptGlobalAgg = searchContext - .queryCollectorManagers() - .entrySet() - .stream() - .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class))) - .map(Map.Entry::getValue) - .collect(Collectors.toList()); - if (managersExceptGlobalAgg.isEmpty() == false) { - collectors.add(createMultiCollectorContext(managersExceptGlobalAgg)); - } - - if (searchContext.minimumScore() != null) { - // apply the minimum score after multi collector so we filter aggs as well - collectors.add(createMinScoreCollectorContext(searchContext.minimumScore())); - // this collector can filter documents during the collection - hasFilterCollector = true; - } - - boolean timeoutSet = scrollContext == null - && searchContext.timeout() != null - && searchContext.timeout().equals(SearchService.NO_TIMEOUT) == false; - - final Runnable timeoutRunnable; - if (timeoutSet) { - timeoutRunnable = searcher.addQueryCancellation(createQueryTimeoutChecker(searchContext)); - } else { - timeoutRunnable = null; - } - - if (searchContext.lowLevelCancellation()) { - searcher.addQueryCancellation(() -> { - SearchShardTask task = searchContext.getTask(); - if (task != null && task.isCancelled()) { - throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled()); - } - }); - } - - try { - boolean shouldRescore = queryPhaseSearcher.searchWith( - searchContext, - searcher, - query, - collectors, - projectionFields, - hasFilterCollector, - timeoutSet - ); - - ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH); - if (executor instanceof EWMATrackingThreadPoolExecutor) { - final EWMATrackingThreadPoolExecutor rExecutor = (EWMATrackingThreadPoolExecutor) executor; - queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize()); - queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA()); - } - - return shouldRescore; - } finally { - // Search phase has finished, no longer need to check for timeout - // otherwise aggregation phase might get cancelled. - if (timeoutRunnable != null) { - searcher.removeQueryCancellation(timeoutRunnable); - } - } - } catch (Exception e) { - throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute main query", e); - } - } - /** * In a package-private method so that it can be tested without having to * wire everything (mapperService, etc.) @@ -487,37 +340,42 @@ private static boolean searchWithCollector( boolean hasFilterCollector, boolean timeoutSet ) throws IOException { - final ArrowCollector collector = createQueryCollector(collectors); + // add passed collector, the first collector context in the chain + collectors.addFirst(Objects.requireNonNull(queryCollectorContext)); + + final Collector queryCollector; + if (searchContext.getProfilers() != null) { + InternalProfileCollector profileCollector = QueryCollectorContext.createQueryCollectorWithProfiler(collectors); + searchContext.getProfilers().getCurrentQueryProfiler().setCollector(profileCollector); + queryCollector = profileCollector; + } else { + queryCollector = QueryCollectorContext.createQueryCollector(collectors); + } QuerySearchResult queryResult = searchContext.queryResult(); - StreamResultFlightProducer.CollectorCallback collectorCallback = new StreamResultFlightProducer.CollectorCallback() { - @Override - public void collect(Collector queryCollector) throws IOException { - try { - searcher.search(query, queryCollector); - } catch (EarlyTerminatingCollector.EarlyTerminationException e) { - // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. Postcollection - // still needs to be processed for Aggregations when early termination takes place. - searchContext.bucketCollectorProcessor().processPostCollection(queryCollector); - queryResult.terminatedEarly(true); - } - if (searchContext.isSearchTimedOut()) { - assert timeoutSet : "TimeExceededException thrown even though timeout wasn't set"; - if (searchContext.request().allowPartialSearchResults() == false) { - throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Time exceeded"); - } - queryResult.searchTimedOut(true); - } - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { - queryResult.terminatedEarly(false); - } - for (QueryCollectorContext ctx : collectors) { - ctx.postProcess(queryResult); - } + try { + searcher.search(query, queryCollector); + } catch (EarlyTerminatingCollector.EarlyTerminationException e) { + // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. Postcollection + // still needs to be processed for Aggregations when early termination takes place. + searchContext.bucketCollectorProcessor().processPostCollection(queryCollector); + queryResult.terminatedEarly(true); + } + if (searchContext.isSearchTimedOut()) { + assert timeoutSet : "TimeExceededException thrown even though timeout wasn't set"; + if (searchContext.request().allowPartialSearchResults() == false) { + throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Time exceeded"); } - }; - Ticket ticket = searchContext.flightService().getFlightProducer().createStream(collector, collectorCallback); - StreamSearchResult streamSearchResult = searchContext.streamSearchResult(); - streamSearchResult.flights(List.of(new OSTicket(ticket.getBytes()))); + queryResult.searchTimedOut(true); + } + if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { + queryResult.terminatedEarly(false); + } + for (QueryCollectorContext ctx : collectors) { + ctx.postProcess(queryResult); + } + if (queryCollectorContext instanceof RescoringQueryCollectorContext) { + return ((RescoringQueryCollectorContext) queryCollectorContext).shouldRescore(); + } return false; } @@ -613,63 +471,4 @@ protected boolean searchWithCollector( ); } } - - /** - * Default {@link QueryPhaseSearcher} implementation which delegates to the {@link QueryPhase}. - * - * @opensearch.internal - */ - public static class StreamQueryPhaseSearcher implements QueryPhaseSearcher { - - /** - * Please use {@link QueryPhase#STREAM_QUERY_PHASE_SEARCHER} - */ - protected StreamQueryPhaseSearcher() {} - - @Override - public boolean searchWith( - SearchContext searchContext, - ContextIndexSearcher searcher, - Query query, - LinkedList collectors, - boolean hasFilterCollector, - boolean hasTimeout - ) throws IOException { - return searchWith(searchContext, searcher, query, collectors, new ArrayList<>(), hasFilterCollector, hasTimeout); - } - - @Override - public boolean searchWith( - SearchContext searchContext, - ContextIndexSearcher searcher, - Query query, - LinkedList collectors, - List projectionFields, - boolean hasFilterCollector, - boolean hasTimeout - ) throws IOException { - return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasFilterCollector, hasTimeout); - } - - protected boolean searchWithCollector( - SearchContext searchContext, - ContextIndexSearcher searcher, - Query query, - LinkedList collectors, - List projectionFields, - boolean hasFilterCollector, - boolean hasTimeout - ) throws IOException { - final ArrowCollectorContext arrowCollectorContext = new ArrowCollectorContext(REASON_SEARCH_TOP_HITS, projectionFields); - return QueryPhase.searchWithCollector( - searchContext, - searcher, - query, - collectors, - arrowCollectorContext, - hasFilterCollector, - hasTimeout - ); - } - } } diff --git a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java index f780f6fe32af2..3d3961b767b01 100644 --- a/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/opensearch/search/query/TopDocsCollectorContext.java @@ -211,7 +211,7 @@ Collector create(Collector in) { } @Override - void postProcess(QuerySearchResult result) { + public void postProcess(QuerySearchResult result) { final TotalHits totalHitCount = hitCountSupplier.get(); final TopDocs topDocs; if (sort != null) { @@ -273,7 +273,7 @@ Collector create(Collector in) throws IOException { } @Override - void postProcess(QuerySearchResult result) throws IOException { + public void postProcess(QuerySearchResult result) throws IOException { final CollapseTopFieldDocs topDocs = topDocsCollector.getTopDocs(); result.topDocs(new TopDocsAndMaxScore(topDocs, maxScoreSupplier.get()), sortFmt); } @@ -619,7 +619,7 @@ TopDocsAndMaxScore newTopDocs() { } @Override - void postProcess(QuerySearchResult result) throws IOException { + public void postProcess(QuerySearchResult result) throws IOException { final TopDocsAndMaxScore topDocs = newTopDocs(); result.topDocs(topDocs, sortAndFormats == null ? null : sortAndFormats.formats); } @@ -684,7 +684,7 @@ protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float m } @Override - void postProcess(QuerySearchResult result) throws IOException { + public void postProcess(QuerySearchResult result) throws IOException { final TopDocsAndMaxScore topDocs = newTopDocs(); if (scrollContext.totalHits == null) { // first round diff --git a/server/src/main/java/org/opensearch/search/query/StreamCollector.java b/server/src/main/java/org/opensearch/search/query/stream/StreamCollector.java similarity index 90% rename from server/src/main/java/org/opensearch/search/query/StreamCollector.java rename to server/src/main/java/org/opensearch/search/query/stream/StreamCollector.java index 3705e071bf85a..9bfbbe8c4bbd1 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamCollector.java +++ b/server/src/main/java/org/opensearch/search/query/stream/StreamCollector.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.search.query; +package org.opensearch.search.query.stream; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; @@ -16,7 +16,6 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Scorable; import org.opensearch.common.annotation.ExperimentalApi; -import org.opensearch.search.query.stream.StreamWriter; import java.io.IOException; @@ -27,6 +26,10 @@ public abstract class StreamCollector extends FilterCollector { private int docsInCurrentBatch; private StreamWriter streamWriter = null; + public StreamCollector(Collector collector) { + this(collector, 1000); + } + public StreamCollector(Collector collector, int batchSize) { super(collector); this.batchSize = batchSize; @@ -34,7 +37,7 @@ public StreamCollector(Collector collector, int batchSize) { } public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - LeafCollector leafCollector =((this.in != null)? super.getLeafCollector(context): null); + LeafCollector leafCollector = ((this.in != null) ? super.getLeafCollector(context) : null); return new LeafCollector() { @Override public void setScorer(Scorable scorable) throws IOException { diff --git a/server/src/main/java/org/opensearch/search/query/stream/StreamContext.java b/server/src/main/java/org/opensearch/search/query/stream/StreamContext.java index f419f3b25c747..4ff56483d7d2f 100644 --- a/server/src/main/java/org/opensearch/search/query/stream/StreamContext.java +++ b/server/src/main/java/org/opensearch/search/query/stream/StreamContext.java @@ -8,11 +8,12 @@ package org.opensearch.search.query.stream; -import io.grpc.internal.ServerStreamListener; import org.apache.arrow.flight.BackpressureStrategy; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.vector.VectorSchemaRoot; +import io.grpc.internal.ServerStreamListener; + public class StreamContext { private VectorSchemaRoot vectorSchemaRoot; @@ -20,8 +21,12 @@ public class StreamContext { private ServerStreamListener listener; private BackpressureStrategy backpressureStrategy; - public StreamContext(VectorSchemaRoot vectorSchemaRoot, FlightDescriptor flightDescriptor, ServerStreamListener listener, - BackpressureStrategy backpressureStrategy) { + public StreamContext( + VectorSchemaRoot vectorSchemaRoot, + FlightDescriptor flightDescriptor, + ServerStreamListener listener, + BackpressureStrategy backpressureStrategy + ) { this.vectorSchemaRoot = vectorSchemaRoot; this.flightDescriptor = flightDescriptor; this.listener = listener; @@ -31,12 +36,15 @@ public StreamContext(VectorSchemaRoot vectorSchemaRoot, FlightDescriptor flightD public VectorSchemaRoot getVectorSchemaRoot() { return vectorSchemaRoot; } + public FlightDescriptor getFlightDescriptor() { return flightDescriptor; } + public ServerStreamListener getListener() { return listener; } + public BackpressureStrategy getBackpressureStrategy() { return backpressureStrategy; } diff --git a/server/src/main/java/org/opensearch/search/query/stream/StreamResultCollector.java b/server/src/main/java/org/opensearch/search/query/stream/StreamResultCollector.java deleted file mode 100644 index 6f9fe5120d4fd..0000000000000 --- a/server/src/main/java/org/opensearch/search/query/stream/StreamResultCollector.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.query.stream; - -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Collector; -import org.apache.lucene.search.LeafCollector; -import org.apache.lucene.search.Scorable; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Weight; -import org.opensearch.arrow.FlightService; - -import java.io.IOException; - -import static java.util.Arrays.asList; -import static org.apache.lucene.search.ScoreMode.TOP_DOCS; - -public class StreamResultCollector implements Collector { - - BufferAllocator allocator; - Field docIDField; - Field joinField; - Schema schema; - FlightService flightService; - - FlightDescriptor flightDescriptor; - Collector in; - StreamContext streamContext; - - public StreamResultCollector(Collector in, FlightService flightService, FlightDescriptor flightDescriptor) { - this.in = in; - this.flightService = flightService; - allocator = flightService.getAllocator(); - docIDField = new Field("docID", FieldType.nullable(new ArrowType.Int(32, true)), null); - joinField = new Field("joinField", FieldType.nullable(new ArrowType.Utf8()), null); - schema = new Schema(asList(docIDField, joinField)); - this.flightDescriptor = flightDescriptor; - } - - public StreamResultCollector(Collector in, StreamContext streamContext) { - this.streamContext = streamContext; - } - - @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - LeafCollector innerLeafCollector = (this.in != null? this.in.getLeafCollector(context) : null); - VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); - VarCharVector joinFieldVector = (VarCharVector) root.getVector("joinField"); - IntVector docIDVector = (IntVector) root.getVector("docID"); - BinaryDocValues docValues = context.reader().getBinaryDocValues("joinField"); - root.getFieldVectors(); - int batchSize = 1000; - docIDVector.allocateNew(batchSize); - joinFieldVector.allocateNew(batchSize); - final int[] i = {0}; - return new LeafCollector() { - @Override - public void setScorer(Scorable scorable) throws IOException { - if (innerLeafCollector != null) { - innerLeafCollector.setScorer(scorable); - } - } - - @Override - public void collect(int docId) throws IOException { - if (innerLeafCollector != null) { - innerLeafCollector.collect(docId); - } - if (docValues != null) { - if (docValues.advanceExact(docId)) { - if (i[0] > batchSize) { - docIDVector.allocateNew(batchSize); - joinFieldVector.allocateNew(batchSize); - } - docIDVector.set(i[0], docId); - joinFieldVector.set(i[0], docValues.binaryValue().bytes); - i[0]++; - } - } - } - - @Override - public void finish() throws IOException { - if (innerLeafCollector != null) { - innerLeafCollector.finish(); - } - root.setRowCount(i[0]); - //flightService.getFlightProducer().addOutput(flightDescriptor, root); - } - }; - } - - @Override - public ScoreMode scoreMode() { - return TOP_DOCS; - } - - public void setWeight(Weight weight) { - if (in != null) { - in.setWeight(weight); - } - } -} diff --git a/server/src/main/java/org/opensearch/search/query/stream/StreamResultFlightProducer.java b/server/src/main/java/org/opensearch/search/query/stream/StreamResultFlightProducer.java index c252582ac1f13..5d91ad570611d 100644 --- a/server/src/main/java/org/opensearch/search/query/stream/StreamResultFlightProducer.java +++ b/server/src/main/java/org/opensearch/search/query/stream/StreamResultFlightProducer.java @@ -8,12 +8,14 @@ package org.opensearch.search.query.stream; -import org.apache.arrow.flight.*; +import org.apache.arrow.flight.BackpressureStrategy; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.lucene.search.Collector; import org.opensearch.common.annotation.ExperimentalApi; -import org.opensearch.search.query.StreamCollector; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -39,7 +41,7 @@ public Ticket createStream(StreamCollector streamCollector, CollectorCallback ca } @Override - public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + public void getStream(FlightProducer.CallContext context, Ticket ticket, ServerStreamListener listener) { if (lookup.get(ticket) == null) { listener.error(new IllegalStateException("Data not ready")); return; @@ -70,6 +72,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l static class StreamState { StreamCollector streamCollector; CollectorCallback collectorCallback; + StreamState(StreamCollector streamCollector, CollectorCallback collectorCallback) { this.streamCollector = streamCollector; this.collectorCallback = collectorCallback; diff --git a/server/src/main/java/org/opensearch/search/query/stream/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/stream/StreamSearchPhase.java index 2742cc813cc86..97c20ff54d752 100644 --- a/server/src/main/java/org/opensearch/search/query/stream/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/stream/StreamSearchPhase.java @@ -21,6 +21,7 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.query.ArrowCollector; import org.opensearch.search.query.EarlyTerminatingCollector; +import org.opensearch.search.query.ProjectionField; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QueryPhaseExecutionException; @@ -35,10 +36,10 @@ public class StreamSearchPhase extends QueryPhase { private static final Logger LOGGER = LogManager.getLogger(StreamSearchPhase.class); - public static final QueryPhaseSearcher DEFAULT_QUERY_PHASE_SEARCHER = new DefaultStreamSearchPhaseSearcher(); + public static final QueryPhaseSearcher DEFAULT_STREAM_PHASE_SEARCHER = new DefaultStreamSearchPhaseSearcher(); public StreamSearchPhase() { - super(DEFAULT_QUERY_PHASE_SEARCHER); + super(DEFAULT_STREAM_PHASE_SEARCHER); } @Override @@ -57,7 +58,6 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep } } - public static class DefaultStreamSearchPhaseSearcher extends DefaultQueryPhaseSearcher { @Override @@ -66,10 +66,11 @@ public boolean searchWith( ContextIndexSearcher searcher, Query query, LinkedList collectors, + List projectionFields, boolean hasFilterCollector, boolean hasTimeout ) throws IOException { - return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasFilterCollector, hasTimeout); } @Override @@ -87,25 +88,27 @@ public void postProcess(SearchContext context) { }; } - protected boolean searchWithCollector( + private static boolean searchWithCollector( SearchContext searchContext, ContextIndexSearcher searcher, Query query, LinkedList collectors, + List projectionFields, boolean hasFilterCollector, boolean hasTimeout ) throws IOException { - return searchWithCollector(searchContext, searcher, query, collectors, hasTimeout); + return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasTimeout); } - private boolean searchWithCollector( + private static boolean searchWithCollector( SearchContext searchContext, ContextIndexSearcher searcher, Query query, LinkedList collectors, + List projectionFields, boolean timeoutSet ) throws IOException { - final ArrowCollector collector = createQueryCollector(collectors); + final ArrowCollector collector = createQueryCollector(collectors, projectionFields); QuerySearchResult queryResult = searchContext.queryResult(); StreamResultFlightProducer.CollectorCallback collectorCallback = new StreamResultFlightProducer.CollectorCallback() { @Override @@ -113,7 +116,8 @@ public void collect(Collector queryCollector) throws IOException { try { searcher.search(query, queryCollector); } catch (EarlyTerminatingCollector.EarlyTerminationException e) { - // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. Postcollection + // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. + // Postcollection // still needs to be processed for Aggregations when early termination takes place. searchContext.bucketCollectorProcessor().processPostCollection(queryCollector); queryResult.terminatedEarly(true); @@ -133,15 +137,17 @@ public void collect(Collector queryCollector) throws IOException { } } }; + searchContext.setArrowCollector(collector); Ticket ticket = searchContext.flightService().getFlightProducer().createStream(collector, collectorCallback); StreamSearchResult streamSearchResult = searchContext.streamSearchResult(); streamSearchResult.flights(List.of(new OSTicket(ticket.getBytes()))); return false; } - public static ArrowCollector createQueryCollector(List collectors) throws IOException { + public static ArrowCollector createQueryCollector(List collectors, List projectionFields) + throws IOException { Collector collector = QueryCollectorContext.createQueryCollector(collectors); - return new ArrowCollector(collector, null, 1000); + return new ArrowCollector(collector, projectionFields, 1000); } } } diff --git a/server/src/main/java/org/opensearch/search/query/stream/StreamWriter.java b/server/src/main/java/org/opensearch/search/query/stream/StreamWriter.java index fdfc70a5e483f..c094707477c80 100644 --- a/server/src/main/java/org/opensearch/search/query/stream/StreamWriter.java +++ b/server/src/main/java/org/opensearch/search/query/stream/StreamWriter.java @@ -10,7 +10,6 @@ import org.apache.arrow.flight.BackpressureStrategy; import org.apache.arrow.flight.FlightProducer.ServerStreamListener; - import org.apache.arrow.vector.VectorSchemaRoot; import org.opensearch.common.annotation.ExperimentalApi; @@ -22,9 +21,7 @@ public class StreamWriter { private static final int timeout = 1000; private int batches = 0; - public StreamWriter(VectorSchemaRoot root, - BackpressureStrategy backpressureStrategy, - ServerStreamListener listener) { + public StreamWriter(VectorSchemaRoot root, BackpressureStrategy backpressureStrategy, ServerStreamListener listener) { this.backpressureStrategy = backpressureStrategy; this.listener = listener; this.root = root; @@ -39,6 +36,5 @@ public void writeBatch(int rowCount) { batches++; } - public void finish() { - } + public void finish() {} } diff --git a/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java b/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java index 8e9ae3c9b29d7..0a88bba191a30 100644 --- a/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java +++ b/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java @@ -9,7 +9,6 @@ package org.opensearch.search.stream; import org.opensearch.common.annotation.ExperimentalApi; - import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchPhaseResult; @@ -67,7 +66,7 @@ public void setShardIndex(int shardIndex) { @Override public QuerySearchResult queryResult() { - return queryResult; + return queryResult; } public List getFlightTickets() { diff --git a/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java b/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java index 0d5ecbf8134c5..c744f2592e24f 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java @@ -435,7 +435,7 @@ public void execute() { query = geoShapeQuery("geopoint", new Rectangle(0.0, 55.0, 55.0, 0.0)).toQuery(queryShardContext); topDocs = searcher.search(query, 10); assertEquals(4, topDocs.totalHits.value); - } + } } } diff --git a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java index 69b4a7c0c53f2..ab6f3cf5baa9c 100644 --- a/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java +++ b/server/src/test/java/org/opensearch/search/DefaultSearchContextTests.java @@ -806,7 +806,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { executor, null, Collections.emptyList(), - null + null ); context.evaluateRequestShouldUseConcurrentSearch(); assertFalse(context.shouldUseConcurrentSearch()); @@ -924,7 +924,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { executor, null, Collections.emptyList(), - null + null ); // Case1: if there is no agg in the query, non-concurrent path is used @@ -953,7 +953,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { executor, null, Collections.emptyList(), - null + null ); // add un-supported agg operation @@ -986,7 +986,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) { executor, null, Collections.emptyList(), - null + null ); // create a supported agg operation context.aggregations(mockAggregations); @@ -1045,7 +1045,7 @@ public Optional create(IndexSettings indexSettin executor, null, concurrentSearchRequestDeciders, - null + null ); // create a supported agg operation context.aggregations(mockAggregations); @@ -1084,7 +1084,7 @@ public Optional create(IndexSettings indexSettin executor, null, concurrentSearchRequestDeciders, - null + null ); // create a supported agg operation @@ -1127,7 +1127,7 @@ public Optional create(IndexSettings indexSettin executor, null, concurrentSearchRequestDeciders, - null + null ); // create a supported agg operation diff --git a/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java b/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java index cc275acc6b512..8273510e96bc9 100644 --- a/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java +++ b/server/src/test/java/org/opensearch/search/query/ArrowCollectorTests.java @@ -10,6 +10,9 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -28,6 +31,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.opensearch.action.search.SearchShardTask; +import org.opensearch.arrow.FlightService; import org.opensearch.index.fielddata.IndexNumericFieldData; import org.opensearch.index.query.ParsedQuery; import org.opensearch.index.shard.IndexShard; @@ -38,12 +42,15 @@ import org.opensearch.test.TestSearchContext; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutorService; +import static org.opensearch.search.query.stream.StreamSearchPhase.DEFAULT_STREAM_PHASE_SEARCHER; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -77,7 +84,7 @@ public void testArrow() throws Exception { IndexWriterConfig iwc = newIndexWriterConfig(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); - // FlightService flightService = new FlightService(); + FlightService flightService = new FlightService(); final int numDocs = scaledRandomIntBetween(100, 200); IndexReader reader = null; try { @@ -89,20 +96,35 @@ public void testArrow() throws Exception { } w.close(); reader = DirectoryReader.open(dir); - // flightService.start(); - TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, null), null); + flightService.start(); + TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, null), null, flightService); context.setSize(1000); context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap())); List projectionFields = new ArrayList<>(); projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "longpoint")); - QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields); - VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getRootVector(); - System.out.println(vectorSchemaRoot.getSchema()); - Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint"); - assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null)); - BigIntVector vector = (BigIntVector) vectorSchemaRoot.getVector("longpoint"); - assertEquals(vector.getValueCount(), numDocs); + DEFAULT_STREAM_PHASE_SEARCHER.searchWith( + context, + context.searcher(), + context.query(), + new LinkedList<>(), + projectionFields, + false, + false + ); + // QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields); + FlightStream flightStream = flightService.getFlightClient().getStream(new Ticket("id1".getBytes(StandardCharsets.UTF_8))); + System.out.println(flightStream.getSchema()); + System.out.println(flightStream.next()); + System.out.println(flightStream.getRoot().contentToTSVString()); + System.out.println(flightStream.getRoot().getRowCount()); + System.out.println(flightStream.next()); + // VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getVectorSchemaRoot(new RootAllocator(Integer.MAX_VALUE)); + // System.out.println(vectorSchemaRoot.getSchema()); + // Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint"); + // assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null)); + // BigIntVector vector = (BigIntVector) vectorSchemaRoot.getVector("longpoint"); + // assertEquals(vector.getValueCount(), numDocs); } finally { if (reader != null) reader.close(); dir.close(); @@ -137,8 +159,16 @@ public void testArrowMultipleFields() throws Exception { List projectionFields = new ArrayList<>(); projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "longpoint")); projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.INT, "intpoint")); - QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields); - VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getRootVector(); + DEFAULT_STREAM_PHASE_SEARCHER.searchWith( + context, + context.searcher(), + context.query(), + new LinkedList<>(), + projectionFields, + false, + false + ); + VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getVectorSchemaRoot(new RootAllocator(Integer.MAX_VALUE)); System.out.println(vectorSchemaRoot.getSchema()); Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint"); assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null)); diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index 34aeb466ae360..954e9b22c3c2a 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -556,8 +556,8 @@ public void testArrow() throws Exception { final int numDocs = scaledRandomIntBetween(100, 200); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); - doc.add(new StringField("joinField", Integer.toString(i%10), Store.NO)); - doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i%10)))); + doc.add(new StringField("joinField", Integer.toString(i % 10), Store.NO)); + doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i % 10)))); w.addDocument(doc); } w.close(); diff --git a/server/src/test/java/org/opensearch/search/query/StreamResultCollectorTests.java b/server/src/test/java/org/opensearch/search/query/StreamResultCollectorTests.java index 9e9d98f798f6a..cf3c4189c9c97 100644 --- a/server/src/test/java/org/opensearch/search/query/StreamResultCollectorTests.java +++ b/server/src/test/java/org/opensearch/search/query/StreamResultCollectorTests.java @@ -9,6 +9,7 @@ package org.opensearch.search.query; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Ticket; import org.apache.lucene.document.BinaryDocValuesField; @@ -35,7 +36,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.concurrent.ExecutorService; @@ -50,13 +50,13 @@ public class StreamResultCollectorTests extends IndexShardTestCase { @ParametersFactory public static Collection concurrency() { - return Collections.singletonList( - new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER } - ); + return Collections.singletonList(new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER }); } + public StreamResultCollectorTests(int concurrency, QueryPhaseSearcher queryPhaseSearcher) { this.queryPhaseSearcher = queryPhaseSearcher; } + @Override public void setUp() throws Exception { super.setUp(); @@ -80,8 +80,8 @@ public void testArrow() throws Exception { try { for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); - doc.add(new StringField("joinField", Integer.toString(i%10), Field.Store.NO)); - doc.add(new BinaryDocValuesField("joinField", new BytesRef(Integer.toString(i%10)))); + doc.add(new StringField("joinField", Integer.toString(i % 10), Field.Store.NO)); + doc.add(new BinaryDocValuesField("joinField", new BytesRef(Integer.toString(i % 10)))); w.addDocument(doc); } w.close(); @@ -102,8 +102,7 @@ public void testArrow() throws Exception { System.out.println(flightStream.next()); flightStream.close(); } finally { - if (reader != null) - reader.close(); + if (reader != null) reader.close(); dir.close(); flightService.stop(); flightService.close(); diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index 7b7c02a80f046..c6bc4be950fe0 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -625,7 +625,7 @@ public QuerySearchResult queryResult() { @Override public StreamSearchResult streamSearchResult() { - return null; + return new StreamSearchResult(); } @Override