diff --git a/docs/changelog/112929.yaml b/docs/changelog/112929.yaml new file mode 100644 index 0000000000000..e5f49897432de --- /dev/null +++ b/docs/changelog/112929.yaml @@ -0,0 +1,5 @@ +pr: 112929 +summary: "ES|QL: Add support for cached strings in plan serialization" +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/112938.yaml b/docs/changelog/112938.yaml new file mode 100644 index 0000000000000..82b98871c3352 --- /dev/null +++ b/docs/changelog/112938.yaml @@ -0,0 +1,35 @@ +pr: 112938 +summary: Enhance SORT push-down to Lucene to cover references to fields and ST_DISTANCE function +area: ES|QL +type: enhancement +issues: + - 109973 +highlight: + title: Enhance SORT push-down to Lucene to cover references to fields and ST_DISTANCE function + body: |- + The most used and likely most valuable geospatial search query in Elasticsearch is the sorted proximity search, + finding items within a certain distance of a point of interest and sorting the results by distance. + This has been possible in ES|QL since 8.15.0, but the sorting was done in-memory, not pushed down to Lucene. + Now the sorting is pushed down to Lucene, which results in a significant performance improvement. + + Queries that perform both filtering and sorting on distance are supported. For example: + + [source,esql] + ---- + FROM test + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(-122.4194 37.7749)")) + | WHERE distance < 1000000 + | SORT distance ASC, name DESC + | LIMIT 10 + ---- + + In addition, the support for sorting on EVAL expressions has been extended to cover references to fields: + + [source,esql] + ---- + FROM test + | EVAL ref = field + | SORT ref ASC + | LIMIT 10 + ---- + notable: false diff --git a/docs/changelog/114411.yaml b/docs/changelog/114411.yaml new file mode 100644 index 0000000000000..23bff3c8e25ba --- /dev/null +++ b/docs/changelog/114411.yaml @@ -0,0 +1,5 @@ +pr: 114411 +summary: "ESQL: Push down filters even in case of renames in Evals" +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/114533.yaml b/docs/changelog/114533.yaml new file mode 100644 index 0000000000000..f45589e8de921 --- /dev/null +++ b/docs/changelog/114533.yaml @@ -0,0 +1,5 @@ +pr: 114533 +summary: Fix dim validation for bit `element_type` +area: Vector Search +type: bug +issues: [] diff --git a/docs/changelog/114552.yaml b/docs/changelog/114552.yaml new file mode 100644 index 0000000000000..00e2f95b5038d --- /dev/null +++ b/docs/changelog/114552.yaml @@ -0,0 +1,5 @@ +pr: 114552 +summary: Improve exception message for bad environment variable placeholders in settings +area: Infra/Settings +type: enhancement +issues: [110858] diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/lifecycle/action/DeleteDataStreamLifecycleAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/lifecycle/action/DeleteDataStreamLifecycleAction.java index 1c4659efc2f8b..1595348649528 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/lifecycle/action/DeleteDataStreamLifecycleAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/lifecycle/action/DeleteDataStreamLifecycleAction.java @@ -34,7 +34,26 @@ private DeleteDataStreamLifecycleAction() {/* no instances */} public static final class Request extends AcknowledgedRequest implements IndicesRequest.Replaceable { private String[] names; - private IndicesOptions indicesOptions = IndicesOptions.fromOptions(false, true, true, true, false, false, 
true, false); + private IndicesOptions indicesOptions = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .includeHidden(false) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(false) + .build() + ) + .build(); public Request(StreamInput in) throws IOException { super(in); diff --git a/muted-tests.yml b/muted-tests.yml index ac6faae986d44..5d14cabdd46ce 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -71,9 +71,6 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=rollup/security_tests/Index-based access} issue: https://github.com/elastic/elasticsearch/issues/111631 -- class: org.elasticsearch.tdigest.ComparisonTests - method: testSparseGaussianDistribution - issue: https://github.com/elastic/elasticsearch/issues/111721 - class: org.elasticsearch.upgrades.FullClusterRestartIT method: testSnapshotRestore {cluster=OLD} issue: https://github.com/elastic/elasticsearch/issues/111777 @@ -343,9 +340,6 @@ tests: - class: org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT method: test {yaml=cluster.stats/30_ccs_stats/cross-cluster search stats search} issue: https://github.com/elastic/elasticsearch/issues/114371 -- class: org.elasticsearch.xpack.esql.qa.single_node.RestEsqlIT - method: testProfileOrdinalsGroupingOperator {SYNC} - issue: https://github.com/elastic/elasticsearch/issues/114380 - class: org.elasticsearch.xpack.inference.services.cohere.CohereServiceTests method: testInfer_StreamRequest issue: https://github.com/elastic/elasticsearch/issues/114385 @@ -384,6 +378,15 @@ tests: - class: org.elasticsearch.datastreams.logsdb.qa.LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT method: testTermsQuery issue: https://github.com/elastic/elasticsearch/issues/114563 +- class: org.elasticsearch.datastreams.logsdb.qa.LogsDbVersusLogsDbReindexedIntoStandardModeChallengeRestIT + method: testMatchAllQuery + issue: https://github.com/elastic/elasticsearch/issues/114607 +- class: org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizerTests + method: testPushSpatialIntersectsEvalToSource {default} + issue: https://github.com/elastic/elasticsearch/issues/114627 +- class: org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizerTests + method: testPushWhereEvalToSource {default} + issue: https://github.com/elastic/elasticsearch/issues/114628 # Examples: # diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/70_dense_vector_telemetry.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/70_dense_vector_telemetry.yml index 66b05e4d0d156..16574ceb587b4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/70_dense_vector_telemetry.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/70_dense_vector_telemetry.yml @@ -21,13 +21,13 @@ setup: element_type: byte index_options: type: hnsw + m: 16 + ef_construction: 100 vector2: type: dense_vector dims: 1024 index: true similarity: dot_product - index_options: - type: int8_hnsw vector3: type: dense_vector dims: 100 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java 
b/server/src/main/java/org/elasticsearch/TransportVersions.java index d136aac8a2e5c..0f9c27a7877b8 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -239,6 +239,8 @@ static TransportVersion def(int id) { public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0); public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0); public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0); + public static final TransportVersion ESQL_CACHED_STRING_SERIALIZATION = def(8_766_00_0); + public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/alias/IndicesAliasesRequest.java b/server/src/main/java/org/elasticsearch/action/admin/indices/alias/IndicesAliasesRequest.java index cf06dd34fd5ca..d66cab1d2d717 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/alias/IndicesAliasesRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/alias/IndicesAliasesRequest.java @@ -58,7 +58,26 @@ public class IndicesAliasesRequest extends AcknowledgedRequest implements IndicesRequest.Replaceable { - public static final IndicesOptions DEFAULT_INDICES_OPTIONS = IndicesOptions.fromOptions( - false, - true, - true, - true, - false, - false, - true, - false - ); + public static final IndicesOptions DEFAULT_INDICES_OPTIONS = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .allowEmptyExpressions(true) + .resolveAliases(false) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(true) + .build() + ) + .build(); private String[] indices; // Delete index should work by default on both open and closed indices. 
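The recurring change above and in the request classes below replaces the eight positional booleans of `IndicesOptions.fromOptions(...)` with the named builder. As a rough, illustrative sketch — the constant name is hypothetical, while the builder calls are exactly the ones used in these hunks — the legacy call `fromOptions(false, true, true, true, false, false, true, false)` corresponds to the new form like this:

[source,java]
----
import org.elasticsearch.action.support.IndicesOptions;

public class IndicesOptionsBuilderSketch {
    // Hypothetical constant name; flag-by-flag equivalent of the removed
    // fromOptions(false, true, true, true, false, false, true, false) call.
    static final IndicesOptions STRICT_EXPAND_OPEN_CLOSED_IGNORE_ALIASES = IndicesOptions.builder()
        // was: ignoreUnavailable == false
        .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS)
        .wildcardOptions(
            IndicesOptions.WildcardOptions.builder()
                .matchOpen(true)             // expand wildcards to open indices
                .matchClosed(true)           // ... and to closed indices
                .includeHidden(false)        // but not to hidden ones
                .resolveAliases(false)       // was: ignoreAliases == true
                .allowEmptyExpressions(true) // was: allowNoIndices == true
                .build()
        )
        .gatekeeperOptions(
            IndicesOptions.GatekeeperOptions.builder()
                .allowAliasToMultipleIndices(false)
                .allowClosedIndices(true)    // was: forbidClosedIndices == false
                .ignoreThrottled(false)
                .build()
        )
        .build();
}
----

The self-describing names are the point of the migration: with the positional form it was easy to flip the wrong boolean, which is presumably why every call site in this change spells the options out.
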
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/TransportSimulateIndexTemplateAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/TransportSimulateIndexTemplateAction.java index fdced5fc18ac9..ec8eb4babfdac 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/TransportSimulateIndexTemplateAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/TransportSimulateIndexTemplateAction.java @@ -280,6 +280,7 @@ public static Template resolveTemplate( templateSettings, mappings ); + MetadataCreateIndexService.validateAdditionalSettings(provider, result, additionalSettings); dummySettings.put(result); additionalSettings.put(result); } diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/DataStreamsStatsAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/DataStreamsStatsAction.java index 2bd4d223bc4ae..fbb084e8cd121 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/DataStreamsStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/DataStreamsStatsAction.java @@ -40,7 +40,32 @@ public static class Request extends BroadcastRequest { public Request() { // this doesn't really matter since data stream name resolution isn't affected by IndicesOptions and // a data stream's backing indices are retrieved from its metadata - super(null, IndicesOptions.fromOptions(false, true, true, true, true, false, true, false)); + super( + null, + IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .includeHidden(false) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(true) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(true) + .build() + ) + .failureStoreOptions( + IndicesOptions.FailureStoreOptions.builder().includeRegularIndices(true).includeFailureIndices(true).build() + ) + .build() + ); } public Request(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/DeleteDataStreamAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/DeleteDataStreamAction.java index 4f3e238796ed6..4f647d4f02884 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/DeleteDataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/DeleteDataStreamAction.java @@ -46,7 +46,25 @@ public static class Request extends MasterNodeRequest implements Indice // empty response can be returned in case wildcards were used or // 404 status code returned in case no wildcard were used. 
private final boolean wildcardExpressionsOriginallySpecified; - private IndicesOptions indicesOptions = IndicesOptions.fromOptions(false, true, true, true, false, false, true, false); + private IndicesOptions indicesOptions = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(true) + .build() + ) + .build(); public Request(TimeValue masterNodeTimeout, String... names) { super(masterNodeTimeout); diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java index 8d7f440ab20e4..c1cf0fa7aab42 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java @@ -56,7 +56,26 @@ private GetDataStreamAction() { public static class Request extends MasterNodeReadRequest implements IndicesRequest.Replaceable { private String[] names; - private IndicesOptions indicesOptions = IndicesOptions.fromOptions(false, true, true, true, false, false, true, false); + private IndicesOptions indicesOptions = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .includeHidden(false) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(true) + .build() + ) + .build(); private boolean includeDefaults = false; private boolean verbose = false; diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/GetDataStreamLifecycleAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/GetDataStreamLifecycleAction.java index 6314f47ab9516..bd628c88a1b1e 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/GetDataStreamLifecycleAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/GetDataStreamLifecycleAction.java @@ -47,7 +47,26 @@ private GetDataStreamLifecycleAction() {/* no instances */} public static class Request extends MasterNodeReadRequest implements IndicesRequest.Replaceable { private String[] names; - private IndicesOptions indicesOptions = IndicesOptions.fromOptions(false, true, true, true, false, false, true, false); + private IndicesOptions indicesOptions = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .includeHidden(false) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(true) + .build() + ) + .build(); private 
boolean includeDefaults = false; public Request(TimeValue masterNodeTimeout, String[] names) { diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/PutDataStreamLifecycleAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/PutDataStreamLifecycleAction.java index 77f723a46f168..b054d12890366 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/PutDataStreamLifecycleAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/lifecycle/PutDataStreamLifecycleAction.java @@ -78,7 +78,26 @@ public static Request parseRequest(XContentParser parser, Factory factory) { } private String[] names; - private IndicesOptions indicesOptions = IndicesOptions.fromOptions(false, true, true, true, false, false, true, false); + private IndicesOptions indicesOptions = IndicesOptions.builder() + .concreteTargetOptions(IndicesOptions.ConcreteTargetOptions.ERROR_WHEN_UNAVAILABLE_TARGETS) + .wildcardOptions( + IndicesOptions.WildcardOptions.builder() + .matchOpen(true) + .matchClosed(true) + .includeHidden(false) + .resolveAliases(false) + .allowEmptyExpressions(true) + .build() + ) + .gatekeeperOptions( + IndicesOptions.GatekeeperOptions.builder() + .allowAliasToMultipleIndices(false) + .allowClosedIndices(true) + .ignoreThrottled(false) + .allowFailureIndices(false) + .build() + ) + .build(); private final DataStreamLifecycle lifecycle; public Request(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index 772f36898202b..7ad61f60c0088 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -116,7 +116,8 @@ private void innerRun() throws Exception { // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase. final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1 && context.getRequest().hasKnnSearch() == false - && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null; + && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null + && (context.getRequest().source() == null || context.getRequest().source().rankBuilder() == null); if (queryAndFetchOptimization) { assert assertConsistentWithQueryAndFetchOptimization(); // query AND fetch optimization diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 1c4eb1c191370..74786dff1648d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -462,8 +462,7 @@ private static SearchHits getHits( : "not enough hits fetched. 
index [" + index + "] length: " + fetchResult.hits().getHits().length; SearchHit searchHit = fetchResult.hits().getHits()[index]; searchHit.shard(fetchResult.getSearchShardTarget()); - if (reducedQueryPhase.queryPhaseRankCoordinatorContext != null) { - assert shardDoc instanceof RankDoc; + if (shardDoc instanceof RankDoc) { searchHit.setRank(((RankDoc) shardDoc).rank); searchHit.score(shardDoc.score); long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc); @@ -735,6 +734,12 @@ static int getTopDocsSize(SearchRequest request) { return DEFAULT_SIZE; } SearchSourceBuilder source = request.source(); + if (source.rankBuilder() != null) { + // if we have a RankBuilder defined, it needs to have access to all the documents in order to rerank them + // so we override size here and keep all `rank_window_size` docs. + // Pagination is taking place later through RankFeaturePhaseRankCoordinatorContext#rankAndPaginate + return source.rankBuilder().rankWindowSize(); + } return (source.size() == -1 ? DEFAULT_SIZE : source.size()) + (source.from() == -1 ? SearchService.DEFAULT_FROM : source.from()); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index f43f1c6b05a15..1cebbabde0769 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -980,7 +980,6 @@ static Settings aggregateIndexSettings( final Settings.Builder indexSettingsBuilder = Settings.builder(); if (sourceMetadata == null) { - final Settings.Builder additionalIndexSettings = Settings.builder(); final Settings templateAndRequestSettings = Settings.builder().put(combinedTemplateSettings).put(request.settings()).build(); final boolean timeSeriesTemplate = Optional.of(request) @@ -990,19 +989,20 @@ static Settings aggregateIndexSettings( // Loop through all the explicit index setting providers, adding them to the // additionalIndexSettings map + final Settings.Builder additionalIndexSettings = Settings.builder(); final var resolvedAt = Instant.ofEpochMilli(request.getNameResolvedAt()); for (IndexSettingProvider provider : indexSettingProviders) { - additionalIndexSettings.put( - provider.getAdditionalIndexSettings( - request.index(), - request.dataStreamName(), - timeSeriesTemplate, - currentState.getMetadata(), - resolvedAt, - templateAndRequestSettings, - combinedTemplateMappings - ) + var newAdditionalSettings = provider.getAdditionalIndexSettings( + request.index(), + request.dataStreamName(), + timeSeriesTemplate, + currentState.getMetadata(), + resolvedAt, + templateAndRequestSettings, + combinedTemplateMappings ); + validateAdditionalSettings(provider, newAdditionalSettings, additionalIndexSettings); + additionalIndexSettings.put(newAdditionalSettings); } // For all the explicit settings, we go through the template and request level settings @@ -1111,6 +1111,29 @@ static Settings aggregateIndexSettings( return indexSettings; } + /** + * Validates whether additional settings don't have keys that are already defined in all additional settings. 
+ * + * @param provider The {@link IndexSettingProvider} that produced additionalSettings + * @param additionalSettings The settings produced by the specified provider + * @param allAdditionalSettings A settings builder containing all additional settings produced by any {@link IndexSettingProvider} + * that already executed + * @throws IllegalArgumentException If keys in additionalSettings are already defined in allAdditionalSettings + */ + public static void validateAdditionalSettings( + IndexSettingProvider provider, + Settings additionalSettings, + Settings.Builder allAdditionalSettings + ) throws IllegalArgumentException { + for (String settingName : additionalSettings.keySet()) { + if (allAdditionalSettings.keys().contains(settingName)) { + var name = provider.getClass().getSimpleName(); + var message = Strings.format("additional index setting [%s] added by [%s] is already present", settingName, name); + throw new IllegalArgumentException(message); + } + } + } + private static void validateSoftDeleteSettings(Settings indexSettings) { if (IndexSettings.INDEX_SOFT_DELETES_SETTING.get(indexSettings) == false && IndexMetadata.SETTING_INDEX_VERSION_CREATED.get(indexSettings).onOrAfter(IndexVersions.V_8_0_0)) { diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java index 2a2cf6743a877..57194ded9422e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java @@ -694,25 +694,25 @@ private void validateIndexTemplateV2(String name, ComposableIndexTemplate indexT // Workaround for the fact that start_time and end_time are injected by the MetadataCreateDataStreamService upon creation, // but when validating templates that create data streams the MetadataCreateDataStreamService isn't used. var finalTemplate = indexTemplate.template(); - var finalSettings = Settings.builder(); final var now = Instant.now(); final var metadata = currentState.getMetadata(); final var combinedMappings = collectMappings(indexTemplate, metadata.componentTemplates(), "tmp_idx"); final var combinedSettings = resolveSettings(indexTemplate, metadata.componentTemplates()); // First apply settings sourced from index setting providers: + var finalSettings = Settings.builder(); for (var provider : indexSettingProviders) { - finalSettings.put( - provider.getAdditionalIndexSettings( - "validate-index-name", - indexTemplate.getDataStreamTemplate() != null ? "validate-data-stream-name" : null, - indexTemplate.getDataStreamTemplate() != null && metadata.isTimeSeriesTemplate(indexTemplate), - currentState.getMetadata(), - now, - combinedSettings, - combinedMappings - ) + var newAdditionalSettings = provider.getAdditionalIndexSettings( + "validate-index-name", + indexTemplate.getDataStreamTemplate() != null ? 
"validate-data-stream-name" : null, + indexTemplate.getDataStreamTemplate() != null && metadata.isTimeSeriesTemplate(indexTemplate), + currentState.getMetadata(), + now, + combinedSettings, + combinedMappings ); + MetadataCreateIndexService.validateAdditionalSettings(provider, newAdditionalSettings, finalSettings); + finalSettings.put(newAdditionalSettings); } // Then apply setting from component templates: finalSettings.put(combinedSettings); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index b1e91ad75e9a2..cc5454ee074e6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -32,6 +32,8 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + public class ES813FlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES813FlatVectorFormat"; @@ -55,6 +57,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return new ES813FlatVectorReader(format.fieldsReader(state)); } + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + static class ES813FlatVectorWriter extends KnnVectorsWriter { private final FlatVectorsWriter writer; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java index 248421fb99d1c..9491598653c44 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java @@ -30,6 +30,8 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + public class ES813Int8FlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES813Int8FlatVectorFormat"; @@ -58,6 +60,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return new ES813FlatVectorReader(format.fieldsReader(state)); } + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + @Override public String toString() { return NAME + "(name=" + NAME + ", innerFormat=" + format + ")"; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java index d6ce73dd4149a..6bb32d8e1ef52 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java @@ -22,6 +22,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; public final class ES814HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat { @@ -70,7 +71,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public int getMaxDimensions(String fieldName) { - return 1024; + 
return MAX_DIMS_COUNT; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java index af771b6a27f19..5cd5872e10421 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java @@ -18,6 +18,8 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + public class ES815BitFlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES815BitFlatVectorFormat"; @@ -45,4 +47,9 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException public String toString() { return NAME; } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java index 5e4656ea94c5b..186dfcbeb5d52 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java @@ -20,6 +20,8 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + public class ES815HnswBitVectorsFormat extends KnnVectorsFormat { static final String NAME = "ES815HnswBitVectorsFormat"; @@ -72,4 +74,9 @@ public String toString() { + flatVectorsFormat + ")"; } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java index 523d5f6c4a91f..e32aea0fb04ae 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java @@ -29,6 +29,8 @@ import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 */ @@ -68,6 +70,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException return new ES816BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); } + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + @Override public String toString() { return "ES816BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java index 989f88e0a7857..097cdffff6ae4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java @@ -39,6 +39,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; import static 
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; /** * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 @@ -128,7 +129,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException @Override public int getMaxDimensions(String fieldName) { - return 1024; + return MAX_DIMS_COUNT; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4adfe619ca4e1..d7353584706d8 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -139,32 +139,27 @@ public static class Builder extends FieldMapper.Builder { if (o instanceof Integer == false) { throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]"); } - int dims = XContentMapValues.nodeIntegerValue(o); - int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT; - int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1; - if (dims < minDims || dims > maxDims) { - throw new MapperParsingException( - "The number of dimensions for field [" - + n - + "] should be in the range [" - + minDims - + ", " - + maxDims - + "] but was [" - + dims - + "]" - ); - } - if (elementType.getValue() == ElementType.BIT) { - if (dims % Byte.SIZE != 0) { + + return XContentMapValues.nodeIntegerValue(o); + }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null) + .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current)) + .addValidator(dims -> { + if (dims == null) { + return; + } + int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT; + int minDims = elementType.getValue() == ElementType.BIT ? 
Byte.SIZE : 1; + if (dims < minDims || dims > maxDims) { throw new MapperParsingException( - "The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]" + "The number of dimensions should be in the range [" + minDims + ", " + maxDims + "] but was [" + dims + "]" ); } - } - return dims; - }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null) - .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current)); + if (elementType.getValue() == ElementType.BIT) { + if (dims % Byte.SIZE != 0) { + throw new MapperParsingException("The number of dimensions should be a multiple of 8 but was [" + dims + "]"); + } + } + }); private final Parameter similarity; private final Parameter indexOptions; diff --git a/server/src/main/java/org/elasticsearch/node/InternalSettingsPreparer.java b/server/src/main/java/org/elasticsearch/node/InternalSettingsPreparer.java index 2e606f7b83eb8..94227aed40d13 100644 --- a/server/src/main/java/org/elasticsearch/node/InternalSettingsPreparer.java +++ b/server/src/main/java/org/elasticsearch/node/InternalSettingsPreparer.java @@ -54,7 +54,11 @@ public static Environment prepareEnvironment( loadOverrides(output, properties); output.put(input); replaceForcedSettings(output); - output.replacePropertyPlaceholders(); + try { + output.replacePropertyPlaceholders(); + } catch (Exception e) { + throw new SettingsException("Failed to replace property placeholders from [" + configFile.getFileName() + "]", e); + } ensureSpecialSettingsExist(output, defaultNodeName); return new Environment(output.build(), configDir); diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java index ff1594de523cf..6bc667d4359b1 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java @@ -75,7 +75,10 @@ public void execute(SearchContext context, int[] docIdsToLoad, RankDocShardInfo return; } - Profiler profiler = context.getProfilers() == null ? Profiler.NOOP : Profilers.startProfilingFetchPhase(); + Profiler profiler = context.getProfilers() == null + || (context.request().source() != null && context.request().source().rankBuilder() != null) + ? Profiler.NOOP + : Profilers.startProfilingFetchPhase(); SearchHits hits = null; try { hits = buildSearchHits(context, docIdsToLoad, profiler, rankDocs); diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 423fad92483ed..d17cd4f69dec7 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -60,6 +60,13 @@ private QueryPhase() {} public static void execute(SearchContext searchContext) throws QueryPhaseExecutionException { if (searchContext.queryPhaseRankShardContext() == null) { + if (searchContext.request().source() != null && searchContext.request().source().rankBuilder() != null) { + // if we have a RankBuilder provided, we want to fetch all rankWindowSize results + // and rerank the documents as per the RankBuilder's instructions. + // Pagination will take place later once they're all (re)ranked. 
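+ // For example, with rank_window_size=100, from=90 and size=10, the query phase collects all 100 candidates here, and RankFeaturePhaseRankCoordinatorContext#rankAndPaginate later returns only reranked hits 91-100.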
+ searchContext.size(searchContext.request().source().rankBuilder().rankWindowSize()); + searchContext.from(0); + } executeQuery(searchContext); } else { executeRank(searchContext); diff --git a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java index 5a8aa5acf09ba..4882d4d8fd5db 100644 --- a/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/CancellableFanOutTests.java @@ -248,6 +248,7 @@ public void testConcurrency() throws InterruptedException { final var itemsProcessed = new AtomicInteger(); final var completionLatch = new CountDownLatch(1); + final var onCompletionCalled = new AtomicBoolean(); new CancellableFanOut() { @Override protected void sendItemRequest(String s, ActionListener listener) { @@ -261,6 +262,7 @@ protected void onItemResponse(String s, String response) { assertCurrentThread(isProcessorThread); assertEquals(s, response); assertThat(itemsProcessed.incrementAndGet(), lessThanOrEqualTo(items.size())); + assertFalse(onCompletionCalled.get()); } @Override @@ -269,10 +271,12 @@ protected void onItemFailure(String s, Exception e) { assertThat(e, instanceOf(ElasticsearchException.class)); assertEquals("sendItemRequest", e.getMessage()); assertThat(itemsProcessed.incrementAndGet(), lessThanOrEqualTo(items.size())); + assertFalse(onCompletionCalled.get()); } @Override protected String onCompletion() { + assertTrue(onCompletionCalled.compareAndSet(false, true)); assertEquals(items.size(), itemsProcessed.get()); assertCurrentThread(anyOf(isTestThread, isProcessorThread)); if (randomBoolean()) { diff --git a/server/src/test/java/org/elasticsearch/index/IndexSettingProviderTests.java b/server/src/test/java/org/elasticsearch/index/IndexSettingProviderTests.java new file mode 100644 index 0000000000000..387340c0a6f50 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/IndexSettingProviderTests.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index; + +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESSingleNodeTestCase; + +import java.time.Instant; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class IndexSettingProviderTests extends ESSingleNodeTestCase { + + public void testIndexCreation() throws Exception { + var indexService = createIndex("my-index1"); + assertFalse(indexService.getIndexSettings().getSettings().hasValue("index.refresh_interval")); + + INDEX_SETTING_PROVIDER1_ENABLED.set(true); + indexService = createIndex("my-index2"); + assertTrue(indexService.getIndexSettings().getSettings().hasValue("index.refresh_interval")); + + INDEX_SETTING_PROVIDER2_ENABLED.set(true); + var e = expectThrows(IllegalArgumentException.class, () -> createIndex("my-index3")); + assertEquals( + "additional index setting [index.refresh_interval] added by [TestIndexSettingsProvider] is already present", + e.getMessage() + ); + } + + @Override + protected Collection> getPlugins() { + return List.of(Plugin1.class, Plugin2.class); + } + + public static class Plugin1 extends Plugin { + + @Override + public Collection getAdditionalIndexSettingProviders(IndexSettingProvider.Parameters parameters) { + return List.of(new TestIndexSettingsProvider("index.refresh_interval", "-1", INDEX_SETTING_PROVIDER1_ENABLED)); + } + + } + + public static class Plugin2 extends Plugin { + + @Override + public Collection getAdditionalIndexSettingProviders(IndexSettingProvider.Parameters parameters) { + return List.of(new TestIndexSettingsProvider("index.refresh_interval", "100s", INDEX_SETTING_PROVIDER2_ENABLED)); + } + } + + private static final AtomicBoolean INDEX_SETTING_PROVIDER1_ENABLED = new AtomicBoolean(false); + private static final AtomicBoolean INDEX_SETTING_PROVIDER2_ENABLED = new AtomicBoolean(false); + + static class TestIndexSettingsProvider implements IndexSettingProvider { + + private final String settingName; + private final String settingValue; + private final AtomicBoolean enabled; + + TestIndexSettingsProvider(String settingName, String settingValue, AtomicBoolean enabled) { + this.settingName = settingName; + this.settingValue = settingValue; + this.enabled = enabled; + } + + @Override + public Settings getAdditionalIndexSettings( + String indexName, + String dataStreamName, + boolean isTimeSeries, + Metadata metadata, + Instant resolvedAt, + Settings indexTemplateAndCreateRequestSettings, + List combinedTemplateMappings + ) { + if (enabled.get()) { + return Settings.builder().put(settingName, settingValue).build(); + } else { + return Settings.EMPTY; + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 8aede4940443c..04b9b05ecfe3a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -175,7 +175,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ), fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", dims) + .field("dims", dims * 8) .field("index", true) .field("similarity", "l2_norm") 
.field("element_type", "bit") @@ -192,7 +192,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ), fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", dims) + .field("dims", dims * 8) .field("index", true) .field("similarity", "l2_norm") .field("element_type", "bit") @@ -891,9 +891,7 @@ public void testDims() { }))); assertThat( e.getMessage(), - equalTo( - "Failed to parse mapping: " + "The number of dimensions for field [field] should be in the range [1, 4096] but was [0]" - ) + equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [0]") ); } // test max limit for non-indexed vectors @@ -904,10 +902,7 @@ public void testDims() { }))); assertThat( e.getMessage(), - equalTo( - "Failed to parse mapping: " - + "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]" - ) + equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]") ); } // test max limit for indexed vectors @@ -919,10 +914,7 @@ public void testDims() { }))); assertThat( e.getMessage(), - equalTo( - "Failed to parse mapping: " - + "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]" - ) + equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]") ); } } @@ -955,6 +947,14 @@ public void testMergeDims() throws IOException { ); } + public void testLargeDimsBit() throws IOException { + createMapperService(fieldMapping(b -> { + b.field("type", "dense_vector"); + b.field("dims", 1024 * Byte.SIZE); + b.field("element_type", ElementType.BIT.toString()); + })); + } + public void testDefaults() throws Exception { DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3))); diff --git a/server/src/test/java/org/elasticsearch/node/InternalSettingsPreparerTests.java b/server/src/test/java/org/elasticsearch/node/InternalSettingsPreparerTests.java index 3d406fff79eb0..32edcc0ad82aa 100644 --- a/server/src/test/java/org/elasticsearch/node/InternalSettingsPreparerTests.java +++ b/server/src/test/java/org/elasticsearch/node/InternalSettingsPreparerTests.java @@ -86,6 +86,20 @@ public void testGarbageIsNotSwallowed() throws IOException { } } + public void testReplacePlaceholderFailure() { + try { + InternalSettingsPreparer.prepareEnvironment( + Settings.builder().put(baseEnvSettings).put("cluster.name", "${ES_CLUSTER_NAME}").build(), + emptyMap(), + null, + () -> "default_node_name" + ); + fail("Expected SettingsException"); + } catch (SettingsException e) { + assertEquals("Failed to replace property placeholders from [elasticsearch.yml]", e.getMessage()); + } + } + public void testSecureSettings() { MockSecureSettings secureSettings = new MockSecureSettings(); secureSettings.setString("foo", "secret"); diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index f50fe16c500dd..809bcd5a9bc12 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -166,7 +166,7 @@ public void 
testSortByManyLongsTooMuchMemoryAsync() throws IOException { "error", matchesMap().extraOk() .entry("bytes_wanted", greaterThan(1000)) - .entry("reason", matchesRegex("\\[request] Data too large, data for \\[(topn|esql_block_factory)] would .+")) + .entry("reason", matchesRegex("\\[request] Data too large, data for \\[.+] would be .+")) .entry("durability", "TRANSIENT") .entry("type", "circuit_breaking_exception") .entry("bytes_limit", greaterThan(1000)) diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java index b5d44d98f476e..767d2f45f90e4 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java @@ -115,7 +115,7 @@ private FieldAttribute(StreamInput in) throws IOException { this( Source.readFrom((StreamInput & PlanStreamInput) in), in.readOptionalWriteable(FieldAttribute::readFrom), - in.readString(), + ((PlanStreamInput) in).readCachedString(), DataType.readFrom(in), EsField.readFrom(in), in.readOptionalString(), @@ -130,7 +130,7 @@ public void writeTo(StreamOutput out) throws IOException { if (((PlanStreamOutput) out).writeAttributeCacheHeader(this)) { Source.EMPTY.writeTo(out); out.writeOptionalWriteable(parent); - out.writeString(name()); + ((PlanStreamOutput) out).writeCachedString(name()); dataType().writeTo(out); field.writeTo(out); // We used to write the qualifier here. We can still do if needed in the future. diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index 979368c300e00..c0092caeb9d5d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -14,6 +14,8 @@ import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper; import org.elasticsearch.xpack.esql.core.plugin.EsqlCorePlugin; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.math.BigInteger; @@ -519,12 +521,12 @@ public DataType counter() { } public void writeTo(StreamOutput out) throws IOException { - out.writeString(typeName); + ((PlanStreamOutput) out).writeCachedString(typeName); } public static DataType readFrom(StreamInput in) throws IOException { // TODO: Use our normal enum serialization pattern - return readFrom(in.readString()); + return readFrom(((PlanStreamInput) in).readCachedString()); } /** diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DateEsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DateEsField.java index fb1c87c570c26..7c4b98c5af84e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DateEsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DateEsField.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import 
org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.Map; @@ -26,12 +28,12 @@ private DateEsField(String name, DataType dataType, Map propert } protected DateEsField(StreamInput in) throws IOException { - this(in.readString(), DataType.DATETIME, in.readImmutableMap(EsField::readFrom), in.readBoolean()); + this(((PlanStreamInput) in).readCachedString(), DataType.DATETIME, in.readImmutableMap(EsField::readFrom), in.readBoolean()); } @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); + ((PlanStreamOutput) out).writeCachedString(getName()); out.writeMap(getProperties(), (o, x) -> x.writeTo(out)); out.writeBoolean(isAggregatable()); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java index 19053d67d34d2..6235176d82de6 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java @@ -60,7 +60,7 @@ public EsField(String name, DataType esDataType, Map properties } public EsField(StreamInput in) throws IOException { - this.name = in.readString(); + this.name = ((PlanStreamInput) in).readCachedString(); this.esDataType = readDataType(in); this.properties = in.readImmutableMap(EsField::readFrom); this.aggregatable = in.readBoolean(); @@ -68,7 +68,7 @@ public EsField(StreamInput in) throws IOException { } private DataType readDataType(StreamInput in) throws IOException { - String name = in.readString(); + String name = ((PlanStreamInput) in).readCachedString(); if (in.getTransportVersion().before(TransportVersions.ESQL_NESTED_UNSUPPORTED) && name.equalsIgnoreCase("NESTED")) { /* * The "nested" data type existed in older versions of ESQL but was @@ -98,7 +98,7 @@ public void writeTo(StreamOutput out) throws IOException { * This needs to be overridden by subclasses for specific serialization */ public void writeContent(StreamOutput out) throws IOException { - out.writeString(name); + ((PlanStreamOutput) out).writeCachedString(name); esDataType.writeTo(out); out.writeMap(properties, (o, x) -> x.writeTo(out)); out.writeBoolean(aggregatable); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedField.java index 18105eb6c724c..40825af56ccfe 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedField.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.Map; @@ -52,7 +54,7 @@ private InvalidMappedField(String name, String errorMessage, Map types() { @@ -61,7 +63,7 @@ public Set types() { @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); + ((PlanStreamOutput) out).writeCachedString(getName()); 
out.writeString(errorMessage); out.writeMap(getProperties(), (o, x) -> x.writeTo(out)); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/KeywordEsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/KeywordEsField.java index fc7f5bd1f1ea8..48995bafec451 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/KeywordEsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/KeywordEsField.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.Collections; @@ -59,7 +61,7 @@ protected KeywordEsField( public KeywordEsField(StreamInput in) throws IOException { this( - in.readString(), + ((PlanStreamInput) in).readCachedString(), KEYWORD, in.readImmutableMap(EsField::readFrom), in.readBoolean(), @@ -71,7 +73,7 @@ public KeywordEsField(StreamInput in) throws IOException { @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); + ((PlanStreamOutput) out).writeCachedString(getName()); out.writeMap(getProperties(), (o, x) -> x.writeTo(out)); out.writeBoolean(isAggregatable()); out.writeInt(precision); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/MultiTypeEsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/MultiTypeEsField.java index 81dc77eddcdf8..522cb682c0943 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/MultiTypeEsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/MultiTypeEsField.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.HashMap; @@ -36,13 +38,18 @@ public MultiTypeEsField(String name, DataType dataType, boolean aggregatable, Ma } protected MultiTypeEsField(StreamInput in) throws IOException { - this(in.readString(), DataType.readFrom(in), in.readBoolean(), in.readImmutableMap(i -> i.readNamedWriteable(Expression.class))); + this( + ((PlanStreamInput) in).readCachedString(), + DataType.readFrom(in), + in.readBoolean(), + in.readImmutableMap(i -> i.readNamedWriteable(Expression.class)) + ); } @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); - out.writeString(getDataType().typeName()); + ((PlanStreamOutput) out).writeCachedString(getName()); + getDataType().writeTo(out); out.writeBoolean(isAggregatable()); out.writeMap(getIndexToConversionExpressions(), (o, v) -> out.writeNamedWriteable(v)); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/TextEsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/TextEsField.java index 3162224579387..c6c494ef289bb 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/TextEsField.java +++ 
b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/TextEsField.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.Map; @@ -32,12 +34,12 @@ public TextEsField(String name, Map properties, boolean hasDocV } protected TextEsField(StreamInput in) throws IOException { - this(in.readString(), in.readImmutableMap(EsField::readFrom), in.readBoolean(), in.readBoolean()); + this(((PlanStreamInput) in).readCachedString(), in.readImmutableMap(EsField::readFrom), in.readBoolean(), in.readBoolean()); } @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); + ((PlanStreamOutput) out).writeCachedString(getName()); out.writeMap(getProperties(), (o, x) -> x.writeTo(out)); out.writeBoolean(isAggregatable()); out.writeBoolean(isAlias()); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsField.java index 13ee2b42a321b..980620cb98847 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsField.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import java.io.IOException; import java.util.Map; @@ -34,13 +36,18 @@ public UnsupportedEsField(String name, String originalType, String inherited, Ma } public UnsupportedEsField(StreamInput in) throws IOException { - this(in.readString(), in.readString(), in.readOptionalString(), in.readImmutableMap(EsField::readFrom)); + this( + ((PlanStreamInput) in).readCachedString(), + ((PlanStreamInput) in).readCachedString(), + in.readOptionalString(), + in.readImmutableMap(EsField::readFrom) + ); } @Override public void writeContent(StreamOutput out) throws IOException { - out.writeString(getName()); - out.writeString(getOriginalType()); + ((PlanStreamOutput) out).writeCachedString(getName()); + ((PlanStreamOutput) out).writeCachedString(getOriginalType()); out.writeOptionalString(getInherited()); out.writeMap(getProperties(), (o, x) -> x.writeTo(out)); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java index 471c9476ad31d..826b0cbfa3498 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java @@ -47,4 +47,6 @@ public interface PlanStreamInput { A readAttributeWithCache(CheckedFunction constructor) throws IOException; A readEsFieldWithCache() throws IOException; + + String readCachedString() throws IOException; } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java 
b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java index 4c30cb66e9f86..e4797411c3796 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java @@ -31,4 +31,6 @@ public interface PlanStreamOutput { * @throws IOException */ boolean writeEsFieldCacheHeader(EsField field) throws IOException; + + void writeCachedString(String field) throws IOException; } diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java index 974180526d750..3388f6f517bdf 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java @@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.startsWith; @@ -331,14 +332,13 @@ public void testProfile() throws IOException { } public void testProfileOrdinalsGroupingOperator() throws IOException { + assumeTrue("requires pragmas", Build.current().isSnapshot()); indexTimestampData(1); RequestObjectBuilder builder = requestObjectBuilder().query(fromIndex() + " | STATS AVG(value) BY test.keyword"); builder.profile(true); - if (Build.current().isSnapshot()) { - // Lock to shard level partitioning, so we get consistent profile output - builder.pragmas(Settings.builder().put("data_partitioning", "shard").build()); - } + // Lock to shard level partitioning, so we get consistent profile output + builder.pragmas(Settings.builder().put("data_partitioning", "shard").build()); Map result = runEsql(builder); List> signatures = new ArrayList<>(); @@ -356,7 +356,7 @@ public void testProfileOrdinalsGroupingOperator() throws IOException { signatures.add(sig); } - assertThat(signatures.get(0).get(2), equalTo("OrdinalsGroupingOperator[aggregators=[\"sum of longs\", \"count\"]]")); + assertThat(signatures, hasItem(hasItem("OrdinalsGroupingOperator[aggregators=[\"sum of longs\", \"count\"]]"))); } public void testInlineStatsProfile() throws IOException { @@ -423,11 +423,9 @@ public void testInlineStatsProfile() throws IOException { .item("ProjectOperator") .item("OutputOperator"), // Second pass read and join via eval - matchesList().item("LuceneSourceOperator") + matchesList().item("LuceneTopNSourceOperator") .item("EvalOperator") .item("ValuesSourceReaderOperator") - .item("TopNOperator") - .item("ValuesSourceReaderOperator") .item("ProjectOperator") .item("ExchangeSinkOperator"), // Second pass node level reduce @@ -591,6 +589,16 @@ private String checkOperatorProfile(Map o) { case "TopNOperator" -> matchesMap().entry("occupied_rows", 0) .entry("ram_used", instanceOf(String.class)) .entry("ram_bytes_used", greaterThan(0)); + case "LuceneTopNSourceOperator" -> matchesMap().entry("pages_emitted", greaterThan(0)) + .entry("current", greaterThan(0)) + .entry("processed_slices", greaterThan(0)) + 
.entry("processed_shards", List.of("rest-esql-test:0")) + .entry("total_slices", greaterThan(0)) + .entry("slice_max", 0) + .entry("slice_min", 0) + .entry("processing_nanos", greaterThan(0)) + .entry("processed_queries", List.of("*:*")) + .entry("slice_index", 0); default -> throw new AssertionError("unexpected status: " + o); }; MapMatcher expectedOp = matchesMap().entry("operator", startsWith(name)); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/spatial.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/spatial.csv-spec index 35416c7945128..4c40808a4ff96 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/spatial.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/spatial.csv-spec @@ -1116,9 +1116,147 @@ count:long | country:k 1 | Poland ; +airportsWithinEvalDistanceBandCopenhagenTrainStationCount +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) +| WHERE distance < 600000 AND distance > 400000 +| STATS count=COUNT() BY country +| SORT count DESC, country ASC +; + +count:long | country:k +3 | Sweden +2 | Norway +1 | Germany +1 | Lithuania +1 | Poland +; + +airportsWithinEvalDistanceBandCopenhagenTrainStationKeepDistance +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), importance = 10 - scalerank +| WHERE distance < 500000 AND distance > 400000 +| STATS count=COUNT() BY distance, importance +| SORT distance ASC, importance DESC, count DESC +; + +count:long | distance:double | importance:integer +1 | 402611.1308019835 | 4 +1 | 433987.3301951482 | 3 +; + +airportsWithinEvalDistanceBandCopenhagenTrainStationCountNonPushableEval +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), position = location::keyword +| WHERE distance < 600000 AND distance > 400000 AND SUBSTRING(position, 1, 5) == "POINT" +| STATS count=COUNT() BY country +| SORT count DESC, country ASC +; + +count:long | country:k +3 | Sweden +2 | Norway +1 | Germany +1 | Lithuania +1 | Poland +; + +airportsWithinEvalDistanceBandCopenhagenTrainStationCountNonPushableWhereConjunctive +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) +| WHERE distance < 500000 AND 0.5*distance < 300000 +| STATS count=COUNT() BY country +| SORT count DESC, country ASC +; + +count:long | country:k +3 | Germany +3 | Sweden +1 | Denmark +1 | Poland +; + +airportsWithinEvalDistanceBandCopenhagenTrainStationCountNonPushableWhereDisjunctive +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) +| WHERE distance < 500000 OR 0.5*distance < 300000 +| STATS count=COUNT() BY country +| SORT count DESC, country ASC +; + +count:long | country:k +5 | Sweden +4 | Germany +2 | Norway +1 | Denmark +1 | Lithuania +1 | Poland +; + +airportsSortCityName +FROM airports +| SORT abbrev +| LIMIT 5 +| KEEP abbrev, name, location, country, city +; + +abbrev:keyword | name:text | location:geo_point | country:keyword | city:keyword +ABJ | Abidjan Port Bouet | POINT(-3.93221929167636 5.2543984451492) | Côte d'Ivoire | Abidjan +ABQ | Albuquerque Int'l | POINT(-106.6166851616 35.0491578018276) | United States | Albuquerque +ABV | Abuja Int'l | POINT(7.27025993974356 9.00437659781094) | Nigeria | Abuja +ACA | General Juan N Alvarez 
Int'l | POINT(-99.7545085619681 16.76196735278) | Mexico | Acapulco de Juárez +ACC | Kotoka Int'l | POINT(-0.171402855660817 5.60698152381193) | Ghana | Accra +; + airportsSortDistanceFromCopenhagenTrainStation required_capability: st_distance +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) +| SORT distance ASC +| LIMIT 5 +| KEEP abbrev, name, location, country, city +; + +abbrev:k | name:text | location:geo_point | country:k | city:k +CPH | Copenhagen | POINT(12.6493508684508 55.6285017221528) | Denmark | Copenhagen +GOT | Gothenburg | POINT(12.2938269092573 57.6857493534879) | Sweden | Gothenburg +HAM | Hamburg | POINT(10.005647830925 53.6320011640866) | Germany | Norderstedt +TXL | Berlin-Tegel Int'l | POINT(13.2903090925074 52.5544287044101) | Germany | Hohen Neuendorf +BRE | Bremen | POINT(8.7858617703132 53.052287104156) | Germany | Bremen +; + +airportsSortDistanceFromCopenhagenTrainStationInline +required_capability: st_distance +required_capability: spatial_distance_pushdown_enhancements + +FROM airports +| SORT ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) ASC +| LIMIT 5 +| KEEP abbrev, name, location, country, city +; + +abbrev:k | name:text | location:geo_point | country:k | city:k +CPH | Copenhagen | POINT(12.6493508684508 55.6285017221528) | Denmark | Copenhagen +GOT | Gothenburg | POINT(12.2938269092573 57.6857493534879) | Sweden | Gothenburg +HAM | Hamburg | POINT(10.005647830925 53.6320011640866) | Germany | Norderstedt +TXL | Berlin-Tegel Int'l | POINT(13.2903090925074 52.5544287044101) | Germany | Hohen Neuendorf +BRE | Bremen | POINT(8.7858617703132 53.052287104156) | Germany | Bremen +; + +airportsSortDistanceFromCopenhagenTrainStationDetails +required_capability: st_distance + FROM airports | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) | SORT distance ASC @@ -1136,6 +1274,26 @@ TXL | Berlin-Tegel Int'l | POINT(13.2903090925074 52.5544287044101) | BRE | Bremen | POINT(8.7858617703132 53.052287104156) | Germany | Bremen | POINT(8.8 53.0833) | 380.5 | 377.22 ; +airportsSortDistanceFromCopenhagenTrainStationDetailsAndNonPushableEval +required_capability: st_distance + +FROM airports +| EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), position = location::keyword +| WHERE distance < 600000 AND SUBSTRING(position, 1, 5) == "POINT" +| SORT distance ASC +| LIMIT 5 +| EVAL distance = ROUND(distance/1000,2) +| KEEP abbrev, name, position, distance +; + +abbrev:k | name:text | position:keyword | distance:d +CPH | Copenhagen | POINT (12.6493508684508 55.6285017221528) | 7.24 +GOT | Gothenburg | POINT (12.2938269092573 57.6857493534879) | 224.42 +HAM | Hamburg | POINT (10.005647830925 53.6320011640866) | 280.34 +TXL | Berlin-Tegel Int'l | POINT (13.2903090925074 52.5544287044101) | 349.97 +BRE | Bremen | POINT (8.7858617703132 53.052287104156) | 380.5 +; + airportsSortDistanceFromAirportToCity required_capability: st_distance diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/topN.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/topN.csv-spec index e09bc933340d1..3d4d890546050 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/topN.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/topN.csv-spec @@ -134,3 +134,37 @@ Otmar |Herbst |[-8.19, -1.9, -0.32] |[-0.32, -1.90, null |Swan |-8.46 |-8.46 |-8 |10034 Sanjiv |Zschoche |[-7.67, -3.25] |[-3.25, -7.67] |[-3, -8] |10053 ; + +sortingOnSwappedFields +FROM 
employees +| EVAL name = last_name, last_name = first_name, first_name = name +| WHERE first_name > "B" AND last_name IS NOT NULL +| SORT name +| LIMIT 10 +| KEEP name, last_name, first_name +; + +name:keyword | last_name:keyword | first_name:keyword +Baek | Premal | Baek +Bamford | Parto | Bamford +Bernatsky | Mokhtar | Bernatsky +Bernini | Brendon | Bernini +Berztiss | Yongqiao | Berztiss +Bierman | Margareta | Bierman +Billingsley | Breannda | Billingsley +Bouloucos | Cristinel | Bouloucos +Brattka | Charlene | Brattka +Bridgland | Patricio | Bridgland +; + +sortingOnSwappedFieldsNoKeep +// Note that this test requires all fields to be returned in order to test a specific code path in physical planning +FROM employees +| EVAL name = first_name, first_name = last_name, last_name = name +| WHERE first_name == "Bernini" AND last_name == "Brendon" +| SORT name +; + +avg_worked_seconds:long | birth_date:date | emp_no:i | gender:k | height:d | height.float:d | height.half_float:d | height.scaled_float:d | hire_date:date | is_rehired:bool | job_positions:k | languages:i | languages.byte:i | languages.long:l | languages.short:short | salary:i | salary_change:d | salary_change.int:i | salary_change.keyword:k | salary_change.long:l | still_hired:bool | name:k | first_name:k | last_name:k +349086555 | 1961-09-01T00:00:00Z | 10056 | F | 1.57 | 1.5700000524520874 | 1.5703125 | 1.57 | 1990-02-01T00:00:00Z | [false, false, true] | [Senior Team Lead] | 2 | 2 | 2 | 2 | 33370 | [-5.17, 10.99] | [-5, 10] | [-5.17, 10.99] | [-5, 10] | true | Brendon | Bernini | Brendon +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index f7454e41a6a8b..9aa4d874c53e2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -165,6 +165,11 @@ public enum Cap { */ SPATIAL_PREDICATES_SUPPORT_MULTIVALUES, + /** + * Support a number of fixes and enhancements to spatial distance pushdown. Done in #112938. 
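     * Exercised by the new spatial.csv-spec queries in this diff via
     * {@code required_capability: spatial_distance_pushdown_enhancements}.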
+ */ + SPATIAL_DISTANCE_PUSHDOWN_ENHANCEMENTS, + /** * Fix to GROK and DISSECT that allows extracting attributes with the same name as the input * https://github.com/elastic/elasticsearch/issues/110184 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index a29e16139dde7..647a29b71c5e1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -552,10 +552,9 @@ private static Failure validateUnsignedLongNegation(Neg neg) { */ private static void checkForSortOnSpatialTypes(LogicalPlan p, Set localFailures) { if (p instanceof OrderBy ob) { - ob.forEachExpression(Attribute.class, attr -> { - DataType dataType = attr.dataType(); - if (DataType.isSpatial(dataType)) { - localFailures.add(fail(attr, "cannot sort on " + dataType.typeName())); + ob.order().forEach(order -> { + if (DataType.isSpatial(order.dataType())) { + localFailures.add(fail(order, "cannot sort on " + order.dataType().typeName())); } }); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java index 931c5d57107cf..08c249662c7d2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java @@ -76,7 +76,7 @@ public UnsupportedAttribute(Source source, String name, UnsupportedEsField field private UnsupportedAttribute(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), - in.readString(), + ((PlanStreamInput) in).readCachedString(), in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) || in.getTransportVersion().isPatchFrom(TransportVersions.ESQL_ATTRIBUTE_CACHED_SERIALIZATION_8_15) ? 
EsField.readFrom(in) @@ -90,7 +90,7 @@ private UnsupportedAttribute(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { if (((PlanStreamOutput) out).writeAttributeCacheHeader(this)) { Source.EMPTY.writeTo(out); - out.writeString(name()); + ((PlanStreamOutput) out).writeCachedString(name()); if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) || out.getTransportVersion().isPatchFrom(TransportVersions.ESQL_ATTRIBUTE_CACHED_SERIALIZATION_8_15)) { field().writeTo(out); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java index 72dd052fc7637..5e8d39217fcca 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.List; +import java.util.Objects; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -80,6 +81,29 @@ public void writeTo(StreamOutput out) throws IOException { // The CRS type is re-resolved from the combination of left and right fields, and also not necessary to serialize } + /** + * Mark the function as expecting the specified fields to arrive as doc-values. + */ + public abstract BinarySpatialFunction withDocValues(boolean foundLeft, boolean foundRight); + + @Override + public int hashCode() { + // NB: the hashcode is currently used for key generation so + // to avoid clashes between aggs with the same arguments, add the class name as variation + return Objects.hash(getClass(), children(), leftDocValues, rightDocValues); + } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj)) { + BinarySpatialFunction other = (BinarySpatialFunction) obj; + return Objects.equals(other.children(), children()) + && Objects.equals(other.leftDocValues, leftDocValues) + && Objects.equals(other.rightDocValues, rightDocValues); + } + return false; + } + @Override protected TypeResolution resolveType() { return spatialTypeResolver.resolveType(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java index 7f578565f81f2..9189c6a7b8f70 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java @@ -26,7 +26,6 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -38,7 +37,6 @@ import java.io.IOException; import java.util.HashMap; import 
java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; @@ -198,10 +196,10 @@ public ShapeRelation queryRelation() { } @Override - public SpatialContains withDocValues(Set attributes) { + public SpatialContains withDocValues(boolean foundLeft, boolean foundRight) { // Only update the docValues flags if the field is found in the attributes - boolean leftDV = leftDocValues || foundField(left(), attributes); - boolean rightDV = rightDocValues || foundField(right(), attributes); + boolean leftDV = leftDocValues || foundLeft; + boolean rightDV = rightDocValues || foundRight; return new SpatialContains(source(), left(), right(), leftDV, rightDV); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java index 47d19ebae884b..ee78f50c4d6bd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java @@ -23,7 +23,6 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -35,7 +34,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; @@ -113,10 +111,10 @@ public ShapeRelation queryRelation() { } @Override - public SpatialDisjoint withDocValues(Set attributes) { + public SpatialDisjoint withDocValues(boolean foundLeft, boolean foundRight) { // Only update the docValues flags if the field is found in the attributes - boolean leftDV = leftDocValues || foundField(left(), attributes); - boolean rightDV = rightDocValues || foundField(right(), attributes); + boolean leftDV = leftDocValues || foundLeft; + boolean rightDV = rightDocValues || foundRight; return new SpatialDisjoint(source(), left(), right(), leftDV, rightDV); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java index 8e287baeaa9b8..8d54e5ee443c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java @@ -23,7 +23,6 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import 
org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -35,7 +34,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; @@ -111,10 +109,10 @@ public ShapeRelation queryRelation() { } @Override - public SpatialIntersects withDocValues(Set attributes) { + public SpatialIntersects withDocValues(boolean foundLeft, boolean foundRight) { // Only update the docValues flags if the field is found in the attributes - boolean leftDV = leftDocValues || foundField(left(), attributes); - boolean rightDV = rightDocValues || foundField(right(), attributes); + boolean leftDV = leftDocValues || foundLeft; + boolean rightDV = rightDocValues || foundRight; return new SpatialIntersects(source(), left(), right(), leftDV, rightDV); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java index ee2b4450a64ff..8ca89334b059b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java @@ -22,7 +22,6 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; @@ -30,9 +29,6 @@ import java.io.IOException; import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.function.Predicate; import static org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils.asGeometryDocValueReader; import static org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils.asLuceneComponent2D; @@ -57,47 +53,6 @@ public DataType dataType() { return DataType.BOOLEAN; } - /** - * Mark the function as expecting the specified fields to arrive as doc-values. - */ - public abstract SpatialRelatesFunction withDocValues(Set attributes); - - /** - * Push-down to Lucene is only possible if one field is an indexed spatial field, and the other is a constant spatial or string column. 
- */ - public boolean canPushToSource(Predicate isAggregatable) { - // The use of foldable here instead of SpatialEvaluatorFieldKey.isConstant is intentional to match the behavior of the - // Lucene pushdown code in EsqlTranslationHandler::SpatialRelatesTranslator - // We could enhance both places to support ReferenceAttributes that refer to constants, but that is a larger change - return isPushableFieldAttribute(left(), isAggregatable) && right().foldable() - || isPushableFieldAttribute(right(), isAggregatable) && left().foldable(); - } - - private static boolean isPushableFieldAttribute(Expression exp, Predicate isAggregatable) { - return exp instanceof FieldAttribute fa - && fa.getExactInfo().hasExact() - && isAggregatable.test(fa) - && DataType.isSpatial(fa.dataType()); - } - - @Override - public int hashCode() { - // NB: the hashcode is currently used for key generation so - // to avoid clashes between aggs with the same arguments, add the class name as variation - return Objects.hash(getClass(), children(), leftDocValues, rightDocValues); - } - - @Override - public boolean equals(Object obj) { - if (super.equals(obj)) { - SpatialRelatesFunction other = (SpatialRelatesFunction) obj; - return Objects.equals(other.children(), children()) - && Objects.equals(other.leftDocValues, leftDocValues) - && Objects.equals(other.rightDocValues, rightDocValues); - } - return false; - } - /** * Produce a map of rules defining combinations of incoming types to the evaluator factory that should be used. */ @@ -115,19 +70,6 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua return SpatialEvaluatorFactory.makeSpatialEvaluator(this, evaluatorRules(), toEvaluator); } - /** - * When performing local physical plan optimization, it is necessary to know if this function has a field attribute. - * This is because the planner might push down a spatial aggregation to lucene, which results in the field being provided - * as doc-values instead of source values, and this function needs to know if it should use doc-values or not. 
- */ - public boolean hasFieldAttribute(Set foundAttributes) { - return foundField(left(), foundAttributes) || foundField(right(), foundAttributes); - } - - protected boolean foundField(Expression expression, Set foundAttributes) { - return expression instanceof FieldAttribute field && foundAttributes.contains(field); - } - protected static class SpatialRelations extends BinarySpatialComparator { protected final ShapeField.QueryRelation queryRelation; protected final ShapeIndexer shapeIndexer; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java index 84ea5b86f1d40..2005709cd37e9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java @@ -23,7 +23,6 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -36,7 +35,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Set; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; @@ -113,10 +111,10 @@ public ShapeRelation queryRelation() { } @Override - public SpatialWithin withDocValues(Set attributes) { + public SpatialWithin withDocValues(boolean foundLeft, boolean foundRight) { // Only update the docValues flags if the field is found in the attributes - boolean leftDV = leftDocValues || foundField(left(), attributes); - boolean rightDV = rightDocValues || foundField(right(), attributes); + boolean leftDV = leftDocValues || foundLeft; + boolean rightDV = rightDocValues || foundRight; return new SpatialWithin(source(), left(), right(), leftDV, rightDV); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java index 17bcc68004bff..ae9d3383bad39 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java @@ -148,6 +148,14 @@ private StDistance(StreamInput in) throws IOException { super(in, false, false, true); } + @Override + public StDistance withDocValues(boolean foundLeft, boolean foundRight) { + // Only update the docValues flags if the field is found in the attributes + boolean leftDV = leftDocValues || foundLeft; + boolean rightDV = rightDocValues || foundRight; + return new StDistance(source(), left(), right(), leftDV, rightDV); + } + @Override public String getWriteableName() { return ENTRY.name; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index 2b09a395c4a3d..c832f64363048 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -62,9 +62,11 @@ public NameId apply(long streamNameId) { private final Map cachedBlocks = new HashMap<>(); - private Attribute[] attributesCache = new Attribute[64]; + private Attribute[] attributesCache = new Attribute[1024]; - private EsField[] esFieldsCache = new EsField[64]; + private EsField[] esFieldsCache = new EsField[1024]; + + private String[] stringCache = new String[1024]; // hook for nameId, where can cache and map, for now just return a NameId of the same long value. private final LongFunction nameIdFunction; @@ -195,10 +197,11 @@ public A readAttributeWithCache(CheckedFunction A readEsFieldWithCache() throws IOException { // it's safe to cast to int, since the max value for this is {@link PlanStreamOutput#MAX_SERIALIZED_ATTRIBUTES} int cacheId = Math.toIntExact(readZLong()); if (cacheId < 0) { - String className = readString(); + String className = readCachedString(); Writeable.Reader reader = EsField.getReader(className); cacheId = -1 - cacheId; EsField result = reader.read(this); @@ -231,17 +234,37 @@ public A readEsFieldWithCache() throws IOException { return (A) esFieldFromCache(cacheId); } } else { - String className = readString(); + String className = readCachedString(); Writeable.Reader reader = EsField.getReader(className); return (A) reader.read(this); } } + /** + * Reads a cached string, serialized with {@link PlanStreamOutput#writeCachedString(String)}. + */ + @Override + public String readCachedString() throws IOException { + if (getTransportVersion().before(TransportVersions.ESQL_CACHED_STRING_SERIALIZATION)) { + return readString(); + } + int cacheId = Math.toIntExact(readZLong()); + if (cacheId < 0) { + String string = readString(); + cacheId = -1 - cacheId; + cacheString(cacheId, string); + return string; + } else { + return stringFromCache(cacheId); + } + } + private EsField esFieldFromCache(int id) throws IOException { - if (esFieldsCache[id] == null) { + EsField field = esFieldsCache[id]; + if (field == null) { throw new IOException("Attribute ID not found in serialization cache [" + id + "]"); } - return esFieldsCache[id]; + return field; } /** @@ -257,4 +280,27 @@ private void cacheEsField(int id, EsField field) { esFieldsCache[id] = field; } + private String stringFromCache(int id) throws IOException { + String value = stringCache[id]; + if (value == null) { + throw new IOException("String not found in serialization cache [" + id + "]"); + } + return value; + } + + private void cacheString(int id, String string) { + assert id >= 0; + if (id >= stringCache.length) { + stringCache = ArrayUtil.grow(stringCache, id + 1); + } + stringCache[id] = string; + } + + @Override + public void close() throws IOException { + super.close(); + this.stringCache = null; + this.attributesCache = null; + this.esFieldsCache = null; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index fe66b799195f8..5e31a57ed669b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; +import java.util.HashMap; import java.util.IdentityHashMap; import java.util.Map; @@ -65,6 +66,8 @@ public final class PlanStreamOutput extends StreamOutput implements org.elastics */ protected final Map cachedEsFields = new IdentityHashMap<>(); + protected final Map stringCache = new HashMap<>(); + private final StreamOutput delegate; private int nextCachedBlock = 0; @@ -105,6 +108,9 @@ public void flush() throws IOException { @Override public void close() throws IOException { delegate.close(); + stringCache.clear(); + cachedEsFields.clear(); + cachedAttributes.clear(); } @Override @@ -121,10 +127,10 @@ public void setTransportVersion(TransportVersion version) { /** * Write a {@link Block} as part of the plan. *
<p>
- * These {@link Block}s are not tracked by {@link BlockFactory} and closing them - * does nothing so they should be small. We do make sure not to send duplicates, - * reusing blocks sent as part of the {@link Configuration#tables()} if - * possible, otherwise sending a {@linkplain Block} inline. + * These {@link Block}s are not tracked by {@link BlockFactory} and closing them + * does nothing so they should be small. We do make sure not to send duplicates, + * reusing blocks sent as part of the {@link Configuration#tables()} if + * possible, otherwise sending a {@linkplain Block} inline. *
</p>
*/ public void writeCachedBlock(Block block) throws IOException { @@ -189,10 +195,37 @@ public boolean writeEsFieldCacheHeader(EsField field) throws IOException { cacheId = cacheEsField(field); writeZLong(-1 - cacheId); } - writeString(field.getWriteableName()); + writeCachedString(field.getWriteableName()); return true; } + /** + * Writes a string caching it, ie. the second time the same string is written, only a small, numeric ID will be sent. + * This should be used only to serialize recurring strings. + * + * Values serialized with this method have to be deserialized with {@link PlanStreamInput#readCachedString()} + */ + @Override + public void writeCachedString(String string) throws IOException { + if (getTransportVersion().before(TransportVersions.ESQL_CACHED_STRING_SERIALIZATION)) { + writeString(string); + return; + } + Integer cacheId = stringCache.get(string); + if (cacheId != null) { + writeZLong(cacheId); + return; + } + cacheId = stringCache.size(); + if (cacheId >= maxSerializedAttributes) { + throw new InvalidArgumentException("Limit of the number of serialized strings exceeded [{}]", maxSerializedAttributes); + } + stringCache.put(string, cacheId); + + writeZLong(-1 - cacheId); + writeString(string); + } + private Integer esFieldIdFromCache(EsField field) { return cachedEsFields.get(field); } @@ -248,12 +281,12 @@ static BytesReference fromPreviousKey(int id) throws IOException { * This is important because some operations like {@code LOOKUP} frequently read * {@linkplain Block}s directly from the configuration. *
<p>
- * It'd be possible to implement this by adding all of the Blocks as "previous" - * keys in the constructor and never use this construct at all, but that'd - * require there be a consistent ordering of Blocks there. We could make one, - * but I'm afraid that'd be brittle as we evolve the code. It'd make wire - * compatibility difficult. This signal is much simpler to deal with even though - * it is more bytes over the wire. + * It'd be possible to implement this by adding all of the Blocks as "previous" + * keys in the constructor and never use this construct at all, but that'd + * require there be a consistent ordering of Blocks there. We could make one, + * but I'm afraid that'd be brittle as we evolve the code. It'd make wire + * compatibility difficult. This signal is much simpler to deal with even though + * it is more bytes over the wire. *
</p>
*/ static BytesReference fromConfigKey(String table, String column) throws IOException { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java index e39b590228d57..ed09d0bc16754 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java @@ -7,10 +7,13 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; +import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -25,6 +28,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import java.util.function.Predicate; public final class PushDownAndCombineFilters extends OptimizerRules.OptimizerRule { @@ -43,20 +47,37 @@ protected LogicalPlan rule(Filter filter) { filter, agg, e -> e instanceof Attribute && agg.output().contains(e) && agg.groupings().contains(e) == false - || e instanceof AggregateFunction + || e instanceof AggregateFunction, + NO_OP ); } else if (child instanceof Eval eval) { - // Don't push if Filter (still) contains references of Eval's fields. - var attributes = new AttributeSet(Expressions.asAttributes(eval.fields())); - plan = maybePushDownPastUnary(filter, eval, attributes::contains); + // Don't push if Filter (still) contains references to Eval's fields. + // Account for simple aliases in the Eval, though - these shouldn't stop us. + AttributeMap.Builder aliasesBuilder = AttributeMap.builder(); + for (Alias alias : eval.fields()) { + aliasesBuilder.put(alias.toAttribute(), alias.child()); + } + AttributeMap evalAliases = aliasesBuilder.build(); + + Function resolveRenames = expr -> expr.transformDown(ReferenceAttribute.class, r -> { + Expression resolved = evalAliases.resolve(r, null); + // Avoid resolving to an intermediate attribute that only lives inside the Eval - only replace if the attribute existed + // before the Eval. 
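// Illustrative case (not from the original change): in `EVAL tmp = a + 1, x = tmp`, a filter on `x`
// must not be rewritten to reference `tmp`, because `tmp` only exists inside the Eval; the
// eval.inputSet() check below guarantees we only substitute attributes that exist below the Eval.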
+ if (resolved instanceof Attribute && eval.inputSet().contains(resolved)) { + return resolved; + } + return r; + }); + + plan = maybePushDownPastUnary(filter, eval, evalAliases::containsKey, resolveRenames); } else if (child instanceof RegexExtract re) { // Push down filters that do not rely on attributes created by RegexExtract var attributes = new AttributeSet(Expressions.asAttributes(re.extractedFields())); - plan = maybePushDownPastUnary(filter, re, attributes::contains); + plan = maybePushDownPastUnary(filter, re, attributes::contains, NO_OP); } else if (child instanceof Enrich enrich) { // Push down filters that do not rely on attributes created by Enrich var attributes = new AttributeSet(Expressions.asAttributes(enrich.enrichFields())); - plan = maybePushDownPastUnary(filter, enrich, attributes::contains); + plan = maybePushDownPastUnary(filter, enrich, attributes::contains, NO_OP); } else if (child instanceof Project) { return PushDownUtils.pushDownPastProject(filter); } else if (child instanceof OrderBy orderBy) { @@ -67,21 +88,35 @@ protected LogicalPlan rule(Filter filter) { return plan; } - private static LogicalPlan maybePushDownPastUnary(Filter filter, UnaryPlan unary, Predicate cannotPush) { + private static Function NO_OP = expression -> expression; + + private static LogicalPlan maybePushDownPastUnary( + Filter filter, + UnaryPlan unary, + Predicate cannotPush, + Function resolveRenames + ) { LogicalPlan plan; List pushable = new ArrayList<>(); List nonPushable = new ArrayList<>(); for (Expression exp : Predicates.splitAnd(filter.condition())) { - (exp.anyMatch(cannotPush) ? nonPushable : pushable).add(exp); + Expression resolvedExp = resolveRenames.apply(exp); + if (resolvedExp.anyMatch(cannotPush)) { + // Add the original expression to the non-pushables. + nonPushable.add(exp); + } else { + // When we can push down, we use the resolved expression. + pushable.add(resolvedExp); + } } // Push the filter down even if it might not be pushable all the way to ES eventually: eval'ing it closer to the source, // potentially still in the Exec Engine, distributes the computation. 
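// Worked example (hypothetical query): for `FROM idx | EVAL x = field, z = a + 1 | WHERE x > 10 AND z > 0`,
// splitAnd yields two conjuncts. resolveRenames turns `x > 10` into `field > 10`, which no longer touches
// anything created by the Eval, so it is pushed below the Eval in resolved form; `z > 0` depends on the
// computed value `z` and stays in nonPushable above the Eval.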
- if (pushable.size() > 0) { - if (nonPushable.size() > 0) { - Filter pushed = new Filter(filter.source(), unary.child(), Predicates.combineAnd(pushable)); + if (pushable.isEmpty() == false) { + Filter pushed = filter.with(unary.child(), Predicates.combineAnd(pushable)); + if (nonPushable.isEmpty() == false) { plan = filter.with(unary.replaceChild(pushed), Predicates.combineAnd(nonPushable)); } else { - plan = unary.replaceChild(filter.with(unary.child(), filter.condition())); + plan = unary.replaceChild(pushed); } } else { plan = filter; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java index e27418c2cf6a9..be6e124502ba5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java @@ -12,8 +12,14 @@ import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.Point; import org.elasticsearch.geometry.utils.WellKnownBinary; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -25,17 +31,42 @@ import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import org.elasticsearch.xpack.esql.plan.physical.FilterExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitAnd; +import static org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushFiltersToSource.canPushSpatialFunctionToSource; +import static org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushFiltersToSource.canPushToSource; +import static org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushFiltersToSource.getAliasReplacedBy; /** * When a spatial distance predicate can be pushed down to lucene, this is done by capturing the distance within the same function. * In principle this is like re-writing the predicate: *
<pre>WHERE ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)")) <= 10000</pre>
* as:
- * <pre>WHERE ST_INTERSECTS(field, TO_GEOSHAPE("CIRCLE(0,0,10000)"))</pre>
+ * <pre>WHERE ST_INTERSECTS(field, TO_GEOSHAPE("CIRCLE(0,0,10000)"))</pre>.
+ * <p>
+ * In addition, since the distance could be calculated in a preceding EVAL command, we also need to consider the case:
+ * <pre>
+ *     FROM index
+ *     | EVAL distance = ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)"))
+ *     | WHERE distance <= 10000
+ * </pre>
+ * And re-write that as:
+ * <pre>
+ *     FROM index
+ *     | WHERE ST_INTERSECTS(field, TO_GEOSHAPE("CIRCLE(0,0,10000)"))
+ *     | EVAL distance = ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)"))
+ * </pre>
+ * Note that the WHERE clause is both rewritten to an intersection and pushed down closer to the EsQueryExec, + * which allows the predicate to be pushed down to Lucene in a later rule, PushFiltersToSource. */ public class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.ParameterizedOptimizerRule< FilterExec, @@ -44,23 +75,121 @@ public class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.Parame @Override protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) { PhysicalPlan plan = filterExec; - if (filterExec.child() instanceof EsQueryExec) { + if (filterExec.child() instanceof EsQueryExec esQueryExec) { + plan = rewrite(filterExec, esQueryExec); + } else if (filterExec.child() instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec esQueryExec) { + plan = rewriteBySplittingFilter(filterExec, evalExec, esQueryExec); + } + + return plan; + } + + private FilterExec rewrite(FilterExec filterExec, EsQueryExec esQueryExec) { + // Find and rewrite any binary comparisons that involve a distance function and a literal + var rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> { + ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType()); + if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) { + return rewriteComparison(comparison, dist, comparison.right(), comparisonType); + } else if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) { + return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); + } + return comparison; + }); + if (rewritten.equals(filterExec.condition()) == false) { + return new FilterExec(filterExec.source(), esQueryExec, rewritten); + } + return filterExec; + } + + /** + * This version of the rewrite will try to split the filter into two parts, one that can be pushed down to the source and + * one that needs to be kept after the EVAL. + * For example: + *
<pre>
+     *     FROM index
+     *     | EVAL distance = ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)")), other = scale * 2
+     *     | WHERE distance <= 10000 AND distance > 5000 AND other > 10
+     * </pre>
+     * Should be rewritten as:
+     * <pre>
+     *     FROM index
+     *     | WHERE ST_INTERSECTS(field, TO_GEOSHAPE("CIRCLE(0,0,10000)"))
+     *         AND ST_DISJOINT(field, TO_GEOSHAPE("CIRCLE(0,0,5000)"))
+     *     | EVAL distance = ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)")), other = scale * 2
+     *     | WHERE other > 10
+     * </pre>
+ */ + private PhysicalPlan rewriteBySplittingFilter(FilterExec filterExec, EvalExec evalExec, EsQueryExec esQueryExec) { + // Find all pushable distance functions in the EVAL + Map distances = getPushableDistances(evalExec.fields()); + + // Don't do anything if there are no distances to push down + if (distances.isEmpty()) { + return filterExec; + } + + // Process the EVAL to get all aliases that might be needed in the filter rewrite + AttributeMap aliasReplacedBy = getAliasReplacedBy(evalExec); + + // First we split the filter into multiple AND'd expressions, and then we evaluate each individually for distance rewrites + List pushable = new ArrayList<>(); + List nonPushable = new ArrayList<>(); + for (Expression exp : splitAnd(filterExec.condition())) { + Expression resExp = exp.transformUp(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r)); // Find and rewrite any binary comparisons that involve a distance function and a literal - var rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> { - ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType()); - if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) { - return rewriteComparison(comparison, dist, comparison.right(), comparisonType); - } else if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) { - return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); - } - return comparison; - }); - if (rewritten.equals(filterExec.condition()) == false) { - plan = new FilterExec(filterExec.source(), filterExec.child(), rewritten); + var rewritten = rewriteDistanceFilters(resExp, distances); + // If all pushable StDistance functions were found and re-written, we need to re-write the FILTER/EVAL combination + if (rewritten.equals(resExp) == false && canPushToSource(rewritten, x -> false)) { + pushable.add(rewritten); + } else { + nonPushable.add(exp); } } - return plan; + // If nothing pushable was rewritten, we can return the original filter + if (pushable.isEmpty()) { + return filterExec; + } + + // Move the rewritten pushable filters below the EVAL + var distanceFilter = new FilterExec(filterExec.source(), esQueryExec, Predicates.combineAnd(pushable)); + var newEval = new EvalExec(evalExec.source(), distanceFilter, evalExec.fields()); + if (nonPushable.isEmpty()) { + // No other filters found, we can just return the original eval with the new filter as child + return newEval; + } else { + // Some other filters found, we need to return two filters with the eval in between + return new FilterExec(filterExec.source(), newEval, Predicates.combineAnd(nonPushable)); + } + } + + private Map getPushableDistances(List aliases) { + Map distances = new LinkedHashMap<>(); + aliases.forEach(alias -> { + if (alias.child() instanceof StDistance distance && canPushSpatialFunctionToSource(distance)) { + distances.put(alias.id(), distance); + } else if (alias.child() instanceof ReferenceAttribute ref && distances.containsKey(ref.id())) { + StDistance distance = distances.get(ref.id()); + distances.put(alias.id(), distance); + } + }); + return distances; + } + + private Expression rewriteDistanceFilters(Expression expr, Map distances) { + return expr.transformDown(EsqlBinaryComparison.class, comparison -> { + ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType()); + if (comparison.left() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && 
comparison.right().foldable()) { + StDistance dist = distances.get(r.id()); + return rewriteComparison(comparison, dist, comparison.right(), comparisonType); + } else if (comparison.right() instanceof ReferenceAttribute r + && distances.containsKey(r.id()) + && comparison.left().foldable()) { + StDistance dist = distances.get(r.id()); + return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); + } + return comparison; + }); } private Expression rewriteComparison( @@ -117,7 +246,7 @@ private Literal makeCircleLiteral(Point point, double distance, Expression liter /** * This enum captures the key differences between various inequalities as perceived from the spatial distance function. * In particular, we need to know which direction the inequality points, with lt=true meaning the left is expected to be smaller - * than the right. And eq=true meaning we expect euality as well. We currently don't support Equals and NotEquals, so the third + * than the right. And eq=true meaning we expect equality as well. We currently don't support Equals and NotEquals, so the third * field disables those. */ enum ComparisonType { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java index 0a71bce2575fa..1ba966e318219 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java @@ -8,10 +8,14 @@ package org.elasticsearch.xpack.esql.optimizer.rules.physical.local; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; @@ -30,6 +34,7 @@ import org.elasticsearch.xpack.esql.core.util.Queries; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; @@ -43,6 +48,7 @@ import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import 
org.elasticsearch.xpack.esql.plan.physical.FilterExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.PlannerUtils; @@ -53,6 +59,7 @@ import static java.util.Arrays.asList; import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitAnd; +import static org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushDownUtils.isAggregatable; public class PushFiltersToSource extends PhysicalOptimizerRules.ParameterizedOptimizerRule { @@ -60,40 +67,88 @@ public class PushFiltersToSource extends PhysicalOptimizerRules.ParameterizedOpt protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) { PhysicalPlan plan = filterExec; if (filterExec.child() instanceof EsQueryExec queryExec) { - List pushable = new ArrayList<>(); - List nonPushable = new ArrayList<>(); - for (Expression exp : splitAnd(filterExec.condition())) { - (canPushToSource(exp, x -> LucenePushDownUtils.hasIdenticalDelegate(x, ctx.searchStats())) ? pushable : nonPushable).add( - exp - ); - } - // Combine GT, GTE, LT and LTE in pushable to Range if possible - List newPushable = combineEligiblePushableToRange(pushable); - if (newPushable.size() > 0) { // update the executable with pushable conditions - Query queryDSL = PlannerUtils.TRANSLATOR_HANDLER.asQuery(Predicates.combineAnd(newPushable)); - QueryBuilder planQuery = queryDSL.asBuilder(); - var query = Queries.combine(Queries.Clause.FILTER, asList(queryExec.query(), planQuery)); - queryExec = new EsQueryExec( - queryExec.source(), - queryExec.index(), - queryExec.indexMode(), - queryExec.output(), - query, - queryExec.limit(), - queryExec.sorts(), - queryExec.estimatedRowSize() - ); - if (nonPushable.size() > 0) { // update filter with remaining non-pushable conditions - plan = new FilterExec(filterExec.source(), queryExec, Predicates.combineAnd(nonPushable)); - } else { // prune Filter entirely - plan = queryExec; - } - } // else: nothing changes + plan = planFilterExec(filterExec, queryExec, ctx); + } else if (filterExec.child() instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec queryExec) { + plan = planFilterExec(filterExec, evalExec, queryExec, ctx); } - return plan; } + private static PhysicalPlan planFilterExec(FilterExec filterExec, EsQueryExec queryExec, LocalPhysicalOptimizerContext ctx) { + List pushable = new ArrayList<>(); + List nonPushable = new ArrayList<>(); + for (Expression exp : splitAnd(filterExec.condition())) { + (canPushToSource(exp, x -> LucenePushDownUtils.hasIdenticalDelegate(x, ctx.searchStats())) ? pushable : nonPushable).add(exp); + } + return rewrite(filterExec, queryExec, pushable, nonPushable, List.of()); + } + + private static PhysicalPlan planFilterExec( + FilterExec filterExec, + EvalExec evalExec, + EsQueryExec queryExec, + LocalPhysicalOptimizerContext ctx + ) { + AttributeMap aliasReplacedBy = getAliasReplacedBy(evalExec); + List pushable = new ArrayList<>(); + List nonPushable = new ArrayList<>(); + for (Expression exp : splitAnd(filterExec.condition())) { + Expression resExp = exp.transformUp(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r)); + (canPushToSource(resExp, x -> LucenePushDownUtils.hasIdenticalDelegate(x, ctx.searchStats())) ? 
pushable : nonPushable).add( + exp + ); + } + // Replace field references with their actual field attributes + pushable.replaceAll(e -> e.transformDown(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r))); + return rewrite(filterExec, queryExec, pushable, nonPushable, evalExec.fields()); + } + + static AttributeMap getAliasReplacedBy(EvalExec evalExec) { + AttributeMap.Builder aliasReplacedByBuilder = AttributeMap.builder(); + evalExec.fields().forEach(alias -> { + if (alias.child() instanceof Attribute attr) { + aliasReplacedByBuilder.put(alias.toAttribute(), attr); + } + }); + return aliasReplacedByBuilder.build(); + } + + private static PhysicalPlan rewrite( + FilterExec filterExec, + EsQueryExec queryExec, + List pushable, + List nonPushable, + List evalFields + ) { + // Combine GT, GTE, LT and LTE in pushable to Range if possible + List newPushable = combineEligiblePushableToRange(pushable); + if (newPushable.size() > 0) { // update the executable with pushable conditions + Query queryDSL = PlannerUtils.TRANSLATOR_HANDLER.asQuery(Predicates.combineAnd(newPushable)); + QueryBuilder planQuery = queryDSL.asBuilder(); + var query = Queries.combine(Queries.Clause.FILTER, asList(queryExec.query(), planQuery)); + queryExec = new EsQueryExec( + queryExec.source(), + queryExec.index(), + queryExec.indexMode(), + queryExec.output(), + query, + queryExec.limit(), + queryExec.sorts(), + queryExec.estimatedRowSize() + ); + // If the eval contains other aliases, not just field attributes, we need to keep them in the plan + PhysicalPlan plan = evalFields.isEmpty() ? queryExec : new EvalExec(filterExec.source(), queryExec, evalFields); + if (nonPushable.size() > 0) { + // update filter with remaining non-pushable conditions + return new FilterExec(filterExec.source(), plan, Predicates.combineAnd(nonPushable)); + } else { + // prune Filter entirely + return plan; + } + } // else: nothing changes + return filterExec; + } + private static List combineEligiblePushableToRange(List pushable) { List bcs = new ArrayList<>(); List ranges = new ArrayList<>(); @@ -189,8 +244,8 @@ public static boolean canPushToSource(Expression exp, Predicate } } else if (exp instanceof CIDRMatch cidrMatch) { return isAttributePushable(cidrMatch.ipField(), cidrMatch, hasIdenticalDelegate) && Expressions.foldable(cidrMatch.matches()); - } else if (exp instanceof SpatialRelatesFunction bc) { - return bc.canPushToSource(LucenePushDownUtils::isAggregatable); + } else if (exp instanceof SpatialRelatesFunction spatial) { + return canPushSpatialFunctionToSource(spatial); } else if (exp instanceof MatchQueryPredicate mqp) { return mqp.field() instanceof FieldAttribute && DataType.isString(mqp.field().dataType()); } else if (exp instanceof StringQueryPredicate) { @@ -201,6 +256,20 @@ public static boolean canPushToSource(Expression exp, Predicate return false; } + /** + * Push-down to Lucene is only possible if one field is an indexed spatial field, and the other is a constant spatial or string column. 
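+ * For example (an illustrative case, not exhaustive), a predicate like
+ * `ST_INTERSECTS(location, TO_GEOSHAPE("POLYGON((42 14, 43 14, 43 15, 42 15, 42 14))"))` is pushable when `location`
+ * is an indexed spatial field, whereas a spatial relation between two field attributes is not.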
+ */ + public static boolean canPushSpatialFunctionToSource(BinarySpatialFunction s) { + // The use of foldable here instead of SpatialEvaluatorFieldKey.isConstant is intentional to match the behavior of the + // Lucene pushdown code in EsqlTranslationHandler::SpatialRelatesTranslator + // We could enhance both places to support ReferenceAttributes that refer to constants, but that is a larger change + return isPushableSpatialAttribute(s.left()) && s.right().foldable() || isPushableSpatialAttribute(s.right()) && s.left().foldable(); + } + + private static boolean isPushableSpatialAttribute(Expression exp) { + return exp instanceof FieldAttribute fa && fa.getExactInfo().hasExact() && isAggregatable(fa) && DataType.isSpatial(fa.dataType()); + } + private static boolean isAttributePushable( Expression expression, Expression operation, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java index 87bc344c397c1..6db35fa0a06d6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java @@ -7,55 +7,217 @@ package org.elasticsearch.xpack.esql.optimizer.rules.physical.local; +import org.elasticsearch.geometry.Geometry; +import org.elasticsearch.geometry.Point; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeMap; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance; import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.function.Predicate; +/** + * We handle two main scenarios here: + *
+ * <ol>
+ *     <li>
+ *         Queries like `FROM index | SORT field` will be pushed to the source if the field is an indexed field.
+ *     </li>
+ *     <li>
+ *         Queries like `FROM index | EVAL ref = ... | SORT ref` will be pushed to the source if the reference function is pushable,
+ *         which can happen under two conditions:
+ *         <ul>
+ *             <li>
+ *                 The reference refers linearly to an indexed field.
+ *                 For example: `FROM index | EVAL ref = field | SORT ref`
+ *             </li>
+ *             <li>
+ *                 The reference refers to a distance function that refers to an indexed field and a constant expression.
+ *                 For example `FROM index | EVAL distance = ST_DISTANCE(field, POINT(0, 0)) | SORT distance`.
+ *                 As with the previous condition, both the attribute and the constant can be further aliased.
+ *             </li>
+ *         </ul>
+ *     </li>
+ * </ol>
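+ * For example (an illustrative chain of aliases, combining both conditions above),
+ * `FROM index | EVAL d = ST_DISTANCE(field, POINT(0, 0)) | EVAL ref = d | SORT ref` resolves `ref` back to the
+ * underlying distance function before deciding whether the sort can be pushed down.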
+ */
 public class PushTopNToSource extends PhysicalOptimizerRules.ParameterizedOptimizerRule {
     @Override
     protected PhysicalPlan rule(TopNExec topNExec, LocalPhysicalOptimizerContext ctx) {
-        PhysicalPlan plan = topNExec;
-        PhysicalPlan child = topNExec.child();
-        if (canPushSorts(child)
-            && canPushDownOrders(topNExec.order(), x -> LucenePushDownUtils.hasIdenticalDelegate(x, ctx.searchStats()))) {
+        Pushable pushable = evaluatePushable(topNExec, x -> LucenePushDownUtils.hasIdenticalDelegate(x, ctx.searchStats()));
+        return pushable.rewrite(topNExec);
+    }
+
+    /**
+     * Multiple scenarios for pushing down TopN to Lucene source. Each involves checking a combination of conditions and then
+     * performing an associated rewrite specific to that scenario. This interface should be extended by each scenario, with each
+     * implementation including the appropriate rewrite logic.
+     */
+    interface Pushable {
+        PhysicalPlan rewrite(TopNExec topNExec);
+    }
+
+    private static final Pushable NO_OP = new NoOpPushable();
+
+    record NoOpPushable() implements Pushable {
+        public PhysicalPlan rewrite(TopNExec topNExec) {
+            return topNExec;
+        }
+    }
+
+    /**
+     * TODO: Consider deleting this case entirely. We do not know if this is ever hit.
+     */
+    record PushableExchangeExec(ExchangeExec exchangeExec, EsQueryExec queryExec) implements Pushable {
+        public PhysicalPlan rewrite(TopNExec topNExec) {
+            var sorts = buildFieldSorts(topNExec.order());
+            var limit = topNExec.limit();
+            return exchangeExec.replaceChild(queryExec.withSorts(sorts).withLimit(limit));
+        }
+    }
+
+    record PushableQueryExec(EsQueryExec queryExec) implements Pushable {
+        public PhysicalPlan rewrite(TopNExec topNExec) {
             var sorts = buildFieldSorts(topNExec.order());
             var limit = topNExec.limit();
+            return queryExec.withSorts(sorts).withLimit(limit);
+        }
+    }
+
+    record PushableGeoDistance(FieldAttribute fieldAttribute, Order order, Point point) {
+        private EsQueryExec.Sort sort() {
+            return new EsQueryExec.GeoDistanceSort(fieldAttribute.exactAttribute(), order.direction(), point.getLat(), point.getLon());
+        }
 
-            if (child instanceof ExchangeExec exchangeExec && exchangeExec.child() instanceof EsQueryExec queryExec) {
-                plan = exchangeExec.replaceChild(queryExec.withSorts(sorts).withLimit(limit));
-            } else {
-                plan = ((EsQueryExec) child).withSorts(sorts).withLimit(limit);
+        private static PushableGeoDistance from(StDistance distance, Order order) {
+            if (distance.left() instanceof Attribute attr && distance.right().foldable()) {
+                return from(attr, distance.right(), order);
+            } else if (distance.right() instanceof Attribute attr && distance.left().foldable()) {
+                return from(attr, distance.left(), order);
             }
+            return null;
+        }
+
+        private static PushableGeoDistance from(Attribute attr, Expression foldable, Order order) {
+            if (attr instanceof FieldAttribute fieldAttribute) {
+                Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(foldable);
+                if (geometry instanceof Point point) {
+                    return new PushableGeoDistance(fieldAttribute, order, point);
+                }
+            }
+            return null;
         }
-        return plan;
     }
 
-    private static boolean canPushSorts(PhysicalPlan plan) {
-        if (plan instanceof EsQueryExec queryExec) {
-            return queryExec.canPushSorts();
+    record PushableCompoundExec(EvalExec evalExec, EsQueryExec queryExec, List pushableSorts) implements Pushable {
+        public PhysicalPlan rewrite(TopNExec topNExec) {
+            // We need to keep the EVAL in place because the coordinator will have its own TopNExec so we need to keep the distance
+            return
evalExec.replaceChild(queryExec.withSorts(pushableSorts).withLimit(topNExec.limit())); } - if (plan instanceof ExchangeExec exchangeExec && exchangeExec.child() instanceof EsQueryExec queryExec) { - return queryExec.canPushSorts(); + } + + private static Pushable evaluatePushable(TopNExec topNExec, Predicate hasIdenticalDelegate) { + PhysicalPlan child = topNExec.child(); + if (child instanceof EsQueryExec queryExec + && queryExec.canPushSorts() + && canPushDownOrders(topNExec.order(), hasIdenticalDelegate)) { + // With the simplest case of `FROM index | SORT ...` we only allow pushing down if the sort is on a field + return new PushableQueryExec(queryExec); + } + if (child instanceof ExchangeExec exchangeExec + && exchangeExec.child() instanceof EsQueryExec queryExec + && queryExec.canPushSorts() + && canPushDownOrders(topNExec.order(), hasIdenticalDelegate)) { + // When we have an exchange between the FROM and the SORT, we also only allow pushing down if the sort is on a field + return new PushableExchangeExec(exchangeExec, queryExec); + } + if (child instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec queryExec && queryExec.canPushSorts()) { + // When we have an EVAL between the FROM and the SORT, we consider pushing down if the sort is on a field and/or + // a distance function defined in the EVAL. We also move the EVAL to after the SORT. + List orders = topNExec.order(); + List fields = evalExec.fields(); + LinkedHashMap distances = new LinkedHashMap<>(); + AttributeMap.Builder aliasReplacedByBuilder = AttributeMap.builder(); + fields.forEach(alias -> { + // TODO: can we support CARTESIAN also? + if (alias.child() instanceof StDistance distance && distance.crsType() == BinarySpatialFunction.SpatialCrsType.GEO) { + distances.put(alias.id(), distance); + } else if (alias.child() instanceof Attribute attr) { + aliasReplacedByBuilder.put(alias.toAttribute(), attr.toAttribute()); + } + }); + AttributeMap aliasReplacedBy = aliasReplacedByBuilder.build(); + + List pushableSorts = new ArrayList<>(); + for (Order order : orders) { + if (LucenePushDownUtils.isPushableFieldAttribute(order.child(), hasIdenticalDelegate)) { + pushableSorts.add( + new EsQueryExec.FieldSort( + ((FieldAttribute) order.child()).exactAttribute(), + order.direction(), + order.nullsPosition() + ) + ); + } else if (order.child() instanceof ReferenceAttribute referenceAttribute) { + Attribute resolvedAttribute = aliasReplacedBy.resolve(referenceAttribute, referenceAttribute); + if (distances.containsKey(resolvedAttribute.id())) { + StDistance distance = distances.get(resolvedAttribute.id()); + StDistance d = (StDistance) distance.transformDown(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r)); + PushableGeoDistance pushableGeoDistance = PushableGeoDistance.from(d, order); + if (pushableGeoDistance != null) { + pushableSorts.add(pushableGeoDistance.sort()); + } else { + // As soon as we see a non-pushable sort, we know we need a final SORT command + break; + } + } else if (aliasReplacedBy.resolve(referenceAttribute, referenceAttribute) instanceof FieldAttribute fieldAttribute + && LucenePushDownUtils.isPushableFieldAttribute(fieldAttribute, hasIdenticalDelegate)) { + // If the SORT refers to a reference to a pushable field, we can push it down + pushableSorts.add( + new EsQueryExec.FieldSort(fieldAttribute.exactAttribute(), order.direction(), order.nullsPosition()) + ); + } else { + // If the SORT refers to a non-pushable reference function, the EVAL must remain before the SORT, + // and we 
can no longer push down anything + break; + } + } else { + // As soon as we see a non-pushable sort, we know we need a final SORT command + break; + } + } + // TODO: We can push down partial sorts where `pushableSorts.size() < orders.size()`, but that should involve benchmarks + if (pushableSorts.size() > 0 && pushableSorts.size() == orders.size()) { + return new PushableCompoundExec(evalExec, queryExec, pushableSorts); + } } - return false; + return NO_OP; } - private boolean canPushDownOrders(List orders, Predicate hasIdenticalDelegate) { + private static boolean canPushDownOrders(List orders, Predicate hasIdenticalDelegate) { // allow only exact FieldAttributes (no expressions) for sorting return orders.stream().allMatch(o -> LucenePushDownUtils.isPushableFieldAttribute(o.child(), hasIdenticalDelegate)); } - private List buildFieldSorts(List orders) { - List sorts = new ArrayList<>(orders.size()); + private static List buildFieldSorts(List orders) { + List sorts = new ArrayList<>(orders.size()); for (Order o : orders) { sorts.add(new EsQueryExec.FieldSort(((FieldAttribute) o.child()).exactAttribute(), o.direction(), o.nullsPosition())); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/SpatialDocValuesExtraction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/SpatialDocValuesExtraction.java index ea6541326458e..d03cd9ef7cb0b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/SpatialDocValuesExtraction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/SpatialDocValuesExtraction.java @@ -9,9 +9,11 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; @@ -26,6 +28,41 @@ import java.util.List; import java.util.Set; +/** + * This rule is responsible for marking spatial fields to be extracted from doc-values instead of source values. + * This is a very specific optimization that is only used in the context of spatial aggregations. + * Normally spatial fields are extracted from source values because this maintains original precision, but is very slow. + * Simply loading from doc-values loses precision for points, and loses the geometry topological information for shapes. + * For this reason we only consider loading from doc values under very specific conditions: + *
+ * <ul>
+ *     <li>The spatial data is consumed by a spatial aggregation (e.g. ST_CENTROIDS_AGG), negating the need for precision.</li>
+ *     <li>This aggregation is planned to run on the data node, so the doc-values Blocks are never transmitted to the coordinator node.</li>
+ *     <li>The data node index in question has doc-values stored for the field in question.</li>
+ * </ul>
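+ * For example (an illustrative query, assuming `location` is a geo_point field with doc-values enabled), a data-node
+ * plan for `FROM airports | STATS centroid = ST_CENTROIDS_AGG(location)` would satisfy all three conditions.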
+ * While we do not support transmitting spatial doc-values to the coordinator node, it is still important on the data node to ensure + * that all spatial functions that will receive these doc-values are aware of this fact. For this reason, if the above conditions are met, + * we need to make four edits to the local physical plan to consistently support spatial doc-values: + *
+ * <ul>
+ *     <li>The spatial aggregation function itself is marked using withDocValues() to enable its
+ *         toEvaluator() method to produce the correct doc-values aware Evaluator functions.</li>
+ *     <li>Any spatial functions called within EVAL commands before the doc-values are consumed by the aggregation
+ *         also need to be marked using withDocValues() so their evaluators are correct.</li>
+ *     <li>Any spatial functions used within filters, WHERE commands, are similarly marked for the same reason.</li>
+ *     <li>The FieldExtractExec that will extract the field is marked with withDocValuesAttributes(...)
+ *         so that it calls the FieldType.blockReader() method with the correct FieldExtractPreference.</li>
+ * </ul>
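+ * For example (illustrative), in `FROM airports | WHERE ST_INTERSECTS(location, shape) | STATS ST_CENTROIDS_AGG(location)`
+ * (with `shape` a constant) both the aggregation and the filter's spatial function would be marked using withDocValues(),
+ * and the FieldExtractExec for `location` would be marked to load the field from doc-values.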
+ * The question has been raised why the spatial functions need to know if they are using doc-values or not. At first glance one might
+ * perceive ES|QL functions as being logical planning only constructs, reflecting only the intent of the user. This, however, is not true.
+ * The ES|QL functions all contain the runtime implementation of the function's behaviour, in the form of one or more static methods,
+ * as well as a toEvaluator() instance method that is used to generate Block traversal code to call these runtime
+ * implementations, based on some internal state of the instance of the function. In most cases this internal state contains information
+ * determined during the logical planning phase, such as the field name and type, and whether it is a literal and can be folded.
+ * In the case of spatial functions, the internal state also contains information about whether the function is using doc-values or not.
+ * This knowledge is determined in the class described here, and only during local physical planning on each data
+ * node. This is because the decision to use doc-values is based on the local data node's index configuration, and the local physical plan
+ * is the only place where this information is available. This also means that the knowledge of the usage of doc-values does not need
+ * to be serialized between nodes, and is only used locally.
+ */
 public class SpatialDocValuesExtraction extends PhysicalOptimizerRules.OptimizerRule {
     @Override
     protected PhysicalPlan rule(AggregateExec aggregate) {
@@ -65,14 +102,7 @@ && allowedForDocValues(fieldAttribute, agg, foundAttributes)) {
                 if (exec instanceof EvalExec evalExec) {
                     List fields = evalExec.fields();
                     List changed = fields.stream()
-                        .map(
-                            f -> (Alias) f.transformDown(
-                                SpatialRelatesFunction.class,
-                                spatialRelatesFunction -> (spatialRelatesFunction.hasFieldAttribute(foundAttributes))
-                                    ? spatialRelatesFunction.withDocValues(foundAttributes)
-                                    : spatialRelatesFunction
-                            )
-                        )
+                        .map(f -> (Alias) f.transformDown(BinarySpatialFunction.class, s -> withDocValues(s, foundAttributes)))
                         .toList();
                     if (changed.equals(fields) == false) {
                         exec = new EvalExec(exec.source(), exec.child(), changed);
@@ -81,13 +111,7 @@ && allowedForDocValues(fieldAttribute, agg, foundAttributes)) {
                 if (exec instanceof FilterExec filterExec) {
                     // Note that ST_CENTROID does not support shapes, but SpatialRelatesFunction does, so when we extend the centroid
                     // to support shapes, we need to consider loading shape doc-values for both centroid and relates (ST_INTERSECTS)
-                    var condition = filterExec.condition()
-                        .transformDown(
-                            SpatialRelatesFunction.class,
-                            spatialRelatesFunction -> (spatialRelatesFunction.hasFieldAttribute(foundAttributes))
-                                ? spatialRelatesFunction.withDocValues(foundAttributes)
-                                : spatialRelatesFunction
-                        );
+                    var condition = filterExec.condition().transformDown(BinarySpatialFunction.class, s -> withDocValues(s, foundAttributes));
                     if (filterExec.condition().equals(condition) == false) {
                         exec = new FilterExec(filterExec.source(), filterExec.child(), condition);
                     }
@@ -110,6 +134,21 @@ && allowedForDocValues(fieldAttribute, agg, foundAttributes)) {
         return plan;
     }
 
+    private BinarySpatialFunction withDocValues(BinarySpatialFunction spatial, Set foundAttributes) {
+        // Only update the docValues flags if the field is found in the attributes
+        boolean foundLeft = foundField(spatial.left(), foundAttributes);
+        boolean foundRight = foundField(spatial.right(), foundAttributes);
+        return foundLeft || foundRight ?
spatial.withDocValues(foundLeft, foundRight) : spatial; + } + + private boolean hasFieldAttribute(BinarySpatialFunction spatial, Set foundAttributes) { + return foundField(spatial.left(), foundAttributes) || foundField(spatial.right(), foundAttributes); + } + + private boolean foundField(Expression expression, Set foundAttributes) { + return expression instanceof FieldAttribute field && foundAttributes.contains(field); + } + /** * This function disallows the use of more than one field for doc-values extraction in the same spatial relation function. * This is because comparing two doc-values fields is not supported in the current implementation. @@ -123,7 +162,7 @@ private boolean allowedForDocValues(FieldAttribute fieldAttribute, AggregateExec var spatialRelatesAttributes = new HashSet(); agg.forEachExpressionDown(SpatialRelatesFunction.class, relatesFunction -> { candidateDocValuesAttributes.forEach(candidate -> { - if (relatesFunction.hasFieldAttribute(Set.of(candidate))) { + if (hasFieldAttribute(relatesFunction, Set.of(candidate))) { spatialRelatesAttributes.add(candidate); } }); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java index 21aa2cb7d1860..82848fb2f1062 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java @@ -11,10 +11,11 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.GeoDistanceSortBuilder; +import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -38,16 +39,17 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( PhysicalPlan.class, "EsQueryExec", - EsQueryExec::new + EsQueryExec::deserialize ); public static final EsField DOC_ID_FIELD = new EsField("_doc", DataType.DOC_DATA_TYPE, Map.of(), false); + public static final List NO_SORTS = List.of(); // only exists to mimic older serialization, but we no longer serialize sorts private final EsIndex index; private final IndexMode indexMode; private final QueryBuilder query; private final Expression limit; - private final List sorts; + private final List sorts; private final List attrs; /** @@ -56,8 +58,17 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize { */ private final Integer estimatedRowSize; - public record FieldSort(FieldAttribute field, Order.OrderDirection direction, Order.NullsPosition nulls) implements Writeable { - public FieldSortBuilder fieldSortBuilder() { + public interface Sort { + SortBuilder sortBuilder(); + + Order.OrderDirection direction(); + + FieldAttribute field(); + } + + public record FieldSort(FieldAttribute field, Order.OrderDirection direction, Order.NullsPosition nulls) implements Sort { + @Override + public SortBuilder sortBuilder() 
{ FieldSortBuilder builder = new FieldSortBuilder(field.name()); builder.order(Direction.from(direction).asOrder()); builder.missing(Missing.from(nulls).searchOrder()); @@ -72,12 +83,14 @@ private static FieldSort readFrom(StreamInput in) throws IOException { in.readEnum(Order.NullsPosition.class) ); } + } + public record GeoDistanceSort(FieldAttribute field, Order.OrderDirection direction, double lat, double lon) implements Sort { @Override - public void writeTo(StreamOutput out) throws IOException { - field().writeTo(out); - out.writeEnum(direction()); - out.writeEnum(nulls()); + public SortBuilder sortBuilder() { + GeoDistanceSortBuilder builder = new GeoDistanceSortBuilder(field.name(), lat, lon); + builder.order(Direction.from(direction).asOrder()); + return builder; } } @@ -92,7 +105,7 @@ public EsQueryExec( List attrs, QueryBuilder query, Expression limit, - List sorts, + List sorts, Integer estimatedRowSize ) { super(source); @@ -105,17 +118,29 @@ public EsQueryExec( this.estimatedRowSize = estimatedRowSize; } - private EsQueryExec(StreamInput in) throws IOException { - this( - Source.readFrom((PlanStreamInput) in), - new EsIndex(in), - EsRelation.readIndexMode(in), - in.readNamedWriteableCollectionAsList(Attribute.class), - in.readOptionalNamedWriteable(QueryBuilder.class), - in.readOptionalNamedWriteable(Expression.class), - in.readOptionalCollectionAsList(FieldSort::readFrom), - in.readOptionalVInt() - ); + /** + * The matching constructor is used during physical plan optimization and needs valid sorts. But we no longer serialize sorts. + * If this cluster node is talking to an older instance it might receive a plan with sorts, but it will ignore them. + */ + public static EsQueryExec deserialize(StreamInput in) throws IOException { + var source = Source.readFrom((PlanStreamInput) in); + var index = new EsIndex(in); + var indexMode = EsRelation.readIndexMode(in); + var attrs = in.readNamedWriteableCollectionAsList(Attribute.class); + var query = in.readOptionalNamedWriteable(QueryBuilder.class); + var limit = in.readOptionalNamedWriteable(Expression.class); + in.readOptionalCollectionAsList(EsQueryExec::readSort); + var rowSize = in.readOptionalVInt(); + // Ignore sorts from the old serialization format + return new EsQueryExec(source, index, indexMode, attrs, query, limit, NO_SORTS, rowSize); + } + + private static Sort readSort(StreamInput in) throws IOException { + return FieldSort.readFrom(in); + } + + private static void writeSort(StreamOutput out, Sort sort) { + throw new IllegalStateException("sorts are no longer serialized"); } @Override @@ -126,7 +151,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteableCollection(output()); out.writeOptionalNamedWriteable(query()); out.writeOptionalNamedWriteable(limit()); - out.writeOptionalCollection(sorts()); + out.writeOptionalCollection(NO_SORTS, EsQueryExec::writeSort); out.writeOptionalVInt(estimatedRowSize()); } @@ -165,7 +190,7 @@ public Expression limit() { return limit; } - public List sorts() { + public List sorts() { return sorts; } @@ -208,7 +233,7 @@ public boolean canPushSorts() { return indexMode != IndexMode.TIME_SERIES; } - public EsQueryExec withSorts(List sorts) { + public EsQueryExec withSorts(List sorts) { if (indexMode == IndexMode.TIME_SERIES) { assert false : "time-series index mode doesn't support sorts"; throw new UnsupportedOperationException("time-series index mode doesn't support sorts"); diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 04be731484267..ab0d68b152262 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -56,7 +56,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; -import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.FieldSort; +import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.Sort; import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.DriverParallelism; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; @@ -161,15 +161,14 @@ public Function querySuppl public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) { final LuceneOperator.Factory luceneFactory; - List sorts = esQueryExec.sorts(); - List> fieldSorts = null; + List sorts = esQueryExec.sorts(); assert esQueryExec.estimatedRowSize() != null : "estimated row size not initialized"; int rowEstimatedSize = esQueryExec.estimatedRowSize(); int limit = esQueryExec.limit() != null ? (Integer) esQueryExec.limit().fold() : NO_LIMIT; if (sorts != null && sorts.isEmpty() == false) { - fieldSorts = new ArrayList<>(sorts.size()); - for (FieldSort sort : sorts) { - fieldSorts.add(sort.fieldSortBuilder()); + List> sortBuilders = new ArrayList<>(sorts.size()); + for (Sort sort : sorts) { + sortBuilders.add(sort.sortBuilder()); } luceneFactory = new LuceneTopNSourceOperator.Factory( shardContexts, @@ -178,7 +177,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, context.queryPragmas().taskConcurrency(), context.pageSize(rowEstimatedSize), limit, - fieldSorts + sortBuilders ); } else { if (esQueryExec.indexMode() == IndexMode.TIME_SERIES) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SpatialRelatesQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SpatialRelatesQuery.java index 4f0bcbb43e260..532825290af0d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SpatialRelatesQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SpatialRelatesQuery.java @@ -9,6 +9,8 @@ import org.apache.lucene.search.ConstantScoreQuery; import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.geo.GeoJson; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.geometry.Geometry; @@ -92,9 +94,16 @@ public ShapeRelation shapeRelation() { */ public abstract class ShapeQueryBuilder implements QueryBuilder { - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - throw new UnsupportedOperationException("Unimplemented: toXContent()"); + protected void doToXContent(String queryName, XContentBuilder builder, Params params) throws IOException { + 
builder.startObject(); + builder.startObject(queryName); + builder.startObject(field); + builder.field("relation", queryRelation); + builder.field("shape"); + GeoJson.toXContent(shape, builder, params); + builder.endObject(); + builder.endObject(); + builder.endObject(); } @Override @@ -157,6 +166,11 @@ public ShapeRelation relation() { public Geometry shape() { return shape; } + + @Override + public String toString() { + return Strings.toString(this, true, true); + } } private class GeoShapeQueryBuilder extends ShapeQueryBuilder { @@ -178,6 +192,13 @@ org.apache.lucene.search.Query buildShapeQuery(SearchExecutionContext context, M final GeoShapeQueryable ft = (GeoShapeQueryable) fieldType; return new ConstantScoreQuery(ft.geoShapeQuery(context, fieldType.name(), shapeRelation(), shape)); } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // Currently used only in testing and debugging + doToXContent(NAME, builder, params); + return builder; + } } private class CartesianShapeQueryBuilder extends ShapeQueryBuilder { @@ -225,5 +246,13 @@ private static org.apache.lucene.search.Query shapeShapeQuery( throw new QueryShardException(context, "Exception creating query on Field [" + fieldName + "] " + e.getMessage(), e); } } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // Currently used only in testing and debugging + doToXContent("cartesian_shape", builder, params); + return builder; + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java index 7be200baf6c58..687b83370f571 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/index/EsIndexSerializationTests.java @@ -135,11 +135,12 @@ public static EsIndex indexWithManyConflicts(boolean withParent) { * See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more. */ public void testManyTypeConflicts() throws IOException { - testManyTypeConflicts(false, ByteSizeValue.ofBytes(991026)); + testManyTypeConflicts(false, ByteSizeValue.ofBytes(916998)); /* * History: * 953.7kb - shorten error messages for UnsupportedAttributes #111973 * 967.7kb - cache EsFields #112008 (little overhead of the cache) + * 895.5kb - string serialization #112929 */ } @@ -148,12 +149,13 @@ public void testManyTypeConflicts() throws IOException { * See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more. 
*/ public void testManyTypeConflictsWithParent() throws IOException { - testManyTypeConflicts(true, ByteSizeValue.ofBytes(1374497)); + testManyTypeConflicts(true, ByteSizeValue.ofBytes(1300467)); /* * History: * 16.9mb - start * 1.8mb - shorten error messages for UnsupportedAttributes #111973 * 1.3mb - cache EsFields #112008 + * 1.2mb - string serialization #112929 */ } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index f66349351bb59..6746b8ff61268 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.Build; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.lucene.BytesRefs; @@ -20,12 +21,15 @@ import org.elasticsearch.geometry.ShapeType; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.ExistsQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.query.RegexpQueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.WildcardQueryBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -41,6 +45,7 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; @@ -151,14 +156,17 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; import static org.elasticsearch.xpack.esql.parser.ExpressionBuilder.MAX_EXPRESSION_DEPTH; import static org.elasticsearch.xpack.esql.parser.LogicalPlanBuilder.MAX_QUERY_DEPTH; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; import static org.hamcrest.Matchers.matchesRegex; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -529,11 +537,11 @@ public void testExtractorForField() { assertThat(names(extract.attributesToExtract()), contains("salary", "emp_no", "last_name")); var source = 
source(extract.child()); assertThat(source.limit(), is(topN.limit())); - assertThat(source.sorts(), is(sorts(topN.order()))); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); assertThat(source.limit(), is(l(10))); assertThat(source.sorts().size(), is(1)); - FieldSort order = source.sorts().get(0); + EsQueryExec.Sort order = source.sorts().get(0); assertThat(order.direction(), is(Order.OrderDirection.ASC)); assertThat(name(order.field()), is("last_name")); // last name is keyword, salary, emp_no, doc id, segment, forwards and backwards doc id maps are all ints @@ -802,6 +810,19 @@ public void testQueryForStatWithMultiAgg() { assertThat(query.query(), is(boolQuery().should(existsQuery("emp_no")).should(existsQuery("salary")))); } + /** + * This used to not allow pushing the sort down to the source, but now it does, since the eval is not used for the sort + * + * TopNExec[[Order[emp_no{f}#6,ASC,LAST]],1[INTEGER],0] + * \_ExchangeExec[[_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, gender{f}#8, job{f}#13, job.raw{f}#14, ..],false] + * \_ProjectExec[[_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, gender{f}#8, job{f}#13, job.raw{f}#14, ..]] + * \_FieldExtractExec[_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..][] + * \_EvalExec[[null[INTEGER] AS nullsum]] + * \_EsQueryExec[test], indexMode[standard], query[][_doc{f}#27], limit[1], sort[[ + * FieldSort[field=emp_no{f}#6, direction=ASC, nulls=LAST] + * ]] estimatedRowSize[340] + * + */ public void testQueryWithNull() { var plan = physicalPlan(""" from test @@ -818,15 +839,10 @@ public void testQueryWithNull() { var exchange = asRemoteExchange(topN.child()); var project = as(exchange.child(), ProjectExec.class); var extract = as(project.child(), FieldExtractExec.class); - var topNLocal = as(extract.child(), TopNExec.class); - // All fields except emp_no are loaded after this topn. We load an extra int for the doc and segment mapping. - assertThat(topNLocal.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES)); - - var extractForEval = as(topNLocal.child(), FieldExtractExec.class); - var eval = as(extractForEval.child(), EvalExec.class); + var eval = as(extract.child(), EvalExec.class); var source = source(eval.child()); - // emp_no and nullsum are longs, doc id is an int - assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2 + Integer.BYTES)); + // All fields loaded + assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + 3 * Integer.BYTES + Long.BYTES)); } public void testPushAndInequalitiesFilter() { @@ -1065,7 +1081,7 @@ public void testProjectAfterTopN() throws Exception { var extract = as(project.child(), FieldExtractExec.class); var source = source(extract.child()); assertThat(source.limit(), is(topN.limit())); - assertThat(source.sorts(), is(sorts(topN.order()))); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); // an int for doc id, an int for segment id, two ints for doc id map, and int for emp_no. 
assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 5 + KEYWORD_EST)); } @@ -1221,11 +1237,11 @@ public void testDoNotAliasesDefinedAfterTheExchange() throws Exception { assertThat(names(extract.attributesToExtract()), contains("languages", "salary")); var source = source(extract.child()); assertThat(source.limit(), is(topN.limit())); - assertThat(source.sorts(), is(sorts(topN.order()))); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); assertThat(source.limit(), is(l(1))); assertThat(source.sorts().size(), is(1)); - FieldSort order = source.sorts().get(0); + EsQueryExec.Sort order = source.sorts().get(0); assertThat(order.direction(), is(Order.OrderDirection.ASC)); assertThat(name(order.field()), is("salary")); // ints for doc id, segment id, forwards and backwards mapping, languages, and salary @@ -1754,6 +1770,124 @@ public void testPushDownNotRLike() { assertThat(regexpQuery.value(), is(".*foo.*")); } + /** + * + * TopNExec[[Order[name{r}#4,ASC,LAST]],1000[INTEGER],0] + * \_ExchangeExec[[_meta_field{f}#20, emp_no{f}#14, gender{f}#16, job{f}#21, job.raw{f}#22, languages{f}#17, + * long_noidx{f}#23, salary{f}#19, name{r}#4, first_name{r}#7, last_name{r}#10 + * ],false] + * \_ProjectExec[[_meta_field{f}#20, emp_no{f}#14, gender{f}#16, job{f}#21, job.raw{f}#22, languages{f}#17, + * long_noidx{f}#23, salary{f}#19, name{r}#4, first_name{r}#7, last_name{r}#10 + * ]] + * \_FieldExtractExec[_meta_field{f}#20, emp_no{f}#14, gender{f}#16, job{..][] + * \_EvalExec[[first_name{f}#15 AS name, last_name{f}#18 AS first_name, name{r}#4 AS last_name]] + * \_FieldExtractExec[first_name{f}#15, last_name{f}#18][] + * \_EsQueryExec[test], indexMode[standard], query[{ + * "bool":{"must":[ + * {"esql_single_value":{"field":"last_name","next":{"term":{"last_name":{"value":"foo"}}},"source":...}}, + * {"esql_single_value":{"field":"first_name","next":{"term":{"first_name":{"value":"bar"}}},"source":...}} + * ],"boost":1.0}}][_doc{f}#37], limit[1000], sort[[ + * FieldSort[field=first_name{f}#15, direction=ASC, nulls=LAST] + * ]] estimatedRowSize[486] + * + */ + public void testPushDownEvalFilter() { + var plan = physicalPlan(""" + FROM test + | EVAL name = first_name, first_name = last_name, last_name = name + | WHERE first_name == "foo" AND last_name == "bar" + | SORT name + """); + var optimized = optimizedPlan(plan); + + var topN = as(optimized, TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + var project = as(exchange.child(), ProjectExec.class); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(extract.attributesToExtract().size(), greaterThan(5)); + var eval = as(extract.child(), EvalExec.class); + extract = as(eval.child(), FieldExtractExec.class); + assertThat( + extract.attributesToExtract().stream().map(Attribute::name).collect(Collectors.toList()), + contains("first_name", "last_name") + ); + + // Now verify the correct Lucene push-down of both the filter and the sort + var source = source(extract.child()); + QueryBuilder query = source.query(); + assertNotNull(query); + assertThat(query, instanceOf(BoolQueryBuilder.class)); + var boolQuery = (BoolQueryBuilder) query; + var must = boolQuery.must(); + assertThat(must.size(), is(2)); + var range1 = (TermQueryBuilder) ((SingleValueQuery.Builder) must.get(0)).next(); + assertThat(range1.fieldName(), is("last_name")); + var range2 = (TermQueryBuilder) ((SingleValueQuery.Builder) must.get(1)).next(); + assertThat(range2.fieldName(), is("first_name")); + var sort = source.sorts(); + 
assertThat(sort.size(), is(1)); + assertThat(sort.get(0).field().fieldName(), is("first_name")); + } + + /** + * + * ProjectExec[[last_name{f}#21 AS name, first_name{f}#18 AS last_name, last_name{f}#21 AS first_name]] + * \_TopNExec[[Order[last_name{f}#21,ASC,LAST]],10[INTEGER],0] + * \_ExchangeExec[[last_name{f}#21, first_name{f}#18],false] + * \_ProjectExec[[last_name{f}#21, first_name{f}#18]] + * \_FieldExtractExec[last_name{f}#21, first_name{f}#18][] + * \_EsQueryExec[test], indexMode[standard], query[{ + * "bool":{"must":[ + * {"esql_single_value":{ + * "field":"last_name", + * "next":{"range":{"last_name":{"gt":"B","boost":1.0}}}, + * "source":"first_name > \"B\"@3:9" + * }}, + * {"exists":{"field":"first_name","boost":1.0}} + * ],"boost":1.0}}][_doc{f}#40], limit[10], sort[[ + * FieldSort[field=last_name{f}#21, direction=ASC, nulls=LAST] + * ]] estimatedRowSize[116] + * + */ + public void testPushDownEvalSwapFilter() { + var plan = physicalPlan(""" + FROM test + | EVAL name = last_name, last_name = first_name, first_name = name + | WHERE first_name > "B" AND last_name IS NOT NULL + | SORT name + | LIMIT 10 + | KEEP name, last_name, first_name + """); + var optimized = optimizedPlan(plan); + + var topProject = as(optimized, ProjectExec.class); + var topN = as(topProject.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + var project = as(exchange.child(), ProjectExec.class); + var extract = as(project.child(), FieldExtractExec.class); + assertThat( + extract.attributesToExtract().stream().map(Attribute::name).collect(Collectors.toList()), + contains("last_name", "first_name") + ); + + // Now verify the correct Lucene push-down of both the filter and the sort + var source = source(extract.child()); + QueryBuilder query = source.query(); + assertNotNull(query); + assertThat(query, instanceOf(BoolQueryBuilder.class)); + var boolQuery = (BoolQueryBuilder) query; + var must = boolQuery.must(); + assertThat(must.size(), is(2)); + var svq = (SingleValueQuery.Builder) must.get(0); + var range = (RangeQueryBuilder) svq.next(); + assertThat(range.fieldName(), is("last_name")); + var exists = (ExistsQueryBuilder) must.get(1); + assertThat(exists.fieldName(), is("first_name")); + var sort = source.sorts(); + assertThat(sort.size(), is(1)); + assertThat(sort.get(0).field().fieldName(), is("last_name")); + } + /** * EnrichExec[first_name{f}#3,foo,fld,idx,[a{r}#11, b{r}#12]] * \_LimitExec[10000[INTEGER]] @@ -3074,6 +3208,107 @@ public void testPushSpatialIntersectsStringToSource() { } } + /** + * Plan: + * + * LimitExec[1000[INTEGER]] + * \_ExchangeExec[[],false] + * \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[ + * Limit[1000[INTEGER]] + * \_Filter[rank{r}#4 lt 4[INTEGER]] + * \_Eval[[scalerank{f}#8 AS rank]] + * \_EsRelation[airports][abbrev{f}#6, city{f}#12, city_location{f}#13, count..]]] + * + * Optimized: + * + * LimitExec[1000[INTEGER]] + * \_ExchangeExec[[abbrev{f}#6, city{f}#12, city_location{f}#13, country{f}#11, location{f}#10, name{f}#7, scalerank{f}#8, + * type{f}#9, rank{r}#4],false] + * \_ProjectExec[[abbrev{f}#6, city{f}#12, city_location{f}#13, country{f}#11, location{f}#10, name{f}#7, scalerank{f}#8, + * type{f}#9, rank{r}#4]] + * \_FieldExtractExec[abbrev{f}#6, city{f}#12, city_location{f}#13, count..][] + * \_LimitExec[1000[INTEGER]] + * \_EvalExec[[scalerank{f}#8 AS rank]] + * \_FieldExtractExec[scalerank{f}#8][] + * \_EsQueryExec[airports], indexMode[standard], query[{" + * 
esql_single_value":{"field":"scalerank","next":{"range":{"scalerank":{"lt":4,"boost":1.0}}},"source":"rank < 4@3:9"} + * }][_doc{f}#23], limit[], sort[] estimatedRowSize[304] + * + */ + public void testPushWhereEvalToSource() { + String query = """ + FROM airports + | EVAL rank = scalerank + | WHERE rank < 4 + """; + + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var limit2 = as(fragment.fragment(), Limit.class); + var filter = as(limit2.child(), Filter.class); + assertThat("filter contains LessThan", filter.condition(), instanceOf(LessThan.class)); + + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + assertThat(fieldExtract.attributesToExtract().size(), greaterThan(5)); + limit = as(fieldExtract.child(), LimitExec.class); + var eval = as(limit.child(), EvalExec.class); + fieldExtract = as(eval.child(), FieldExtractExec.class); + assertThat(fieldExtract.attributesToExtract().stream().map(Attribute::name).collect(Collectors.toList()), contains("scalerank")); + var source = source(fieldExtract.child()); + var condition = as(source.query(), SingleValueQuery.Builder.class); + assertThat("Expected predicate to be passed to Lucene query", condition.source().text(), equalTo("rank < 4")); + assertThat("Expected field to be passed to Lucene query", condition.field(), equalTo("scalerank")); + var range = as(condition.next(), RangeQueryBuilder.class); + assertThat("Expected range have no lower bound", range.from(), nullValue()); + assertThat("Expected range to be less than 4", range.to(), equalTo(4)); + } + + public void testPushSpatialIntersectsEvalToSource() { + for (String query : new String[] { """ + FROM airports + | EVAL point = location + | WHERE ST_INTERSECTS(point, TO_GEOSHAPE("POLYGON((42 14, 43 14, 43 15, 42 15, 42 14))")) + """, """ + FROM airports + | EVAL point = location + | WHERE ST_INTERSECTS(TO_GEOSHAPE("POLYGON((42 14, 43 14, 43 15, 42 15, 42 14))"), point) + """ }) { + + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var limit2 = as(fragment.fragment(), Limit.class); + var filter = as(limit2.child(), Filter.class); + assertThat("filter contains ST_INTERSECTS", filter.condition(), instanceOf(SpatialIntersects.class)); + + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + assertThat(fieldExtract.attributesToExtract().size(), greaterThan(5)); + limit = as(fieldExtract.child(), LimitExec.class); + var eval = as(limit.child(), EvalExec.class); + fieldExtract = as(eval.child(), FieldExtractExec.class); + assertThat(fieldExtract.attributesToExtract().stream().map(Attribute::name).collect(Collectors.toList()), contains("location")); + var source = source(fieldExtract.child()); + var condition = as(source.query(), SpatialRelatesQuery.ShapeQueryBuilder.class); + assertThat("Geometry field name", condition.fieldName(), equalTo("location")); + 
assertThat("Spatial relationship", condition.relation(), equalTo(ShapeRelation.INTERSECTS)); + assertThat("Geometry is Polygon", condition.shape().type(), equalTo(ShapeType.POLYGON)); + var polygon = as(condition.shape(), Polygon.class); + assertThat("Polygon shell length", polygon.getPolygon().length(), equalTo(5)); + assertThat("Polygon holes", polygon.getNumberOfHoles(), equalTo(0)); + } + } + private record TestSpatialRelation(ShapeRelation relation, TestDataSource index, boolean literalRight, boolean canPushToSource) { String function() { return switch (relation) { @@ -3257,30 +3492,22 @@ public void testPushDownSpatialRelatesStringToSourceAndUseDocValuesForCentroid() * * Optimized: * - * LimitExec[500[INTEGER]] - * \_AggregateExec[[],[SPATIALCENTROID(location{f}#12) AS centroid, COUNT([2a][KEYWORD]) AS count],FINAL,58] + * LimitExec[1000[INTEGER]] + * \_AggregateExec[[],[SPATIALCENTROID(location{f}#12) AS centroid, COUNT([2a][KEYWORD]) AS count],FINAL,[...],29] * \_ExchangeExec[[xVal{r}#16, xDel{r}#17, yVal{r}#18, yDel{r}#19, count{r}#20, count{r}#21, seen{r}#22],true] - * \_AggregateExec[[],[SPATIALCENTROID(location{f}#12) AS centroid, COUNT([2a][KEYWORD]) AS count],PARTIAL,58] + * \_AggregateExec[[],[SPATIALCENTROID(location{f}#12) AS centroid, COUNT([2a][KEYWORD]) AS count],INITIAL,[...],29] * \_FieldExtractExec[location{f}#12][location{f}#12] - * \_EsQueryExec[airports], query[{ - * "esql_single_value":{ - * "field":"location", - * "next":{ - * "geo_shape":{ - * "location":{ - * "shape":{ - * "type":"Polygon", - * "coordinates":[[[42.0,14.0],[43.0,14.0],[43.0,15.0],[42.0,15.0],[42.0,14.0]]] - * }, - * "relation":"intersects" - * }, - * "ignore_unmapped":false, - * "boost":1.0 + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "geo_shape":{ + * "location":{ + * "relation":"INTERSECTS", + * "shape":{ + * "type":"Polygon", + * "coordinates":[[[42.0,14.0],[43.0,14.0],[43.0,15.0],[42.0,15.0],[42.0,14.0]]] * } - * }, - * "source":"ST_INTERSECTS(location, \"POLYGON((42 14, 43 14, 43 15, 42 15, 42 14))\")@2:9" + * } * } - * }][_doc{f}#140, limit[], sort[] estimatedRowSize[54] + * }][_doc{f}#47], limit[], sort[] estimatedRowSize[25] * */ public void testPushSpatialIntersectsStringToSourceAndUseDocValuesForCentroid() { @@ -3732,6 +3959,167 @@ AND ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) >= 400000 assertShapeQueryRange(shapeQueryBuilders, 400000.0, 600000.0); } + /** + * Plan: + * + * LimitExec[1000[INTEGER]] + * \_ExchangeExec[[],false] + * \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[ + * Limit[1000[INTEGER]] + * \_Filter[distance{r}#4 le 600000[INTEGER] AND distance{r}#4 ge 400000[INTEGER]] + * \_Eval[[STDISTANCE(location{f}#11,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_EsRelation[airports][abbrev{f}#7, city{f}#13, city_location{f}#14, count..]]] + * + * Optimized: + * + * LimitExec[1000[INTEGER]] + * \_ExchangeExec[[abbrev{f}#7, city{f}#13, city_location{f}#14, country{f}#12, location{f}#11, name{f}#8, scalerank{f}#9, type{ + * f}#10, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#7, city{f}#13, city_location{f}#14, country{f}#12, location{f}#11, name{f}#8, scalerank{f}#9, type{ + * f}#10, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#7, city{f}#13, city_location{f}#14, count..][] + * \_LimitExec[1000[INTEGER]] + * \_EvalExec[[STDISTANCE(location{f}#11,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * 
\_FieldExtractExec[location{f}#11][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "must":[ + * { + * "geo_shape":{ + * "location":{ + * "relation":"INTERSECTS", + * "shape":{ + * "type":"Circle", + * "radius":"600000.0m", + * "coordinates":[12.565,55.673] + * } + * } + * } + * }, + * { + * "geo_shape":{ + * "location":{ + * "relation":"DISJOINT", + * "shape":{ + * "type":"Circle", + * "radius":"400000.0m", + * "coordinates":[12.565,55.673] + * } + * } + * } + * } + * ], + * "boost":1.0 + * }}][_doc{f}#24], limit[], sort[] estimatedRowSize[308] + * + */ + public void testPushSpatialDistanceEvalToSource() { + var query = """ + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance <= 600000 + AND distance >= 400000 + """; + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var limit2 = as(fragment.fragment(), Limit.class); + var filter = as(limit2.child(), Filter.class); + + // Validate the EVAL expression + var eval = as(filter.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var stDistance = as(alias.child(), StDistance.class); + var location = as(stDistance.left(), FieldAttribute.class); + assertThat(location.fieldName(), is("location")); + + // Validate the filter condition + var and = as(filter.condition(), And.class); + for (Expression expression : and.arguments()) { + var comp = as(expression, EsqlBinaryComparison.class); + var expectedComp = comp.equals(and.left()) ? LessThanOrEqual.class : GreaterThanOrEqual.class; + assertThat("filter contains expected binary comparison", comp, instanceOf(expectedComp)); + var distance = as(comp.left(), ReferenceAttribute.class); + assertThat(distance.name(), is("distance")); + } + + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var limit3 = as(fieldExtract.child(), LimitExec.class); + var evalExec = as(limit3.child(), EvalExec.class); + var fieldExtract2 = as(evalExec.child(), FieldExtractExec.class); + var source = source(fieldExtract2.child()); + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected zero range query builders", rangeQueryBuilders.size(), equalTo(0)); + var shapeQueryBuilders = bool.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 400000.0, 600000.0); + } + + public void testPushSpatialDistanceMultiEvalToSource() { + var query = """ + FROM airports + | EVAL poi = TO_GEOPOINT("POINT(12.565 55.673)") + | EVAL distance = ST_DISTANCE(location, poi) + | WHERE distance <= 600000 + AND distance >= 400000 + """; + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var limit2 = as(fragment.fragment(), Limit.class); + var filter = as(limit2.child(), Filter.class); + + // Validate the EVAL expression + var eval = as(filter.child(), Eval.class); + 
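// Expect two EVAL fields: the poi literal first, then the distance that references it + 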
assertThat(eval.fields().size(), is(2)); + var alias1 = as(eval.fields().get(0), Alias.class); + assertThat(alias1.name(), is("poi")); + var poi = as(alias1.child(), Literal.class); + assertThat(poi.fold(), instanceOf(BytesRef.class)); + var alias2 = as(eval.fields().get(1), Alias.class); + assertThat(alias2.name(), is("distance")); + var stDistance = as(alias2.child(), StDistance.class); + var location = as(stDistance.left(), FieldAttribute.class); + assertThat(location.fieldName(), is("location")); + var poiRef = as(stDistance.right(), Literal.class); + assertThat(poiRef.fold(), instanceOf(BytesRef.class)); + assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + + // Validate the filter condition + var and = as(filter.condition(), And.class); + for (Expression expression : and.arguments()) { + var comp = as(expression, EsqlBinaryComparison.class); + var expectedComp = comp.equals(and.left()) ? LessThanOrEqual.class : GreaterThanOrEqual.class; + assertThat("filter contains expected binary comparison", comp, instanceOf(expectedComp)); + var distance = as(comp.left(), ReferenceAttribute.class); + assertThat(distance.name(), is("distance")); + } + + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var limit3 = as(fieldExtract.child(), LimitExec.class); + var evalExec = as(limit3.child(), EvalExec.class); + var fieldExtract2 = as(evalExec.child(), FieldExtractExec.class); + var source = source(fieldExtract2.child()); + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected zero range query builders", rangeQueryBuilders.size(), equalTo(0)); + var shapeQueryBuilders = bool.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 400000.0, 600000.0); + } + public void testPushSpatialDistanceDisjointBandsToSource() { var query = """ FROM airports @@ -3826,6 +4214,1361 @@ AND ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) >= 200000)) } } + /** + * + * \_ExchangeExec[[abbrev{f}#22, city{f}#28, city_location{f}#29, country{f}#27, location{f}#26, name{f}#23, scalerank{f}#24, + * type{f}#25, poi_x{r}#3, distance_x{r}#7, poi{r}#10, distance{r}#13],false] + * \_ProjectExec[[abbrev{f}#22, city{f}#28, city_location{f}#29, country{f}#27, location{f}#26, name{f}#23, scalerank{f}#24, + * type{f}#25, poi_x{r}#3, distance_x{r}#7, poi{r}#10, distance{r}#13]] + * \_FieldExtractExec[abbrev{f}#22, city{f}#28, city_location{f}#29, coun..][] + * \_LimitExec[1000[INTEGER]] + * \_EvalExec[[ + * [1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT] AS poi_x, + * STDISTANCE(location{f}#26,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance_x, + * [1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT] AS poi, + * distance_x{r}#7 AS distance + * ]] + * \_FieldExtractExec[location{f}#26][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "must":[ + * {"esql_single_value":{ + * "field":"abbrev", + * "next":{"bool":{"must_not":[{"term":{"abbrev":{"value":"PLQ"}}}],"boost":1.0}}, + * "source":"NOT abbrev == \"PLQ\"@10:9" + * }}, + * {"esql_single_value":{ + * "field":"scalerank", + * 
"next":{"range":{"scalerank":{"lt":6,"boost":1.0}}}, + * "source":"scalerank lt 6@11:9" + * }} + * ], + * "filter":[ + * {"bool":{ + * "should":[ + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}}, + * {"bool":{"must_not":[ + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}} + * ],"boost":1.0}} + * ],"boost":1.0}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}} + * ],"boost":1.0 + * }} + * ],"boost":1.0}}][_doc{f}#34], limit[], sort[] estimatedRowSize[329] + * + */ + public void testPushSpatialDistanceComplexPredicateWithEvalToSource() { + var query = """ + FROM airports + | EVAL poi_x = TO_GEOPOINT("POINT(12.565 55.673)") + | EVAL distance_x = ST_DISTANCE(location, poi_x) + | EVAL poi = poi_x + | EVAL distance = distance_x + | WHERE ((distance <= 600000 + AND distance >= 400000 + AND NOT (distance <= 500000 + AND distance >= 430000)) + OR (distance <= 300000 + AND distance >= 200000)) + AND NOT abbrev == "PLQ" + AND scalerank < 6 + """; + var plan = this.physicalPlan(query, airports); + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + var exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var limit2 = as(fieldExtract.child(), LimitExec.class); + var evalExec = as(limit2.child(), EvalExec.class); + var fieldExtract2 = as(evalExec.child(), FieldExtractExec.class); + var source = source(fieldExtract2.child()); + var bool = as(source.query(), BoolQueryBuilder.class); + assertThat("Expected boolean query of three MUST clauses", bool.must().size(), equalTo(2)); + assertThat("Expected boolean query of one FILTER clause", bool.filter().size(), equalTo(1)); + var boolDisjuntive = as(bool.filter().get(0), BoolQueryBuilder.class); + var disjuntiveQueryBuilders = boolDisjuntive.should().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + assertThat("Expected two disjunctive query builders", disjuntiveQueryBuilders.size(), equalTo(2)); + for (int i = 0; i < disjuntiveQueryBuilders.size(); i++) { + var subRangeBool = as(disjuntiveQueryBuilders.get(i), BoolQueryBuilder.class); + var shapeQueryBuilders = subRangeBool.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, i == 0 ? 400000.0 : 200000.0, i == 0 ? 
600000.0 : 300000.0); + } + } + + /** + * Plan: + * + * LimitExec[1000[INTEGER]] + * \_AggregateExec[[],[COUNT([2a][KEYWORD]) AS count],FINAL,[count{r}#17, seen{r}#18],null] + * \_ExchangeExec[[count{r}#17, seen{r}#18],true] + * \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[ + * Aggregate[STANDARD,[],[COUNT([2a][KEYWORD]) AS count]] + * \_Filter[distance{r}#4 lt 1000000[INTEGER] AND distance{r}#4 gt 10000[INTEGER]] + * \_Eval[[ + * STDISTANCE(location{f}#13,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance + * ]] + * \_EsRelation[airports][abbrev{f}#9, city{f}#15, city_location{f}#16, count..]]] + * + * Optimized: + * + * LimitExec[1000[INTEGER]] + * \_AggregateExec[[],[COUNT([2a][KEYWORD]) AS count],FINAL,[count{r}#17, seen{r}#18],8] + * \_ExchangeExec[[count{r}#17, seen{r}#18],true] + * \_AggregateExec[[],[COUNT([2a][KEYWORD]) AS count],INITIAL,[count{r}#31, seen{r}#32],8] + * \_EvalExec[[ + * STDISTANCE(location{f}#13,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance + * ]] + * \_FieldExtractExec[location{f}#13][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}}][_doc{f}#33], limit[], sort[] estimatedRowSize[33] + * + */ + public void testPushSpatialDistanceEvalWithSimpleStatsToSource() { + var query = """ + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance < 1000000 AND distance > 10000 + | STATS count=COUNT(*) + """; + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var agg = as(limit.child(), AggregateExec.class); + var exchange = as(agg.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var agg2 = as(fragment.fragment(), Aggregate.class); + var filter = as(agg2.child(), Filter.class); + + // Validate the filter condition (two distance filters) + var and = as(filter.condition(), And.class); + for (Expression expression : and.arguments()) { + var comp = as(expression, EsqlBinaryComparison.class); + var expectedComp = comp.equals(and.left()) ? 
LessThan.class : GreaterThan.class; + assertThat("filter contains expected binary comparison", comp, instanceOf(expectedComp)); + var distance = as(comp.left(), ReferenceAttribute.class); + assertThat(distance.name(), is("distance")); + } + + // Validate the eval (calculating distance) + var eval = as(filter.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + as(eval.child(), EsRelation.class); + + // Now optimize the plan + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + var aggExec = as(topLimit.child(), AggregateExec.class); + var exchangeExec = as(aggExec.child(), ExchangeExec.class); + var aggExec2 = as(exchangeExec.child(), AggregateExec.class); + // TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown + // Right now we don't mark the distance field as doc-values, introducing a performance hit + // However, fixing this to doc-values is not as good as removing the EVAL entirely, which is a more sensible optimization + var evalExec = as(aggExec2.child(), EvalExec.class); + var stDistance = as(evalExec.fields().get(0).child(), StDistance.class); + assertThat("Expect distance function to not use doc-values", stDistance.leftDocValues(), is(false)); + var source = assertChildIsGeoPointExtract(evalExec, false); + + // No sort is pushed down + assertThat(source.limit(), nullValue()); + assertThat(source.sorts(), nullValue()); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var shapeQueryBuilders = bool.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 1000000.0); + } + + /** + * Plan: + * + * TopNExec[[Order[count{r}#10,DESC,FIRST], Order[country{f}#21,ASC,LAST]],1000[INTEGER],null] + * \_AggregateExec[[country{f}#21],[COUNT([2a][KEYWORD]) AS count, SPATIALCENTROID(location{f}#20) AS centroid, country{f}#21],FINA + * L,[country{f}#21, count{r}#24, seen{r}#25, xVal{r}#26, xDel{r}#27, yVal{r}#28, yDel{r}#29, count{r}#30],null] + * \_ExchangeExec[[country{f}#21, count{r}#24, seen{r}#25, xVal{r}#26, xDel{r}#27, yVal{r}#28, yDel{r}#29, count{r}#30],true] + * \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[ + * Aggregate[STANDARD,[country{f}#21],[COUNT([2a][KEYWORD]) AS count, SPATIALCENTROID(location{f}#20) AS centroid, country{f} + * #21]] + * \_Filter[distance{r}#4 lt 1000000[INTEGER] AND distance{r}#4 gt 10000[INTEGER]] + * \_Eval[[STDISTANCE(location{f}#20,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_Filter[scalerank{f}#18 lt 6[INTEGER]] + * \_EsRelation[airports][abbrev{f}#16, city{f}#22, city_location{f}#23, coun..]]] + * + * Optimized: + * + * TopNExec[[Order[count{r}#10,DESC,FIRST], Order[country{f}#21,ASC,LAST]],1000[INTEGER],0] + * \_AggregateExec[[country{f}#21],[COUNT([2a][KEYWORD]) AS count, SPATIALCENTROID(location{f}#20) AS centroid, country{f}#21],FINA + * L,[country{f}#21, count{r}#24, seen{r}#25, xVal{r}#26, xDel{r}#27, yVal{r}#28, yDel{r}#29, count{r}#30],79] + * \_ExchangeExec[[country{f}#21, count{r}#24, seen{r}#25, xVal{r}#26, xDel{r}#27, yVal{r}#28, yDel{r}#29, count{r}#30],true] + * \_AggregateExec[[country{f}#21],[COUNT([2a][KEYWORD]) AS count, SPATIALCENTROID(location{f}#20) AS centroid, country{f}#21],INIT + * IAL,[country{f}#21, count{r}#49, seen{r}#50, xVal{r}#51, xDel{r}#52, yVal{r}#53, 
yDel{r}#54, count{r}#55],79] + * \_EvalExec[[STDISTANCE(location{f}#20,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_FieldExtractExec[location{f}#20][location{f}#20] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "filter":[ + * { + * "esql_single_value":{ + * "field":"scalerank", + * "next":{"range":{"scalerank":{"lt":6,"boost":1.0}}}, + * "source":"scalerank lt 6@3:31" + * } + * }, + * { + * "bool":{ + * "must":[ + * {"geo_shape":{ + * "location":{ + * "relation":"INTERSECTS", + * "shape":{"type":"Circle","radius":"1000000m","coordinates":[12.565,55.673]} + * } + * }}, + * {"geo_shape":{ + * "location":{ + * "relation":"DISJOINT", + * "shape":{"type":"Circle","radius":"10000m","coordinates":[12.565,55.673]} + * } + * }} + * ], + * "boost":1.0 + * } + * } + * ], + * "boost":1.0 + * }}][_doc{f}#56], limit[], sort[] estimatedRowSize[33] + * + */ + public void testPushSpatialDistanceEvalWithStatsToSource() { + var query = """ + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance < 1000000 AND scalerank < 6 AND distance > 10000 + | STATS count=COUNT(*), centroid=ST_CENTROID_AGG(location) BY country + | SORT count DESC, country ASC + """; + var plan = this.physicalPlan(query, airports); + var topN = as(plan, TopNExec.class); + var agg = as(topN.child(), AggregateExec.class); + var exchange = as(agg.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var agg2 = as(fragment.fragment(), Aggregate.class); + var filter = as(agg2.child(), Filter.class); + + // Validate the filter condition (two distance filters) + var and = as(filter.condition(), And.class); + for (Expression expression : and.arguments()) { + var comp = as(expression, EsqlBinaryComparison.class); + var expectedComp = comp.equals(and.left()) ? 
LessThan.class : GreaterThan.class; + assertThat("filter contains expected binary comparison", comp, instanceOf(expectedComp)); + var distance = as(comp.left(), ReferenceAttribute.class); + assertThat(distance.name(), is("distance")); + } + + // Validate the eval (calculating distance) + var eval = as(filter.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var filter2 = as(eval.child(), Filter.class); + + // Now optimize the plan + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, TopNExec.class); + var aggExec = as(topLimit.child(), AggregateExec.class); + var exchangeExec = as(aggExec.child(), ExchangeExec.class); + var aggExec2 = as(exchangeExec.child(), AggregateExec.class); + // TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown + var evalExec = as(aggExec2.child(), EvalExec.class); + var stDistance = as(evalExec.fields().get(0).child(), StDistance.class); + assertThat("Expect distance function to use doc-values", stDistance.leftDocValues(), is(true)); + var source = assertChildIsGeoPointExtract(evalExec, true); + + // No sort is pushed down + assertThat(source.limit(), nullValue()); + assertThat(source.sorts(), nullValue()); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 1000000.0); + } + + /** + * ProjectExec[[languages{f}#8, salary{f}#10]] + * \_TopNExec[[Order[salary{f}#10,DESC,FIRST]],10[INTEGER],0] + * \_ExchangeExec[[languages{f}#8, salary{f}#10],false] + * \_ProjectExec[[languages{f}#8, salary{f}#10]] + * \_FieldExtractExec[languages{f}#8, salary{f}#10][] + * \_EsQueryExec[test], + * indexMode[standard], + * query[][_doc{f}#25], + * limit[10], + * sort[[FieldSort[field=salary{f}#10, direction=DESC, nulls=FIRST]]] estimatedRowSize[24] + */ + public void testPushTopNToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM test + | SORT salary DESC + | LIMIT 10 + | KEEP languages, salary + """)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("languages", "salary")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("languages", "salary")); + var source = source(extract.child()); + assertThat(source.limit(), is(topN.limit())); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); + + assertThat(source.limit(), is(l(10))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.DESC)); + assertThat(name(sort.field()), is("salary")); + assertThat(sort.sortBuilder(), 
isA(FieldSortBuilder.class)); + assertNull(source.query()); + } + + /** + * ProjectExec[[languages{f}#9, salary{f}#11]] + * \_TopNExec[[Order[salary{f}#11,DESC,FIRST]],10[INTEGER],0] + * \_ExchangeExec[[languages{f}#9, salary{f}#11],false] + * \_ProjectExec[[languages{f}#9, salary{f}#11]] + * \_FieldExtractExec[languages{f}#9, salary{f}#11][] + * \_EsQueryExec[test], + * indexMode[standard], + * query[{"esql_single_value":{ + * "field":"salary", + * "next":{"range":{"salary":{"gt":50000,"boost":1.0}}}, + * "source":"salary > 50000@2:9" + * }}][_doc{f}#26], + * limit[10], + * sort[[FieldSort[field=salary{f}#11, direction=DESC, nulls=FIRST]]] estimatedRowSize[24] + */ + public void testPushTopNWithFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM test + | WHERE salary > 50000 + | SORT salary DESC + | LIMIT 10 + | KEEP languages, salary + """)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("languages", "salary")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("languages", "salary")); + var source = source(extract.child()); + assertThat(source.limit(), is(topN.limit())); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); + + assertThat(source.limit(), is(l(10))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.DESC)); + assertThat(name(sort.field()), is("salary")); + assertThat(sort.sortBuilder(), isA(FieldSortBuilder.class)); + var rq = as(sv(source.query(), "salary"), RangeQueryBuilder.class); + assertThat(rq.fieldName(), equalTo("salary")); + assertThat(rq.from(), equalTo(50000)); + assertThat(rq.includeLower(), equalTo(false)); + assertThat(rq.to(), nullValue()); + } + + /** + * ProjectExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327]] + * \_TopNExec[[Order[abbrev{f}#12321,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327],false] + * \_ProjectExec[[abbrev{f}#12321, name{f}#12322, location{f}#12325, country{f}#12326, city{f}#12327]] + * \_FieldExtractExec[abbrev{f}#12321, name{f}#12322, location{f}#12325, ..][] + * \_EsQueryExec[airports], + * indexMode[standard], + * query[][_doc{f}#12337], + * limit[5], + * sort[[FieldSort[field=abbrev{f}#12321, direction=ASC, nulls=LAST]]] estimatedRowSize[237] + */ + public void testPushTopNKeywordToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | SORT abbrev + | LIMIT 5 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "location", "country", "city")); + var source = source(extract.child()); + assertThat(source.limit(), is(topN.limit())); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); + + assertThat(source.limit(), is(l(5))); 
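+ // Fine-grained checks on the pushed down keyword sort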
+ assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("abbrev")); + assertThat(sort.sortBuilder(), isA(FieldSortBuilder.class)); + assertNull(source.query()); + } + + /** + * + * ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, abbrev{f}#12 AS code]] + * \_TopNExec[[Order[abbrev{f}#12,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18],false] + * \_ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18]] + * \_FieldExtractExec[abbrev{f}#12, name{f}#13, location{f}#16, country{f..][] + * \_EsQueryExec[airports], indexMode[standard], query[][_doc{f}#29], limit[5], + * sort[[FieldSort[field=abbrev{f}#12, direction=ASC, nulls=LAST]]] estimatedRowSize[237] + * + */ + public void testPushTopNAliasedKeywordToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL code = abbrev + | SORT code + | LIMIT 5 + | KEEP abbrev, name, location, country, city, code + """, airports)); + + var project = as(optimized, ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "code")); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "location", "country", "city")); + var source = source(extract.child()); + assertThat(source.limit(), is(topN.limit())); + assertThat(source.sorts(), is(fieldSorts(topN.order()))); + + assertThat(source.limit(), is(l(5))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("abbrev")); + assertThat(sort.sortBuilder(), isA(FieldSortBuilder.class)); + assertNull(source.query()); + } + + /** + * ProjectExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17]] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#11, name{f}#12, location{f}#15, country{f}#16, city{f}#17, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#11, name{f}#12, country{f}#16, city{f}#17][] + * \_EvalExec[[STDISTANCE(location{f}#15,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_FieldExtractExec[location{f}#15][] + * \_EsQueryExec[airports], + * indexMode[standard], + * query[][_doc{f}#28], + * limit[5], + * sort[[GeoDistanceSort[field=location{f}#15, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] + */ + public void testPushTopNDistanceToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | SORT distance ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = 
as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var evalExec = as(extract.child(), EvalExec.class); + var alias = as(evalExec.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var stDistance = as(alias.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set<String> orderSet = orderAsSet(topN.order()); + Set<String> sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(5))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("location")); + assertThat(sort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + assertNull(source.query()); + } + + /** + * ProjectExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, city{f}#14]] + * \_TopNExec[[Order[$$order_by$0$0{r}#16,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, city{f}#14, $$order_by$0$0{r}#16],false] + * \_ProjectExec[[abbrev{f}#8, name{f}#9, location{f}#12, country{f}#13, city{f}#14, $$order_by$0$0{r}#16]] + * \_FieldExtractExec[abbrev{f}#8, name{f}#9, country{f}#13, city{f}#14][] + * \_EvalExec[[ + * STDISTANCE(location{f}#12,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS $$order_by$0$0 + * ]] + * \_FieldExtractExec[location{f}#12][] + * \_EsQueryExec[airports], + * indexMode[standard], + * query[][_doc{f}#26], + * limit[5], + * sort[[GeoDistanceSort[field=location{f}#12, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] + */ + public void testPushTopNInlineDistanceToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | SORT ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "$$order_by$0$0")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var evalExec = as(extract.child(), EvalExec.class); + var alias = as(evalExec.fields().get(0), Alias.class); + assertThat(alias.name(), is("$$order_by$0$0")); + var stDistance = as(alias.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) 
is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set<String> orderSet = orderAsSet(topN.order()); + Set<String> sortsSet = sortsAsSet(source.sorts(), Map.of("location", "$$order_by$0$0")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(5))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("location")); + assertThat(sort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + assertNull(source.query()); + } + + /** + * + * ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18]] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#12, name{f}#13, location{f}#16, country{f}#17, city{f}#18, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#12, name{f}#13, country{f}#17, city{f}#18][] + * \_EvalExec[[STDISTANCE(location{f}#16,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance + * ]] + * \_FieldExtractExec[location{f}#16][] + * \_EsQueryExec[airports], indexMode[standard], query[ + * { + * "geo_shape":{ + * "location":{ + * "relation":"DISJOINT", + * "shape":{ + * "type":"Circle", + * "radius":"50000.00000000001m", + * "coordinates":[12.565,55.673] + * } + * } + * } + * }][_doc{f}#29], limit[5], sort[[GeoDistanceSort[field=location{f}#16, direction=ASC, lat=55.673, lon=12.565]]] estimatedRowSize[245] + * + */ + public void testPushTopNDistanceWithFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance > 50000 + | SORT distance ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var evalExec = as(extract.child(), EvalExec.class); + var alias = as(evalExec.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var stDistance = as(alias.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set<String> orderSet = orderAsSet(topN.order()); + Set<String> sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(5))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("location")); + assertThat(sort.sortBuilder(), 
isA(GeoDistanceSortBuilder.class)); + + var condition = as(source.query(), SpatialRelatesQuery.ShapeQueryBuilder.class); + assertThat("Geometry field name", condition.fieldName(), equalTo("location")); + assertThat("Spatial relationship", condition.relation(), equalTo(ShapeRelation.DISJOINT)); + assertThat("Geometry is Circle", condition.shape().type(), equalTo(ShapeType.CIRCLE)); + var circle = as(condition.shape(), Circle.class); + assertThat("Circle center-x", circle.getX(), equalTo(12.565)); + assertThat("Circle center-y", circle.getY(), equalTo(55.673)); + assertThat("Circle radius for predicate", circle.getRadiusMeters(), closeTo(50000.0, 1e-9)); + } + + /** + * + * ProjectExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20]] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#14, name{f}#15, location{f}#18, country{f}#19, city{f}#20, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#14, name{f}#15, country{f}#19, city{f}#20][] + * \_EvalExec[[STDISTANCE(location{f}#18,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_FieldExtractExec[location{f}#18][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "filter":[ + * { + * "esql_single_value":{ + * "field":"scalerank", + * "next":{"range":{"scalerank":{"lt":6,"boost":1.0}}}, + * "source":"scalerank lt 6@3:31" + * } + * }, + * { + * "bool":{ + * "must":[ + * {"geo_shape":{ + * "location":{ + * "relation":"INTERSECTS", + * "shape":{"type":"Circle","radius":"499999.99999999994m","coordinates":[12.565,55.673]} + * } + * }}, + * {"geo_shape":{ + * "location":{ + * "relation":"DISJOINT", + * "shape":{"type":"Circle","radius":"10000.000000000002m","coordinates":[12.565,55.673]} + * } + * }} + * ], + * "boost":1.0 + * } + * } + * ], + * "boost":1.0 + * }}][_doc{f}#31], limit[5], sort[[ + * GeoDistanceSort[field=location{f}#18, direction=ASC, lat=55.673, lon=12.565] + * ]] estimatedRowSize[245] + * + */ + public void testPushTopNDistanceWithCompoundFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance < 500000 AND scalerank < 6 AND distance > 10000 + | SORT distance ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var evalExec = as(extract.child(), EvalExec.class); + var alias = as(evalExec.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var stDistance = as(alias.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); 
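+ // The "distance" alias from the EVAL is rewritten to a geo-distance sort on the underlying "location" field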
+ Set<String> orderSet = orderAsSet(topN.order()); + Set<String> sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(5))); + assertThat(source.sorts().size(), is(1)); + EsQueryExec.Sort sort = source.sorts().get(0); + assertThat(sort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(sort.field()), is("location")); + assertThat(sort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + /** + * This test shows that with an additional EVAL used in the filter, we can no longer push down the SORT distance. + * TODO: This could be optimized in future work. Consider moving much of EnableSpatialDistancePushdown into logical planning. + * + * ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25 AS scale]] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scalerank{f}#25,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scalerank{f}#25, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#23, name{f}#24, country{f}#28, city{f}#29][] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scalerank{f}#25,ASC,LAST]],5[INTEGER],208] + * \_FieldExtractExec[scalerank{f}#25][] + * \_FilterExec[SUBSTRING(position{r}#7,1[INTEGER],5[INTEGER]) == [50 4f 49 4e 54][KEYWORD]] + * \_EvalExec[[ + * STDISTANCE(location{f}#27,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance, + * TOSTRING(location{f}#27) AS position + * ]] + * \_FieldExtractExec[location{f}#27][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{"filter":[ + * {"esql_single_value":{"field":"scalerank","next":{"range":{"scalerank":{"lt":6,"boost":1.0}}},"source":...}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}}],"boost":1.0}}][_doc{f}#42], limit[], sort[] estimatedRowSize[87] + * + */ + public void testPushTopNDistanceAndNonPushableEvalWithCompoundFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), position = location::keyword, scale = scalerank + | WHERE distance < 500000 AND SUBSTRING(position, 1, 5) == "POINT" AND distance > 10000 AND scale < 6 + | SORT distance ASC, scale ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city, scale + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), 
TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var topNChild = as(extract.child(), TopNExec.class); + extract = as(topNChild.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("scalerank")); + var filter = as(extract.child(), FilterExec.class); + var evalExec = as(filter.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(2)); + var aliasDistance = as(evalExec.fields().get(0), Alias.class); + assertThat(aliasDistance.name(), is("distance")); + var stDistance = as(aliasDistance.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + var aliasPosition = as(evalExec.fields().get(1), Alias.class); + assertThat(aliasPosition.name(), is("position")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // In this example TopN is not pushed down (we can optimize that in later work) + assertThat(source.limit(), nullValue()); + assertThat(source.sorts(), nullValue()); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + /** + * This test further shows that when an EVAL overwrites an existing field name with a computed (non-aliasing) expression, even less gets pushed down. 
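+ * Here scalerank = 10 - scalerank shadows the original field, so the scalerank > 3 filter and the compound sort stay above the source; only the distance predicates reach Lucene. 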
+ * + * ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10]] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scale{r}#10,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#23, name{f}#24, location{f}#27, country{f}#28, city{f}#29, scale{r}#10, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#23, name{f}#24, country{f}#28, city{f}#29][] + * \_TopNExec[[Order[distance{r}#4,ASC,LAST], Order[scale{r}#10,ASC,LAST]],5[INTEGER],208] + * \_FilterExec[ + * SUBSTRING(position{r}#7,1[INTEGER],5[INTEGER]) == [50 4f 49 4e 54][KEYWORD] + * AND scale{r}#10 > 3[INTEGER] + * ] + * \_EvalExec[[ + * STDISTANCE(location{f}#27,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance, + * TOSTRING(location{f}#27) AS position, + * 10[INTEGER] - scalerank{f}#25 AS scale + * ]] + * \_FieldExtractExec[location{f}#27, scalerank{f}#25][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}}][_doc{f}#42], limit[], sort[] estimatedRowSize[91] + * + */ + public void testPushTopNDistanceAndNonPushableEvalsWithCompoundFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), + position = location::keyword, scalerank = 10 - scalerank + | WHERE distance < 500000 AND SUBSTRING(position, 1, 5) == "POINT" AND distance > 10000 AND scalerank > 3 + | SORT distance ASC, scalerank ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city, scalerank + """, airports)); + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var topNChild = as(extract.child(), TopNExec.class); + var filter = as(topNChild.child(), FilterExec.class); + assertThat(filter.condition(), isA(And.class)); + var and = (And) filter.condition(); + assertThat(and.left(), isA(Equals.class)); + assertThat(and.right(), isA(GreaterThan.class)); + var evalExec = as(filter.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(3)); + var aliasDistance = as(evalExec.fields().get(0), Alias.class); + assertThat(aliasDistance.name(), is("distance")); + var stDistance = as(aliasDistance.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + var aliasPosition = as(evalExec.fields().get(1), Alias.class); + assertThat(aliasPosition.name(), is("position")); + var aliasScale = as(evalExec.fields().get(2), Alias.class); + assertThat(aliasScale.name(), is("scalerank")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location", "scalerank")); + var source = source(extract.child()); + + // In this example TopN is not pushed down (we can optimize that in later work) + assertThat(source.limit(), nullValue()); + assertThat(source.sorts(), 
nullValue()); + + // Fine-grained checks on the pushed down query: only the spatial distance gets pushed down, not the scale filter + var bool = as(source.query(), BoolQueryBuilder.class); + var shapeQueryBuilders = bool.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + /** + * This test shows that if the top-level AND'd predicate contains a non-pushable component, we should not push anything. + * + * ProjectExec[[abbrev{f}#8612, name{f}#8613, location{f}#8616, country{f}#8617, city{f}#8618, scalerank{f}#8614 AS scale]] + * \_TopNExec[[Order[distance{r}#8596,ASC,LAST], Order[scalerank{f}#8614,ASC,LAST]],5[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#8612, name{f}#8613, location{f}#8616, country{f}#8617, city{f}#8618, + * scalerank{f}#8614, distance{r}#8596 + * ],false] + * \_ProjectExec[[abbrev{f}#8612, name{f}#8613, location{f}#8616, country{f}#8617, city{f}#8618, + * scalerank{f}#8614, distance{r}#8596 + * ]] + * \_FieldExtractExec[abbrev{f}#8612, name{f}#8613, country{f}#8617, city..][] + * \_TopNExec[[Order[distance{r}#8596,ASC,LAST], Order[scalerank{f}#8614,ASC,LAST]],5[INTEGER],208] + * \_FilterExec[ + * distance{r}#8596 < 500000[INTEGER] + * AND distance{r}#8596 > 10000[INTEGER] + * AND scalerank{f}#8614 < 6[INTEGER] + * OR SUBSTRING(TOSTRING(location{f}#8616),1[INTEGER],5[INTEGER]) == [50 4f 49 4e 54][KEYWORD] + * ] + * \_FieldExtractExec[scalerank{f}#8614][] + * \_EvalExec[[ + * STDISTANCE(location{f}#8616,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance + * ]] + * \_FieldExtractExec[location{f}#8616][] + * \_EsQueryExec[airports], indexMode[standard], query[][_doc{f}#8630], limit[], sort[] estimatedRowSize[37] + * + */ + public void testPushTopNDistanceWithCompoundFilterToSourceAndDisjunctiveNonPushableEval() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")), scale = scalerank + | WHERE distance < 500000 AND distance > 10000 AND scale < 6 OR SUBSTRING(location::keyword, 1, 5) == "POINT" + | SORT distance ASC, scale ASC + | LIMIT 5 + | KEEP abbrev, name, location, country, city, scale + """, airports)); + + var project = as(optimized, ProjectExec.class); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city")); + var topNChild = as(extract.child(), TopNExec.class); + var filter = as(topNChild.child(), FilterExec.class); + assertThat(filter.condition(), isA(Or.class)); + var filterOr = (Or) filter.condition(); + assertThat(filterOr.left(), isA(And.class)); + assertThat(filterOr.right(), isA(Equals.class)); + extract = as(filter.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("scalerank")); + var evalExec = as(extract.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(1)); + var aliasDistance = as(evalExec.fields().get(0), Alias.class); + assertThat(aliasDistance.name(), is("distance")); + var stDistance = as(aliasDistance.child(), StDistance.class); + assertThat(stDistance.left().toString(), 
startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // In this example neither TopN not filter is pushed down + assertThat(source.limit(), nullValue()); + assertThat(source.sorts(), nullValue()); + assertThat(source.query(), nullValue()); + } + + /** + * + * ProjectExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, city{f}#21]] + * \_TopNExec[[Order[scalerank{f}#17,ASC,LAST], Order[distance{r}#4,ASC,LAST]],15[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, city{f}#21, scalerank{f}#17, distance{r}#4],false] + * \_ProjectExec[[abbrev{f}#15, name{f}#16, location{f}#19, country{f}#20, city{f}#21, scalerank{f}#17, distance{r}#4]] + * \_FieldExtractExec[abbrev{f}#15, name{f}#16, country{f}#20, city{f}#21, ..][] + * \_EvalExec[[STDISTANCE(location{f}#19,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) + * AS distance]] + * \_FieldExtractExec[location{f}#19][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "filter":[ + * {"esql_single_value":{"field":"scalerank",...,"source":"scalerank lt 6@3:31"}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}} + * ],"boost":1.0 + * } + * }][_doc{f}#32], limit[], sort[[ + * FieldSort[field=scalerank{f}#17, direction=ASC, nulls=LAST], + * GeoDistanceSort[field=location{f}#19, direction=ASC, lat=55.673, lon=12.565] + * ]] estimatedRowSize[37] + * + */ + public void testPushCompoundTopNDistanceWithCompoundFilterToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL distance = ST_DISTANCE(location, TO_GEOPOINT("POINT(12.565 55.673)")) + | WHERE distance < 500000 AND scalerank < 6 AND distance > 10000 + | SORT scalerank, distance + | LIMIT 15 + | KEEP abbrev, name, location, country, city + """, airports)); + + var project = as(optimized, ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city")); + var topN = as(project.child(), TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + project = as(exchange.child(), ProjectExec.class); + assertThat(names(project.projections()), contains("abbrev", "name", "location", "country", "city", "scalerank", "distance")); + var extract = as(project.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("abbrev", "name", "country", "city", "scalerank")); + var evalExec = as(extract.child(), EvalExec.class); + var alias = as(evalExec.fields().get(0), Alias.class); + assertThat(alias.name(), is("distance")); + var stDistance = as(alias.child(), StDistance.class); + assertThat(stDistance.left().toString(), startsWith("location")); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set orderSet = orderAsSet(topN.order()); + Set sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(15))); + assertThat(source.sorts().size(), 
is(2)); + EsQueryExec.Sort fieldSort = source.sorts().get(0); + assertThat(fieldSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(fieldSort.field()), is("scalerank")); + assertThat(fieldSort.sortBuilder(), isA(FieldSortBuilder.class)); + EsQueryExec.Sort distSort = source.sorts().get(1); + assertThat(distSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(distSort.field()), is("location")); + assertThat(distSort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + /** + * + * TopNExec[[Order[scalerank{f}#15,ASC,LAST], Order[distance{r}#7,ASC,LAST]],15[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#13, city{f}#19, city_location{f}#20, country{f}#18, location{f}#17, name{f}#14, scalerank{f}#15, + * type{f}#16, poi{r}#3, distance{r}#7],false] + * \_ProjectExec[[abbrev{f}#13, city{f}#19, city_location{f}#20, country{f}#18, location{f}#17, name{f}#14, scalerank{f}#15, + * type{f}#16, poi{r}#3, distance{r}#7]] + * \_FieldExtractExec[abbrev{f}#13, city{f}#19, city_location{f}#20, coun..][] + * \_EvalExec[[ + * [1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT] AS poi, + * STDISTANCE(location{f}#17,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance + * ]] + * \_FieldExtractExec[location{f}#17][] + * \_EsQueryExec[airports], indexMode[standard], query[{ + * "bool":{ + * "filter":[ + * {"esql_single_value":{"field":"scalerank",...,"source":"scalerank lt 6@4:31"}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * {"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}} + * ],"boost":1.0 + * } + * }][_doc{f}#31], limit[15], sort[[ + * FieldSort[field=scalerank{f}#15, direction=ASC, nulls=LAST], + * GeoDistanceSort[field=location{f}#17, direction=ASC, lat=55.673, lon=12.565] + * ]] estimatedRowSize[341] + * + */ + public void testPushCompoundTopNDistanceWithCompoundFilterAndCompoundEvalToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL poi = TO_GEOPOINT("POINT(12.565 55.673)") + | EVAL distance = ST_DISTANCE(location, poi) + | WHERE distance < 500000 AND scalerank < 6 AND distance > 10000 + | SORT scalerank, distance + | LIMIT 15 + """, airports)); + + var topN = as(optimized, TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + var project = as(exchange.child(), ProjectExec.class); + assertThat( + names(project.projections()), + containsInAnyOrder("abbrev", "name", "type", "location", "country", "city", "city_location", "scalerank", "poi", "distance") + ); + var extract = as(project.child(), FieldExtractExec.class); + assertThat( + names(extract.attributesToExtract()), + containsInAnyOrder("abbrev", "name", "type", "country", "city", "city_location", 
"scalerank") + ); + var evalExec = as(extract.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(2)); + var alias1 = as(evalExec.fields().get(0), Alias.class); + assertThat(alias1.name(), is("poi")); + var poi = as(alias1.child(), Literal.class); + assertThat(poi.fold(), instanceOf(BytesRef.class)); + var alias2 = as(evalExec.fields().get(1), Alias.class); + assertThat(alias2.name(), is("distance")); + var stDistance = as(alias2.child(), StDistance.class); + var location = as(stDistance.left(), FieldAttribute.class); + assertThat(location.fieldName(), is("location")); + var poiRef = as(stDistance.right(), Literal.class); + assertThat(poiRef.fold(), instanceOf(BytesRef.class)); + assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set orderSet = orderAsSet(topN.order()); + Set sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(15))); + assertThat(source.sorts().size(), is(2)); + EsQueryExec.Sort fieldSort = source.sorts().get(0); + assertThat(fieldSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(fieldSort.field()), is("scalerank")); + assertThat(fieldSort.sortBuilder(), isA(FieldSortBuilder.class)); + EsQueryExec.Sort distSort = source.sorts().get(1); + assertThat(distSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(distSort.field()), is("location")); + assertThat(distSort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + public void testPushCompoundTopNDistanceWithDeeplyNestedCompoundEvalToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL poi = TO_GEOPOINT("POINT(12.565 55.673)") + | EVAL poi2 = poi, poi3 = poi2 + | EVAL loc2 = location + | EVAL loc3 = loc2 + | EVAL dist = ST_DISTANCE(loc3, poi3) + | EVAL distance = dist + | SORT scalerank, distance + | LIMIT 15 + """, airports)); + + var topN = as(optimized, TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + var project = as(exchange.child(), ProjectExec.class); + assertThat( + names(project.projections()), + containsInAnyOrder( + "abbrev", + "name", + "type", + "location", + "country", + "city", + "city_location", + "scalerank", + "poi", + "poi2", + "poi3", + "loc2", + "loc3", + "dist", + "distance" + ) + ); + var extract = as(project.child(), FieldExtractExec.class); + assertThat( + names(extract.attributesToExtract()), + 
containsInAnyOrder("abbrev", "name", "type", "country", "city", "city_location", "scalerank") + ); + var evalExec = as(extract.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(7)); + var alias1 = as(evalExec.fields().get(0), Alias.class); + assertThat(alias1.name(), is("poi")); + var poi = as(alias1.child(), Literal.class); + assertThat(poi.fold(), instanceOf(BytesRef.class)); + var alias4 = as(evalExec.fields().get(3), Alias.class); + assertThat(alias4.name(), is("loc2")); + as(alias4.child(), FieldAttribute.class); + var alias5 = as(evalExec.fields().get(4), Alias.class); + assertThat(alias5.name(), is("loc3")); + as(alias5.child(), ReferenceAttribute.class); + var alias6 = as(evalExec.fields().get(5), Alias.class); + assertThat(alias6.name(), is("dist")); + var stDistance = as(alias6.child(), StDistance.class); + var refLocation = as(stDistance.left(), ReferenceAttribute.class); + assertThat(refLocation.name(), is("loc3")); + var poiRef = as(stDistance.right(), Literal.class); + assertThat(poiRef.fold(), instanceOf(BytesRef.class)); + assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + var alias7 = as(evalExec.fields().get(6), Alias.class); + assertThat(alias7.name(), is("distance")); + as(alias7.child(), ReferenceAttribute.class); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set orderSet = orderAsSet(topN.order()); + Set sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(15))); + assertThat(source.sorts().size(), is(2)); + EsQueryExec.Sort fieldSort = source.sorts().get(0); + assertThat(fieldSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(fieldSort.field()), is("scalerank")); + assertThat(fieldSort.sortBuilder(), isA(FieldSortBuilder.class)); + EsQueryExec.Sort distSort = source.sorts().get(1); + assertThat(distSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(distSort.field()), is("location")); + assertThat(distSort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + + // No filter is pushed down + assertThat(source.query(), nullValue()); + } + + /** + * TopNExec[[Order[scalerank{f}#15,ASC,LAST], Order[distance{r}#7,ASC,LAST]],15[INTEGER],0] + * \_ExchangeExec[[abbrev{f}#13, city{f}#19, city_location{f}#20, country{f}#18, location{f}#17, name{f}#14, scalerank{f}#15, + * type{f}#16, poi{r}#3, distance{r}#7],false] + * \_ProjectExec[[abbrev{f}#13, city{f}#19, city_location{f}#20, country{f}#18, location{f}#17, name{f}#14, scalerank{f}#15, + * type{f}#16, poi{r}#3, distance{r}#7]] + * \_FieldExtractExec[abbrev{f}#13, city{f}#19, city_location{f}#20, coun..][] + * \_EvalExec[[ + * [1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT] AS poi, + * STDISTANCE(location{f}#17,[1 1 0 0 0 e1 7a 14 ae 47 21 29 40 a0 1a 2f dd 24 d6 4b 40][GEO_POINT]) AS distance] + * ] + * \_FieldExtractExec[location{f}#17][] + * \_EsQueryExec[airports], indexMode[standard], query[{"bool":{ + * "filter":[ + * {"esql_single_value":{"field":"scalerank","next":{"range":{...}},"source":"scalerank lt 6@4:31"}}, + * {"bool":{"must":[ + * {"geo_shape":{"location":{"relation":"INTERSECTS","shape":{...}}}}, + * 
{"geo_shape":{"location":{"relation":"DISJOINT","shape":{...}}}} + * ],"boost":1.0}} + * ],"boost":1.0 + * }}][_doc{f}#31], limit[15], sort[[ + * FieldSort[field=scalerank{f}#15, direction=ASC, nulls=LAST], + * GeoDistanceSort[field=location{f}#17, direction=ASC, lat=55.673, lon=12.565] + * ]] estimatedRowSize[341] + */ + public void testPushCompoundTopNDistanceWithCompoundFilterAndNestedCompoundEvalToSource() { + var optimized = optimizedPlan(physicalPlan(""" + FROM airports + | EVAL poi = TO_GEOPOINT("POINT(12.565 55.673)") + | EVAL distance = ST_DISTANCE(location, poi) + | WHERE distance < 500000 AND scalerank < 6 AND distance > 10000 + | SORT scalerank, distance + | LIMIT 15 + """, airports)); + + var topN = as(optimized, TopNExec.class); + var exchange = asRemoteExchange(topN.child()); + + var project = as(exchange.child(), ProjectExec.class); + assertThat( + names(project.projections()), + containsInAnyOrder("abbrev", "name", "type", "location", "country", "city", "city_location", "scalerank", "poi", "distance") + ); + var extract = as(project.child(), FieldExtractExec.class); + assertThat( + names(extract.attributesToExtract()), + containsInAnyOrder("abbrev", "name", "type", "country", "city", "city_location", "scalerank") + ); + var evalExec = as(extract.child(), EvalExec.class); + assertThat(evalExec.fields().size(), is(2)); + var alias1 = as(evalExec.fields().get(0), Alias.class); + assertThat(alias1.name(), is("poi")); + var poi = as(alias1.child(), Literal.class); + assertThat(poi.fold(), instanceOf(BytesRef.class)); + var alias2 = as(evalExec.fields().get(1), Alias.class); + assertThat(alias2.name(), is("distance")); + var stDistance = as(alias2.child(), StDistance.class); + var location = as(stDistance.left(), FieldAttribute.class); + assertThat(location.fieldName(), is("location")); + var poiRef = as(stDistance.right(), Literal.class); + assertThat(poiRef.fold(), instanceOf(BytesRef.class)); + assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + extract = as(evalExec.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), contains("location")); + var source = source(extract.child()); + + // Assert that the TopN(distance) is pushed down as geo-sort(location) + assertThat(source.limit(), is(topN.limit())); + Set orderSet = orderAsSet(topN.order()); + Set sortsSet = sortsAsSet(source.sorts(), Map.of("location", "distance")); + assertThat(orderSet, is(sortsSet)); + + // Fine-grained checks on the pushed down sort + assertThat(source.limit(), is(l(15))); + assertThat(source.sorts().size(), is(2)); + EsQueryExec.Sort fieldSort = source.sorts().get(0); + assertThat(fieldSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(fieldSort.field()), is("scalerank")); + assertThat(fieldSort.sortBuilder(), isA(FieldSortBuilder.class)); + EsQueryExec.Sort distSort = source.sorts().get(1); + assertThat(distSort.direction(), is(Order.OrderDirection.ASC)); + assertThat(name(distSort.field()), is("location")); + assertThat(distSort.sortBuilder(), isA(GeoDistanceSortBuilder.class)); + + // Fine-grained checks on the pushed down query + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var filterBool = bool.filter().stream().filter(p -> p 
instanceof BoolQueryBuilder).toList(); + var fb = as(filterBool.get(0), BoolQueryBuilder.class); + var shapeQueryBuilders = fb.must().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertShapeQueryRange(shapeQueryBuilders, 10000.0, 500000.0); + } + + // Renders the TopN orders as "name->direction" strings so they can be compared as a set + private Set<String> orderAsSet(List<Order> sorts) { + return sorts.stream().map(o -> ((Attribute) o.child()).name() + "->" + o.direction()).collect(Collectors.toSet()); + } + + // Renders the pushed-down sorts the same way, mapping each sort field back to the name of the reference that produced it + private Set<String> sortsAsSet(List<EsQueryExec.Sort> sorts, Map<String, String> fieldMap) { + return sorts.stream() + .map(s -> fieldMap.getOrDefault(s.field().name(), s.field().name()) + "->" + s.direction()) + .collect(Collectors.toSet()); + } + private void assertShapeQueryRange(List<QueryBuilder> shapeQueryBuilders, double min, double max) { assertThat("Expected two shape query builders", shapeQueryBuilders.size(), equalTo(2)); var relationStats = new HashMap<ShapeRelation, Integer>(); @@ -3838,7 +5581,7 @@ private void assertShapeQueryRange(List<QueryBuilder> shapeQueryBuilders, double var circle = as(condition.shape(), Circle.class); assertThat("Circle center-x", circle.getX(), equalTo(12.565)); assertThat("Circle center-y", circle.getY(), equalTo(55.673)); - assertThat("Circle radius for shape relation " + condition.relation(), circle.getRadiusMeters(), equalTo(expected)); + assertThat("Circle radius for shape relation " + condition.relation(), circle.getRadiusMeters(), closeTo(expected, 1e-9)); } assertThat("Expected one INTERSECTS and one DISJOINT", relationStats.size(), equalTo(2)); assertThat("Expected one INTERSECTS", relationStats.get(ShapeRelation.INTERSECTS), equalTo(1)); @@ -4796,7 +6539,7 @@ private PhysicalPlan physicalPlan(String query, TestDataSource dataSource) { return physical; } - private List<FieldSort> sorts(List<Order> orders) { + private List<FieldSort> fieldSorts(List<Order> orders) { return orders.stream().map(o -> new FieldSort((FieldAttribute) o.child(), o.direction(), o.nullsPosition())).toList(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java index 121c0ef337817..49a738f4f4fa3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java @@ -9,9 +9,14 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; @@ -20,17 +25,23 @@ import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import
org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; +import java.util.ArrayList; import java.util.List; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FOUR; import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; @@ -77,6 +88,115 @@ public void testPushDownFilter() { assertEquals(new EsqlProject(EMPTY, combinedFilter, projections), new PushDownAndCombineFilters().apply(fb)); } + public void testPushDownFilterPastRenamingProject() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + EsRelation relation = relation(List.of(a, b)); + + Alias aRenamed = new Alias(EMPTY, "a_renamed", a); + Alias aRenamedTwice = new Alias(EMPTY, "a_renamed_twice", aRenamed.toAttribute()); + Alias bRenamed = new Alias(EMPTY, "b_renamed", b); + + Project project = new Project(EMPTY, relation, List.of(aRenamed, aRenamedTwice, bRenamed)); + + GreaterThan aRenamedTwiceGreaterThanOne = greaterThanOf(aRenamedTwice.toAttribute(), ONE); + LessThan bRenamedLessThanTwo = lessThanOf(bRenamed.toAttribute(), TWO); + Filter filter = new Filter(EMPTY, project, Predicates.combineAnd(List.of(aRenamedTwiceGreaterThanOne, bRenamedLessThanTwo))); + + LogicalPlan optimized = new PushDownAndCombineFilters().apply(filter); + + Project optimizedProject = as(optimized, Project.class); + assertEquals(optimizedProject.projections(), project.projections()); + Filter optimizedFilter = as(optimizedProject.child(), Filter.class); + assertEquals(optimizedFilter.condition(), Predicates.combineAnd(List.of(greaterThanOf(a, ONE), lessThanOf(b, TWO)))); + EsRelation optimizedRelation = as(optimizedFilter.child(), EsRelation.class); + assertEquals(optimizedRelation, relation); + } + + // ... | eval a_renamed = a, a_renamed_twice = a_renamed, a_squared = pow(a, 2) + // | where a_renamed > 1 and a_renamed_twice < 2 and a_squared < 4 + // -> + // ... | where a > 1 and a < 2 | eval a_renamed = a, a_renamed_twice = a_renamed, a_squared = pow(a, 2) | where a_squared < 4 + public void testPushDownFilterOnAliasInEval() { + FieldAttribute a = getFieldAttribute("a"); + FieldAttribute b = getFieldAttribute("b"); + EsRelation relation = relation(List.of(a, b)); + + Alias aRenamed = new Alias(EMPTY, "a_renamed", a); + Alias aRenamedTwice = new Alias(EMPTY, "a_renamed_twice", aRenamed.toAttribute()); + Alias bRenamed = new Alias(EMPTY, "b_renamed", b); + Alias aSquared = new Alias(EMPTY, "a_squared", new Pow(EMPTY, a, TWO)); + Eval eval = new Eval(EMPTY, relation, List.of(aRenamed, aRenamedTwice, aSquared, bRenamed)); + + // We'll construct a Filter after the Eval that has conditions that can or cannot be pushed before the Eval.
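+ // Illustrative sketch (not an exhaustive enumeration): a_renamed > 1 resolves through the rename aliases to a > 1 and can move below the Eval, + // while a_squared < 4 depends on the computed pow(a, 2) and must stay in a Filter above it.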
+ List pushableConditionsBefore = List.of( + greaterThanOf(a.toAttribute(), TWO), + greaterThanOf(aRenamed.toAttribute(), ONE), + lessThanOf(aRenamedTwice.toAttribute(), TWO), + lessThanOf(aRenamedTwice.toAttribute(), bRenamed.toAttribute()) + ); + List pushableConditionsAfter = List.of( + greaterThanOf(a.toAttribute(), TWO), + greaterThanOf(a.toAttribute(), ONE), + lessThanOf(a.toAttribute(), TWO), + lessThanOf(a.toAttribute(), b.toAttribute()) + ); + List nonPushableConditions = List.of( + lessThanOf(aSquared.toAttribute(), FOUR), + greaterThanOf(aRenamedTwice.toAttribute(), aSquared.toAttribute()) + ); + + // Try different combinations of pushable and non-pushable conditions in the filter while also randomizing their order a bit. + for (int numPushable = 0; numPushable <= pushableConditionsBefore.size(); numPushable++) { + for (int numNonPushable = 0; numNonPushable <= nonPushableConditions.size(); numNonPushable++) { + if (numPushable == 0 && numNonPushable == 0) { + continue; + } + + List conditions = new ArrayList<>(); + + int pushableIndex = 0, nonPushableIndex = 0; + // Loop and add either a pushable or non-pushable condition to the filter. + boolean addPushable; + while (pushableIndex < numPushable || nonPushableIndex < numNonPushable) { + if (pushableIndex == numPushable) { + addPushable = false; + } else if (nonPushableIndex == numNonPushable) { + addPushable = true; + } else { + addPushable = randomBoolean(); + } + + if (addPushable) { + conditions.add(pushableConditionsBefore.get(pushableIndex++)); + } else { + conditions.add(nonPushableConditions.get(nonPushableIndex++)); + } + } + + Filter filter = new Filter(EMPTY, eval, Predicates.combineAnd(conditions)); + + LogicalPlan plan = new PushDownAndCombineFilters().apply(filter); + + if (numNonPushable > 0) { + Filter optimizedFilter = as(plan, Filter.class); + assertEquals(optimizedFilter.condition(), Predicates.combineAnd(nonPushableConditions.subList(0, numNonPushable))); + plan = optimizedFilter.child(); + } + Eval optimizedEval = as(plan, Eval.class); + assertEquals(optimizedEval.fields(), eval.fields()); + plan = optimizedEval.child(); + if (numPushable > 0) { + Filter pushedFilter = as(plan, Filter.class); + assertEquals(pushedFilter.condition(), Predicates.combineAnd(pushableConditionsAfter.subList(0, numPushable))); + plan = pushedFilter.child(); + } + EsRelation optimizedRelation = as(plan, EsRelation.class); + assertEquals(optimizedRelation, relation); + } + } + } + public void testPushDownLikeRlikeFilter() { EsRelation relation = relation(); org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike conditionA = rlike(getFieldAttribute("a"), "foo"); @@ -125,7 +245,17 @@ public void testSelectivelyPushDownFilterPastFunctionAgg() { assertEquals(expected, new PushDownAndCombineFilters().apply(fb)); } - private EsRelation relation() { - return new EsRelation(EMPTY, new EsIndex(randomAlphaOfLength(8), emptyMap()), randomFrom(IndexMode.values()), randomBoolean()); + private static EsRelation relation() { + return relation(List.of()); + } + + private static EsRelation relation(List fieldAttributes) { + return new EsRelation( + EMPTY, + new EsIndex(randomAlphaOfLength(8), emptyMap()), + fieldAttributes, + randomFrom(IndexMode.values()), + randomBoolean() + ); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java new 
file mode 100644 index 0000000000000..0fe7eb6b3d43b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java @@ -0,0 +1,466 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.physical.local; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.geometry.Geometry; +import org.elasticsearch.geometry.utils.GeometryValidator; +import org.elasticsearch.geometry.utils.WellKnownBinary; +import org.elasticsearch.geometry.utils.WellKnownText; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.index.EsIndex; +import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext; +import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.stats.DisabledSearchStats; + +import java.io.IOException; +import java.nio.ByteOrder; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushTopNToSourceTests.TestPhysicalPlanBuilder.from; +import static org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests.randomEstimatedRowSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; + +public class PushTopNToSourceTests extends ESTestCase { + + public void testSimpleSortField() { + // FROM index | SORT field | LIMIT 10 + var query = from("index").sort("field").limit(10); + assertPushdownSort(query); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortMultipleFields() { + // FROM index | SORT field, integer, double | LIMIT 10 + var query = 
from("index").sort("field").sort("integer").sort("double").limit(10); + assertPushdownSort(query); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortFieldAndEvalLiteral() { + // FROM index | EVAL x = 1 | SORT field | LIMIT 10 + var query = from("index").eval("x", e -> e.i(1)).sort("field").limit(10); + assertPushdownSort(query, List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortFieldWithAlias() { + // FROM index | EVAL x = field | SORT field | LIMIT 10 + var query = from("index").eval("x", b -> b.field("field")).sort("field").limit(10); + assertPushdownSort(query, List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortMultipleFieldsWithAliases() { + // FROM index | EVAL x = field, y = integer, z = double | SORT field, integer, double | LIMIT 10 + var query = from("index").eval("x", b -> b.field("field")) + .eval("y", b -> b.field("integer")) + .eval("z", b -> b.field("double")) + .sort("field") + .sort("integer") + .sort("double") + .limit(10); + assertPushdownSort(query, List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortFieldAsAlias() { + // FROM index | EVAL x = field | SORT x | LIMIT 10 + var query = from("index").eval("x", b -> b.field("field")).sort("x").limit(10); + assertPushdownSort(query, Map.of("x", "field"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortFieldAndEvalSumLiterals() { + // FROM index | EVAL sum = 1 + 2 | SORT field | LIMIT 10 + var query = from("index").eval("sum", b -> b.add(b.i(1), b.i(2))).sort("field").limit(10); + assertPushdownSort(query, List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortFieldAndEvalSumLiteralAndField() { + // FROM index | EVAL sum = 1 + integer | SORT integer | LIMIT 10 + var query = from("index").eval("sum", b -> b.add(b.i(1), b.field("integer"))).sort("integer").limit(10); + assertPushdownSort(query, List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSimpleSortEvalSumLiteralAndField() { + // FROM index | EVAL sum = 1 + integer | SORT sum | LIMIT 10 + var query = from("index").eval("sum", b -> b.add(b.i(1), b.field("integer"))).sort("sum").limit(10); + // TODO: Consider supporting this if we can determine that the eval function maintains the same order + assertNoPushdownSort(query, "when sorting on a derived field"); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/114515") + public void testPartiallyPushableSort() { + // FROM index | EVAL sum = 1 + integer | SORT integer, sum, field | LIMIT 10 + var query = from("index").eval("sum", b -> b.add(b.i(1), b.field("integer"))).sort("integer").sort("sum").sort("field").limit(10); + // Both integer and field can be pushed down, but we can only push down the leading sortable fields, so the 'sum' blocks 'field' + assertPushdownSort(query, List.of(query.orders.get(0)), null, List.of(EvalExec.class, EsQueryExec.class)); + 
assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoPointField() { + // FROM index | SORT location | LIMIT 10 + var query = from("index").sort("location", Order.OrderDirection.ASC).limit(10); + // NOTE: while geo_point is not sortable, this is checked during logical planning and the physical planner does not know or care + assertPushdownSort(query); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunction() { + // FROM index | EVAL distance = ST_DISTANCE(location, POINT(1 2)) | SORT distance | LIMIT 10 + var query = from("index").eval("distance", b -> b.distance("location", "POINT(1 2)")) + .sort("distance", Order.OrderDirection.ASC) + .limit(10); + // The pushed-down sort will use the underlying field 'location', not the sorted reference field 'distance' + assertPushdownSort(query, Map.of("distance", "location"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunctionInverted() { + // FROM index | EVAL distance = ST_DISTANCE(POINT(1 2), location) | SORT distance | LIMIT 10 + var query = from("index").eval("distance", b -> b.distance("POINT(1 2)", "location")) + .sort("distance", Order.OrderDirection.ASC) + .limit(10); + // The pushed-down sort will use the underlying field 'location', not the sorted reference field 'distance' + assertPushdownSort(query, Map.of("distance", "location"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunctionLiterals() { + // FROM index | EVAL distance = ST_DISTANCE(POINT(2 1), POINT(1 2)) | SORT distance | LIMIT 10 + var query = from("index").eval("distance", b -> b.distance("POINT(2 1)", "POINT(1 2)")) + .sort("distance", Order.OrderDirection.ASC) + .limit(10); + // Both arguments are literals, so the distance folds to a constant and there is no underlying field to push the sort down to + assertNoPushdownSort(query, "sort on foldable distance function"); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunctionAndFieldsWithAliases() { + // FROM index | EVAL distance = ST_DISTANCE(location, POINT(1 2)), x = field | SORT distance, field, integer | LIMIT 10 + var query = from("index").eval("distance", b -> b.distance("location", "POINT(1 2)")) + .eval("x", b -> b.field("field")) + .sort("distance", Order.OrderDirection.ASC) + .sort("field", Order.OrderDirection.DESC) + .sort("integer", Order.OrderDirection.DESC) + .limit(10); + // The pushed-down sort will use the underlying field 'location', not the sorted reference field 'distance' + assertPushdownSort(query, query.orders, Map.of("distance", "location"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunctionAndFieldsAndAliases() { + // FROM index | EVAL distance = ST_DISTANCE(location, POINT(1 2)), x = field | SORT distance, x, integer | LIMIT 10 + var query = from("index").eval("distance", b -> b.distance("location", "POINT(1 2)")) + .eval("x", b -> b.field("field")) + .sort("distance", Order.OrderDirection.ASC) + .sort("x", Order.OrderDirection.DESC) + .sort("integer", Order.OrderDirection.DESC) + .limit(10); + // The pushed-down sort will use the underlying field 'location', not the sorted reference field
'distance' + assertPushdownSort(query, query.orders, Map.of("distance", "location", "x", "field"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + public void testSortGeoDistanceFunctionAndFieldsAndManyAliases() { + // FROM index + // | EVAL loc = location, loc2 = loc, loc3 = loc2, distance = ST_DISTANCE(loc3, POINT(1 2)), x = field + // | SORT distance, x, integer + // | LIMIT 10 + var query = from("index").eval("loc", b -> b.field("location")) + .eval("loc2", b -> b.ref("loc")) + .eval("loc3", b -> b.ref("loc2")) + .eval("distance", b -> b.distance("loc3", "POINT(1 2)")) + .eval("x", b -> b.field("field")) + .sort("distance", Order.OrderDirection.ASC) + .sort("x", Order.OrderDirection.DESC) + .sort("integer", Order.OrderDirection.DESC) + .limit(10); + // The pushed-down sort will use the underlying field 'location', not the sorted reference field 'distance' + assertPushdownSort(query, Map.of("distance", "location", "x", "field"), List.of(EvalExec.class, EsQueryExec.class)); + assertNoPushdownSort(query.asTimeSeries(), "for time series index mode"); + } + + private static void assertPushdownSort(TestPhysicalPlanBuilder builder) { + assertPushdownSort(builder, null, List.of(EsQueryExec.class)); + } + + private static void assertPushdownSort(TestPhysicalPlanBuilder builder, List<Class<? extends PhysicalPlan>> topClass) { + assertPushdownSort(builder, null, topClass); + } + + private static void assertPushdownSort( + TestPhysicalPlanBuilder builder, + Map<String, String> fieldMap, + List<Class<? extends PhysicalPlan>> topClass + ) { + var topNExec = builder.build(); + var result = pushTopNToSource(topNExec); + assertPushdownSort(result, builder.orders, fieldMap, topClass); + } + + private static void assertPushdownSort( + TestPhysicalPlanBuilder builder, + List<Order> expectedSorts, + Map<String, String> fieldMap, + List<Class<? extends PhysicalPlan>> topClass + ) { + var topNExec = builder.build(); + var result = pushTopNToSource(topNExec); + assertPushdownSort(result, expectedSorts, fieldMap, topClass); + } + + private static void assertNoPushdownSort(TestPhysicalPlanBuilder builder, String message) { + var topNExec = builder.build(); + var result = pushTopNToSource(topNExec); + assertNoPushdownSort(result, message); + } + + private static PhysicalPlan pushTopNToSource(TopNExec topNExec) { + var configuration = EsqlTestUtils.configuration("from test"); + var searchStats = new DisabledSearchStats(); + var ctx = new LocalPhysicalOptimizerContext(configuration, searchStats); + var pushTopNToSource = new PushTopNToSource(); + return pushTopNToSource.rule(topNExec, ctx); + } + + private static void assertNoPushdownSort(PhysicalPlan plan, String message) { + var esQueryExec = findEsQueryExec(plan); + var sorts = esQueryExec.sorts(); + assertThat("Expect no sorts " + message, sorts.size(), is(0)); + } + + private static void assertPushdownSort( + PhysicalPlan plan, + List<Order> expectedSorts, + Map<String, String> fieldMap, + List<Class<? extends PhysicalPlan>> topClass + ) { + if (topClass != null && topClass.size() > 0) { + PhysicalPlan current = plan; + for (var clazz : topClass) { + assertThat("Expect non-null physical plan class to match " + clazz.getSimpleName(), current, notNullValue()); + assertThat("Expect top physical plan class to match", current.getClass(), is(clazz)); + current = current.children().size() > 0 ?
current.children().get(0) : null; + } + if (current != null) { + fail("No more child classes expected in plan, but found: " + current.getClass().getSimpleName()); + } + } + var esQueryExec = findEsQueryExec(plan); + var sorts = esQueryExec.sorts(); + assertThat("Expect sorts count to match", sorts.size(), is(expectedSorts.size())); + for (int i = 0; i < expectedSorts.size(); i++) { + String name = ((Attribute) expectedSorts.get(i).child()).name(); + String fieldName = sorts.get(i).field().fieldName(); + assertThat("Expect sort[" + i + "] name to match", fieldName, is(sortName(name, fieldMap))); + assertThat("Expect sort[" + i + "] direction to match", sorts.get(i).direction(), is(expectedSorts.get(i).direction())); + } + } + + private static String sortName(String name, Map<String, String> fieldMap) { + return fieldMap != null ? fieldMap.getOrDefault(name, name) : name; + } + + private static EsQueryExec findEsQueryExec(PhysicalPlan plan) { + if (plan instanceof EsQueryExec esQueryExec) { + return esQueryExec; + } + // We assume no physical plans with multiple children would be generated + return findEsQueryExec(plan.children().get(0)); + } + + /** + * This builder allows for easy creation of physical plans using a syntax like `from("index").sort("field").limit(10)`. + * The idea is to create tests that are clearly related to real queries, but also easy to make assertions on. + * It only supports a very small subset of possible plans, with FROM, EVAL and SORT+LIMIT, in that order, matching + * the physical plan rules that are being tested: TopNExec, EvalExec and EsQueryExec. + */ + static class TestPhysicalPlanBuilder { + private final String index; + private final LinkedHashMap<String, FieldAttribute> fields; + private final LinkedHashMap<String, ReferenceAttribute> refs; + private IndexMode indexMode; + private final List<Alias> aliases = new ArrayList<>(); + private final List<Order> orders = new ArrayList<>(); + private int limit = Integer.MAX_VALUE; + + private TestPhysicalPlanBuilder(String index, IndexMode indexMode) { + this.index = index; + this.indexMode = indexMode; + this.fields = new LinkedHashMap<>(); + this.refs = new LinkedHashMap<>(); + addSortableFieldAttributes(this.fields); + } + + private static void addSortableFieldAttributes(Map<String, FieldAttribute> fields) { + addFieldAttribute(fields, "field", KEYWORD); + addFieldAttribute(fields, "integer", INTEGER); + addFieldAttribute(fields, "double", DOUBLE); + addFieldAttribute(fields, "keyword", KEYWORD); + addFieldAttribute(fields, "location", GEO_POINT); + } + + private static void addFieldAttribute(Map<String, FieldAttribute> fields, String name, DataType type) { + fields.put(name, new FieldAttribute(Source.EMPTY, name, new EsField(name, type, new HashMap<>(), true))); + } + + static TestPhysicalPlanBuilder from(String index) { + return new TestPhysicalPlanBuilder(index, IndexMode.STANDARD); + } + + public TestPhysicalPlanBuilder eval(Alias...
aliases) { + if (orders.isEmpty() == false) { + throw new IllegalArgumentException("Eval must be before sort"); + } + if (aliases.length == 0) { + throw new IllegalArgumentException("At least one alias must be provided"); + } + for (Alias alias : aliases) { + if (refs.containsKey(alias.name())) { + throw new IllegalArgumentException("Reference already exists: " + alias.name()); + } + refs.put( + alias.name(), + new ReferenceAttribute(Source.EMPTY, alias.name(), alias.dataType(), Nullability.FALSE, alias.id(), alias.synthetic()) + ); + this.aliases.add(alias); + } + return this; + } + + public TestPhysicalPlanBuilder eval(String name, Function builder) { + var testExpressionBuilder = new TestExpressionBuilder(); + Expression expression = builder.apply(testExpressionBuilder); + return eval(new Alias(Source.EMPTY, name, expression)); + } + + public TestPhysicalPlanBuilder sort(String field) { + return sort(field, Order.OrderDirection.ASC); + } + + public TestPhysicalPlanBuilder sort(String field, Order.OrderDirection direction) { + Attribute attr = refs.get(field); + if (attr == null) { + attr = fields.get(field); + } + if (attr == null) { + throw new IllegalArgumentException("Field not found: " + field); + } + orders.add(new Order(Source.EMPTY, attr, direction, Order.NullsPosition.LAST)); + return this; + } + + public TestPhysicalPlanBuilder limit(int limit) { + this.limit = limit; + return this; + } + + public TopNExec build() { + EsIndex esIndex = new EsIndex(this.index, Map.of()); + List attributes = new ArrayList<>(fields.values()); + PhysicalPlan child = new EsQueryExec(Source.EMPTY, esIndex, indexMode, attributes, null, null, List.of(), 0); + if (aliases.isEmpty() == false) { + child = new EvalExec(Source.EMPTY, child, aliases); + } + return new TopNExec(Source.EMPTY, child, orders, new Literal(Source.EMPTY, limit, INTEGER), randomEstimatedRowSize()); + } + + public TestPhysicalPlanBuilder asTimeSeries() { + this.indexMode = IndexMode.TIME_SERIES; + return this; + } + + class TestExpressionBuilder { + Expression field(String name) { + return fields.get(name); + } + + Expression ref(String name) { + return refs.get(name); + } + + Expression literal(Object value, DataType dataType) { + return new Literal(Source.EMPTY, value, dataType); + } + + Expression i(int value) { + return new Literal(Source.EMPTY, value, DataType.INTEGER); + } + + Expression d(double value) { + return new Literal(Source.EMPTY, value, DOUBLE); + } + + Expression k(String value) { + return new Literal(Source.EMPTY, value, KEYWORD); + } + + public Expression add(Expression left, Expression right) { + return new Add(Source.EMPTY, left, right); + } + + public Expression distance(String left, String right) { + return new StDistance(Source.EMPTY, geoExpr(left), geoExpr(right)); + } + + private Expression geoExpr(String text) { + if (text.startsWith("POINT")) { + try { + Geometry geometry = WellKnownText.fromWKT(GeometryValidator.NOOP, false, text); + BytesRef bytes = new BytesRef(WellKnownBinary.toWKB(geometry, ByteOrder.LITTLE_ENDIAN)); + return new Literal(Source.EMPTY, bytes, GEO_POINT); + } catch (IOException | ParseException e) { + throw new IllegalArgumentException("Failed to parse WKT: " + text, e); + } + } + if (fields.containsKey(text)) { + return fields.get(text); + } + if (refs.containsKey(text)) { + return refs.get(text); + } + throw new IllegalArgumentException("Unknown field: " + text); + } + } + + } +} diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExecSerializationTests.java index 6bb5111b154e6..6104069769085 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExecSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExecSerializationTests.java @@ -13,20 +13,15 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.Order; -import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.EsIndexSerializationTests; import java.io.IOException; import java.util.List; -import static org.elasticsearch.xpack.esql.plan.logical.AbstractLogicalPlanSerializationTests.randomFieldAttributes; - public class EsQueryExecSerializationTests extends AbstractPhysicalPlanSerializationTests { public static EsQueryExec randomEsQueryExec() { Source source = randomSource(); @@ -35,26 +30,14 @@ public static EsQueryExec randomEsQueryExec() { List attrs = randomFieldAttributes(1, 10, false); QueryBuilder query = randomQuery(); Expression limit = new Literal(randomSource(), between(0, Integer.MAX_VALUE), DataType.INTEGER); - List sorts = randomFieldSorts(); Integer estimatedRowSize = randomEstimatedRowSize(); - return new EsQueryExec(source, index, indexMode, attrs, query, limit, sorts, estimatedRowSize); + return new EsQueryExec(source, index, indexMode, attrs, query, limit, EsQueryExec.NO_SORTS, estimatedRowSize); } public static QueryBuilder randomQuery() { return randomBoolean() ? 
new MatchAllQueryBuilder() : new TermQueryBuilder(randomAlphaOfLength(4), randomAlphaOfLength(4)); } - public static List randomFieldSorts() { - return randomList(0, 4, EsQueryExecSerializationTests::randomFieldSort); - } - - public static EsQueryExec.FieldSort randomFieldSort() { - FieldAttribute field = FieldAttributeTests.createFieldAttribute(0, false); - Order.OrderDirection direction = randomFrom(Order.OrderDirection.values()); - Order.NullsPosition nulls = randomFrom(Order.NullsPosition.values()); - return new EsQueryExec.FieldSort(field, direction, nulls); - } - @Override protected EsQueryExec createTestInstance() { return randomEsQueryExec(); @@ -67,9 +50,8 @@ protected EsQueryExec mutateInstance(EsQueryExec instance) throws IOException { List attrs = instance.attrs(); QueryBuilder query = instance.query(); Expression limit = instance.limit(); - List sorts = instance.sorts(); Integer estimatedRowSize = instance.estimatedRowSize(); - switch (between(0, 6)) { + switch (between(0, 5)) { case 0 -> index = randomValueOtherThan(index, EsIndexSerializationTests::randomEsIndex); case 1 -> indexMode = randomValueOtherThan(indexMode, () -> randomFrom(IndexMode.values())); case 2 -> attrs = randomValueOtherThan(attrs, () -> randomFieldAttributes(1, 10, false)); @@ -78,13 +60,12 @@ protected EsQueryExec mutateInstance(EsQueryExec instance) throws IOException { limit, () -> new Literal(randomSource(), between(0, Integer.MAX_VALUE), DataType.INTEGER) ); - case 5 -> sorts = randomValueOtherThan(sorts, EsQueryExecSerializationTests::randomFieldSorts); - case 6 -> estimatedRowSize = randomValueOtherThan( + case 5 -> estimatedRowSize = randomValueOtherThan( estimatedRowSize, AbstractPhysicalPlanSerializationTests::randomEstimatedRowSize ); } - return new EsQueryExec(instance.source(), index, indexMode, attrs, query, limit, sorts, estimatedRowSize); + return new EsQueryExec(instance.source(), index, indexMode, attrs, query, limit, EsQueryExec.NO_SORTS, estimatedRowSize); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java index dd163609de8a8..1f52795dbacd7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/ExchangeSinkExecSerializationTests.java @@ -66,11 +66,12 @@ protected boolean alwaysEmptySource() { * See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more. */ public void testManyTypeConflicts() throws IOException { - testManyTypeConflicts(false, ByteSizeValue.ofBytes(1897373)); + testManyTypeConflicts(false, ByteSizeValue.ofBytes(1424048)); /* * History: * 2.3mb - shorten error messages for UnsupportedAttributes #111973 * 1.8mb - cache EsFields #112008 + * 1.4mb - string serialization #112929 */ } @@ -79,13 +80,14 @@ public void testManyTypeConflicts() throws IOException { * See {@link #testManyTypeConflicts(boolean, ByteSizeValue)} for more. 
*/ public void testManyTypeConflictsWithParent() throws IOException { - testManyTypeConflicts(true, ByteSizeValue.ofBytes(3271486)); + testManyTypeConflicts(true, ByteSizeValue.ofBytes(2774214)); /* * History: * 2 gb+ - start * 43.3mb - Cache attribute subclasses #111447 * 5.6mb - shorten error messages for UnsupportedAttributes #111973 * 3.1mb - cache EsFields #112008 + * 2.6mb - string serialization #112929 */ } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 272321b0f350b..f60e5384e1a6f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -107,6 +107,21 @@ public void testLuceneTopNSourceOperator() throws IOException { assertThat(factory.limit(), equalTo(10)); } + public void testLuceneTopNSourceOperatorDistanceSort() throws IOException { + int estimatedRowSize = randomEstimatedRowSize(estimatedRowSizeIsHuge); + FieldAttribute sortField = new FieldAttribute(Source.EMPTY, "point", new EsField("point", DataType.GEO_POINT, Map.of(), true)); + EsQueryExec.GeoDistanceSort sort = new EsQueryExec.GeoDistanceSort(sortField, Order.OrderDirection.ASC, 1, -1); + Literal limit = new Literal(Source.EMPTY, 10, DataType.INTEGER); + LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan( + new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, limit, List.of(sort), estimatedRowSize) + ); + assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency())); + LocalExecutionPlanner.DriverSupplier supplier = plan.driverFactories.get(0).driverSupplier(); + var factory = (LuceneTopNSourceOperator.Factory) supplier.physicalOperation().sourceOperatorFactory; + assertThat(factory.maxPageSize(), maxPageSizeMatcher(estimatedRowSizeIsHuge, estimatedRowSize)); + assertThat(factory.limit(), equalTo(10)); + } + private int randomEstimatedRowSize(boolean huge) { int hugeBoundary = SourceOperator.MIN_TARGET_PAGE_SIZE * 10; return huge ? 
between(hugeBoundary, Integer.MAX_VALUE) : between(1, hugeBoundary); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java index 0ab3980f112ef..d186b4c199d77 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java @@ -87,6 +87,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; +import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; import static org.mockito.Mockito.mock; /** @@ -421,7 +422,11 @@ public void accept(Page page) { // Grok.Parser is a record / final, cannot be mocked return Grok.pattern(Source.EMPTY, randomGrokPattern()); } else if (argClass == EsQueryExec.FieldSort.class) { + // TODO: It appears neither FieldSort nor GeoDistanceSort are ever actually tested return randomFieldSort(); + } else if (argClass == EsQueryExec.GeoDistanceSort.class) { + // TODO: It appears neither FieldSort nor GeoDistanceSort are ever actually tested + return randomGeoDistanceSort(); } else if (toBuildClass == Pow.class && Expression.class.isAssignableFrom(argClass)) { return randomResolvedExpression(randomBoolean() ? FieldAttribute.class : Literal.class); } else if (isPlanNodeClass(toBuildClass) && Expression.class.isAssignableFrom(argClass)) { @@ -679,6 +684,15 @@ static EsQueryExec.FieldSort randomFieldSort() { ); } + static EsQueryExec.GeoDistanceSort randomGeoDistanceSort() { + return new EsQueryExec.GeoDistanceSort( + field(randomAlphaOfLength(16), GEO_POINT), + randomFrom(EnumSet.allOf(Order.OrderDirection.class)), + randomDoubleBetween(-90, 90, false), + randomDoubleBetween(-180, 180, false) + ); + } + static FieldAttribute field(String name, DataType type) { return new FieldAttribute(Source.EMPTY, name, new EsField(name, type, Collections.emptyMap(), false)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java index a85b92dd1a055..93d435eb0b69f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsOptions.java @@ -10,7 +10,8 @@ public enum ChunkingSettingsOptions { STRATEGY("strategy"), MAX_CHUNK_SIZE("max_chunk_size"), - OVERLAP("overlap"); + OVERLAP("overlap"), + SENTENCE_OVERLAP("sentence_overlap"); private final String chunkingSettingsOption; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 3a53ecc7ae958..5df940d6a3fba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -34,6 +34,7 @@ public class SentenceBoundaryChunker implements Chunker { public SentenceBoundaryChunker() { sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); wordIterator = 
BreakIterator.getWordInstance(Locale.ROOT); + } /** @@ -46,7 +47,7 @@ public SentenceBoundaryChunker() { @Override public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { - return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize); + return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0); } else { throw new IllegalArgumentException( Strings.format( @@ -64,7 +65,7 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * @param maxNumberWordsPerChunk Maximum size of the chunk * @return The input text chunked */ - public List chunk(String input, int maxNumberWordsPerChunk) { + public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { var chunks = new ArrayList(); sentenceIterator.setText(input); @@ -75,24 +76,46 @@ public List chunk(String input, int maxNumberWordsPerChunk) { int sentenceStart = 0; int chunkWordCount = 0; + int wordsInPrecedingSentenceCount = 0; + int previousSentenceStart = 0; + int boundary = sentenceIterator.next(); while (boundary != BreakIterator.DONE) { int sentenceEnd = sentenceIterator.current(); - int countWordsInSentence = countWords(sentenceStart, sentenceEnd); + int wordsInSentenceCount = countWords(sentenceStart, sentenceEnd); - if (chunkWordCount + countWordsInSentence > maxNumberWordsPerChunk) { + if (chunkWordCount + wordsInSentenceCount > maxNumberWordsPerChunk) { // over the max chunk size, roll back to the last sentence + int nextChunkWordCount = wordsInSentenceCount; if (chunkWordCount > 0) { // add a new chunk containing all the input up to this sentence chunks.add(input.substring(chunkStart, chunkEnd)); - chunkStart = chunkEnd; - chunkWordCount = countWordsInSentence; // the next chunk will contain this sentence + + if (includePrecedingSentence) { + if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) { + // cut the last sentence + int numWordsToSkip = numWordsToSkipInPreviousSentence(wordsInPrecedingSentenceCount, maxNumberWordsPerChunk); + + chunkStart = skipWords(input, previousSentenceStart, numWordsToSkip); + chunkWordCount = (wordsInPrecedingSentenceCount - numWordsToSkip) + wordsInSentenceCount; + } else { + chunkWordCount = wordsInPrecedingSentenceCount + wordsInSentenceCount; + chunkStart = previousSentenceStart; + } + + nextChunkWordCount = chunkWordCount; + } else { + chunkStart = chunkEnd; + chunkWordCount = wordsInSentenceCount; // the next chunk will contain this sentence + } } - if (countWordsInSentence > maxNumberWordsPerChunk) { - // This sentence is bigger than the max chunk size. + // Is the next chunk larger than max chunk size? + // If so split it + if (nextChunkWordCount > maxNumberWordsPerChunk) { + // This sentence (and optional overlap) is bigger than the max chunk size. 
// Split the sentence on the word boundary var sentenceSplits = splitLongSentence( input.substring(chunkStart, sentenceEnd), @@ -113,7 +136,12 @@ public List chunk(String input, int maxNumberWordsPerChunk) { chunkWordCount = sentenceSplits.get(i).wordCount(); } } else { - chunkWordCount += countWordsInSentence; + chunkWordCount += wordsInSentenceCount; + } + + if (includePrecedingSentence) { + previousSentenceStart = sentenceStart; + wordsInPrecedingSentenceCount = wordsInSentenceCount; } sentenceStart = sentenceEnd; @@ -133,6 +161,45 @@ static List splitLongSentence(String text, in return new WordBoundaryChunker().chunkPositions(text, maxNumberOfWords, overlap); } + static int numWordsToSkipInPreviousSentence(int wordsInPrecedingSentenceCount, int maxNumberWordsPerChunk) { + var maxWordsInOverlap = maxWordsInOverlap(maxNumberWordsPerChunk); + if (wordsInPrecedingSentenceCount > maxWordsInOverlap) { + return wordsInPrecedingSentenceCount - maxWordsInOverlap; + } else { + return 0; + } + } + + static int maxWordsInOverlap(int maxNumberWordsPerChunk) { + return Math.min(maxNumberWordsPerChunk / 2, 20); + } + + private int skipWords(String input, int start, int numWords) { + var itr = BreakIterator.getWordInstance(Locale.ROOT); + itr.setText(input); + return skipWords(start, numWords, itr); + } + + static int skipWords(int start, int numWords, BreakIterator wordIterator) { + wordIterator.preceding(start); // start of the current word + + int boundary = wordIterator.current(); + int wordCount = 0; + while (boundary != BreakIterator.DONE && wordCount < numWords) { + int wordStatus = wordIterator.getRuleStatus(); + if (wordStatus != BreakIterator.WORD_NONE) { + wordCount++; + } + boundary = wordIterator.next(); + } + + if (boundary == BreakIterator.DONE) { + return wordIterator.last(); + } else { + return boundary; + } + } + private int countWords(int start, int end) { return countWords(start, end, this.wordIterator); } @@ -157,6 +224,6 @@ static int countWords(int start, int end, BreakIterator wordIterator) { } private static int overlapForChunkSize(int chunkSize) { - return (chunkSize - 1) / 2; + return Math.min(20, (chunkSize - 1) / 2); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 0d1903895f615..758dd5d04e268 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.ModelConfigurations; @@ -30,16 +31,25 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings { private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE; private static final Set VALID_KEYS = Set.of( ChunkingSettingsOptions.STRATEGY.toString(), - ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString() + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() ); + + private static int 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 0d1903895f615..758dd5d04e268 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.ModelConfigurations; @@ -30,16 +31,25 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings { private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE; private static final Set<String> VALID_KEYS = Set.of( ChunkingSettingsOptions.STRATEGY.toString(), - ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString() + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() ); + + private static final int DEFAULT_OVERLAP = 0; + protected final int maxChunkSize; + protected int sentenceOverlap = DEFAULT_OVERLAP; - public SentenceBoundaryChunkingSettings(Integer maxChunkSize) { + public SentenceBoundaryChunkingSettings(Integer maxChunkSize, @Nullable Integer sentenceOverlap) { this.maxChunkSize = maxChunkSize; + this.sentenceOverlap = sentenceOverlap == null ? DEFAULT_OVERLAP : sentenceOverlap; } public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { maxChunkSize = in.readInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) { + sentenceOverlap = in.readVInt(); + } } public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map) { @@ -59,11 +69,24 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map) validationException ); + Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger( + map, + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + + if (sentenceOverlap != null && sentenceOverlap > 1) { + validationException.addValidationError( + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1" + ); + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new SentenceBoundaryChunkingSettings(maxChunkSize); + return new SentenceBoundaryChunkingSettings(maxChunkSize, sentenceOverlap); } @Override @@ -72,6 +95,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws { builder.field(ChunkingSettingsOptions.STRATEGY.toString(), STRATEGY); builder.field(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize); + builder.field(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), sentenceOverlap); } builder.endObject(); return builder; @@ -90,6 +114,9 @@ public TransportVersion getMinimalSupportedVersion() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxChunkSize); + if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) { + out.writeVInt(sentenceOverlap); + } } @Override
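For reference, a hedged sketch of how the extended parser would be driven. The literal keys are assumptions based on the `ChunkingSettingsOptions` names (`strategy`, `max_chunk_size`, `sentence_overlap`), the class is assumed to sit in the same package as `SentenceBoundaryChunkingSettings`, and the map is built as a mutable `HashMap` in case the extract helpers consume entries:

[source,java]
----
// Hypothetical caller of SentenceBoundaryChunkingSettings.fromMap; the key
// strings and package placement are assumptions, not taken from this change.
import java.util.HashMap;
import java.util.Map;

class SentenceOverlapSettingsExample {
    public static void main(String[] args) {
        Map<String, Object> map = new HashMap<>();
        map.put("strategy", "sentence");
        map.put("max_chunk_size", 250);
        map.put("sentence_overlap", 1); // optional, defaults to 0; values above 1 are rejected

        SentenceBoundaryChunkingSettings settings = SentenceBoundaryChunkingSettings.fromMap(map);
        // With sentence_overlap = 2 the call would instead throw a
        // ValidationException: sentence_overlap[2] must be either 0 or 1
    }
}
----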
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 6517e0eea14d9..5b91e122b9c80 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -52,7 +52,7 @@ public static WordBoundaryChunkingSettings fromMap(Map<String, Object> map) { var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); if (invalidSettings.length > 0) { validationException.addValidationError( - Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + Strings.format("Word based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) ); }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java index 6bc43a4309b0c..fd0ad220faa3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java @@ -23,8 +23,6 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.rerank.RerankingQueryPhaseRankCoordinatorContext; -import org.elasticsearch.search.rank.rerank.RerankingQueryPhaseRankShardContext; import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.XContentBuilder; @@ -157,12 +155,12 @@ public Explanation explainHit(Explanation baseExplanation, RankDoc scoreDoc, Lis @Override public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) { - return new RerankingQueryPhaseRankShardContext(queries, rankWindowSize()); + return null; } @Override public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - return new RerankingQueryPhaseRankCoordinatorContext(rankWindowSize()); + return null; } @Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 061ea677e6fe1..3c09984ac0162 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -56,7 +56,7 @@ private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunking ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize ), - new SentenceBoundaryChunkingSettings(maxChunkSize) + new SentenceBoundaryChunkingSettings(maxChunkSize, 1) ); } }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java index 2482586c75595..8373ae93354b1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java @@ -25,7 +25,7 @@ public static ChunkingSettings createRandomChunkingSettings() { return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2)); } case SENTENCE -> { - return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 
0 : 1); } default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]"); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index 335752faa6b22..5687ebc4dbae7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -13,19 +13,24 @@ import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; +import java.util.ArrayList; import java.util.Arrays; import java.util.Locale; import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.startsWith; public class SentenceBoundaryChunkerTests extends ESTestCase { public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk); + var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); @@ -39,11 +44,94 @@ public void testChunkSplitLargeChunkSizes() { } } + public void testChunkSplitLargeChunkSizes_withOverlap() { + boolean overlap = true; + for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) { + var chunker = new SentenceBoundaryChunker(); + var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap); + + int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap); + assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length)); + + assertTrue(Character.isUpperCase(chunks.get(0).charAt(0))); + + for (int i = 0; i < overlaps.length; i++) { + if (overlaps[i] == 0) { + // start of a sentence + assertTrue(Character.isUpperCase(chunks.get(i).charAt(0))); + } else { + // The start of this chunk should contain some text from the end of the previous + var previousChunk = chunks.get(i - 1); + assertThat(chunks.get(i), containsString(previousChunk.substring(previousChunk.length() - 20))); + } + } + + var trailingWhiteSpaceRemoved = chunks.get(0).strip(); + var lastChar = trailingWhiteSpaceRemoved.charAt(trailingWhiteSpaceRemoved.length() - 1); + assertThat(lastChar, Matchers.is('.')); + trailingWhiteSpaceRemoved = chunks.get(chunks.size() - 1).strip(); + lastChar = trailingWhiteSpaceRemoved.charAt(trailingWhiteSpaceRemoved.length() - 1); + assertThat(lastChar, Matchers.is('.')); + } + } + + public void testWithOverlap_SentencesFitInChunks() { + int numChunks = 4; + int chunkSize = 100; + + var sb = new StringBuilder(); + + int[] sentenceStartIndexes = new int[numChunks]; + sentenceStartIndexes[0] = 0; + + int numSentences = randomIntBetween(2, 5); + int sentenceIndex = 0; + int lastSentenceSize = 0; + int roughSentenceSize = (chunkSize / numSentences) - 1; + for (int j = 0; j < numSentences; j++) { + sb.append(makeSentence(roughSentenceSize, 
sentenceIndex++)); + lastSentenceSize = roughSentenceSize; + } + + for (int i = 1; i < numChunks; i++) { + sentenceStartIndexes[i] = sentenceIndex - 1; + + roughSentenceSize = (chunkSize / numSentences) - 1; + int wordCount = lastSentenceSize; + + while (wordCount + roughSentenceSize < chunkSize) { + sb.append(makeSentence(roughSentenceSize, sentenceIndex++)); + lastSentenceSize = roughSentenceSize; + wordCount += roughSentenceSize; + } + } + + var chunker = new SentenceBoundaryChunker(); + var chunks = chunker.chunk(sb.toString(), chunkSize, true); + assertThat(chunks, hasSize(numChunks)); + for (int i = 0; i < numChunks; i++) { + assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i])); + assertThat("num sentences " + numSentences, chunks.get(i).trim(), endsWith(".")); + } + } + + private String makeSentence(int numWords, int sentenceIndex) { + StringBuilder sb = new StringBuilder(); + sb.append("SStart").append(sentenceIndex).append(' '); + for (int i = 1; i < numWords - 1; i++) { + sb.append(i).append(' '); + } + sb.append(numWords - 1).append(". "); + return sb.toString(); + } + public void testChunk_ChunkSizeLargerThanText() { int maxWordsPerChunk = 500; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk); + var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + assertEquals(chunks.get(0), TEST_TEXT); + chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); assertEquals(chunks.get(0), TEST_TEXT); } @@ -54,7 +142,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk); + var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i])); for (var chunk : chunks) { @@ -76,6 +164,48 @@ public void testChunkSplit_SentencesLongerThanChunkSize() { } } + public void testChunkSplit_SentencesLongerThanChunkSize_WithOverlap() { + var chunkSizes = new int[] { 10, 30, 50 }; + + // Chunk sizes are shorter than the sentences, so most of the sentences will be split. + for (int i = 0; i < chunkSizes.length; i++) { + int maxWordsPerChunk = chunkSizes[i]; + var chunker = new SentenceBoundaryChunker(); + var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing")); + assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm.")); + } + } + + public void testShortLongShortSentences_WithOverlap() { + int maxWordsPerChunk = 40; + var sb = new StringBuilder(); + int[] sentenceLengths = new int[] { 15, 30, 20, 5 }; + for (int l = 0; l < sentenceLengths.length; l++) { + sb.append("SStart").append(l).append(" "); + for (int i = 1; i < sentenceLengths[l] - 1; i++) { + sb.append(i).append(' '); + } + sb.append(sentenceLengths[l] - 1).append(". 
"); + } + + var chunker = new SentenceBoundaryChunker(); + var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true); + assertThat(chunks, hasSize(5)); + assertTrue(chunks.get(0).trim().startsWith("SStart0")); // Entire sentence + assertTrue(chunks.get(0).trim().endsWith(".")); // Entire sentence + + assertTrue(chunks.get(1).trim().startsWith("SStart0")); // contains previous sentence + assertFalse(chunks.get(1).trim().endsWith(".")); // not a full sentence(s) + + assertTrue(chunks.get(2).trim().endsWith(".")); + assertTrue(chunks.get(3).trim().endsWith(".")); + + assertTrue(chunks.get(4).trim().startsWith("SStart2")); // contains previous sentence + assertThat(chunks.get(4), containsString("SStart3")); // last chunk contains 2 sentences + assertTrue(chunks.get(4).trim().endsWith(".")); // full sentence(s) + } + public void testCountWords() { // Test word count matches the whitespace separated word count. var splitByWhiteSpaceSentenceSizes = sentenceSizes(TEST_TEXT); @@ -102,6 +232,30 @@ public void testCountWords() { assertEquals(BreakIterator.DONE, sentenceIterator.next()); } + public void testSkipWords() { + int numWords = 50; + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numWords; i++) { + sb.append("word").append(i).append(" "); + } + var text = sb.toString(); + + var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); + wordIterator.setText(text); + + int start = 0; + int pos = SentenceBoundaryChunker.skipWords(start, 3, wordIterator); + assertThat(text.substring(pos), startsWith("word3 ")); + pos = SentenceBoundaryChunker.skipWords(pos + 1, 1, wordIterator); + assertThat(text.substring(pos), startsWith("word4 ")); + pos = SentenceBoundaryChunker.skipWords(pos + 1, 5, wordIterator); + assertThat(text.substring(pos), startsWith("word9 ")); + + // past the end of the input + pos = SentenceBoundaryChunker.skipWords(0, numWords + 10, wordIterator); + assertThat(pos, greaterThan(0)); + } + public void testCountWords_short() { // Test word count matches the whitespace separated word count. var text = "This is a short sentence. 
+ public void testCountWords_short() { // Test word count matches the whitespace separated word count. var text = "This is a short sentence. Followed by another."; @@ -148,7 +302,7 @@ public void testCountWords_WithSymbols() { public void testChunkSplitLargeChunkSizesWithChunkingSettings() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); - SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk); + SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0); var chunks = chunker.chunk(TEST_TEXT, chunkingSettings); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); @@ -182,16 +336,30 @@ private int[] sentenceSizes(String text) { } private int expectedNumberOfChunks(int[] sentenceLengths, int maxWordsPerChunk) { - int numChunks = 1; + return chunkOverlaps(sentenceLengths, maxWordsPerChunk, false).length; + } + + private int[] chunkOverlaps(int[] sentenceLengths, int maxWordsPerChunk, boolean includeSingleSentenceOverlap) { + int maxOverlap = SentenceBoundaryChunker.maxWordsInOverlap(maxWordsPerChunk); + + var overlaps = new ArrayList<Integer>(); + overlaps.add(0); int runningWordCount = 0; for (int i = 0; i < sentenceLengths.length; i++) { if (runningWordCount + sentenceLengths[i] > maxWordsPerChunk) { - numChunks++; runningWordCount = sentenceLengths[i]; + if (includeSingleSentenceOverlap && i > 0) { + // include what is carried over from the previous sentence + int overlap = Math.min(maxOverlap, sentenceLengths[i - 1]); + overlaps.add(overlap); + runningWordCount += overlap; + } else { + overlaps.add(0); + } } else { runningWordCount += sentenceLengths[i]; } } - return numChunks; + return overlaps.stream().mapToInt(Integer::intValue).toArray(); } }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java index 3f304a593144b..fe97d7eb3af54 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java @@ -59,13 +59,12 @@ protected Writeable.Reader<SentenceBoundaryChunkingSettings> instanceReader() { @Override protected SentenceBoundaryChunkingSettings createTestInstance() { - return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 
0 : 1); } @Override protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException { var chunkSize = randomValueOtherThan(instance.maxChunkSize, ESTestCase::randomNonNegativeInt); - - return new SentenceBoundaryChunkingSettings(chunkSize); + return new SentenceBoundaryChunkingSettings(chunkSize, instance.sentenceOverlap); } }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index 21d8c65ad7dcd..08c0724f36270 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -54,9 +54,6 @@ public class WordBoundaryChunkerTests extends ESTestCase { + " خليفہ المومنين يا خليفہ المسلمين يا صحابی يا رضي الله عنه چئي۔ (ب) آنحضور ﷺ جي گھروارين کان علاوه ڪنھن کي ام المومنين " + "چئي۔ (ج) آنحضور ﷺ جي خاندان جي اھل بيت کان علاوہڍه ڪنھن کي اھل بيت چئي۔ (د) پنھنجي عبادت گاھ کي مسجد چئي۔" }; - private static final int DEFAULT_MAX_CHUNK_SIZE = 250; - private static final int DEFAULT_OVERLAP = 100; - public static int NUM_WORDS_IN_TEST_TEXT; static { var wordIterator = BreakIterator.getWordInstance(Locale.ROOT); @@ -139,7 +136,7 @@ public void testNumberOfChunksWithWordBoundaryChunkingSettings() { } public void testInvalidChunkingSettingsProvided() { - ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), 0); assertThrows(IllegalArgumentException.class, () -> { new WordBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); }); }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java index a3605aade1fa1..6d6403b69ea11 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java @@ -42,4 +42,12 @@ protected RankBuilder getThrowingRankBuilder(int rankWindowSize, String rankFeat protected Collection<Class<? extends Plugin>> pluginsNeeded() { return List.of(InferencePlugin.class, TextSimilarityTestPlugin.class); } + + public void testQueryPhaseShardThrowingAllShardsFail() throws Exception { + // no-op + } + + public void testQueryPhaseCoordinatorThrowingAllShardsFail() throws Exception { + // no-op + } }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 120527f489549..358aa9804b916 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -7,14 +7,10 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; -import 
org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.search.SearchPhaseController; import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.client.internal.Client; @@ -28,11 +24,8 @@ import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.SearchHits; -import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankShardResult; -import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; -import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; import org.elasticsearch.search.rank.rerank.AbstractRerankerIT; @@ -246,37 +239,6 @@ public void doXContent(XContentBuilder builder, ToXContent.Params params) throws builder.field(THROWING_TYPE_FIELD.getPreferredName(), throwingRankBuilderType); } - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) { - if (this.throwingRankBuilderType == AbstractRerankerIT.ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT) - return new QueryPhaseRankShardContext(queries, rankWindowSize()) { - @Override - public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) { - throw new UnsupportedOperationException("qps - simulated failure"); - } - }; - else { - return super.buildQueryPhaseShardContext(queries, from); - } - } - - @Override - public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - if (this.throwingRankBuilderType == AbstractRerankerIT.ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT) - return new QueryPhaseRankCoordinatorContext(rankWindowSize()) { - @Override - public ScoreDoc[] rankQueryPhaseResults( - List<QuerySearchResult> querySearchResults, - SearchPhaseController.TopDocsStats topDocStats - ) { - throw new UnsupportedOperationException("qpc - simulated failure"); - } - }; - else { - return super.buildQueryPhaseCoordinatorContext(size, from); - } - } - @Override public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { if (this.throwingRankBuilderType == AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT)