Skip to content

Commit

Permalink
Fix gradle
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
harshavamsi committed Sep 18, 2024
1 parent db65b0e commit 7bc5f34
Showing 28 changed files with 252 additions and 465 deletions.
Original file line number Diff line number Diff line change
@@ -499,7 +499,7 @@ ReducedQueryPhase reducedQueryPhase(
for (SearchPhaseResult entry : queryResults) {
QuerySearchResult result = entry.queryResult();
if (entry instanceof StreamSearchResult) {
tickets.addAll(((StreamSearchResult)entry).getFlightTickets());
tickets.addAll(((StreamSearchResult) entry).getFlightTickets());
}
from = result.from();
// sorted queries can set the size to 0 if they have enough competitive hits.
@@ -728,7 +728,7 @@ public static final class ReducedQueryPhase {
this.from = from;
this.isEmptyResult = isEmptyResult;
this.sortValueFormats = sortValueFormats;
this.osTickets = osTickets;
this.osTickets = osTickets;
}

/**
Original file line number Diff line number Diff line change
@@ -243,7 +243,6 @@ public void sendExecuteQuery(
// we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
// this used to be the QUERY_AND_FETCH which doesn't exist anymore.


if (request.isStreamRequest()) {
Writeable.Reader<SearchPhaseResult> reader = StreamSearchResult::new;
final ActionListener handler = responseWrapper.apply(connection, listener);
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ public static SearchType fromId(byte id) {
} else if (id == 1 || id == 3) { // TODO this bwc layer can be removed once this is back-ported to 5.3 QUERY_AND_FETCH is removed
// now
return QUERY_THEN_FETCH;
} else if (id == 5) {
} else if (id == 5) {
return STREAM;
} else {
throw new IllegalArgumentException("No search type for [" + id + "]");
Original file line number Diff line number Diff line change
@@ -33,30 +33,19 @@
package org.opensearch.action.search;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopFieldDocs;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.routing.GroupShardsIterator;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.stream.OSTicket;
import org.opensearch.search.stream.StreamSearchResult;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.transport.Transport;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -72,8 +61,46 @@
*/
class StreamAsyncAction extends SearchQueryThenFetchAsyncAction {

public StreamAsyncAction(Logger logger, SearchTransportService searchTransportService, BiFunction<String, String, Transport.Connection> nodeIdToConnection, Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts, Map<String, Set<String>> indexRoutings, SearchPhaseController searchPhaseController, Executor executor, QueryPhaseResultConsumer resultConsumer, SearchRequest request, ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext, Tracer tracer) {
super(logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, resultConsumer, request, listener, shardsIts, timeProvider, clusterState, task, clusters, searchRequestContext, tracer);
public StreamAsyncAction(
Logger logger,
SearchTransportService searchTransportService,
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
Map<String, AliasFilter> aliasFilter,
Map<String, Float> concreteIndexBoosts,
Map<String, Set<String>> indexRoutings,
SearchPhaseController searchPhaseController,
Executor executor,
QueryPhaseResultConsumer resultConsumer,
SearchRequest request,
ActionListener<SearchResponse> listener,
GroupShardsIterator<SearchShardIterator> shardsIts,
TransportSearchAction.SearchTimeProvider timeProvider,
ClusterState clusterState,
SearchTask task,
SearchResponse.Clusters clusters,
SearchRequestContext searchRequestContext,
Tracer tracer
) {
super(
logger,
searchTransportService,
nodeIdToConnection,
aliasFilter,
concreteIndexBoosts,
indexRoutings,
searchPhaseController,
executor,
resultConsumer,
request,
listener,
shardsIts,
timeProvider,
clusterState,
task,
clusters,
searchRequestContext,
tracer
);
}

@Override
@@ -83,6 +110,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> r

class StreamSearchReducePhase extends SearchPhase {
private SearchPhaseContext context;

protected StreamSearchReducePhase(String name, SearchPhaseContext context) {
super(name);
this.context = context;
@@ -97,10 +125,12 @@ public void run() {
class StreamReduceAction extends AbstractRunnable {
private SearchPhaseContext context;
private SearchPhase phase;

StreamReduceAction(SearchPhaseContext context, SearchPhase phase) {
this.context = context;

}

@Override
protected void doRun() throws Exception {
List<OSTicket> tickets = new ArrayList<>();
@@ -109,7 +139,17 @@ protected void doRun() throws Exception {
tickets.addAll(((StreamSearchResult) entry).getFlightTickets());
}
}
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(),null, null, null, false, false, 1, Collections.emptyList(), tickets);
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(
SearchHits.empty(),
null,
null,
null,
false,
false,
1,
Collections.emptyList(),
tickets
);
context.sendSearchResponse(internalSearchResponse, results.getAtomicArray());
}

Original file line number Diff line number Diff line change
@@ -124,7 +124,8 @@
import java.util.stream.StreamSupport;

import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.opensearch.action.search.SearchType.*;
import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH;
import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;

/**
19 changes: 9 additions & 10 deletions server/src/main/java/org/opensearch/bootstrap/Security.java
Original file line number Diff line number Diff line change
@@ -40,7 +40,6 @@
import org.opensearch.http.HttpTransportSettings;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.plugins.PluginsService;
import org.opensearch.secure_sm.SecureSM;
import org.opensearch.transport.TcpTransport;

import java.io.IOException;
@@ -138,15 +137,15 @@ static void configure(Environment environment, boolean filterBadDefaults) throws

// enable security policy: union of template and environment-based paths, and possibly plugin permissions
Map<String, URL> codebases = getCodebaseJarMap(JarHell.parseClassPath());
// Policy.setPolicy(
// new OpenSearchPolicy(
// codebases,
// createPermissions(environment),
// getPluginPermissions(environment),
// filterBadDefaults,
// createRecursiveDataPathPermission(environment)
// )
// );
// Policy.setPolicy(
// new OpenSearchPolicy(
// codebases,
// createPermissions(environment),
// getPluginPermissions(environment),
// filterBadDefaults,
// createRecursiveDataPathPermission(environment)
// )
// );

// enable security manager
final String[] classesThatCanExit = new String[] {
Original file line number Diff line number Diff line change
@@ -872,7 +872,6 @@ public void executeQueryPhase(
}, wrapFailureListener(listener, readerContext, markAsUsed));
}


public void executeStreamPhase(QuerySearchRequest request, SearchShardTask task, ActionListener<StreamSearchResult> listener) {
final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
Original file line number Diff line number Diff line change
@@ -114,7 +114,7 @@ public InternalSearchResponse(StreamInput in) throws IOException {
in.readOptionalWriteable(SearchProfileShardResults::new),
in.readVInt(),
readSearchExtBuildersOnOrAfter(in),
(in.readBoolean()? in.readList(OSTicket::new): null)
(in.readBoolean() ? in.readList(OSTicket::new) : null)
);
}

Original file line number Diff line number Diff line change
@@ -31,10 +31,13 @@

package org.opensearch.search.internal;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.action.search.SearchType;
import org.opensearch.arrow.FlightService;
@@ -81,6 +84,7 @@
import org.opensearch.search.stream.StreamSearchResult;
import org.opensearch.search.suggest.SuggestionSearchContext;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
@@ -124,7 +128,17 @@ public List<InternalAggregation> toInternalAggregations(Collection<Collector> co
}
};

public static final ArrowCollector NO_OP_ARROW_COLLECTOR = new ArrowCollector();
public static final ArrowCollector NO_OP_ARROW_COLLECTOR = new ArrowCollector(new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
return null;
}

@Override
public ScoreMode scoreMode() {
return null;
}
}, 1000);

private final List<Releasable> releasables = new CopyOnWriteArrayList<>();
private final AtomicBoolean closed = new AtomicBoolean(false);
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@
* @opensearch.api
*/
@PublicApi(since = "1.0.0")
public class /**/SearchLookup {
public class /**/ SearchLookup {
/**
* The maximum depth of field dependencies.
* When a runtime field's doc values depends on another runtime field's doc values,
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@
import org.apache.lucene.search.ScoreMode;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.index.fielddata.IndexNumericFieldData;
import org.opensearch.search.query.stream.StreamCollector;

import java.io.IOException;
import java.util.ArrayList;
@@ -49,17 +50,28 @@ public class ArrowCollector extends StreamCollector {
List<ProjectionField> projectionFields;
VectorSchemaRoot root;

final int BATCH_SIZE = 1000;
Map<String, FieldVector> vectors;

public ArrowCollector(){
super();
private final int batchSize;

public ArrowCollector(Collector in, int batchSize) {
super(in);
this.batchSize = batchSize;
}

public ArrowCollector(Collector in, List<ProjectionField> projectionFields, int batchSize) {
// super(delegateCollector);
super(in, batchSize);
allocator = new RootAllocator();
this.projectionFields = projectionFields;
this.batchSize = batchSize;
}

public Map<String, FieldVector> getVectors() {
return this.vectors;
}

public void setVectors(Map<String, FieldVector> vectors) {
this.vectors = vectors;
}

private Field createArrowField(String fieldName, IndexNumericFieldData.NumericType type) {
@@ -162,6 +174,7 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept
}
iterators.put(field.fieldName, numericDocValues[0]);
});
setVectors(vectors);
schema = new Schema(arrowFields.values());
root = new VectorSchemaRoot(new ArrayList<>(arrowFields.values()), new ArrayList<>(vectors.values()));
final int[] i = { 0 };
@@ -184,8 +197,8 @@ public void collect(int docId) throws IOException {
}
if (iterator.advanceExact(docId)) {
index[0] = i[0] / iterators.size();
if (index[0] > BATCH_SIZE || vector.getValueCapacity() == 0) {
vector.allocateNew(BATCH_SIZE);
if (index[0] > batchSize || vector.getValueCapacity() == 0) {
vector.allocateNew(batchSize);
}
setValue(vector, index[0], iterator.longValue());
i[0]++;
@@ -205,20 +218,18 @@ public void finish() throws IOException {

@Override
public void onNewBatch() {

for (FieldVector vector : getVectors().values()) {
((BaseFixedWidthVector) vector).allocateNew(batchSize);
}
}

@Override
public VectorSchemaRoot getVectorSchemaRoot(BufferAllocator allocator) {
return null;
return root;
}

@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}

public VectorSchemaRoot getRootVector() {
return root;
}
}
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@
* @opensearch.internal
*/
public class EarlyTerminatingCollector extends FilterCollector {
static final class EarlyTerminationException extends RuntimeException {
public static final class EarlyTerminationException extends RuntimeException {
EarlyTerminationException(String msg) {
super(msg);
}
Original file line number Diff line number Diff line change
@@ -135,7 +135,7 @@ protected InternalProfileCollectorManager createWithProfiler(InternalProfileColl
*
* @param result The query search result to populate
*/
void postProcess(QuerySearchResult result) throws IOException {}
public void postProcess(QuerySearchResult result) throws IOException {}

/**
* Creates the collector tree from the provided <code>collectors</code>
271 changes: 35 additions & 236 deletions server/src/main/java/org/opensearch/search/query/QueryPhase.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -211,7 +211,7 @@ Collector create(Collector in) {
}

@Override
void postProcess(QuerySearchResult result) {
public void postProcess(QuerySearchResult result) {
final TotalHits totalHitCount = hitCountSupplier.get();
final TopDocs topDocs;
if (sort != null) {
@@ -273,7 +273,7 @@ Collector create(Collector in) throws IOException {
}

@Override
void postProcess(QuerySearchResult result) throws IOException {
public void postProcess(QuerySearchResult result) throws IOException {
final CollapseTopFieldDocs topDocs = topDocsCollector.getTopDocs();
result.topDocs(new TopDocsAndMaxScore(topDocs, maxScoreSupplier.get()), sortFmt);
}
@@ -619,7 +619,7 @@ TopDocsAndMaxScore newTopDocs() {
}

@Override
void postProcess(QuerySearchResult result) throws IOException {
public void postProcess(QuerySearchResult result) throws IOException {
final TopDocsAndMaxScore topDocs = newTopDocs();
result.topDocs(topDocs, sortAndFormats == null ? null : sortAndFormats.formats);
}
@@ -684,7 +684,7 @@ protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float m
}

@Override
void postProcess(QuerySearchResult result) throws IOException {
public void postProcess(QuerySearchResult result) throws IOException {
final TopDocsAndMaxScore topDocs = newTopDocs();
if (scrollContext.totalHits == null) {
// first round
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.search.query;
package org.opensearch.search.query.stream;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -16,7 +16,6 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.search.query.stream.StreamWriter;

import java.io.IOException;

@@ -27,14 +26,18 @@ public abstract class StreamCollector extends FilterCollector {
private int docsInCurrentBatch;
private StreamWriter streamWriter = null;

public StreamCollector(Collector collector) {
this(collector, 1000);
}

public StreamCollector(Collector collector, int batchSize) {
super(collector);
this.batchSize = batchSize;
docsInCurrentBatch = 0;
}

public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
LeafCollector leafCollector =((this.in != null)? super.getLeafCollector(context): null);
LeafCollector leafCollector = ((this.in != null) ? super.getLeafCollector(context) : null);
return new LeafCollector() {
@Override
public void setScorer(Scorable scorable) throws IOException {
Original file line number Diff line number Diff line change
@@ -8,20 +8,25 @@

package org.opensearch.search.query.stream;

import io.grpc.internal.ServerStreamListener;
import org.apache.arrow.flight.BackpressureStrategy;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.vector.VectorSchemaRoot;

import io.grpc.internal.ServerStreamListener;

public class StreamContext {

private VectorSchemaRoot vectorSchemaRoot;
private FlightDescriptor flightDescriptor;
private ServerStreamListener listener;
private BackpressureStrategy backpressureStrategy;

public StreamContext(VectorSchemaRoot vectorSchemaRoot, FlightDescriptor flightDescriptor, ServerStreamListener listener,
BackpressureStrategy backpressureStrategy) {
public StreamContext(
VectorSchemaRoot vectorSchemaRoot,
FlightDescriptor flightDescriptor,
ServerStreamListener listener,
BackpressureStrategy backpressureStrategy
) {
this.vectorSchemaRoot = vectorSchemaRoot;
this.flightDescriptor = flightDescriptor;
this.listener = listener;
@@ -31,12 +36,15 @@ public StreamContext(VectorSchemaRoot vectorSchemaRoot, FlightDescriptor flightD
public VectorSchemaRoot getVectorSchemaRoot() {
return vectorSchemaRoot;
}

public FlightDescriptor getFlightDescriptor() {
return flightDescriptor;
}

public ServerStreamListener getListener() {
return listener;
}

public BackpressureStrategy getBackpressureStrategy() {
return backpressureStrategy;
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -8,12 +8,14 @@

package org.opensearch.search.query.stream;

import org.apache.arrow.flight.*;
import org.apache.arrow.flight.BackpressureStrategy;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.lucene.search.Collector;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.search.query.StreamCollector;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
@@ -39,7 +41,7 @@ public Ticket createStream(StreamCollector streamCollector, CollectorCallback ca
}

@Override
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
public void getStream(FlightProducer.CallContext context, Ticket ticket, ServerStreamListener listener) {
if (lookup.get(ticket) == null) {
listener.error(new IllegalStateException("Data not ready"));
return;
@@ -70,6 +72,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l
static class StreamState {
StreamCollector streamCollector;
CollectorCallback collectorCallback;

StreamState(StreamCollector streamCollector, CollectorCallback collectorCallback) {
this.streamCollector = streamCollector;
this.collectorCallback = collectorCallback;
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.query.ArrowCollector;
import org.opensearch.search.query.EarlyTerminatingCollector;
import org.opensearch.search.query.ProjectionField;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QueryPhaseExecutionException;
@@ -35,10 +36,10 @@

public class StreamSearchPhase extends QueryPhase {
private static final Logger LOGGER = LogManager.getLogger(StreamSearchPhase.class);
public static final QueryPhaseSearcher DEFAULT_QUERY_PHASE_SEARCHER = new DefaultStreamSearchPhaseSearcher();
public static final QueryPhaseSearcher DEFAULT_STREAM_PHASE_SEARCHER = new DefaultStreamSearchPhaseSearcher();

public StreamSearchPhase() {
super(DEFAULT_QUERY_PHASE_SEARCHER);
super(DEFAULT_STREAM_PHASE_SEARCHER);
}

@Override
@@ -57,7 +58,6 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep
}
}


public static class DefaultStreamSearchPhaseSearcher extends DefaultQueryPhaseSearcher {

@Override
@@ -66,10 +66,11 @@ public boolean searchWith(
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
List<ProjectionField> projectionFields,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasFilterCollector, hasTimeout);
}

@Override
@@ -87,33 +88,36 @@ public void postProcess(SearchContext context) {
};
}

protected boolean searchWithCollector(
private static boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
List<ProjectionField> projectionFields,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
return searchWithCollector(searchContext, searcher, query, collectors, hasTimeout);
return searchWithCollector(searchContext, searcher, query, collectors, projectionFields, hasTimeout);
}

private boolean searchWithCollector(
private static boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
List<ProjectionField> projectionFields,
boolean timeoutSet
) throws IOException {
final ArrowCollector collector = createQueryCollector(collectors);
final ArrowCollector collector = createQueryCollector(collectors, projectionFields);
QuerySearchResult queryResult = searchContext.queryResult();
StreamResultFlightProducer.CollectorCallback collectorCallback = new StreamResultFlightProducer.CollectorCallback() {
@Override
public void collect(Collector queryCollector) throws IOException {
try {
searcher.search(query, queryCollector);
} catch (EarlyTerminatingCollector.EarlyTerminationException e) {
// EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. Postcollection
// EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection.
// Postcollection
// still needs to be processed for Aggregations when early termination takes place.
searchContext.bucketCollectorProcessor().processPostCollection(queryCollector);
queryResult.terminatedEarly(true);
@@ -133,15 +137,17 @@ public void collect(Collector queryCollector) throws IOException {
}
}
};
searchContext.setArrowCollector(collector);
Ticket ticket = searchContext.flightService().getFlightProducer().createStream(collector, collectorCallback);
StreamSearchResult streamSearchResult = searchContext.streamSearchResult();
streamSearchResult.flights(List.of(new OSTicket(ticket.getBytes())));
return false;
}

public static ArrowCollector createQueryCollector(List<QueryCollectorContext> collectors) throws IOException {
public static ArrowCollector createQueryCollector(List<QueryCollectorContext> collectors, List<ProjectionField> projectionFields)
throws IOException {
Collector collector = QueryCollectorContext.createQueryCollector(collectors);
return new ArrowCollector(collector, null, 1000);
return new ArrowCollector(collector, projectionFields, 1000);
}
}
}
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@

import org.apache.arrow.flight.BackpressureStrategy;
import org.apache.arrow.flight.FlightProducer.ServerStreamListener;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.common.annotation.ExperimentalApi;

@@ -22,9 +21,7 @@ public class StreamWriter {
private static final int timeout = 1000;
private int batches = 0;

public StreamWriter(VectorSchemaRoot root,
BackpressureStrategy backpressureStrategy,
ServerStreamListener listener) {
public StreamWriter(VectorSchemaRoot root, BackpressureStrategy backpressureStrategy, ServerStreamListener listener) {
this.backpressureStrategy = backpressureStrategy;
this.listener = listener;
this.root = root;
@@ -39,6 +36,5 @@ public void writeBatch(int rowCount) {
batches++;
}

public void finish() {
}
public void finish() {}
}
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
package org.opensearch.search.stream;

import org.opensearch.common.annotation.ExperimentalApi;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.SearchPhaseResult;
@@ -67,7 +66,7 @@ public void setShardIndex(int shardIndex) {

@Override
public QuerySearchResult queryResult() {
return queryResult;
return queryResult;
}

public List<OSTicket> getFlightTickets() {
Original file line number Diff line number Diff line change
@@ -435,7 +435,7 @@ public void execute() {
query = geoShapeQuery("geopoint", new Rectangle(0.0, 55.0, 55.0, 0.0)).toQuery(queryShardContext);
topDocs = searcher.search(query, 10);
assertEquals(4, topDocs.totalHits.value);
}
}
}
}

Original file line number Diff line number Diff line change
@@ -806,7 +806,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) {
executor,
null,
Collections.emptyList(),
null
null
);
context.evaluateRequestShouldUseConcurrentSearch();
assertFalse(context.shouldUseConcurrentSearch());
@@ -924,7 +924,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) {
executor,
null,
Collections.emptyList(),
null
null
);

// Case1: if there is no agg in the query, non-concurrent path is used
@@ -953,7 +953,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) {
executor,
null,
Collections.emptyList(),
null
null
);

// add un-supported agg operation
@@ -986,7 +986,7 @@ protected Engine.Searcher acquireSearcherInternal(String source) {
executor,
null,
Collections.emptyList(),
null
null
);
// create a supported agg operation
context.aggregations(mockAggregations);
@@ -1045,7 +1045,7 @@ public Optional<ConcurrentSearchRequestDecider> create(IndexSettings indexSettin
executor,
null,
concurrentSearchRequestDeciders,
null
null
);
// create a supported agg operation
context.aggregations(mockAggregations);
@@ -1084,7 +1084,7 @@ public Optional<ConcurrentSearchRequestDecider> create(IndexSettings indexSettin
executor,
null,
concurrentSearchRequestDeciders,
null
null
);

// create a supported agg operation
@@ -1127,7 +1127,7 @@ public Optional<ConcurrentSearchRequestDecider> create(IndexSettings indexSettin
executor,
null,
concurrentSearchRequestDeciders,
null
null
);

// create a supported agg operation
Original file line number Diff line number Diff line change
@@ -10,6 +10,9 @@

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -28,6 +31,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.arrow.FlightService;
import org.opensearch.index.fielddata.IndexNumericFieldData;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.shard.IndexShard;
@@ -38,12 +42,15 @@
import org.opensearch.test.TestSearchContext;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;

import static org.opensearch.search.query.stream.StreamSearchPhase.DEFAULT_STREAM_PHASE_SEARCHER;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@@ -77,7 +84,7 @@ public void testArrow() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);

// FlightService flightService = new FlightService();
FlightService flightService = new FlightService();
final int numDocs = scaledRandomIntBetween(100, 200);
IndexReader reader = null;
try {
@@ -89,20 +96,35 @@ public void testArrow() throws Exception {
}
w.close();
reader = DirectoryReader.open(dir);
// flightService.start();
TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, null), null);
flightService.start();
TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader, null), null, flightService);
context.setSize(1000);
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
List<ProjectionField> projectionFields = new ArrayList<>();
projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "longpoint"));
QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields);
VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getRootVector();
System.out.println(vectorSchemaRoot.getSchema());
Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint");
assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null));
BigIntVector vector = (BigIntVector) vectorSchemaRoot.getVector("longpoint");
assertEquals(vector.getValueCount(), numDocs);
DEFAULT_STREAM_PHASE_SEARCHER.searchWith(
context,
context.searcher(),
context.query(),
new LinkedList<>(),
projectionFields,
false,
false
);
// QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields);
FlightStream flightStream = flightService.getFlightClient().getStream(new Ticket("id1".getBytes(StandardCharsets.UTF_8)));
System.out.println(flightStream.getSchema());
System.out.println(flightStream.next());
System.out.println(flightStream.getRoot().contentToTSVString());
System.out.println(flightStream.getRoot().getRowCount());
System.out.println(flightStream.next());
// VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getVectorSchemaRoot(new RootAllocator(Integer.MAX_VALUE));
// System.out.println(vectorSchemaRoot.getSchema());
// Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint");
// assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null));
// BigIntVector vector = (BigIntVector) vectorSchemaRoot.getVector("longpoint");
// assertEquals(vector.getValueCount(), numDocs);
} finally {
if (reader != null) reader.close();
dir.close();
@@ -137,8 +159,16 @@ public void testArrowMultipleFields() throws Exception {
List<ProjectionField> projectionFields = new ArrayList<>();
projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.LONG, "longpoint"));
projectionFields.add(new ProjectionField(IndexNumericFieldData.NumericType.INT, "intpoint"));
QueryPhase.executeStreamInternal(context.withCleanQueryResult(), QueryPhase.STREAM_QUERY_PHASE_SEARCHER, projectionFields);
VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getRootVector();
DEFAULT_STREAM_PHASE_SEARCHER.searchWith(
context,
context.searcher(),
context.query(),
new LinkedList<>(),
projectionFields,
false,
false
);
VectorSchemaRoot vectorSchemaRoot = context.getArrowCollector().getVectorSchemaRoot(new RootAllocator(Integer.MAX_VALUE));
System.out.println(vectorSchemaRoot.getSchema());
Field longPoint = vectorSchemaRoot.getSchema().findField("longpoint");
assertEquals(longPoint, new Field("longpoint", FieldType.nullable(new ArrowType.Int(64, true)), null));
Original file line number Diff line number Diff line change
@@ -556,8 +556,8 @@ public void testArrow() throws Exception {
final int numDocs = scaledRandomIntBetween(100, 200);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
doc.add(new StringField("joinField", Integer.toString(i%10), Store.NO));
doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i%10))));
doc.add(new StringField("joinField", Integer.toString(i % 10), Store.NO));
doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i % 10))));
w.addDocument(doc);
}
w.close();
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
package org.opensearch.search.query;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Ticket;
import org.apache.lucene.document.BinaryDocValuesField;
@@ -35,7 +36,6 @@

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.ExecutorService;
@@ -50,13 +50,13 @@ public class StreamResultCollectorTests extends IndexShardTestCase {

@ParametersFactory
public static Collection<Object[]> concurrency() {
return Collections.singletonList(
new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER }
);
return Collections.singletonList(new Object[] { 0, QueryPhase.DEFAULT_QUERY_PHASE_SEARCHER });
}

public StreamResultCollectorTests(int concurrency, QueryPhaseSearcher queryPhaseSearcher) {
this.queryPhaseSearcher = queryPhaseSearcher;
}

@Override
public void setUp() throws Exception {
super.setUp();
@@ -80,8 +80,8 @@ public void testArrow() throws Exception {
try {
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
doc.add(new StringField("joinField", Integer.toString(i%10), Field.Store.NO));
doc.add(new BinaryDocValuesField("joinField", new BytesRef(Integer.toString(i%10))));
doc.add(new StringField("joinField", Integer.toString(i % 10), Field.Store.NO));
doc.add(new BinaryDocValuesField("joinField", new BytesRef(Integer.toString(i % 10))));
w.addDocument(doc);
}
w.close();
@@ -102,8 +102,7 @@ public void testArrow() throws Exception {
System.out.println(flightStream.next());
flightStream.close();
} finally {
if (reader != null)
reader.close();
if (reader != null) reader.close();
dir.close();
flightService.stop();
flightService.close();
Original file line number Diff line number Diff line change
@@ -625,7 +625,7 @@ public QuerySearchResult queryResult() {

@Override
public StreamSearchResult streamSearchResult() {
return null;
return new StreamSearchResult();
}

@Override

0 comments on commit 7bc5f34

Please sign in to comment.