From bde7bff6dd491a81a2543fd8bd0f3a3cfe410673 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 05:17:35 +1100 Subject: [PATCH 01/28] Mute org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests testSearcherId #118374 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 2caa85eebd568..3d0a57c8d650a 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -291,6 +291,9 @@ tests: - class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS1UnavailableRemotesIT method: testEsqlRcs1UnavailableRemoteScenarios issue: https://github.com/elastic/elasticsearch/issues/118350 +- class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests + method: testSearcherId + issue: https://github.com/elastic/elasticsearch/issues/118374 # Examples: # From 56379a314377f08c4fc1b365e4052de04d6cc8d6 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 10 Dec 2024 21:01:18 +0100 Subject: [PATCH 02/28] Lazily initialize TransportActionStats in RequestHandlerRegistry (#117435) Same as #114107 but for the transport side. We are consuming more than 0.5M for these stat instances and a large fraction if not the majority will go unused in many use-cases, given how expensive stats are in terms of volatile operations the overhead of an added `getAcquire` is trivial in the absence of writes. Outside of reducing the prod footprint this is also kind of nice in x-pack internal cluster tests in particular where it saves quite a bit of heap. --- .../bootstrap/Elasticsearch.java | 8 +++- .../elasticsearch/rest/MethodHandlers.java | 4 +- .../transport/RequestHandlerRegistry.java | 39 +++++++++++++++++-- .../transport/TransportActionStats.java | 2 + 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index 04f93494a50f1..ae59f6578f03a 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -44,6 +44,8 @@ import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; import org.elasticsearch.plugins.PluginsLoader; +import org.elasticsearch.rest.MethodHandlers; +import org.elasticsearch.transport.RequestHandlerRegistry; import java.io.IOException; import java.io.InputStream; @@ -201,7 +203,11 @@ private static void initPhase2(Bootstrap bootstrap) throws IOException { SubscribableListener.class, RunOnce.class, // We eagerly initialize to work around log4j permissions & JDK-8309727 - VectorUtil.class + VectorUtil.class, + // RequestHandlerRegistry and MethodHandlers classes do nontrivial static initialization which should always succeed but load + // it now (before SM) to be sure + RequestHandlerRegistry.class, + MethodHandlers.class ); // load the plugin Java modules and layers now for use in entitlements diff --git a/server/src/main/java/org/elasticsearch/rest/MethodHandlers.java b/server/src/main/java/org/elasticsearch/rest/MethodHandlers.java index 2f53f48f9ae5b..bfb8ca018fca2 100644 --- a/server/src/main/java/org/elasticsearch/rest/MethodHandlers.java +++ b/server/src/main/java/org/elasticsearch/rest/MethodHandlers.java @@ -22,13 +22,13 @@ /** * Encapsulate multiple handlers for the same path, allowing different handlers for different HTTP verbs and versions. 
*/ -final class MethodHandlers { +public final class MethodHandlers { private final String path; private final Map> methodHandlers; @SuppressWarnings("unused") // only accessed via #STATS_TRACKER_HANDLE, lazy initialized because instances consume non-trivial heap - private volatile HttpRouteStatsTracker statsTracker; + private HttpRouteStatsTracker statsTracker; private static final VarHandle STATS_TRACKER_HANDLE; diff --git a/server/src/main/java/org/elasticsearch/transport/RequestHandlerRegistry.java b/server/src/main/java/org/elasticsearch/transport/RequestHandlerRegistry.java index bfe6377b495ab..84a8ee1b2ebbf 100644 --- a/server/src/main/java/org/elasticsearch/transport/RequestHandlerRegistry.java +++ b/server/src/main/java/org/elasticsearch/transport/RequestHandlerRegistry.java @@ -19,6 +19,8 @@ import org.elasticsearch.telemetry.tracing.Tracer; import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; import java.util.concurrent.Executor; import static org.elasticsearch.core.Releasables.assertOnce; @@ -33,7 +35,19 @@ public class RequestHandlerRegistry implements private final TaskManager taskManager; private final Tracer tracer; private final Writeable.Reader requestReader; - private final TransportActionStatsTracker statsTracker = new TransportActionStatsTracker(); + @SuppressWarnings("unused") // only accessed via #STATS_TRACKER_HANDLE, lazy initialized because instances consume non-trivial heap + private TransportActionStatsTracker statsTracker; + + private static final VarHandle STATS_TRACKER_HANDLE; + + static { + try { + STATS_TRACKER_HANDLE = MethodHandles.lookup() + .findVarHandle(RequestHandlerRegistry.class, "statsTracker", TransportActionStatsTracker.class); + } catch (Exception e) { + throw new ExceptionInInitializerError(e); + } + } public RequestHandlerRegistry( String action, @@ -118,15 +132,34 @@ public static RequestHandlerRegistry replaceHand } public void addRequestStats(int messageSize) { - statsTracker.addRequestStats(messageSize); + statsTracker().addRequestStats(messageSize); } @Override public void addResponseStats(int messageSize) { - statsTracker.addResponseStats(messageSize); + statsTracker().addResponseStats(messageSize); } public TransportActionStats getStats() { + var statsTracker = existingStatsTracker(); + if (statsTracker == null) { + return TransportActionStats.EMPTY; + } return statsTracker.getStats(); } + + private TransportActionStatsTracker statsTracker() { + var tracker = existingStatsTracker(); + if (tracker == null) { + var newTracker = new TransportActionStatsTracker(); + if ((tracker = (TransportActionStatsTracker) STATS_TRACKER_HANDLE.compareAndExchange(this, null, newTracker)) == null) { + tracker = newTracker; + } + } + return tracker; + } + + private TransportActionStatsTracker existingStatsTracker() { + return (TransportActionStatsTracker) STATS_TRACKER_HANDLE.getAcquire(this); + } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionStats.java b/server/src/main/java/org/elasticsearch/transport/TransportActionStats.java index feed042a5934e..f35443cdc8d6d 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportActionStats.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportActionStats.java @@ -27,6 +27,8 @@ public record TransportActionStats( long[] responseSizeHistogram ) implements Writeable, ToXContentObject { + public static final TransportActionStats EMPTY = new TransportActionStats(0, 0, new long[0], 0, 0, new long[0]); + 
public TransportActionStats(StreamInput in) throws IOException { this(in.readVLong(), in.readVLong(), in.readVLongArray(), in.readVLong(), in.readVLong(), in.readVLongArray()); } From 537f4ce8718e6f238e5a519c18374718f8566748 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Tue, 10 Dec 2024 15:01:53 -0500 Subject: [PATCH 03/28] Fix log message format bugs (#118354) --- docs/changelog/118354.yaml | 5 +++++ .../elasticsearch/repositories/s3/S3BlobContainer.java | 2 +- .../upgrades/FullClusterRestartDownsampleIT.java | 2 +- .../java/org/elasticsearch/upgrades/DownsampleIT.java | 2 +- .../elasticsearch/index/seqno/RetentionLeaseIT.java | 2 +- .../repositories/blobstore/BlobStoreRepository.java | 10 ++++------ .../xpack/downsample/DownsampleShardIndexer.java | 2 +- .../org/elasticsearch/xpack/enrich/EnrichPlugin.java | 2 +- 8 files changed, 15 insertions(+), 12 deletions(-) create mode 100644 docs/changelog/118354.yaml diff --git a/docs/changelog/118354.yaml b/docs/changelog/118354.yaml new file mode 100644 index 0000000000000..e2d72db121276 --- /dev/null +++ b/docs/changelog/118354.yaml @@ -0,0 +1,5 @@ +pr: 118354 +summary: Fix log message format bugs +area: Ingest Node +type: bug +issues: [] diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java index f527dcd42814c..bf693222a4b72 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java @@ -1025,7 +1025,7 @@ public void onResponse(Void unused) { // should be no other processes interacting with the repository. logger.warn( Strings.format( - "failed to clean up multipart upload [{}] of blob [{}][{}][{}]", + "failed to clean up multipart upload [%s] of blob [%s][%s][%s]", abortMultipartUploadRequest.getUploadId(), blobStore.getRepositoryMetadata().name(), abortMultipartUploadRequest.getBucketName(), diff --git a/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/FullClusterRestartDownsampleIT.java b/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/FullClusterRestartDownsampleIT.java index d98d53baf9015..f907870fc8254 100644 --- a/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/FullClusterRestartDownsampleIT.java +++ b/qa/full-cluster-restart/src/javaRestTest/java/org/elasticsearch/upgrades/FullClusterRestartDownsampleIT.java @@ -263,7 +263,7 @@ private String getRollupIndexName() throws IOException { if (asMap.size() == 1) { return (String) asMap.keySet().toArray()[0]; } - logger.warn("--> No matching rollup name for path [%s]", endpoint); + logger.warn("--> No matching rollup name for path [{}]", endpoint); return null; } diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DownsampleIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DownsampleIT.java index bca0c26ad2c32..b1212913b7fb0 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DownsampleIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DownsampleIT.java @@ -238,7 +238,7 @@ private String getRollupIndexName() throws IOException { if (asMap.size() == 1) { return (String) asMap.keySet().toArray()[0]; } - logger.warn("--> No matching rollup name for path [%s]", endpoint); + logger.warn("--> No matching rollup name for path 
[{}]", endpoint); return null; } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/RetentionLeaseIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/RetentionLeaseIT.java index 1c5d67d1fa40a..70689dc689673 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/RetentionLeaseIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/RetentionLeaseIT.java @@ -336,7 +336,7 @@ public void testRetentionLeasesSyncOnRecovery() throws Exception { .getShardOrNull(new ShardId(resolveIndex("index"), 0)); final int length = randomIntBetween(1, 8); final Map currentRetentionLeases = new LinkedHashMap<>(); - logger.info("adding retention [{}}] leases", length); + logger.info("adding retention [{}] leases", length); for (int i = 0; i < length; i++) { final String id = randomValueOtherThanMany(currentRetentionLeases.keySet()::contains, () -> randomAlphaOfLength(8)); final long retainingSequenceNumber = randomLongBetween(0, Long.MAX_VALUE); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index f1c3d82b74cab..cc22b60991703 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -3484,12 +3484,10 @@ private static void ensureNotAborted(ShardId shardId, SnapshotId snapshotId, Ind // A normally running shard snapshot should be in stage INIT or STARTED. And we know it's not in PAUSING or ABORTED because // the ensureNotAborted() call above did not throw. The remaining options don't make sense, if they ever happen. logger.error( - () -> Strings.format( - "Shard snapshot found an unexpected state. ShardId [{}], SnapshotID [{}], Stage [{}]", - shardId, - snapshotId, - shardSnapshotStage - ) + "Shard snapshot found an unexpected state. 
ShardId [{}], SnapshotID [{}], Stage [{}]", + shardId, + snapshotId, + shardSnapshotStage ); assert false; } diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardIndexer.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardIndexer.java index 81aea221ebb4f..50149ec2cbe58 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardIndexer.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DownsampleShardIndexer.java @@ -386,7 +386,7 @@ public void collect(int docId, long owningBucketOrd) throws IOException { if (logger.isTraceEnabled()) { logger.trace( - "Doc: [{}] - _tsid: [{}], @timestamp: [{}}] -> downsample bucket ts: [{}]", + "Doc: [{}] - _tsid: [{}], @timestamp: [{}] -> downsample bucket ts: [{}]", docId, DocValueFormat.TIME_SERIES_ID.format(tsidHash), timestampFormat.format(timestamp), diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java index d46639d700420..ef455acb645d9 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java @@ -172,7 +172,7 @@ public EnrichPlugin(final Settings settings) { if (settings.hasValue(CACHE_SIZE_SETTING_NAME)) { throw new IllegalArgumentException( Strings.format( - "Both [{}] and [{}] are set, please use [{}]", + "Both [%s] and [%s] are set, please use [%s]", CACHE_SIZE_SETTING_NAME, CACHE_SIZE_SETTING_BWC_NAME, CACHE_SIZE_SETTING_NAME From 5406850abf7a83e8078001dc7d81081fa6e77bc8 Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Tue, 10 Dec 2024 14:35:49 -0600 Subject: [PATCH 04/28] Disabling ReindexDataStreamTransportActionIT in release mode (#118377) The actions in the migrate plugin are currently behind a feature flag. This disables the integration test for them if the flag is not set. 
Closes #118376 --- .../migrate/action/ReindexDataStreamTransportActionIT.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java index 515250bb58a94..7f2243ed76849 100644 --- a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java +++ b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java @@ -40,6 +40,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.REINDEX_DATA_STREAM_FEATURE_FLAG; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -51,6 +52,7 @@ protected Collection> nodePlugins() { } public void testNonExistentDataStream() { + assumeTrue("requires the migration reindex feature flag", REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()); String nonExistentDataStreamName = randomAlphaOfLength(50); ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( ReindexDataStreamAction.Mode.UPGRADE, @@ -64,6 +66,7 @@ public void testNonExistentDataStream() { } public void testAlreadyUpToDateDataStream() throws Exception { + assumeTrue("requires the migration reindex feature flag", REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()); String dataStreamName = randomAlphaOfLength(50).toLowerCase(Locale.ROOT); ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( ReindexDataStreamAction.Mode.UPGRADE, From 87b0f72cf571491b7535c7a96024b8d7e55a0d01 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 10 Dec 2024 16:28:12 -0500 Subject: [PATCH 05/28] ESQL: Opt into extra data stream resolution (#118378) * ESQL: Opt into extra data stream resolution This opts ESQL's data node request into extra data stream resolution. 
* Update docs/changelog/118378.yaml --- docs/changelog/118378.yaml | 5 + .../xpack/esql/EsqlSecurityIT.java | 129 ++++++++++++++++++ .../src/javaRestTest/resources/roles.yml | 54 ++++++++ .../xpack/esql/plugin/DataNodeRequest.java | 5 + 4 files changed, 193 insertions(+) create mode 100644 docs/changelog/118378.yaml diff --git a/docs/changelog/118378.yaml b/docs/changelog/118378.yaml new file mode 100644 index 0000000000000..d6c388b671968 --- /dev/null +++ b/docs/changelog/118378.yaml @@ -0,0 +1,5 @@ +pr: 118378 +summary: Opt into extra data stream resolution +area: ES|QL +type: bug +issues: [] diff --git a/x-pack/plugin/esql/qa/security/src/javaRestTest/java/org/elasticsearch/xpack/esql/EsqlSecurityIT.java b/x-pack/plugin/esql/qa/security/src/javaRestTest/java/org/elasticsearch/xpack/esql/EsqlSecurityIT.java index f8f1fe872711d..9566aeb8f28dc 100644 --- a/x-pack/plugin/esql/qa/security/src/javaRestTest/java/org/elasticsearch/xpack/esql/EsqlSecurityIT.java +++ b/x-pack/plugin/esql/qa/security/src/javaRestTest/java/org/elasticsearch/xpack/esql/EsqlSecurityIT.java @@ -34,6 +34,7 @@ import java.util.Locale; import java.util.Map; +import static org.elasticsearch.test.ListMatcher.matchesList; import static org.elasticsearch.test.MapMatcher.assertMap; import static org.elasticsearch.test.MapMatcher.matchesMap; import static org.hamcrest.Matchers.containsString; @@ -56,6 +57,11 @@ public class EsqlSecurityIT extends ESRestTestCase { .user("metadata1_read2", "x-pack-test-password", "metadata1_read2", false) .user("alias_user1", "x-pack-test-password", "alias_user1", false) .user("alias_user2", "x-pack-test-password", "alias_user2", false) + .user("logs_foo_all", "x-pack-test-password", "logs_foo_all", false) + .user("logs_foo_16_only", "x-pack-test-password", "logs_foo_16_only", false) + .user("logs_foo_after_2021", "x-pack-test-password", "logs_foo_after_2021", false) + .user("logs_foo_after_2021_pattern", "x-pack-test-password", "logs_foo_after_2021_pattern", false) + .user("logs_foo_after_2021_alias", "x-pack-test-password", "logs_foo_after_2021_alias", false) .build(); @Override @@ -342,6 +348,14 @@ public void testDocumentLevelSecurity() throws Exception { assertThat(respMap.get("values"), equalTo(List.of(List.of(10.0)))); } + public void testDocumentLevelSecurityFromStar() throws Exception { + Response resp = runESQLCommand("user3", "from in*x | stats sum=sum(value)"); + assertOK(resp); + Map respMap = entityAsMap(resp); + assertThat(respMap.get("columns"), equalTo(List.of(Map.of("name", "sum", "type", "double")))); + assertThat(respMap.get("values"), equalTo(List.of(List.of(10.0)))); + } + public void testFieldLevelSecurityAllow() throws Exception { Response resp = runESQLCommand("fls_user", "FROM index* | SORT value | LIMIT 1"); assertOK(resp); @@ -545,6 +559,22 @@ private void removeEnrichPolicy() throws Exception { client().performRequest(new Request("DELETE", "_enrich/policy/songs")); } + public void testDataStream() throws IOException { + createDataStream(); + MapMatcher twoResults = matchesMap().extraOk().entry("values", matchesList().item(matchesList().item(2))); + MapMatcher oneResult = matchesMap().extraOk().entry("values", matchesList().item(matchesList().item(1))); + assertMap(entityAsMap(runESQLCommand("logs_foo_all", "FROM logs-foo | STATS COUNT(*)")), twoResults); + assertMap(entityAsMap(runESQLCommand("logs_foo_16_only", "FROM logs-foo | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021", "FROM logs-foo | STATS COUNT(*)")), oneResult); 
+ assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021_pattern", "FROM logs-foo | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021_alias", "FROM alias-foo | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_all", "FROM logs-* | STATS COUNT(*)")), twoResults); + assertMap(entityAsMap(runESQLCommand("logs_foo_16_only", "FROM logs-* | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021", "FROM logs-* | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021_pattern", "FROM logs-* | STATS COUNT(*)")), oneResult); + assertMap(entityAsMap(runESQLCommand("logs_foo_after_2021_alias", "FROM alias-* | STATS COUNT(*)")), oneResult); + } + protected Response runESQLCommand(String user, String command) throws IOException { if (command.toLowerCase(Locale.ROOT).contains("limit") == false) { // add a (high) limit to avoid warnings on default limit @@ -592,4 +622,103 @@ static Settings randomPragmas() { } return settings.build(); } + + private void createDataStream() throws IOException { + createDataStreamPolicy(); + createDataStreamComponentTemplate(); + createDataStreamIndexTemplate(); + createDataStreamDocuments(); + createDataStreamAlias(); + } + + private void createDataStreamPolicy() throws IOException { + Request request = new Request("PUT", "_ilm/policy/my-lifecycle-policy"); + request.setJsonEntity(""" + { + "policy": { + "phases": { + "hot": { + "actions": { + "rollover": { + "max_primary_shard_size": "50gb" + } + } + }, + "delete": { + "min_age": "735d", + "actions": { + "delete": {} + } + } + } + } + }"""); + client().performRequest(request); + } + + private void createDataStreamComponentTemplate() throws IOException { + Request request = new Request("PUT", "_component_template/my-template"); + request.setJsonEntity(""" + { + "template": { + "settings": { + "index.lifecycle.name": "my-lifecycle-policy" + }, + "mappings": { + "properties": { + "@timestamp": { + "type": "date", + "format": "date_optional_time||epoch_millis" + }, + "data_stream": { + "properties": { + "namespace": {"type": "keyword"} + } + } + } + } + } + }"""); + client().performRequest(request); + } + + private void createDataStreamIndexTemplate() throws IOException { + Request request = new Request("PUT", "_index_template/my-index-template"); + request.setJsonEntity(""" + { + "index_patterns": ["logs-*"], + "data_stream": {}, + "composed_of": ["my-template"], + "priority": 500 + }"""); + client().performRequest(request); + } + + private void createDataStreamDocuments() throws IOException { + Request request = new Request("POST", "logs-foo/_bulk"); + request.addParameter("refresh", ""); + request.setJsonEntity(""" + { "create" : {} } + { "@timestamp": "2099-05-06T16:21:15.000Z", "data_stream": {"namespace": "16"} } + { "create" : {} } + { "@timestamp": "2001-05-06T16:21:15.000Z", "data_stream": {"namespace": "17"} } + """); + assertMap(entityAsMap(client().performRequest(request)), matchesMap().extraOk().entry("errors", false)); + } + + private void createDataStreamAlias() throws IOException { + Request request = new Request("PUT", "_alias"); + request.setJsonEntity(""" + { + "actions": [ + { + "add": { + "index": "logs-foo", + "alias": "alias-foo" + } + } + ] + }"""); + assertMap(entityAsMap(client().performRequest(request)), matchesMap().extraOk().entry("errors", false)); + } } diff --git a/x-pack/plugin/esql/qa/security/src/javaRestTest/resources/roles.yml 
b/x-pack/plugin/esql/qa/security/src/javaRestTest/resources/roles.yml index 5c0164782d181..365a072edb74e 100644 --- a/x-pack/plugin/esql/qa/security/src/javaRestTest/resources/roles.yml +++ b/x-pack/plugin/esql/qa/security/src/javaRestTest/resources/roles.yml @@ -92,3 +92,57 @@ fls_user: privileges: [ 'read' ] field_security: grant: [ value ] + +logs_foo_all: + cluster: [] + indices: + - names: [ 'logs-foo' ] + privileges: [ 'read' ] + +logs_foo_16_only: + cluster: [] + indices: + - names: [ 'logs-foo' ] + privileges: [ 'read' ] + query: | + { + "term": { + "data_stream.namespace": "16" + } + } + +logs_foo_after_2021: + cluster: [] + indices: + - names: [ 'logs-foo' ] + privileges: [ 'read' ] + query: | + { + "range": { + "@timestamp": {"gte": "2021-01-01T00:00:00"} + } + } + +logs_foo_after_2021_pattern: + cluster: [] + indices: + - names: [ 'logs-*' ] + privileges: [ 'read' ] + query: | + { + "range": { + "@timestamp": {"gte": "2021-01-01T00:00:00"} + } + } + +logs_foo_after_2021_alias: + cluster: [] + indices: + - names: [ 'alias-foo' ] + privileges: [ 'read' ] + query: | + { + "range": { + "@timestamp": {"gte": "2021-01-01T00:00:00"} + } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java index 4c01d326ed7bc..6014e24e39c5f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java @@ -118,6 +118,11 @@ public IndicesRequest indices(String... indices) { return this; } + @Override + public boolean includeDataStreams() { + return true; + } + @Override public IndicesOptions indicesOptions() { return indicesOptions; From a27e5dba731bbcb2859228ba463bc6e16a9e74cb Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Tue, 10 Dec 2024 15:29:35 -0600 Subject: [PATCH 06/28] Disabling the migration reindex yaml rest tests unless feature is available (#118382) This disables the migration reindex yaml rest tests when the feature flag is not available, so that the build doesn't fail in release mode, like: ``` ./gradlew ":x-pack:plugin:yamlRestTest" --tests "org.elasticsearch.xpack.test.rest.XPackRestIT.test {p0=migrate/*}" -Dtests.seed=28C1E85204B9DCAE -Dbuild.snapshot=false -Dtests.jvm.argline="-Dbuild.snapshot=false" -Dlicense.key=x-pack/license-tools/src/test/resources/public.key -Dtests.locale=bn-BD -Dtests.timezone=America/Dawson -Druntime.java=23 ``` --- .../rest/RestMigrationReindexAction.java | 14 ++++++++++ .../rest-api-spec/test/migrate/10_reindex.yml | 28 +++++++++++++++++++ .../test/migrate/20_reindex_status.yml | 14 ++++++++++ 3 files changed, 56 insertions(+) diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java index a7f630d68234d..19cb439495e9a 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java @@ -20,11 +20,16 @@ import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamResponse; import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; import java.util.List; +import java.util.Set; import static 
org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.REINDEX_DATA_STREAM_FEATURE_FLAG; public class RestMigrationReindexAction extends BaseRestHandler { + public static final String MIGRATION_REINDEX_CAPABILITY = "migration_reindex"; @Override public String getName() { @@ -49,6 +54,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli ); } + @Override + public Set supportedCapabilities() { + Set capabilities = new HashSet<>(); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + capabilities.add(MIGRATION_REINDEX_CAPABILITY); + } + return Collections.unmodifiableSet(capabilities); + } + static class ReindexDataStreamRestToXContentListener extends RestBuilderListener { ReindexDataStreamRestToXContentListener(RestChannel channel) { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml index 01a41b3aa8c94..f50a7a65f53d3 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml @@ -6,6 +6,13 @@ setup: --- "Test Reindex With Unsupported Mode": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: catch: /illegal_argument_exception/ migrate.reindex: @@ -19,6 +26,13 @@ setup: --- "Test Reindex With Nonexistent Data Stream": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: catch: /resource_not_found_exception/ migrate.reindex: @@ -44,6 +58,13 @@ setup: --- "Test Reindex With Bad Data Stream Name": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: catch: /illegal_argument_exception/ migrate.reindex: @@ -57,6 +78,13 @@ setup: --- "Test Reindex With Existing Data Stream": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: indices.put_index_template: name: my-template1 diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/20_reindex_status.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/20_reindex_status.yml index 3fe133aeda70e..ae343a0b4db95 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/20_reindex_status.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/20_reindex_status.yml @@ -6,6 +6,13 @@ setup: --- "Test get reindex status with nonexistent task id": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: catch: /resource_not_found_exception/ migrate.get_reindex_status: @@ -13,6 +20,13 @@ setup: --- "Test Reindex With Existing Data Stream": + - requires: + reason: "migration reindex is behind a feature flag" + test_runner_features: [capabilities] + 
capabilities: + - method: POST + path: /_migration/reindex + capabilities: [migration_reindex] - do: indices.put_index_template: name: my-template1 From 47be542459ee024801fc2fd8bea5e8cc5dea2e0e Mon Sep 17 00:00:00 2001 From: Dianna Hohensee Date: Tue, 10 Dec 2024 16:33:57 -0500 Subject: [PATCH 07/28] Add a shard snapshot current status debug string (#118198) SnapshotShardsService will update the new debug string field in IndexShardSnapshotStatus as a shard snapshot operation proceeds so that the current state can be logged. The SnapshotShutdownProgressTracker will iterate through the SnapshotShardsService's list of shard snapshots and log their current status. We want to know where in the code a shard snapshot operation potentially gets stuck. This new field should be updated as frequently as is reasonable to inform on the shard snapshot's progress. Closes ES-10261 --- .../snapshots/SnapshotShutdownIT.java | 17 ++++ .../snapshots/IndexShardSnapshotStatus.java | 37 ++++++-- .../blobstore/BlobStoreRepository.java | 13 +++ .../snapshots/SnapshotShardsService.java | 87 ++++++++++++++++--- .../SnapshotShutdownProgressTracker.java | 18 +++- .../SnapshotShutdownProgressTrackerTests.java | 19 +++- 6 files changed, 166 insertions(+), 25 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShutdownIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShutdownIT.java index e5e641bfdda21..755ee960be73e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShutdownIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotShutdownIT.java @@ -524,6 +524,15 @@ public void testSnapshotShutdownProgressTracker() throws Exception { "Pause signals have been set for all shard snapshots on data node [" + nodeForRemovalId + "]" ) ); + mockLog.addExpectation( + new MockLog.SeenEventExpectation( + "SnapshotShutdownProgressTracker index shard snapshot status messages", + SnapshotShutdownProgressTracker.class.getCanonicalName(), + Level.INFO, + // Expect the shard snapshot to stall in data file upload, since we've blocked the data node file upload to the blob store. 
+ "statusDescription='enqueued file snapshot tasks: threads running concurrent file uploads'" + ) + ); putShutdownForRemovalMetadata(nodeForRemoval, clusterService); @@ -583,6 +592,14 @@ public void testSnapshotShutdownProgressTracker() throws Exception { "Current active shard snapshot stats on data node [" + nodeForRemovalId + "]*Paused [" + numShards + "]" ) ); + mockLog.addExpectation( + new MockLog.SeenEventExpectation( + "SnapshotShutdownProgressTracker index shard snapshot messages", + SnapshotShutdownProgressTracker.class.getCanonicalName(), + Level.INFO, + "statusDescription='finished: master notification attempt complete'" + ) + ); // Release the master node to respond snapshotStatusUpdateLatch.countDown(); diff --git a/server/src/main/java/org/elasticsearch/index/snapshots/IndexShardSnapshotStatus.java b/server/src/main/java/org/elasticsearch/index/snapshots/IndexShardSnapshotStatus.java index d8bd460f6f819..6aa6a5e498789 100644 --- a/server/src/main/java/org/elasticsearch/index/snapshots/IndexShardSnapshotStatus.java +++ b/server/src/main/java/org/elasticsearch/index/snapshots/IndexShardSnapshotStatus.java @@ -98,6 +98,7 @@ public enum AbortStatus { private long processedSize; private String failure; private final SubscribableListener abortListeners = new SubscribableListener<>(); + private volatile String statusDescription; private IndexShardSnapshotStatus( final Stage stage, @@ -110,7 +111,8 @@ private IndexShardSnapshotStatus( final long totalSize, final long processedSize, final String failure, - final ShardGeneration generation + final ShardGeneration generation, + final String statusDescription ) { this.stage = new AtomicReference<>(Objects.requireNonNull(stage)); this.generation = new AtomicReference<>(generation); @@ -124,6 +126,7 @@ private IndexShardSnapshotStatus( this.processedSize = processedSize; this.incrementalSize = incrementalSize; this.failure = failure; + updateStatusDescription(statusDescription); } public synchronized Copy moveToStarted( @@ -272,6 +275,15 @@ public synchronized void addProcessedFiles(int count, long totalSize) { processedSize += totalSize; } + /** + * Updates the string explanation for what the snapshot is actively doing right now. + */ + public void updateStatusDescription(String statusString) { + assert statusString != null; + assert statusString.isEmpty() == false; + this.statusDescription = statusString; + } + /** * Returns a copy of the current {@link IndexShardSnapshotStatus}. This method is * intended to be used when a coherent state of {@link IndexShardSnapshotStatus} is needed. 
@@ -289,12 +301,13 @@ public synchronized IndexShardSnapshotStatus.Copy asCopy() { incrementalSize, totalSize, processedSize, - failure + failure, + statusDescription ); } public static IndexShardSnapshotStatus newInitializing(ShardGeneration generation) { - return new IndexShardSnapshotStatus(Stage.INIT, 0L, 0L, 0, 0, 0, 0, 0, 0, null, generation); + return new IndexShardSnapshotStatus(Stage.INIT, 0L, 0L, 0, 0, 0, 0, 0, 0, null, generation, "initializing"); } public static IndexShardSnapshotStatus.Copy newFailed(final String failure) { @@ -302,7 +315,7 @@ public static IndexShardSnapshotStatus.Copy newFailed(final String failure) { if (failure == null) { throw new IllegalArgumentException("A failure description is required for a failed IndexShardSnapshotStatus"); } - return new IndexShardSnapshotStatus(Stage.FAILURE, 0L, 0L, 0, 0, 0, 0, 0, 0, failure, null).asCopy(); + return new IndexShardSnapshotStatus(Stage.FAILURE, 0L, 0L, 0, 0, 0, 0, 0, 0, failure, null, "initialized as failed").asCopy(); } public static IndexShardSnapshotStatus.Copy newDone( @@ -326,7 +339,8 @@ public static IndexShardSnapshotStatus.Copy newDone( size, incrementalSize, null, - generation + generation, + "initialized as done" ).asCopy(); } @@ -345,6 +359,7 @@ public static class Copy { private final long processedSize; private final long incrementalSize; private final String failure; + private final String statusDescription; public Copy( final Stage stage, @@ -356,7 +371,8 @@ public Copy( final long incrementalSize, final long totalSize, final long processedSize, - final String failure + final String failure, + final String statusDescription ) { this.stage = stage; this.startTime = startTime; @@ -368,6 +384,7 @@ public Copy( this.processedSize = processedSize; this.incrementalSize = incrementalSize; this.failure = failure; + this.statusDescription = statusDescription; } public Stage getStage() { @@ -410,6 +427,10 @@ public String getFailure() { return failure; } + public String getStatusDescription() { + return statusDescription; + } + @Override public String toString() { return "index shard snapshot status (" @@ -433,6 +454,8 @@ public String toString() { + processedSize + ", failure='" + failure + + "', statusDescription='" + + statusDescription + '\'' + ')'; } @@ -461,6 +484,8 @@ public String toString() { + processedSize + ", failure='" + failure + + "', statusDescription='" + + statusDescription + '\'' + ')'; } diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index cc22b60991703..11386eba10196 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -3186,6 +3186,7 @@ private void writeAtomic( @Override public void snapshotShard(SnapshotShardContext context) { + context.status().updateStatusDescription("queued in snapshot task runner"); shardSnapshotTaskRunner.enqueueShardSnapshot(context); } @@ -3198,6 +3199,7 @@ private void doSnapshotShard(SnapshotShardContext context) { final ShardId shardId = store.shardId(); final SnapshotId snapshotId = context.snapshotId(); final IndexShardSnapshotStatus snapshotStatus = context.status(); + snapshotStatus.updateStatusDescription("snapshot task runner: setting up shard snapshot"); final long startTime = threadPool.absoluteTimeInMillis(); try { final ShardGeneration generation = 
snapshotStatus.generation(); @@ -3206,6 +3208,7 @@ private void doSnapshotShard(SnapshotShardContext context) { final Set blobs; if (generation == null) { snapshotStatus.ensureNotAborted(); + snapshotStatus.updateStatusDescription("snapshot task runner: listing blob prefixes"); try { blobs = shardContainer.listBlobsByPrefix(OperationPurpose.SNAPSHOT_METADATA, SNAPSHOT_INDEX_PREFIX).keySet(); } catch (IOException e) { @@ -3216,6 +3219,7 @@ private void doSnapshotShard(SnapshotShardContext context) { } snapshotStatus.ensureNotAborted(); + snapshotStatus.updateStatusDescription("snapshot task runner: loading snapshot blobs"); Tuple tuple = buildBlobStoreIndexShardSnapshots( context.indexId(), shardId.id(), @@ -3316,6 +3320,7 @@ private void doSnapshotShard(SnapshotShardContext context) { indexCommitPointFiles = filesFromSegmentInfos; } + snapshotStatus.updateStatusDescription("snapshot task runner: starting shard snapshot"); snapshotStatus.moveToStarted( startTime, indexIncrementalFileCount, @@ -3342,6 +3347,7 @@ private void doSnapshotShard(SnapshotShardContext context) { BlobStoreIndexShardSnapshot.FileInfo.SERIALIZE_WRITER_UUID, Boolean.toString(writeFileInfoWriterUUID) ); + snapshotStatus.updateStatusDescription("snapshot task runner: updating blob store with new shard generation"); INDEX_SHARD_SNAPSHOTS_FORMAT.write( updatedBlobStoreIndexShardSnapshots, shardContainer, @@ -3387,6 +3393,7 @@ private void doSnapshotShard(SnapshotShardContext context) { BlobStoreIndexShardSnapshot.FileInfo.SERIALIZE_WRITER_UUID, Boolean.toString(writeFileInfoWriterUUID) ); + snapshotStatus.updateStatusDescription("no shard generations: writing new index-${N} file"); writeShardIndexBlobAtomic(shardContainer, newGen, updatedBlobStoreIndexShardSnapshots, serializationParams); } catch (IOException e) { throw new IndexShardSnapshotFailedException( @@ -3401,6 +3408,7 @@ private void doSnapshotShard(SnapshotShardContext context) { } snapshotStatus.addProcessedFiles(finalFilesInShardMetadataCount, finalFilesInShardMetadataSize); try { + snapshotStatus.updateStatusDescription("no shard generations: deleting blobs"); deleteFromContainer(OperationPurpose.SNAPSHOT_METADATA, shardContainer, blobsToDelete.iterator()); } catch (IOException e) { logger.warn( @@ -3414,6 +3422,7 @@ private void doSnapshotShard(SnapshotShardContext context) { // filesToSnapshot will be emptied while snapshotting the file. We make a copy here for cleanup purpose in case of failure. 
final AtomicReference> fileToCleanUp = new AtomicReference<>(List.copyOf(filesToSnapshot)); final ActionListener> allFilesUploadedListener = ActionListener.assertOnce(ActionListener.wrap(ignore -> { + snapshotStatus.updateStatusDescription("all files uploaded: finalizing"); final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.moveToFinalize(); // now create and write the commit point @@ -3435,6 +3444,7 @@ private void doSnapshotShard(SnapshotShardContext context) { BlobStoreIndexShardSnapshot.FileInfo.SERIALIZE_WRITER_UUID, Boolean.toString(writeFileInfoWriterUUID) ); + snapshotStatus.updateStatusDescription("all files uploaded: writing to index shard file"); INDEX_SHARD_SNAPSHOT_FORMAT.write( blobStoreIndexShardSnapshot, shardContainer, @@ -3451,10 +3461,12 @@ private void doSnapshotShard(SnapshotShardContext context) { ByteSizeValue.ofBytes(blobStoreIndexShardSnapshot.totalSize()), getSegmentInfoFileCount(blobStoreIndexShardSnapshot.indexFiles()) ); + snapshotStatus.updateStatusDescription("all files uploaded: done"); snapshotStatus.moveToDone(threadPool.absoluteTimeInMillis(), shardSnapshotResult); context.onResponse(shardSnapshotResult); }, e -> { try { + snapshotStatus.updateStatusDescription("all files uploaded: cleaning up data files, exception while finalizing: " + e); shardContainer.deleteBlobsIgnoringIfNotExists( OperationPurpose.SNAPSHOT_DATA, Iterators.flatMap(fileToCleanUp.get().iterator(), f -> Iterators.forRange(0, f.numberOfParts(), f::partName)) @@ -3517,6 +3529,7 @@ protected void snapshotFiles( ) { final int noOfFilesToSnapshot = filesToSnapshot.size(); final ActionListener filesListener = fileQueueListener(filesToSnapshot, noOfFilesToSnapshot, allFilesUploadedListener); + context.status().updateStatusDescription("enqueued file snapshot tasks: threads running concurrent file uploads"); for (int i = 0; i < noOfFilesToSnapshot; i++) { shardSnapshotTaskRunner.enqueueFileSnapshot(context, filesToSnapshot::poll, filesListener); } diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 234c0239a68ce..90111c44fbd96 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -61,6 +61,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static java.util.Collections.emptyMap; import static org.elasticsearch.core.Strings.format; @@ -108,6 +109,7 @@ public SnapshotShardsService( this.threadPool = transportService.getThreadPool(); this.snapshotShutdownProgressTracker = new SnapshotShutdownProgressTracker( () -> clusterService.state().nodes().getLocalNodeId(), + (callerLogger) -> logIndexShardSnapshotStatuses(callerLogger), clusterService.getClusterSettings(), threadPool ); @@ -234,6 +236,14 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh } } + private void logIndexShardSnapshotStatuses(Logger callerLogger) { + for (var snapshotStatuses : shardSnapshots.values()) { + for (var shardSnapshot : snapshotStatuses.entrySet()) { + callerLogger.info(Strings.format("ShardId %s, %s", shardSnapshot.getKey(), shardSnapshot.getValue())); + } + } + } + /** * Returns status of shards that are snapshotted on the node and belong to the given snapshot *
@@ -321,7 +331,8 @@ private void handleUpdatedSnapshotsInProgressEntry(String localNodeId, boolean r sid, ShardState.FAILED, shard.getValue().reason(), - shard.getValue().generation() + shard.getValue().generation(), + () -> null ); } } else { @@ -372,6 +383,7 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr + snapshotStatus.generation() + "] for snapshot with old-format compatibility"; shardSnapshotTasks.add(newShardSnapshotTask(shardId, snapshot, indexId, snapshotStatus, entry.version(), entry.startTime())); + snapshotStatus.updateStatusDescription("shard snapshot scheduled to start"); } threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run)); @@ -383,6 +395,7 @@ private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInPr for (final Map.Entry shardEntry : entry.shards().entrySet()) { final ShardId shardId = shardEntry.getKey(); final ShardSnapshotStatus masterShardSnapshotStatus = shardEntry.getValue(); + IndexShardSnapshotStatus indexShardSnapshotStatus = localShardSnapshots.get(shardId); if (masterShardSnapshotStatus.state() != ShardState.INIT) { // shard snapshot not currently scheduled by master @@ -402,7 +415,11 @@ private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInPr shardId, ShardState.PAUSED_FOR_NODE_REMOVAL, "paused", - masterShardSnapshotStatus.generation() + masterShardSnapshotStatus.generation(), + () -> { + indexShardSnapshotStatus.updateStatusDescription("finished: master notification attempt complete"); + return null; + } ); } else { // shard snapshot currently running, mark for pause @@ -419,9 +436,16 @@ private Runnable newShardSnapshotTask( final IndexVersion entryVersion, final long entryStartTime ) { + Supplier postMasterNotificationAction = () -> { + snapshotStatus.updateStatusDescription("finished: master notification attempt complete"); + return null; + }; + + // Listener that runs on completion of the shard snapshot: it will notify the master node of success or failure. 
ActionListener snapshotResultListener = new ActionListener<>() { @Override public void onResponse(ShardSnapshotResult shardSnapshotResult) { + snapshotStatus.updateStatusDescription("snapshot succeeded: proceeding to notify master of success"); final ShardGeneration newGeneration = shardSnapshotResult.getGeneration(); assert newGeneration != null; assert newGeneration.equals(snapshotStatus.generation()); @@ -436,11 +460,13 @@ public void onResponse(ShardSnapshotResult shardSnapshotResult) { snapshotStatus.generation() ); } - notifySuccessfulSnapshotShard(snapshot, shardId, shardSnapshotResult); + + notifySuccessfulSnapshotShard(snapshot, shardId, shardSnapshotResult, postMasterNotificationAction); } @Override public void onFailure(Exception e) { + snapshotStatus.updateStatusDescription("failed with exception '" + e + ": proceeding to notify master of failure"); final String failure; final Stage nextStage; if (e instanceof AbortedSnapshotException) { @@ -457,7 +483,14 @@ public void onFailure(Exception e) { logger.warn(() -> format("[%s][%s] failed to snapshot shard", shardId, snapshot), e); } final var shardState = snapshotStatus.moveToUnsuccessful(nextStage, failure, threadPool.absoluteTimeInMillis()); - notifyUnsuccessfulSnapshotShard(snapshot, shardId, shardState, failure, snapshotStatus.generation()); + notifyUnsuccessfulSnapshotShard( + snapshot, + shardId, + shardState, + failure, + snapshotStatus.generation(), + postMasterNotificationAction + ); } }; @@ -508,6 +541,7 @@ private void snapshot( ActionListener resultListener ) { ActionListener.run(resultListener, listener -> { + snapshotStatus.updateStatusDescription("has started"); snapshotStatus.ensureNotAborted(); final IndexShard indexShard = indicesService.indexServiceSafe(shardId.getIndex()).getShard(shardId.id()); if (indexShard.routingEntry().primary() == false) { @@ -527,7 +561,9 @@ private void snapshot( final Repository repository = repositoriesService.repository(snapshot.getRepository()); SnapshotIndexCommit snapshotIndexCommit = null; try { + snapshotStatus.updateStatusDescription("acquiring commit reference from IndexShard: triggers a shard flush"); snapshotIndexCommit = new SnapshotIndexCommit(indexShard.acquireIndexCommitForSnapshot()); + snapshotStatus.updateStatusDescription("commit reference acquired, proceeding with snapshot"); final var shardStateId = getShardStateId(indexShard, snapshotIndexCommit.indexCommit()); // not aborted so indexCommit() ok snapshotStatus.addAbortListener(makeAbortListener(indexShard.shardId(), snapshot, snapshotIndexCommit)); snapshotStatus.ensureNotAborted(); @@ -652,8 +688,12 @@ private void syncShardStatsOnNewMaster(List entries) snapshot.snapshot(), shardId ); - notifySuccessfulSnapshotShard(snapshot.snapshot(), shardId, localShard.getValue().getShardSnapshotResult()); - + notifySuccessfulSnapshotShard( + snapshot.snapshot(), + shardId, + localShard.getValue().getShardSnapshotResult(), + () -> null + ); } else if (stage == Stage.FAILURE) { // but we think the shard failed - we need to make new master know that the shard failed logger.debug( @@ -667,7 +707,8 @@ private void syncShardStatsOnNewMaster(List entries) shardId, ShardState.FAILED, indexShardSnapshotStatus.getFailure(), - localShard.getValue().generation() + localShard.getValue().generation(), + () -> null ); } else if (stage == Stage.PAUSED) { // but we think the shard has paused - we need to make new master know that @@ -680,7 +721,8 @@ private void syncShardStatsOnNewMaster(List entries) shardId, 
ShardState.PAUSED_FOR_NODE_REMOVAL, indexShardSnapshotStatus.getFailure(), - localShard.getValue().generation() + localShard.getValue().generation(), + () -> null ); } } @@ -693,10 +735,20 @@ private void syncShardStatsOnNewMaster(List entries) /** * Notify the master node that the given shard snapshot completed successfully. */ - private void notifySuccessfulSnapshotShard(final Snapshot snapshot, final ShardId shardId, ShardSnapshotResult shardSnapshotResult) { + private void notifySuccessfulSnapshotShard( + final Snapshot snapshot, + final ShardId shardId, + ShardSnapshotResult shardSnapshotResult, + Supplier postMasterNotificationAction + ) { assert shardSnapshotResult != null; assert shardSnapshotResult.getGeneration() != null; - sendSnapshotShardUpdate(snapshot, shardId, ShardSnapshotStatus.success(clusterService.localNode().getId(), shardSnapshotResult)); + sendSnapshotShardUpdate( + snapshot, + shardId, + ShardSnapshotStatus.success(clusterService.localNode().getId(), shardSnapshotResult), + postMasterNotificationAction + ); } /** @@ -707,13 +759,15 @@ private void notifyUnsuccessfulSnapshotShard( final ShardId shardId, final ShardState shardState, final String failure, - final ShardGeneration generation + final ShardGeneration generation, + Supplier postMasterNotificationAction ) { assert shardState == ShardState.FAILED || shardState == ShardState.PAUSED_FOR_NODE_REMOVAL : shardState; sendSnapshotShardUpdate( snapshot, shardId, - new ShardSnapshotStatus(clusterService.localNode().getId(), shardState, generation, failure) + new ShardSnapshotStatus(clusterService.localNode().getId(), shardState, generation, failure), + postMasterNotificationAction ); if (shardState == ShardState.PAUSED_FOR_NODE_REMOVAL) { logger.debug( @@ -726,7 +780,12 @@ private void notifyUnsuccessfulSnapshotShard( } /** Updates the shard snapshot status by sending a {@link UpdateIndexShardSnapshotStatusRequest} to the master node */ - private void sendSnapshotShardUpdate(final Snapshot snapshot, final ShardId shardId, final ShardSnapshotStatus status) { + private void sendSnapshotShardUpdate( + final Snapshot snapshot, + final ShardId shardId, + final ShardSnapshotStatus status, + Supplier postMasterNotificationAction + ) { ActionListener updateResultListener = new ActionListener<>() { @Override public void onResponse(Void aVoid) { @@ -738,9 +797,11 @@ public void onFailure(Exception e) { logger.warn(() -> format("[%s][%s] failed to update snapshot state to [%s]", shardId, snapshot, status), e); } }; + snapshotShutdownProgressTracker.trackRequestSentToMaster(snapshot, shardId); var releaseTrackerRequestRunsBeforeResultListener = ActionListener.runBefore(updateResultListener, () -> { snapshotShutdownProgressTracker.releaseRequestSentToMaster(snapshot, shardId); + postMasterNotificationAction.get(); }); remoteFailedRequestDeduplicator.executeOnce( diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTracker.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTracker.java index 5d81e3c4e46af..45f2fb96fce4e 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTracker.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTracker.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -45,6 +46,7 @@ public class SnapshotShutdownProgressTracker { 
private static final Logger logger = LogManager.getLogger(SnapshotShutdownProgressTracker.class); private final Supplier getLocalNodeId; + private final Consumer logIndexShardSnapshotStatuses; private final ThreadPool threadPool; private volatile TimeValue progressLoggerInterval; @@ -83,8 +85,14 @@ public class SnapshotShutdownProgressTracker { private final AtomicLong abortedCount = new AtomicLong(); private final AtomicLong pausedCount = new AtomicLong(); - public SnapshotShutdownProgressTracker(Supplier localNodeIdSupplier, ClusterSettings clusterSettings, ThreadPool threadPool) { + public SnapshotShutdownProgressTracker( + Supplier localNodeIdSupplier, + Consumer logShardStatuses, + ClusterSettings clusterSettings, + ThreadPool threadPool + ) { this.getLocalNodeId = localNodeIdSupplier; + this.logIndexShardSnapshotStatuses = logShardStatuses; clusterSettings.initializeAndWatch( SNAPSHOT_PROGRESS_DURING_SHUTDOWN_LOG_INTERVAL_SETTING, value -> this.progressLoggerInterval = value @@ -122,14 +130,14 @@ private void cancelProgressLogger() { } /** - * Logs some statistics about shard snapshot progress. + * Logs information about shard snapshot progress. */ private void logProgressReport() { logger.info( """ Current active shard snapshot stats on data node [{}]. \ - Node shutdown cluster state update received at [{}]. \ - Finished signalling shard snapshots to pause at [{}]. \ + Node shutdown cluster state update received at [{} millis]. \ + Finished signalling shard snapshots to pause at [{} millis]. \ Number shard snapshots running [{}]. \ Number shard snapshots waiting for master node reply to status update request [{}] \ Shard snapshot completion stats since shutdown began: Done [{}]; Failed [{}]; Aborted [{}]; Paused [{}]\ @@ -144,6 +152,8 @@ private void logProgressReport() { abortedCount.get(), pausedCount.get() ); + // Use a callback to log the shard snapshot details. + logIndexShardSnapshotStatuses.accept(logger); } /** diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTrackerTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTrackerTests.java index fbf742ae2ea57..8adcb3eb9d5f4 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTrackerTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotShutdownProgressTrackerTests.java @@ -110,8 +110,10 @@ void simulateShardSnapshotsCompleting(SnapshotShutdownProgressTracker tracker, i } public void testTrackerLogsStats() { + final String dummyStatusMsg = "Dummy log message for index shard snapshot statuses"; SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> callerLogger.info(dummyStatusMsg), clusterSettings, testThreadPool ); @@ -144,6 +146,14 @@ public void testTrackerLogsStats() { "*Shard snapshot completion stats since shutdown began: Done [2]; Failed [1]; Aborted [1]; Paused [1]*" ) ); + mockLog.addExpectation( + new MockLog.SeenEventExpectation( + "index shard snapshot statuses", + SnapshotShutdownProgressTracker.class.getCanonicalName(), + Level.INFO, + dummyStatusMsg + ) + ); // Simulate updating the shard snapshot completion stats. 
simulateShardSnapshotsCompleting(tracker, 5); @@ -171,6 +181,7 @@ public void testTrackerProgressLoggingIntervalSettingCanBeDisabled() { ); SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> {}, clusterSettingsDisabledLogging, testThreadPool ); @@ -214,6 +225,7 @@ public void testTrackerIntervalSettingDynamically() { ); SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> {}, clusterSettingsDisabledLogging, testThreadPool ); @@ -253,6 +265,7 @@ public void testTrackerIntervalSettingDynamically() { public void testTrackerPauseTimestamp() { SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> {}, clusterSettings, testThreadPool ); @@ -263,7 +276,7 @@ public void testTrackerPauseTimestamp() { "pausing timestamp should be set", SnapshotShutdownProgressTracker.class.getName(), Level.INFO, - "*Finished signalling shard snapshots to pause at [" + testThreadPool.relativeTimeInMillis() + "]*" + "*Finished signalling shard snapshots to pause at [" + testThreadPool.relativeTimeInMillis() + " millis]*" ) ); @@ -283,6 +296,7 @@ public void testTrackerPauseTimestamp() { public void testTrackerRequestsToMaster() { SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> {}, clusterSettings, testThreadPool ); @@ -335,6 +349,7 @@ public void testTrackerRequestsToMaster() { public void testTrackerClearShutdown() { SnapshotShutdownProgressTracker tracker = new SnapshotShutdownProgressTracker( getLocalNodeIdSupplier, + (callerLogger) -> {}, clusterSettings, testThreadPool ); @@ -345,7 +360,7 @@ public void testTrackerClearShutdown() { "pausing timestamp should be unset", SnapshotShutdownProgressTracker.class.getName(), Level.INFO, - "*Finished signalling shard snapshots to pause at [-1]*" + "*Finished signalling shard snapshots to pause at [-1 millis]*" ) ); From 730f42a293ca06560deae2dcfdc2c3943662ce2c Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 10 Dec 2024 22:54:04 +0100 Subject: [PATCH 08/28] Handle all exceptions in data nodes can match (#117469) During the can match phase, prior to the query phase, we may have exceptions that are returned back to the coordinating node, handled gracefully as if the shard returned canMatch=true. During the query phase, we perform an additional rewrite and can match phase to eventually shortcut the query phase for the shard. That needs to handle exceptions as well. Currently, an exception there causes shard failures, while we should rather go ahead and execute the query on the shard. Instead of adding another try catch on consumers code, this commit adds exception handling to the method itself so that it can no longer throw exceptions and similar mistakes can no longer be made in the future. At the same time, this commit makes the can match method more easily testable without requiring a full-blown SearchService instance. 
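For reference, a minimal, self-contained sketch of the pattern this change applies (illustrative names only, not the actual SearchService API): the can-match check catches its own exceptions and falls back to canMatch=true, so callers need no extra try/catch and a failure during the query-phase shortcut can no longer fail the shard.

    // Sketch of the "handle exceptions inside can-match" pattern; CanMatchResult,
    // ShardCheck and canMatch are hypothetical stand-ins for the real classes.
    import java.util.function.Supplier;

    public final class CanMatchSketch {

        // Simplified stand-in for CanMatchShardResponse: just the boolean, no sort min/max.
        record CanMatchResult(boolean canMatch) {}

        // Stand-in for the rewrite logic that may throw (e.g. a query rewrite failure).
        interface ShardCheck {
            boolean queryStillMatchesAfterRewrite() throws Exception;
        }

        // The method owns its exception handling: on any failure it reports "can match",
        // so the query phase simply runs on the shard instead of failing it.
        static CanMatchResult canMatch(ShardCheck check) {
            try {
                return new CanMatchResult(check.queryStillMatchesAfterRewrite());
            } catch (Exception e) {
                return new CanMatchResult(true); // safe fallback: execute the query normally
            }
        }

        public static void main(String[] args) {
            // false: the shard can be skipped
            System.out.println(canMatch(() -> false).canMatch());
            // true: an exception never turns into a shard failure
            System.out.println(canMatch(() -> { throw new Exception("rewrite failed"); }).canMatch());
        }
    }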
Closes #104994 --- docs/changelog/117469.yaml | 6 + .../elasticsearch/search/SearchService.java | 182 +- .../search/SearchServiceSingleNodeTests.java | 3011 ++++++++++++++++ .../search/SearchServiceTests.java | 3148 ++--------------- .../action/TransportTermsEnumAction.java | 2 +- .../index/engine/frozen/FrozenIndexTests.java | 41 +- .../BaseSearchableSnapshotsIntegTestCase.java | 12 + ...pshotsCanMatchOnCoordinatorIntegTests.java | 14 - .../SearchableSnapshotsSearchIntegTests.java | 129 + 9 files changed, 3524 insertions(+), 3021 deletions(-) create mode 100644 docs/changelog/117469.yaml create mode 100644 server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java create mode 100644 x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsSearchIntegTests.java diff --git a/docs/changelog/117469.yaml b/docs/changelog/117469.yaml new file mode 100644 index 0000000000000..cfb14f78cb578 --- /dev/null +++ b/docs/changelog/117469.yaml @@ -0,0 +1,6 @@ +pr: 117469 +summary: Handle exceptions in query phase can match +area: Search +type: bug +issues: + - 104994 diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index cec1fa8712ec1..b9bd398500c71 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -147,6 +147,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import java.util.function.Function; import java.util.function.LongSupplier; import java.util.function.Supplier; @@ -549,16 +551,17 @@ public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, // check if we can shortcut the query phase entirely. if (orig.canReturnNullResponseIfMatchNoDocs()) { assert orig.scroll() == null; - final CanMatchShardResponse canMatchResp; - try { - ShardSearchRequest clone = new ShardSearchRequest(orig); - canMatchResp = canMatch(clone, false); - } catch (Exception exc) { - l.onFailure(exc); - return; - } + ShardSearchRequest clone = new ShardSearchRequest(orig); + CanMatchContext canMatchContext = new CanMatchContext( + clone, + indicesService::indexServiceSafe, + this::findReaderContext, + defaultKeepAlive, + maxKeepAlive + ); + CanMatchShardResponse canMatchResp = canMatch(canMatchContext, false); if (canMatchResp.canMatch() == false) { - l.onResponse(QuerySearchResult.nullInstance()); + listener.onResponse(QuerySearchResult.nullInstance()); return; } } @@ -1191,10 +1194,14 @@ public void freeAllScrollContexts() { } private long getKeepAlive(ShardSearchRequest request) { + return getKeepAlive(request, defaultKeepAlive, maxKeepAlive); + } + + private static long getKeepAlive(ShardSearchRequest request, long defaultKeepAlive, long maxKeepAlive) { if (request.scroll() != null) { - return getScrollKeepAlive(request.scroll()); + return getScrollKeepAlive(request.scroll(), defaultKeepAlive, maxKeepAlive); } else if (request.keepAlive() != null) { - checkKeepAliveLimit(request.keepAlive().millis()); + checkKeepAliveLimit(request.keepAlive().millis(), maxKeepAlive); return request.keepAlive().getMillis(); } else { return request.readerId() == null ? 
defaultKeepAlive : -1; @@ -1202,14 +1209,22 @@ private long getKeepAlive(ShardSearchRequest request) { } private long getScrollKeepAlive(Scroll scroll) { + return getScrollKeepAlive(scroll, defaultKeepAlive, maxKeepAlive); + } + + private static long getScrollKeepAlive(Scroll scroll, long defaultKeepAlive, long maxKeepAlive) { if (scroll != null && scroll.keepAlive() != null) { - checkKeepAliveLimit(scroll.keepAlive().millis()); + checkKeepAliveLimit(scroll.keepAlive().millis(), maxKeepAlive); return scroll.keepAlive().getMillis(); } return defaultKeepAlive; } private void checkKeepAliveLimit(long keepAlive) { + checkKeepAliveLimit(keepAlive, maxKeepAlive); + } + + private static void checkKeepAliveLimit(long keepAlive, long maxKeepAlive) { if (keepAlive > maxKeepAlive) { throw new IllegalArgumentException( "Keep alive for request (" @@ -1620,6 +1635,7 @@ public void canMatch(CanMatchNodeRequest request, ActionListener responses = new ArrayList<>(shardLevelRequests.size()); for (var shardLevelRequest : shardLevelRequests) { try { + // TODO remove the exception handling as it's now in canMatch itself responses.add(new CanMatchNodeResponse.ResponseOrFailure(canMatch(request.createShardSearchRequest(shardLevelRequest)))); } catch (Exception e) { responses.add(new CanMatchNodeResponse.ResponseOrFailure(e)); @@ -1631,82 +1647,145 @@ public void canMatch(CanMatchNodeRequest request, ActionListener indexServiceLookup; + private final BiFunction findReaderContext; + private final long defaultKeepAlive; + private final long maxKeepAlive; + + private IndexService indexService; + + CanMatchContext( + ShardSearchRequest request, + Function indexServiceLookup, + BiFunction findReaderContext, + long defaultKeepAlive, + long maxKeepAlive + ) { + this.request = request; + this.indexServiceLookup = indexServiceLookup; + this.findReaderContext = findReaderContext; + this.defaultKeepAlive = defaultKeepAlive; + this.maxKeepAlive = maxKeepAlive; + } + + long getKeepAlive() { + return SearchService.getKeepAlive(request, defaultKeepAlive, maxKeepAlive); + } + + ReaderContext findReaderContext() { + return findReaderContext.apply(request.readerId(), request); + } + + QueryRewriteContext getQueryRewriteContext(IndexService indexService) { + return indexService.newQueryRewriteContext(request::nowInMillis, request.getRuntimeMappings(), request.getClusterAlias()); + } + + SearchExecutionContext getSearchExecutionContext(Engine.Searcher searcher) { + return getIndexService().newSearchExecutionContext( + request.shardId().id(), + 0, + searcher, + request::nowInMillis, + request.getClusterAlias(), + request.getRuntimeMappings() + ); + } + + IndexShard getShard() { + return getIndexService().getShard(request.shardId().getId()); + } + + IndexService getIndexService() { + if (this.indexService == null) { + this.indexService = indexServiceLookup.apply(request.shardId().getIndex()); + } + return this.indexService; + } + } + + static CanMatchShardResponse canMatch(CanMatchContext canMatchContext, boolean checkRefreshPending) { + assert canMatchContext.request.searchType() == SearchType.QUERY_THEN_FETCH + : "unexpected search type: " + canMatchContext.request.searchType(); Releasable releasable = null; try { IndexService indexService; final boolean hasRefreshPending; final Engine.Searcher canMatchSearcher; - if (request.readerId() != null) { + if (canMatchContext.request.readerId() != null) { hasRefreshPending = false; ReaderContext readerContext; Engine.Searcher searcher; try { - readerContext = 
findReaderContext(request.readerId(), request); - releasable = readerContext.markAsUsed(getKeepAlive(request)); + readerContext = canMatchContext.findReaderContext(); + releasable = readerContext.markAsUsed(canMatchContext.getKeepAlive()); indexService = readerContext.indexService(); - if (canMatchAfterRewrite(request, indexService) == false) { + QueryRewriteContext queryRewriteContext = canMatchContext.getQueryRewriteContext(indexService); + if (queryStillMatchesAfterRewrite(canMatchContext.request, queryRewriteContext) == false) { return new CanMatchShardResponse(false, null); } searcher = readerContext.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE); } catch (SearchContextMissingException e) { - final String searcherId = request.readerId().getSearcherId(); + final String searcherId = canMatchContext.request.readerId().getSearcherId(); if (searcherId == null) { - throw e; + return new CanMatchShardResponse(true, null); } - indexService = indicesService.indexServiceSafe(request.shardId().getIndex()); - if (canMatchAfterRewrite(request, indexService) == false) { + if (queryStillMatchesAfterRewrite( + canMatchContext.request, + canMatchContext.getQueryRewriteContext(canMatchContext.getIndexService()) + ) == false) { return new CanMatchShardResponse(false, null); } - IndexShard indexShard = indexService.getShard(request.shardId().getId()); - final Engine.SearcherSupplier searcherSupplier = indexShard.acquireSearcherSupplier(); + final Engine.SearcherSupplier searcherSupplier = canMatchContext.getShard().acquireSearcherSupplier(); if (searcherId.equals(searcherSupplier.getSearcherId()) == false) { searcherSupplier.close(); - throw e; + return new CanMatchShardResponse(true, null); } releasable = searcherSupplier; searcher = searcherSupplier.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE); } canMatchSearcher = searcher; } else { - indexService = indicesService.indexServiceSafe(request.shardId().getIndex()); - if (canMatchAfterRewrite(request, indexService) == false) { + if (queryStillMatchesAfterRewrite( + canMatchContext.request, + canMatchContext.getQueryRewriteContext(canMatchContext.getIndexService()) + ) == false) { return new CanMatchShardResponse(false, null); } - IndexShard indexShard = indexService.getShard(request.shardId().getId()); - boolean needsWaitForRefresh = request.waitForCheckpoint() != UNASSIGNED_SEQ_NO; + boolean needsWaitForRefresh = canMatchContext.request.waitForCheckpoint() != UNASSIGNED_SEQ_NO; // If this request wait_for_refresh behavior, it is safest to assume a refresh is pending. Theoretically, // this can be improved in the future by manually checking that the requested checkpoint has already been refresh. // However, this will request modifying the engine to surface that information. 
+ IndexShard indexShard = canMatchContext.getShard(); hasRefreshPending = needsWaitForRefresh || (indexShard.hasRefreshPending() && checkRefreshPending); canMatchSearcher = indexShard.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE); } try (canMatchSearcher) { - SearchExecutionContext context = indexService.newSearchExecutionContext( - request.shardId().id(), - 0, - canMatchSearcher, - request::nowInMillis, - request.getClusterAlias(), - request.getRuntimeMappings() - ); - final boolean canMatch = queryStillMatchesAfterRewrite(request, context); - final MinAndMax minMax; + SearchExecutionContext context = canMatchContext.getSearchExecutionContext(canMatchSearcher); + final boolean canMatch = queryStillMatchesAfterRewrite(canMatchContext.request, context); if (canMatch || hasRefreshPending) { - FieldSortBuilder sortBuilder = FieldSortBuilder.getPrimaryFieldSortOrNull(request.source()); - minMax = sortBuilder != null ? FieldSortBuilder.getMinMaxOrNull(context, sortBuilder) : null; - } else { - minMax = null; + FieldSortBuilder sortBuilder = FieldSortBuilder.getPrimaryFieldSortOrNull(canMatchContext.request.source()); + final MinAndMax minMax = sortBuilder != null ? FieldSortBuilder.getMinMaxOrNull(context, sortBuilder) : null; + return new CanMatchShardResponse(true, minMax); } - return new CanMatchShardResponse(canMatch || hasRefreshPending, minMax); + return new CanMatchShardResponse(false, null); } + } catch (Exception e) { + return new CanMatchShardResponse(true, null); } finally { Releasables.close(releasable); } @@ -1719,15 +1798,6 @@ private CanMatchShardResponse canMatch(ShardSearchRequest request, boolean check * {@link MatchNoneQueryBuilder}. This allows us to avoid extra work for example making the shard search active and waiting for * refreshes. */ - private static boolean canMatchAfterRewrite(final ShardSearchRequest request, final IndexService indexService) throws IOException { - final QueryRewriteContext queryRewriteContext = indexService.newQueryRewriteContext( - request::nowInMillis, - request.getRuntimeMappings(), - request.getClusterAlias() - ); - return queryStillMatchesAfterRewrite(request, queryRewriteContext); - } - @SuppressWarnings("unchecked") public static boolean queryStillMatchesAfterRewrite(ShardSearchRequest request, QueryRewriteContext context) throws IOException { Rewriteable.rewrite(request.getRewriteable(), context, false); diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java new file mode 100644 index 0000000000000..02593e41f5d84 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java @@ -0,0 +1,3011 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.search; + +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FilterDirectoryReader; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHitCountCollectorManager; +import org.apache.lucene.store.AlreadyClosedException; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.ClosePointInTimeRequest; +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import org.elasticsearch.action.search.SearchPhaseController; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.action.search.SearchShardTask; +import org.elasticsearch.action.search.SearchType; +import org.elasticsearch.action.search.TransportClosePointInTimeAction; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; +import org.elasticsearch.action.search.TransportSearchAction; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.routing.ShardRouting; +import org.elasticsearch.cluster.routing.ShardRoutingState; +import org.elasticsearch.cluster.routing.TestShardRouting; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexService; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.engine.Engine; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.search.stats.SearchStats; 
+import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.index.shard.SearchOperationListener; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.indices.settings.InternalOrPrivateSettingsPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.script.MockScriptEngine; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.SearchService.ResultsType; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.AggregationReduceContext; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.dfs.AggregatedDfs; +import org.elasticsearch.search.fetch.FetchSearchResult; +import org.elasticsearch.search.fetch.ShardFetchRequest; +import org.elasticsearch.search.fetch.ShardFetchSearchRequest; +import org.elasticsearch.search.fetch.subphase.FieldAndFormat; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.ContextIndexSearcher; +import org.elasticsearch.search.internal.ReaderContext; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchContextId; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.query.NonCountingTermQuery; +import org.elasticsearch.search.query.QuerySearchRequest; +import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.query.SearchTimeoutException; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.rank.RankShardResult; +import org.elasticsearch.search.rank.TestRankBuilder; +import org.elasticsearch.search.rank.TestRankShardResult; +import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; +import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; +import org.elasticsearch.search.rank.feature.RankFeatureDoc; +import org.elasticsearch.search.rank.feature.RankFeatureResult; +import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; +import org.elasticsearch.search.rank.feature.RankFeatureShardResult; +import org.elasticsearch.search.suggest.SuggestBuilder; +import org.elasticsearch.tasks.TaskCancelHelper; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.threadpool.ThreadPool; +import 
org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.IntConsumer; +import java.util.function.Supplier; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason.DELETED; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; +import static org.elasticsearch.search.SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED; +import static org.elasticsearch.search.SearchService.SEARCH_WORKER_THREADS_ENABLED; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchHits; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.not; +import static org.mockito.Mockito.mock; + +public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase { + + @Override + protected boolean resetNodeAfterTest() { + return true; + } + + @Override + protected Collection> getPlugins() { + return pluginList( + FailOnRewriteQueryPlugin.class, + CustomScriptPlugin.class, + ReaderWrapperCountPlugin.class, + InternalOrPrivateSettingsPlugin.class, + MockSearchService.TestPlugin.class + ); + } + + public static class ReaderWrapperCountPlugin extends Plugin { + @Override + public void onIndexModule(IndexModule indexModule) { + indexModule.setReaderWrapper(service -> SearchServiceSingleNodeTests::apply); + } + } + + @Before + public void resetCount() { + numWrapInvocations = new AtomicInteger(0); + } + + private static AtomicInteger numWrapInvocations = new AtomicInteger(0); + + private static DirectoryReader apply(DirectoryReader directoryReader) throws IOException { + numWrapInvocations.incrementAndGet(); + return new FilterDirectoryReader(directoryReader, new FilterDirectoryReader.SubReaderWrapper() { + @Override + public LeafReader wrap(LeafReader reader) { + return reader; + } + }) { + @Override + protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException { + return in; + } + + @Override + public 
CacheHelper getReaderCacheHelper() { + return directoryReader.getReaderCacheHelper(); + } + }; + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + + static final String DUMMY_SCRIPT = "dummyScript"; + + @Override + protected Map, Object>> pluginScripts() { + return Collections.singletonMap(DUMMY_SCRIPT, vars -> "dummy"); + } + + @Override + public void onIndexModule(IndexModule indexModule) { + indexModule.addSearchOperationListener(new SearchOperationListener() { + @Override + public void onFetchPhase(SearchContext context, long tookInNanos) { + if ("throttled_threadpool_index".equals(context.indexShard().shardId().getIndex().getName())) { + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search_throttled]")); + } else { + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); + } + } + + @Override + public void onQueryPhase(SearchContext context, long tookInNanos) { + if ("throttled_threadpool_index".equals(context.indexShard().shardId().getIndex().getName())) { + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search_throttled]")); + } else { + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); + } + } + }); + } + } + + @Override + protected Settings nodeSettings() { + return Settings.builder().put("search.default_search_timeout", "5s").build(); + } + + public void testClearOnClose() { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertResponse( + client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), + searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) + ); + SearchService service = getInstanceFromNode(SearchService.class); + + assertEquals(1, service.getActiveContexts()); + service.doClose(); // this kills the keep-alive reaper we have to reset the node after this test + assertEquals(0, service.getActiveContexts()); + } + + public void testClearOnStop() { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertResponse( + client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), + searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) + ); + SearchService service = getInstanceFromNode(SearchService.class); + + assertEquals(1, service.getActiveContexts()); + service.doStop(); + assertEquals(0, service.getActiveContexts()); + } + + public void testClearIndexDelete() { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertResponse( + client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), + searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) + ); + SearchService service = getInstanceFromNode(SearchService.class); + + assertEquals(1, service.getActiveContexts()); + assertAcked(indicesAdmin().prepareDelete("index")); + awaitIndexShardCloseAsyncTasks(); + assertEquals(0, service.getActiveContexts()); + } + + public void testCloseSearchContextOnRewriteException() { + // if refresh happens while checking the exception, the subsequent reference count might not match, so we switch it off + createIndex("index", Settings.builder().put("index.refresh_interval", -1).build()); + prepareIndex("index").setId("1").setSource("field", 
"value").setRefreshPolicy(IMMEDIATE).get(); + + SearchService service = getInstanceFromNode(SearchService.class); + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + + final int activeContexts = service.getActiveContexts(); + final int activeRefs = indexShard.store().refCount(); + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch("index").setQuery(new FailOnRewriteQueryBuilder()).get() + ); + assertEquals(activeContexts, service.getActiveContexts()); + assertEquals(activeRefs, indexShard.store().refCount()); + } + + public void testSearchWhileIndexDeleted() throws InterruptedException { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + + SearchService service = getInstanceFromNode(SearchService.class); + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + AtomicBoolean running = new AtomicBoolean(true); + CountDownLatch startGun = new CountDownLatch(1); + final int permitCount = 100; + Semaphore semaphore = new Semaphore(permitCount); + ShardRouting routing = TestShardRouting.newShardRouting( + indexShard.shardId(), + randomAlphaOfLength(5), + randomBoolean(), + ShardRoutingState.INITIALIZING + ); + final Thread thread = new Thread(() -> { + startGun.countDown(); + while (running.get()) { + if (randomBoolean()) { + service.afterIndexRemoved(indexService.index(), indexService.getIndexSettings(), DELETED); + } else { + service.beforeIndexShardCreated(routing, indexService.getIndexSettings().getSettings()); + } + if (randomBoolean()) { + // here we trigger some refreshes to ensure the IR go out of scope such that we hit ACE if we access a search + // context in a non-sane way. + try { + semaphore.acquire(); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + prepareIndex("index").setSource("field", "value") + .setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())) + .execute(ActionListener.running(semaphore::release)); + } + } + }); + thread.start(); + startGun.await(); + try { + final int rounds = scaledRandomIntBetween(100, 10000); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchRequest scrollSearchRequest = new SearchRequest().allowPartialSearchResults(true) + .scroll(new Scroll(TimeValue.timeValueMinutes(1))); + for (int i = 0; i < rounds; i++) { + try { + try { + PlainActionFuture result = new PlainActionFuture<>(); + final boolean useScroll = randomBoolean(); + service.executeQueryPhase( + new ShardSearchRequest( + OriginalIndices.NONE, + useScroll ? 
scrollSearchRequest : searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ), + new SearchShardTask(123L, "", "", "", null, emptyMap()), + result.delegateFailure((l, r) -> { + r.incRef(); + l.onResponse(r); + }) + ); + final SearchPhaseResult searchPhaseResult = result.get(); + try { + List intCursors = new ArrayList<>(1); + intCursors.add(0); + ShardFetchRequest req = new ShardFetchRequest( + searchPhaseResult.getContextId(), + intCursors, + null/* not a scroll */ + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener); + listener.get(); + if (useScroll) { + // have to free context since this test does not remove the index from IndicesService. + service.freeReaderContext(searchPhaseResult.getContextId()); + } + } finally { + searchPhaseResult.decRef(); + } + } catch (ExecutionException ex) { + assertThat(ex.getCause(), instanceOf(RuntimeException.class)); + throw ((RuntimeException) ex.getCause()); + } + } catch (AlreadyClosedException ex) { + throw ex; + } catch (IllegalStateException ex) { + assertEquals(AbstractRefCounted.ALREADY_CLOSED_MESSAGE, ex.getMessage()); + } catch (SearchContextMissingException ex) { + // that's fine + } + } + } finally { + running.set(false); + thread.join(); + semaphore.acquire(permitCount); + } + + assertEquals(0, service.getActiveContexts()); + + SearchStats.Stats totalStats = indexShard.searchStats().getTotal(); + assertEquals(0, totalStats.getQueryCurrent()); + assertEquals(0, totalStats.getScrollCurrent()); + assertEquals(0, totalStats.getFetchCurrent()); + } + + public void testRankFeaturePhaseSearchPhases() throws InterruptedException, ExecutionException { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex(indexName)); + final IndexShard indexShard = indexService.getShard(0); + SearchShardTask searchTask = new SearchShardTask(123L, "", "", "", null, emptyMap()); + + // create a SearchRequest that will return all documents and defines a TestRankBuilder with shard-level only operations + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true) + .source( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(DEFAULT_SIZE) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + 
@Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(RankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = (numDocs - i) + randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ); + + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + QuerySearchResult queryResult = null; + RankFeatureResult rankResult = null; + try { + // Execute the query phase and store the result in a SearchPhaseResult container using a PlainActionFuture + PlainActionFuture queryPhaseResults = new PlainActionFuture<>(); + service.executeQueryPhase(request, searchTask, queryPhaseResults); + queryResult = (QuerySearchResult) queryPhaseResults.get(); + + // these are the matched docs from the query phase + final RankDoc[] queryRankDocs = ((TestRankShardResult) queryResult.getRankShardResult()).testRankDocs; + + // assume that we have cut down to these from the coordinator node as the top-docs to run the rank feature phase upon + List topRankWindowSizeDocs = randomNonEmptySubsetOf(Arrays.stream(queryRankDocs).map(x -> x.doc).toList()); + + // now we create a RankFeatureShardRequest to extract feature info for the top-docs above + RankFeatureShardRequest rankFeatureShardRequest = new RankFeatureShardRequest( + OriginalIndices.NONE, + queryResult.getContextId(), // use the context from the query phase + request, + topRankWindowSizeDocs + ); + PlainActionFuture rankPhaseResults = new PlainActionFuture<>(); + service.executeRankFeaturePhase(rankFeatureShardRequest, searchTask, rankPhaseResults); + rankResult = rankPhaseResults.get(); + + assertNotNull(rankResult); + assertNotNull(rankResult.rankFeatureResult()); + RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult(); + assertNotNull(rankFeatureShardResult); + + List sortedRankWindowDocs = topRankWindowSizeDocs.stream().sorted().toList(); + assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length); + for (int i = 0; i < sortedRankWindowDocs.size(); i++) { + assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc); + assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i)); + } + + List globalTopKResults = randomNonEmptySubsetOf( + Arrays.stream(rankFeatureShardResult.rankFeatureDocs).map(x -> x.doc).toList() + ); + + // finally let's create a fetch request to bring back fetch 
info for the top results + ShardFetchSearchRequest fetchRequest = new ShardFetchSearchRequest( + OriginalIndices.NONE, + rankResult.getContextId(), + request, + globalTopKResults, + null, + null, + rankResult.getRescoreDocIds(), + null + ); + + // execute fetch phase and perform any validations once we retrieve the response + // the difference in how we do assertions here is needed because once the transport service sends back the response + // it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null + PlainActionFuture fetchListener = new PlainActionFuture<>() { + @Override + public void onResponse(FetchSearchResult fetchSearchResult) { + assertNotNull(fetchSearchResult); + assertNotNull(fetchSearchResult.hits()); + + int totalHits = fetchSearchResult.hits().getHits().length; + assertEquals(globalTopKResults.size(), totalHits); + for (int i = 0; i < totalHits; i++) { + // rank and score are set by the SearchPhaseController#merge so no need to validate that here + SearchHit hit = fetchSearchResult.hits().getAt(i); + assertNotNull(hit.getFields().get(fetchFieldName)); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + } + super.onResponse(fetchSearchResult); + } + + @Override + public void onFailure(Exception e) { + super.onFailure(e); + throw new AssertionError("No failure should have been raised", e); + } + }; + service.executeFetchPhase(fetchRequest, searchTask, fetchListener); + fetchListener.get(); + } catch (Exception ex) { + if (queryResult != null) { + if (queryResult.hasReferences()) { + queryResult.decRef(); + } + service.freeReaderContext(queryResult.getContextId()); + } + if (rankResult != null && rankResult.hasReferences()) { + rankResult.decRef(); + } + throw ex; + } + } + + public void testRankFeaturePhaseUsingClient() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 4; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + ElasticsearchAssertions.assertResponse( + client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( + int size, + int from, + Client client + ) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = 
featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (RankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); + RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(RankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + } + ) + ), + (response) -> { + SearchHits hits = response.getHits(); + assertEquals(hits.getTotalHits().value(), numDocs); + assertEquals(hits.getHits().length, 2); + int index = 0; + for (SearchHit hit : hits.getHits()) { + assertEquals(hit.getRank(), 3 + index); + assertTrue(hit.getScore() >= 0); + assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); + index++; + } + } + ); + } + + public void testRankFeaturePhaseExceptionOnCoordinatingNode() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + 
indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> client().prepareSearch(indexName) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .size(2) + .from(2) + .fetchField(fetchFieldName) + .rankBuilder(new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( + int size, + int from, + Client client + ) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + throw new IllegalStateException("should have failed earlier"); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(RankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + }; + } + }) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionAllShardFail() { + final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + expectThrows( + SearchPhaseExecutionException.class, + () -> 
client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( + int size, + int from, + Client client + ) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (RankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); + RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(RankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + } + } + ) + ) + .get() + ); + } + + public void testRankFeaturePhaseExceptionOneShardFails() { + // if we have only one shard and it fails, it will fallback to context.onPhaseFailure which will eventually clean up all contexts. + // in this test we want to make sure that even if one shard (of many) fails during the RankFeaturePhase, then the appropriate + // context will have been cleaned up. 
+ final String indexName = "index"; + final String rankFeatureFieldName = "field"; + final String searchFieldName = "search_field"; + final String searchFieldValue = "some_value"; + final String fetchFieldName = "fetch_field"; + final String fetchFieldValue = "fetch_value"; + + final int minDocs = 3; + final int maxDocs = 10; + int numDocs = between(minDocs, maxDocs); + createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).build()); + // index some documents + for (int i = 0; i < numDocs; i++) { + prepareIndex(indexName).setId(String.valueOf(i)) + .setSource( + rankFeatureFieldName, + "aardvark_" + i, + searchFieldName, + searchFieldValue, + fetchFieldName, + fetchFieldValue + "_" + i + ) + .get(); + } + indicesAdmin().prepareRefresh(indexName).get(); + + assertResponse( + client().prepareSearch(indexName) + .setAllowPartialSearchResults(true) + .setSource( + new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) + .fetchField(fetchFieldName) + .rankBuilder( + // here we override only the shard-level contexts + new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + + // no need for more than one queries + @Override + public boolean isCompoundBuilder() { + return false; + } + + @Override + public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( + int size, + int from, + Client client + ) { + return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { + @Override + protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { + float[] scores = new float[featureDocs.length]; + for (int i = 0; i < featureDocs.length; i++) { + scores[i] = featureDocs[i].score; + } + scoreListener.onResponse(scores); + } + }; + } + + @Override + public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { + return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { + @Override + public ScoreDoc[] rankQueryPhaseResults( + List querySearchResults, + SearchPhaseController.TopDocsStats topDocStats + ) { + List rankDocs = new ArrayList<>(); + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + TestRankShardResult shardResult = (TestRankShardResult) querySearchResult + .getRankShardResult(); + for (RankDoc trd : shardResult.testRankDocs) { + trd.shardIndex = i; + rankDocs.add(trd); + } + } + rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); + RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); + topDocStats.fetchHits = topResults.length; + return topResults; + } + }; + } + + @Override + public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { + return new QueryPhaseRankShardContext(queries, from) { + + @Override + public int rankWindowSize() { + return DEFAULT_RANK_WINDOW_SIZE; + } + + @Override + public RankShardResult combineQueryPhaseResults(List rankResults) { + // we know we have just 1 query, so return all the docs from it + return new TestRankShardResult( + Arrays.stream(rankResults.get(0).scoreDocs) + .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) + .limit(rankWindowSize()) + .toArray(RankDoc[]::new) + ); + } + }; + } + + @Override + public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { + return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { + @Override + public RankShardResult 
buildRankFeatureShardResult(SearchHits hits, int shardId) { + if (shardId == 0) { + throw new UnsupportedOperationException("simulated failure"); + } else { + RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; + for (int i = 0; i < hits.getHits().length; i++) { + SearchHit hit = hits.getHits()[i]; + rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); + rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); + rankFeatureDocs[i].score = randomFloat(); + rankFeatureDocs[i].rank = i + 1; + } + return new RankFeatureShardResult(rankFeatureDocs); + } + } + }; + } + } + ) + ), + (searchResponse) -> { + assertEquals(1, searchResponse.getSuccessfulShards()); + assertEquals("simulated failure", searchResponse.getShardFailures()[0].getCause().getMessage()); + assertNotEquals(0, searchResponse.getHits().getHits().length); + for (SearchHit hit : searchResponse.getHits().getHits()) { + assertEquals(fetchFieldValue + "_" + hit.getId(), hit.getFields().get(fetchFieldName).getValue()); + assertEquals(1, hit.getShard().getShardId().id()); + } + } + ); + } + + public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws ExecutionException, InterruptedException { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + + MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); + service.setOnPutContext(context -> { + if (context.indexShard() == indexShard) { + assertAcked(indicesAdmin().prepareDelete("index")); + } + }); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchRequest scrollSearchRequest = new SearchRequest().allowPartialSearchResults(true) + .scroll(new Scroll(TimeValue.timeValueMinutes(1))); + + // the scrolls are not explicitly freed, but should all be gone when the test finished. + // for completeness, we also randomly test the regular search path. + final boolean useScroll = randomBoolean(); + PlainActionFuture result = new PlainActionFuture<>(); + service.executeQueryPhase( + new ShardSearchRequest( + OriginalIndices.NONE, + useScroll ? 
scrollSearchRequest : searchRequest, + new ShardId(resolveIndex("index"), 0), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ), + new SearchShardTask(123L, "", "", "", null, emptyMap()), + result + ); + + try { + result.get(); + } catch (Exception e) { + // ok + } + + expectThrows(IndexNotFoundException.class, () -> indicesAdmin().prepareGetIndex().setIndices("index").get()); + + assertEquals(0, service.getActiveContexts()); + + SearchStats.Stats totalStats = indexShard.searchStats().getTotal(); + assertEquals(0, totalStats.getQueryCurrent()); + assertEquals(0, totalStats.getScrollCurrent()); + assertEquals(0, totalStats.getFetchCurrent()); + } + + public void testBeforeShardLockDuringShardCreate() { + IndexService indexService = createIndex("index", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).build()); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertResponse( + client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), + searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) + ); + SearchService service = getInstanceFromNode(SearchService.class); + + assertEquals(1, service.getActiveContexts()); + service.beforeIndexShardCreated( + TestShardRouting.newShardRouting( + "test", + 0, + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomBoolean(), + ShardRoutingState.INITIALIZING + ), + indexService.getIndexSettings().getSettings() + ); + assertEquals(1, service.getActiveContexts()); + + service.beforeIndexShardCreated( + TestShardRouting.newShardRouting( + new ShardId(indexService.index(), 0), + randomAlphaOfLength(5), + randomBoolean(), + ShardRoutingState.INITIALIZING + ), + indexService.getIndexSettings().getSettings() + ); + assertEquals(0, service.getActiveContexts()); + } + + public void testTimeout() throws IOException { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + final ShardSearchRequest requestWithDefaultTimeout = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + + try ( + ReaderContext reader = createReaderContext(indexService, indexShard); + SearchContext contextWithDefaultTimeout = service.createContext( + reader, + requestWithDefaultTimeout, + mock(SearchShardTask.class), + ResultsType.NONE, + randomBoolean() + ) + ) { + // the search context should inherit the default timeout + assertThat(contextWithDefaultTimeout.timeout(), equalTo(TimeValue.timeValueSeconds(5))); + } + + final long seconds = randomIntBetween(6, 10); + searchRequest.source(new SearchSourceBuilder().timeout(TimeValue.timeValueSeconds(seconds))); + final ShardSearchRequest requestWithCustomTimeout = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + try ( + ReaderContext reader = createReaderContext(indexService, indexShard); + SearchContext context = service.createContext( + reader, + requestWithCustomTimeout, + mock(SearchShardTask.class), + ResultsType.NONE, + randomBoolean() + ) + ) { + // the search context 
should inherit the query timeout + assertThat(context.timeout(), equalTo(TimeValue.timeValueSeconds(seconds))); + } + } + + /** + * test that getting more than the allowed number of docvalue_fields throws an exception + */ + public void testMaxDocvalueFieldsSearch() throws IOException { + final Settings settings = Settings.builder().put(IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey(), 1).build(); + createIndex("index", settings, null, "field1", "keyword", "field2", "keyword"); + prepareIndex("index").setId("1").setSource("field1", "value1", "field2", "value2").setRefreshPolicy(IMMEDIATE).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + searchSourceBuilder.docValueField("field1"); + + final ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + try ( + ReaderContext reader = createReaderContext(indexService, indexShard); + SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) + ) { + assertNotNull(context); + } + + searchSourceBuilder.docValueField("unmapped_field"); + try ( + ReaderContext reader = createReaderContext(indexService, indexShard); + SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) + ) { + assertNotNull(context); + } + + searchSourceBuilder.docValueField("field2"); + try (ReaderContext reader = createReaderContext(indexService, indexShard)) { + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) + ); + assertEquals( + "Trying to retrieve too many docvalue_fields. Must be less than or equal to: [1] but was [2]. 
" + + "This limit can be set by changing the [index.max_docvalue_fields_search] index level setting.", + ex.getMessage() + ); + } + } + + public void testDeduplicateDocValuesFields() throws Exception { + createIndex("index", Settings.EMPTY, "_doc", "field1", "type=date", "field2", "type=date"); + prepareIndex("index").setId("1").setSource("field1", "2022-08-03", "field2", "2022-08-04").setRefreshPolicy(IMMEDIATE).get(); + SearchService service = getInstanceFromNode(SearchService.class); + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + + try (ReaderContext reader = createReaderContext(indexService, indexShard)) { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + searchSourceBuilder.docValueField("f*"); + if (randomBoolean()) { + searchSourceBuilder.docValueField("field*"); + } + if (randomBoolean()) { + searchSourceBuilder.docValueField("*2"); + } + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + try ( + SearchContext context = service.createContext( + reader, + request, + mock(SearchShardTask.class), + ResultsType.NONE, + randomBoolean() + ) + ) { + Collection fields = context.docValuesContext().fields(); + assertThat(fields, containsInAnyOrder(new FieldAndFormat("field1", null), new FieldAndFormat("field2", null))); + } + } + } + + /** + * test that getting more than the allowed number of script_fields throws an exception + */ + public void testMaxScriptFieldsSearch() throws IOException { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + // adding the maximum allowed number of script_fields to retrieve + int maxScriptFields = indexService.getIndexSettings().getMaxScriptFields(); + for (int i = 0; i < maxScriptFields; i++) { + searchSourceBuilder.scriptField( + "field" + i, + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) + ); + } + final ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + + try (ReaderContext reader = createReaderContext(indexService, indexShard)) { + try ( + SearchContext context = service.createContext( + reader, + request, + mock(SearchShardTask.class), + ResultsType.NONE, + randomBoolean() + ) + ) { + assertNotNull(context); + } + searchSourceBuilder.scriptField( + "anotherScriptField", + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) + ); + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) 
+ ); + assertEquals( + "Trying to retrieve too many script_fields. Must be less than or equal to: [" + + maxScriptFields + + "] but was [" + + (maxScriptFields + 1) + + "]. This limit can be set by changing the [index.max_script_fields] index level setting.", + ex.getMessage() + ); + } + } + + public void testIgnoreScriptfieldIfSizeZero() throws IOException { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + searchSourceBuilder.scriptField( + "field" + 0, + new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) + ); + searchSourceBuilder.size(0); + final ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + try ( + ReaderContext reader = createReaderContext(indexService, indexShard); + SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) + ) { + assertEquals(0, context.scriptFields().fields().size()); + } + } + + /** + * test that creating more than the allowed number of scroll contexts throws an exception + */ + public void testMaxOpenScrollContexts() throws Exception { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + // Open all possible scrolls, clear some of them, then open more until the limit is reached + LinkedList clearScrollIds = new LinkedList<>(); + + for (int i = 0; i < SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY); i++) { + assertResponse(client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), searchResponse -> { + if (randomInt(4) == 0) clearScrollIds.addLast(searchResponse.getScrollId()); + }); + } + + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.setScrollIds(clearScrollIds); + client().clearScroll(clearScrollRequest).get(); + + for (int i = 0; i < clearScrollIds.size(); i++) { + client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)).get().decRef(); + } + + final ShardScrollRequestTest request = new ShardScrollRequestTest(indexShard.shardId()); + ElasticsearchException ex = expectThrows( + ElasticsearchException.class, + () -> service.createAndPutReaderContext( + request, + indexService, + indexShard, + indexShard.acquireSearcherSupplier(), + SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() + ) + ); + assertEquals( + "Trying to create too many scroll contexts. Must be less than or equal to: [" + + SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY) + + "]. 
" + + "This limit can be set by changing the [search.max_open_scroll_context] setting.", + ex.getMessage() + ); + assertEquals(RestStatus.TOO_MANY_REQUESTS, ex.status()); + + service.freeAllScrollContexts(); + } + + public void testOpenScrollContextsConcurrently() throws Exception { + createIndex("index"); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + + final int maxScrollContexts = SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY); + final SearchService searchService = getInstanceFromNode(SearchService.class); + Thread[] threads = new Thread[randomIntBetween(2, 8)]; + CountDownLatch latch = new CountDownLatch(threads.length); + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + latch.countDown(); + try { + latch.await(); + for (;;) { + final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier(); + try { + final ShardScrollRequestTest request = new ShardScrollRequestTest(indexShard.shardId()); + searchService.createAndPutReaderContext( + request, + indexService, + indexShard, + reader, + SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() + ); + } catch (ElasticsearchException e) { + assertThat( + e.getMessage(), + equalTo( + "Trying to create too many scroll contexts. Must be less than or equal to: " + + "[" + + maxScrollContexts + + "]. " + + "This limit can be set by changing the [search.max_open_scroll_context] setting." + ) + ); + return; + } + } + } catch (Exception e) { + throw new AssertionError(e); + } + }); + threads[i].setName("elasticsearch[node_s_0][search]"); + threads[i].start(); + } + for (Thread thread : threads) { + thread.join(); + } + assertThat(searchService.getActiveContexts(), equalTo(maxScrollContexts)); + searchService.freeAllScrollContexts(); + } + + public static class FailOnRewriteQueryPlugin extends Plugin implements SearchPlugin { + @Override + public List> getQueries() { + return singletonList(new QuerySpec<>("fail_on_rewrite_query", FailOnRewriteQueryBuilder::new, parseContext -> { + throw new UnsupportedOperationException("No query parser for this plugin"); + })); + } + } + + public static class FailOnRewriteQueryBuilder extends DummyQueryBuilder { + + public FailOnRewriteQueryBuilder(StreamInput in) throws IOException { + super(in); + } + + public FailOnRewriteQueryBuilder() {} + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { + if (queryRewriteContext.convertToSearchExecutionContext() != null) { + throw new IllegalStateException("Fail on rewrite phase"); + } + return this; + } + } + + private static class ShardScrollRequestTest extends ShardSearchRequest { + private Scroll scroll; + + ShardScrollRequestTest(ShardId shardId) { + super( + OriginalIndices.NONE, + new SearchRequest().allowPartialSearchResults(true), + shardId, + 0, + 1, + AliasFilter.EMPTY, + 1f, + -1, + null + ); + this.scroll = new Scroll(TimeValue.timeValueMinutes(1)); + } + + @Override + public Scroll scroll() { + return this.scroll; + } + } + + public void testCanMatch() throws Exception { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = 
indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + assertTrue( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + + searchRequest.source(new SearchSourceBuilder()); + assertTrue( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + + searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); + assertTrue( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + + searchRequest.source( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) + .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(0)) + ); + assertTrue( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + searchRequest.source( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()).aggregation(new GlobalAggregationBuilder("test")) + ); + assertTrue( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + + searchRequest.source(new SearchSourceBuilder().query(new MatchNoneQueryBuilder())); + assertFalse( + service.canMatch( + new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) + ).canMatch() + ); + assertEquals(5, numWrapInvocations.get()); + + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + + /* + * Checks that canMatch takes into account the alias filter + */ + // the source cannot be rewritten to a match_none + searchRequest.indices("alias").source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); + assertFalse( + service.canMatch( + new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.of(new TermQueryBuilder("foo", "bar"), "alias"), + 1f, + -1, + null + ) + ).canMatch() + ); + // the source can match and can be rewritten to a match_none, but not the alias filter + final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); + assertEquals(RestStatus.CREATED, response.status()); + searchRequest.indices("alias").source(new SearchSourceBuilder().query(new TermQueryBuilder("id", "1"))); + assertFalse( + service.canMatch( + new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.of(new TermQueryBuilder("foo", "bar"), "alias"), + 1f, + -1, + null + ) + ).canMatch() + ); + + CountDownLatch latch = new CountDownLatch(1); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + // Because the foo field used in alias filter is unmapped the term query builder rewrite can resolve to a match no docs query, + // without acquiring a searcher and that means the wrapper is not called + assertEquals(5, numWrapInvocations.get()); + service.executeQueryPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + try { + // 
make sure that the wrapper is called when the query is actually executed + assertEquals(6, numWrapInvocations.get()); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + throw new AssertionError(e); + } finally { + latch.countDown(); + } + } + }); + latch.await(); + } + + public void testCanRewriteToMatchNone() { + assertFalse( + SearchService.canRewriteToMatchNone( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()).aggregation(new GlobalAggregationBuilder("test")) + ) + ); + assertFalse(SearchService.canRewriteToMatchNone(new SearchSourceBuilder())); + assertFalse(SearchService.canRewriteToMatchNone(null)); + assertFalse( + SearchService.canRewriteToMatchNone( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) + .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(0)) + ) + ); + assertTrue(SearchService.canRewriteToMatchNone(new SearchSourceBuilder().query(new TermQueryBuilder("foo", "bar")))); + assertTrue( + SearchService.canRewriteToMatchNone( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) + .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(1)) + ) + ); + assertFalse( + SearchService.canRewriteToMatchNone( + new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) + .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(1)) + .suggest(new SuggestBuilder()) + ) + ); + assertFalse( + SearchService.canRewriteToMatchNone( + new SearchSourceBuilder().query(new TermQueryBuilder("foo", "bar")).suggest(new SuggestBuilder()) + ) + ); + } + + public void testSetSearchThrottled() throws IOException { + createIndex("throttled_threadpool_index"); + client().execute( + InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, + new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( + "throttled_threadpool_index", + IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), + "true" + ) + ).actionGet(); + final SearchService service = getInstanceFromNode(SearchService.class); + Index index = resolveIndex("throttled_threadpool_index"); + assertTrue(service.getIndicesService().indexServiceSafe(index).getIndexSettings().isSearchThrottled()); + prepareIndex("throttled_threadpool_index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertSearchHits( + client().prepareSearch("throttled_threadpool_index") + .setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED) + .setSize(1), + "1" + ); + // we add a search action listener in a plugin above to assert that this is actually used + client().execute( + InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, + new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( + "throttled_threadpool_index", + IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), + "false" + ) + ).actionGet(); + + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> indicesAdmin().prepareUpdateSettings("throttled_threadpool_index") + .setSettings(Settings.builder().put(IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), false)) + .get() + ); + assertEquals("can not update private setting [index.search.throttled]; this setting is managed by Elasticsearch", iae.getMessage()); + assertFalse(service.getIndicesService().indexServiceSafe(index).getIndexSettings().isSearchThrottled()); + } + + public void testAggContextGetsMatchAll() throws 
IOException { + createIndex("test"); + withAggregationContext("test", context -> assertThat(context.query(), equalTo(new MatchAllDocsQuery()))); + } + + public void testAggContextGetsNestedFilter() throws IOException { + XContentBuilder mapping = JsonXContent.contentBuilder().startObject().startObject("properties"); + mapping.startObject("nested").field("type", "nested").endObject(); + mapping.endObject().endObject(); + + createIndex("test", Settings.EMPTY, mapping); + withAggregationContext("test", context -> assertThat(context.query(), equalTo(new MatchAllDocsQuery()))); + } + + /** + * Build an {@link AggregationContext} with the named index. + */ + private void withAggregationContext(String index, Consumer check) throws IOException { + IndexService indexService = getInstanceFromNode(IndicesService.class).indexServiceSafe(resolveIndex(index)); + ShardId shardId = new ShardId(indexService.index(), 0); + + SearchRequest request = new SearchRequest().indices(index) + .source(new SearchSourceBuilder().aggregation(new FiltersAggregationBuilder("test", new MatchAllQueryBuilder()))) + .allowPartialSearchResults(false); + ShardSearchRequest shardRequest = new ShardSearchRequest( + OriginalIndices.NONE, + request, + shardId, + 0, + 1, + AliasFilter.EMPTY, + 1, + 0, + null + ); + + try (ReaderContext readerContext = createReaderContext(indexService, indexService.getShard(0))) { + try ( + SearchContext context = getInstanceFromNode(SearchService.class).createContext( + readerContext, + shardRequest, + mock(SearchShardTask.class), + ResultsType.QUERY, + true + ) + ) { + check.accept(context.aggregations().factories().context()); + } + } + } + + public void testExpandSearchThrottled() { + createIndex("throttled_threadpool_index"); + client().execute( + InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, + new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( + "throttled_threadpool_index", + IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), + "true" + ) + ).actionGet(); + + prepareIndex("throttled_threadpool_index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertHitCount(client().prepareSearch(), 1L); + assertHitCount(client().prepareSearch().setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED), 1L); + } + + public void testExpandSearchFrozen() { + String indexName = "frozen_index"; + createIndex(indexName); + client().execute( + InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, + new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request(indexName, "index.frozen", "true") + ).actionGet(); + + prepareIndex(indexName).setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + assertHitCount(client().prepareSearch(), 0L); + assertHitCount(client().prepareSearch().setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED), 1L); + assertWarnings(TransportSearchAction.FROZEN_INDICES_DEPRECATION_MESSAGE.replace("{}", indexName)); + } + + public void testCreateReduceContext() { + SearchService service = getInstanceFromNode(SearchService.class); + AggregationReduceContext.Builder reduceContextBuilder = service.aggReduceContextBuilder( + () -> false, + new SearchRequest().source(new SearchSourceBuilder()).source().aggregations() + ); + { + AggregationReduceContext reduceContext = reduceContextBuilder.forFinalReduction(); + expectThrows( + MultiBucketConsumerService.TooManyBucketsException.class, + () -> 
reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1) + ); + } + { + AggregationReduceContext reduceContext = reduceContextBuilder.forPartialReduction(); + reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1); + } + } + + public void testMultiBucketConsumerServiceCB() { + MultiBucketConsumerService service = new MultiBucketConsumerService( + getInstanceFromNode(ClusterService.class), + Settings.EMPTY, + new NoopCircuitBreaker("test") { + + @Override + public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + throw new CircuitBreakingException("tripped", getDurability()); + } + } + ); + // for partial + { + IntConsumer consumer = service.createForPartial(); + for (int i = 0; i < 1023; i++) { + consumer.accept(0); + } + CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); + assertThat(ex.getMessage(), equalTo("tripped")); + } + // for final + { + IntConsumer consumer = service.createForFinal(); + for (int i = 0; i < 1023; i++) { + consumer.accept(0); + } + CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); + assertThat(ex.getMessage(), equalTo("tripped")); + } + } + + public void testCreateSearchContext() throws IOException { + String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); + IndexService indexService = createIndex(index); + final SearchService service = getInstanceFromNode(SearchService.class); + ShardId shardId = new ShardId(indexService.index(), 0); + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.allowPartialSearchResults(randomBoolean()); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + shardId, + 0, + indexService.numberOfShards(), + AliasFilter.EMPTY, + 1f, + nowInMillis, + clusterAlias + ); + try (SearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) { + SearchShardTarget searchShardTarget = searchContext.shardTarget(); + SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext(); + String expectedIndexName = clusterAlias == null ? index : clusterAlias + ":" + index; + assertEquals(expectedIndexName, searchExecutionContext.getFullyQualifiedIndex().getName()); + assertEquals(expectedIndexName, searchShardTarget.getFullyQualifiedIndexName()); + assertEquals(clusterAlias, searchShardTarget.getClusterAlias()); + assertEquals(shardId, searchShardTarget.getShardId()); + + assertNull(searchContext.dfsResult()); + searchContext.addDfsResult(); + assertSame(searchShardTarget, searchContext.dfsResult().getSearchShardTarget()); + + assertNull(searchContext.queryResult()); + searchContext.addQueryResult(); + assertSame(searchShardTarget, searchContext.queryResult().getSearchShardTarget()); + + assertNull(searchContext.fetchResult()); + searchContext.addFetchResult(); + assertSame(searchShardTarget, searchContext.fetchResult().getSearchShardTarget()); + } + } + + /** + * While we have no NPE in DefaultContext constructor anymore, we still want to guard against it (or other failures) in the future to + * avoid leaking searchers. 
+ */ + public void testCreateSearchContextFailure() throws Exception { + final String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); + final IndexService indexService = createIndex(index); + final SearchService service = getInstanceFromNode(SearchService.class); + final ShardId shardId = new ShardId(indexService.index(), 0); + final ShardSearchRequest request = new ShardSearchRequest(shardId, 0, null) { + @Override + public SearchType searchType() { + // induce an artificial NPE + throw new NullPointerException("expected"); + } + }; + try (ReaderContext reader = createReaderContext(indexService, indexService.getShard(shardId.id()))) { + NullPointerException e = expectThrows( + NullPointerException.class, + () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) + ); + assertEquals("expected", e.getMessage()); + } + // Needs to busily assert because Engine#refreshNeeded can increase the refCount. + assertBusy( + () -> assertEquals("should have 2 store refs (IndexService + InternalEngine)", 2, indexService.getShard(0).store().refCount()) + ); + } + + public void testMatchNoDocsEmptyResponse() throws InterruptedException { + createIndex("index"); + Thread currentThread = Thread.currentThread(); + SearchService service = getInstanceFromNode(SearchService.class); + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().aggregation(AggregationBuilders.count("count").field("value"))); + ShardSearchRequest shardRequest = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 5, + AliasFilter.EMPTY, + 1.0f, + 0, + null + ); + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + + { + CountDownLatch latch = new CountDownLatch(1); + shardRequest.source().query(new MatchAllQueryBuilder()); + service.executeQueryPhase(shardRequest, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult result) { + try { + assertNotSame(Thread.currentThread(), currentThread); + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); + assertThat(result, instanceOf(QuerySearchResult.class)); + assertFalse(result.queryResult().isNull()); + assertNotNull(result.queryResult().topDocs()); + assertNotNull(result.queryResult().aggregations()); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception exc) { + try { + throw new AssertionError(exc); + } finally { + latch.countDown(); + } + } + }); + latch.await(); + } + + { + CountDownLatch latch = new CountDownLatch(1); + shardRequest.source().query(new MatchNoneQueryBuilder()); + service.executeQueryPhase(shardRequest, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult result) { + try { + assertNotSame(Thread.currentThread(), currentThread); + assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); + assertThat(result, instanceOf(QuerySearchResult.class)); + assertFalse(result.queryResult().isNull()); + assertNotNull(result.queryResult().topDocs()); + assertNotNull(result.queryResult().aggregations()); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception 
exc) { + try { + throw new AssertionError(exc); + } finally { + latch.countDown(); + } + } + }); + latch.await(); + } + + { + CountDownLatch latch = new CountDownLatch(1); + shardRequest.canReturnNullResponseIfMatchNoDocs(true); + service.executeQueryPhase(shardRequest, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult result) { + try { + // make sure we don't use the search threadpool + assertSame(Thread.currentThread(), currentThread); + assertThat(result, instanceOf(QuerySearchResult.class)); + assertTrue(result.queryResult().isNull()); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + throw new AssertionError(e); + } finally { + latch.countDown(); + } + } + }); + latch.await(); + } + } + + public void testDeleteIndexWhileSearch() throws Exception { + createIndex("test"); + int numDocs = randomIntBetween(1, 20); + for (int i = 0; i < numDocs; i++) { + prepareIndex("test").setSource("f", "v").get(); + } + indicesAdmin().prepareRefresh("test").get(); + AtomicBoolean stopped = new AtomicBoolean(false); + Thread[] searchers = new Thread[randomIntBetween(1, 4)]; + CountDownLatch latch = new CountDownLatch(searchers.length); + for (int i = 0; i < searchers.length; i++) { + searchers[i] = new Thread(() -> { + latch.countDown(); + while (stopped.get() == false) { + try { + client().prepareSearch("test").setRequestCache(false).get().decRef(); + } catch (Exception ignored) { + return; + } + } + }); + searchers[i].start(); + } + latch.await(); + indicesAdmin().prepareDelete("test").get(); + stopped.set(true); + for (Thread searcher : searchers) { + searcher.join(); + } + } + + public void testLookUpSearchContext() throws Exception { + createIndex("index"); + SearchService searchService = getInstanceFromNode(SearchService.class); + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + IndexShard indexShard = indexService.getShard(0); + List contextIds = new ArrayList<>(); + int numContexts = randomIntBetween(1, 10); + CountDownLatch latch = new CountDownLatch(1); + indexShard.getThreadPool().executor(ThreadPool.Names.SEARCH).execute(() -> { + try { + for (int i = 0; i < numContexts; i++) { + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + new SearchRequest().allowPartialSearchResults(true), + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + final ReaderContext context = searchService.createAndPutReaderContext( + request, + indexService, + indexShard, + indexShard.acquireSearcherSupplier(), + SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() + ); + assertThat(context.id().getId(), equalTo((long) (i + 1))); + contextIds.add(context.id()); + } + assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); + while (contextIds.isEmpty() == false) { + final ShardSearchContextId contextId = randomFrom(contextIds); + assertFalse(searchService.freeReaderContext(new ShardSearchContextId(UUIDs.randomBase64UUID(), contextId.getId()))); + assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); + if (randomBoolean()) { + assertTrue(searchService.freeReaderContext(contextId)); + } else { + assertTrue( + searchService.freeReaderContext((new ShardSearchContextId(contextId.getSessionId(), contextId.getId()))) + ); + } + contextIds.remove(contextId); + assertThat(searchService.getActiveContexts(), 
equalTo(contextIds.size())); + assertFalse(searchService.freeReaderContext(contextId)); + assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); + } + } finally { + latch.countDown(); + } + }); + latch.await(); + } + + public void testOpenReaderContext() { + createIndex("index"); + SearchService searchService = getInstanceFromNode(SearchService.class); + PlainActionFuture future = new PlainActionFuture<>(); + searchService.openReaderContext(new ShardId(resolveIndex("index"), 0), TimeValue.timeValueMinutes(between(1, 10)), future); + future.actionGet(); + assertThat(searchService.getActiveContexts(), equalTo(1)); + assertTrue(searchService.freeReaderContext(future.actionGet())); + } + + public void testCancelQueryPhaseEarly() throws Exception { + createIndex("index"); + final MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + + CountDownLatch latch1 = new CountDownLatch(1); + SearchShardTask task = new SearchShardTask(1, "", "", "", TaskId.EMPTY_TASK_ID, emptyMap()); + service.executeQueryPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + service.freeReaderContext(searchPhaseResult.getContextId()); + latch1.countDown(); + } + + @Override + public void onFailure(Exception e) { + try { + fail("Search should not be cancelled"); + } finally { + latch1.countDown(); + } + } + }); + latch1.await(); + + CountDownLatch latch2 = new CountDownLatch(1); + service.executeDfsPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + service.freeReaderContext(searchPhaseResult.getContextId()); + latch2.countDown(); + } + + @Override + public void onFailure(Exception e) { + try { + fail("Search should not be cancelled"); + } finally { + latch2.countDown(); + } + } + }); + latch2.await(); + + AtomicBoolean searchContextCreated = new AtomicBoolean(false); + service.setOnCreateSearchContext(c -> searchContextCreated.set(true)); + CountDownLatch latch3 = new CountDownLatch(1); + TaskCancelHelper.cancel(task, "simulated"); + service.executeQueryPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + try { + fail("Search not cancelled early"); + } finally { + service.freeReaderContext(searchPhaseResult.getContextId()); + searchPhaseResult.decRef(); + latch3.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + assertThat(e, is(instanceOf(TaskCancelledException.class))); + assertThat(e.getMessage(), is("task cancelled [simulated]")); + assertThat(((TaskCancelledException) e).status(), is(RestStatus.BAD_REQUEST)); + assertThat(searchContextCreated.get(), is(false)); + latch3.countDown(); + } + }); + latch3.await(); + + searchContextCreated.set(false); + CountDownLatch latch4 = new CountDownLatch(1); + service.executeDfsPhase(request, task, new ActionListener<>() { + @Override + public void onResponse(SearchPhaseResult searchPhaseResult) { + try 
{ + fail("Search not cancelled early"); + } finally { + service.freeReaderContext(searchPhaseResult.getContextId()); + latch4.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + assertThat(e, is(instanceOf(TaskCancelledException.class))); + assertThat(e.getMessage(), is("task cancelled [simulated]")); + assertThat(((TaskCancelledException) e).status(), is(RestStatus.BAD_REQUEST)); + assertThat(searchContextCreated.get(), is(false)); + latch4.countDown(); + } + }); + latch4.await(); + } + + public void testCancelFetchPhaseEarly() throws Exception { + createIndex("index"); + final MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + + AtomicBoolean searchContextCreated = new AtomicBoolean(false); + service.setOnCreateSearchContext(c -> searchContextCreated.set(true)); + + // Test fetch phase is cancelled early + String scrollId; + var searchResponse = client().search(searchRequest.allowPartialSearchResults(false).scroll(TimeValue.timeValueMinutes(10))).get(); + try { + scrollId = searchResponse.getScrollId(); + } finally { + searchResponse.decRef(); + } + + client().searchScroll(new SearchScrollRequest(scrollId)).get().decRef(); + assertThat(searchContextCreated.get(), is(true)); + + ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); + clearScrollRequest.addScrollId(scrollId); + client().clearScroll(clearScrollRequest); + + searchResponse = client().search(searchRequest.allowPartialSearchResults(false).scroll(TimeValue.timeValueMinutes(10))).get(); + try { + scrollId = searchResponse.getScrollId(); + } finally { + searchResponse.decRef(); + } + searchContextCreated.set(false); + service.setOnCheckCancelled(t -> { + SearchShardTask task = new SearchShardTask(randomLong(), "transport", "action", "", TaskId.EMPTY_TASK_ID, emptyMap()); + TaskCancelHelper.cancel(task, "simulated"); + return task; + }); + CountDownLatch latch = new CountDownLatch(1); + client().searchScroll(new SearchScrollRequest(scrollId), new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + fail("Search not cancelled early"); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + Throwable cancelledExc = e.getCause().getCause(); + assertThat(cancelledExc, is(instanceOf(TaskCancelledException.class))); + assertThat(cancelledExc.getMessage(), is("task cancelled [simulated]")); + assertThat(((TaskCancelledException) cancelledExc).status(), is(RestStatus.BAD_REQUEST)); + latch.countDown(); + } + }); + latch.await(); + assertThat(searchContextCreated.get(), is(false)); + + clearScrollRequest.setScrollIds(singletonList(scrollId)); + client().clearScroll(clearScrollRequest); + } + + public void testWaitOnRefresh() throws ExecutionException, InterruptedException { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30)); + searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); + + final DocWriteResponse response = 
prepareIndex("index").setSource("id", "1").get(); + assertEquals(RestStatus.CREATED, response.status()); + + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null, + null, + null + ); + PlainActionFuture future = new PlainActionFuture<>(); + service.executeQueryPhase(request, task, future.delegateFailure((l, r) -> { + assertEquals(1, r.queryResult().getTotalHits().value()); + l.onResponse(null); + })); + future.get(); + } + + public void testWaitOnRefreshFailsWithRefreshesDisabled() { + createIndex("index", Settings.builder().put("index.refresh_interval", "-1").build()); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30)); + searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); + + final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); + assertEquals(RestStatus.CREATED, response.status()); + + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + PlainActionFuture future = new PlainActionFuture<>(); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null, + null, + null + ); + service.executeQueryPhase(request, task, future); + IllegalArgumentException illegalArgumentException = expectThrows(IllegalArgumentException.class, future::actionGet); + assertThat( + illegalArgumentException.getMessage(), + containsString("Cannot use wait_for_checkpoints with [index.refresh_interval=-1]") + ); + } + + public void testWaitOnRefreshFailsIfCheckpointNotIndexed() { + createIndex("index"); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + // Increased timeout to avoid cancelling the search task prior to its completion, + // as we expect to raise an Exception. Timeout itself is tested on the following `testWaitOnRefreshTimeout` test. 
+ searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(200, 300))); + searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 1 })); + + final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); + assertEquals(RestStatus.CREATED, response.status()); + + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + PlainActionFuture future = new PlainActionFuture<>(); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null, + null, + null + ); + service.executeQueryPhase(request, task, future); + + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, future::actionGet); + assertThat( + ex.getMessage(), + containsString("Cannot wait for unissued seqNo checkpoint [wait_for_checkpoint=1, max_issued_seqNo=0]") + ); + } + + public void testWaitOnRefreshTimeout() { + createIndex("index", Settings.builder().put("index.refresh_interval", "60s").build()); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(10, 100))); + searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); + + final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); + assertEquals(RestStatus.CREATED, response.status()); + + SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); + PlainActionFuture future = new PlainActionFuture<>(); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null, + null, + null + ); + service.executeQueryPhase(request, task, future); + + SearchTimeoutException ex = expectThrows(SearchTimeoutException.class, future::actionGet); + assertThat(ex.getMessage(), containsString("Wait for seq_no [0] refreshed timed out [")); + } + + public void testMinimalSearchSourceInShardRequests() { + createIndex("test"); + int numDocs = between(0, 10); + for (int i = 0; i < numDocs; i++) { + prepareIndex("test").setSource("id", Integer.toString(i)).get(); + } + indicesAdmin().prepareRefresh("test").get(); + + BytesReference pitId = client().execute( + TransportOpenPointInTimeAction.TYPE, + new OpenPointInTimeRequest("test").keepAlive(TimeValue.timeValueMinutes(10)) + ).actionGet().getPointInTimeId(); + final MockSearchService searchService = (MockSearchService) getInstanceFromNode(SearchService.class); + final List shardRequests = new CopyOnWriteArrayList<>(); + searchService.setOnCreateSearchContext(ctx -> shardRequests.add(ctx.request())); + try { + assertHitCount( + client().prepareSearch() + .setSource( + new SearchSourceBuilder().size(between(numDocs, numDocs * 2)).pointInTimeBuilder(new PointInTimeBuilder(pitId)) + ), + numDocs + ); + } finally { + client().execute(TransportClosePointInTimeAction.TYPE, new ClosePointInTimeRequest(pitId)).actionGet(); + } + assertThat(shardRequests, not(emptyList())); + for (ShardSearchRequest shardRequest : shardRequests) { + 
assertNotNull(shardRequest.source()); + assertNotNull(shardRequest.source().pointInTimeBuilder()); + assertThat(shardRequest.source().pointInTimeBuilder().getEncodedId(), equalTo(BytesArray.EMPTY)); + } + } + + public void testDfsQueryPhaseRewrite() { + createIndex("index"); + prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); + final SearchService service = getInstanceFromNode(SearchService.class); + final IndicesService indicesService = getInstanceFromNode(IndicesService.class); + final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); + final IndexShard indexShard = indexService.getShard(0); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.source(SearchSourceBuilder.searchSource().query(new TestRewriteCounterQueryBuilder())); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + indexShard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1.0f, + -1, + null + ); + final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier(); + ReaderContext context = service.createAndPutReaderContext( + request, + indexService, + indexShard, + reader, + SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() + ); + PlainActionFuture plainActionFuture = new PlainActionFuture<>(); + service.executeQueryPhase( + new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)), + new SearchShardTask(42L, "", "", "", null, emptyMap()), + plainActionFuture + ); + + plainActionFuture.actionGet(); + assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1)); + final ShardSearchContextId contextId = context.id(); + assertTrue(service.freeReaderContext(contextId)); + } + + public void testEnableSearchWorkerThreads() throws IOException { + IndexService indexService = createIndex("index", Settings.EMPTY); + IndexShard indexShard = indexService.getShard(0); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + new SearchRequest().allowPartialSearchResults(randomBoolean()), + indexShard.shardId(), + 0, + indexService.numberOfShards(), + AliasFilter.EMPTY, + 1f, + System.currentTimeMillis(), + null + ); + try (ReaderContext readerContext = createReaderContext(indexService, indexShard)) { + SearchService service = getInstanceFromNode(SearchService.class); + SearchShardTask task = new SearchShardTask(0, "type", "action", "description", null, emptyMap()); + + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { + assertTrue(searchContext.searcher().hasExecutor()); + } + + try { + ClusterUpdateSettingsResponse response = client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().put(SEARCH_WORKER_THREADS_ENABLED.getKey(), false).build()) + .get(); + assertTrue(response.isAcknowledged()); + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { + assertFalse(searchContext.searcher().hasExecutor()); + } + } finally { + // reset original default setting + client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().putNull(SEARCH_WORKER_THREADS_ENABLED.getKey()).build()) + .get(); + try (SearchContext searchContext = 
service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { + assertTrue(searchContext.searcher().hasExecutor()); + } + } + } + } + + /** + * Verify that a single slice is created for requests that don't support parallel collection, while an executor is still + * provided to the searcher to parallelize other operations. Also ensure multiple slices are created for requests that do support + * parallel collection. + */ + public void testSlicingBehaviourForParallelCollection() throws Exception { + IndexService indexService = createIndex("index", Settings.EMPTY); + ThreadPoolExecutor executor = (ThreadPoolExecutor) indexService.getThreadPool().executor(ThreadPool.Names.SEARCH); + final int configuredMaxPoolSize = 10; + executor.setMaximumPoolSize(configuredMaxPoolSize); // We set this explicitly to be independent of CPU cores. + int numDocs = randomIntBetween(50, 100); + for (int i = 0; i < numDocs; i++) { + prepareIndex("index").setId(String.valueOf(i)).setSource("field", "value").get(); + if (i % 5 == 0) { + indicesAdmin().prepareRefresh("index").get(); + } + } + final IndexShard indexShard = indexService.getShard(0); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + new SearchRequest().allowPartialSearchResults(randomBoolean()), + indexShard.shardId(), + 0, + indexService.numberOfShards(), + AliasFilter.EMPTY, + 1f, + System.currentTimeMillis(), + null + ); + SearchService service = getInstanceFromNode(SearchService.class); + NonCountingTermQuery termQuery = new NonCountingTermQuery(new Term("field", "value")); + assertEquals(0, executor.getCompletedTaskCount()); + try (ReaderContext readerContext = createReaderContext(indexService, indexShard)) { + SearchShardTask task = new SearchShardTask(0, "type", "action", "description", null, emptyMap()); + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertTrue(searcher.hasExecutor()); + + final int maxPoolSize = executor.getMaximumPoolSize(); + assertEquals( + "Sanity check to ensure this isn't the default of 1 when pool size is unset", + configuredMaxPoolSize, + maxPoolSize + ); + + final int expectedSlices = ContextIndexSearcher.computeSlices( + searcher.getIndexReader().leaves(), + maxPoolSize, + 1 + ).length; + assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool size is unset", 1, expectedSlices); + + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "DFS supports parallel collection, so the number of slices should be > 1.", + expectedSlices - 1, // one slice executes on the calling thread + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertTrue(searcher.hasExecutor()); + + final int maxPoolSize = executor.getMaximumPoolSize(); + assertEquals( + "Sanity check to ensure this isn't the default of 1 when pool size is unset", + configuredMaxPoolSize, + maxPoolSize + ); + + final int expectedSlices = ContextIndexSearcher.computeSlices( + searcher.getIndexReader().leaves(), + maxPoolSize, + 1 + ).length; + assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool 
size is unset", 1, expectedSlices); + + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "QUERY supports parallel collection when enabled, so the number of slices should be > 1.", + expectedSlices - 1, // one slice executes on the calling thread + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.FETCH, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertFalse(searcher.hasExecutor()); + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "The number of slices should be 1 as FETCH does not support parallel collection and thus runs on the calling" + + " thread.", + 0, + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.NONE, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertFalse(searcher.hasExecutor()); + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "The number of slices should be 1 as NONE does not support parallel collection.", + 0, // zero since one slice executes on the calling thread + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + + try { + ClusterUpdateSettingsResponse response = client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().put(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.getKey(), false).build()) + .get(); + assertTrue(response.isAcknowledged()); + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertFalse(searcher.hasExecutor()); + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "The number of slices should be 1 when QUERY parallel collection is disabled.", + 0, // zero since one slice executes on the calling thread + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + } finally { + // Reset to the original default setting and check to ensure it takes effect. 
+ client().admin() + .cluster() + .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) + .setPersistentSettings(Settings.builder().putNull(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.getKey()).build()) + .get(); + { + try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { + ContextIndexSearcher searcher = searchContext.searcher(); + assertTrue(searcher.hasExecutor()); + + final int maxPoolSize = executor.getMaximumPoolSize(); + assertEquals( + "Sanity check to ensure this isn't the default of 1 when pool size is unset", + configuredMaxPoolSize, + maxPoolSize + ); + + final int expectedSlices = ContextIndexSearcher.computeSlices( + searcher.getIndexReader().leaves(), + maxPoolSize, + 1 + ).length; + assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool size is unset", 1, expectedSlices); + + final long priorExecutorTaskCount = executor.getCompletedTaskCount(); + searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); + assertBusy( + () -> assertEquals( + "QUERY supports parallel collection when enabled, so the number of slices should be > 1.", + expectedSlices - 1, // one slice executes on the calling thread + executor.getCompletedTaskCount() - priorExecutorTaskCount + ) + ); + } + } + } + } + } + + private static ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) { + return new ReaderContext( + new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), + indexService, + indexShard, + indexShard.acquireSearcherSupplier(), + randomNonNegativeLong(), + false + ); + } + + private static class TestRewriteCounterQueryBuilder extends AbstractQueryBuilder { + + final int asyncRewriteCount; + final Supplier fetched; + + TestRewriteCounterQueryBuilder() { + asyncRewriteCount = 0; + fetched = null; + } + + private TestRewriteCounterQueryBuilder(int asyncRewriteCount, Supplier fetched) { + this.asyncRewriteCount = asyncRewriteCount; + this.fetched = fetched; + } + + @Override + public String getWriteableName() { + return "test_query"; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ZERO; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException {} + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException {} + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + return new MatchAllDocsQuery(); + } + + @Override + protected boolean doEquals(TestRewriteCounterQueryBuilder other) { + return true; + } + + @Override + protected int doHashCode() { + return 42; + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (asyncRewriteCount > 0) { + return this; + } + if (fetched != null) { + if (fetched.get() == null) { + return this; + } + assert fetched.get(); + return new TestRewriteCounterQueryBuilder(1, null); + } + if (queryRewriteContext.convertToDataRewriteContext() != null) { + SetOnce awaitingFetch = new SetOnce<>(); + queryRewriteContext.registerAsyncAction((c, l) -> { + awaitingFetch.set(true); + l.onResponse(null); + }); + return new TestRewriteCounterQueryBuilder(0, awaitingFetch::get); + } + return this; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index 89fd25f638e1c..31bcab31ca8a7 
100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -6,3006 +6,298 @@ * your election, the "Elastic License 2.0", the "GNU Affero General Public * License v3.0 only", or the "Server Side Public License, v 1". */ + package org.elasticsearch.search; -import org.apache.lucene.index.DirectoryReader; -import org.apache.lucene.index.FilterDirectoryReader; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollectorManager; -import org.apache.lucene.store.AlreadyClosedException; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteResponse; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.SortField; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; -import org.elasticsearch.action.search.ClearScrollRequest; -import org.elasticsearch.action.search.ClosePointInTimeRequest; -import org.elasticsearch.action.search.OpenPointInTimeRequest; -import org.elasticsearch.action.search.SearchPhaseController; -import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.SearchScrollRequest; -import org.elasticsearch.action.search.SearchShardTask; -import org.elasticsearch.action.search.SearchType; -import org.elasticsearch.action.search.TransportClosePointInTimeAction; -import org.elasticsearch.action.search.TransportOpenPointInTimeAction; -import org.elasticsearch.action.search.TransportSearchAction; -import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.routing.ShardRouting; -import org.elasticsearch.cluster.routing.ShardRoutingState; -import org.elasticsearch.cluster.routing.TestShardRouting; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.UUIDs; -import org.elasticsearch.common.breaker.CircuitBreakingException; -import org.elasticsearch.common.breaker.NoopCircuitBreaker; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.AbstractRefCounted; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.Index; -import org.elasticsearch.index.IndexModule; -import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.common.util.BigArrays; import 
org.elasticsearch.index.IndexService; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.engine.Engine; -import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.fielddata.FieldDataContext; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.fielddata.LeafFieldData; +import org.elasticsearch.index.mapper.KeywordFieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.MapperMetrics; +import org.elasticsearch.index.mapper.Mapping; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.RootObjectMapper; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.index.query.TermQueryBuilder; -import org.elasticsearch.index.search.stats.SearchStats; import org.elasticsearch.index.shard.IndexShard; -import org.elasticsearch.index.shard.SearchOperationListener; +import org.elasticsearch.index.shard.IndexShardTestCase; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.indices.IndicesService; -import org.elasticsearch.indices.settings.InternalOrPrivateSettingsPlugin; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.plugins.SearchPlugin; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.script.MockScriptEngine; -import org.elasticsearch.script.MockScriptPlugin; -import org.elasticsearch.script.Script; -import org.elasticsearch.script.ScriptType; -import org.elasticsearch.search.SearchService.ResultsType; -import org.elasticsearch.search.aggregations.AggregationBuilders; -import org.elasticsearch.search.aggregations.AggregationReduceContext; -import org.elasticsearch.search.aggregations.MultiBucketConsumerService; -import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder; -import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.aggregations.support.AggregationContext; -import org.elasticsearch.search.aggregations.support.ValueType; -import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.aggregations.support.ValuesSourceType; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.dfs.AggregatedDfs; -import org.elasticsearch.search.fetch.FetchSearchResult; -import org.elasticsearch.search.fetch.ShardFetchRequest; -import org.elasticsearch.search.fetch.ShardFetchSearchRequest; -import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.internal.AliasFilter; -import org.elasticsearch.search.internal.ContextIndexSearcher; -import org.elasticsearch.search.internal.ReaderContext; -import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; -import 
org.elasticsearch.search.query.NonCountingTermQuery; -import org.elasticsearch.search.query.QuerySearchRequest; -import org.elasticsearch.search.query.QuerySearchResult; -import org.elasticsearch.search.query.SearchTimeoutException; -import org.elasticsearch.search.rank.RankBuilder; -import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.rank.RankShardResult; -import org.elasticsearch.search.rank.TestRankBuilder; -import org.elasticsearch.search.rank.TestRankShardResult; -import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; -import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; -import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext; -import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext; -import org.elasticsearch.search.rank.feature.RankFeatureDoc; -import org.elasticsearch.search.rank.feature.RankFeatureResult; -import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; -import org.elasticsearch.search.rank.feature.RankFeatureShardResult; -import org.elasticsearch.search.suggest.SuggestBuilder; -import org.elasticsearch.tasks.TaskCancelHelper; -import org.elasticsearch.tasks.TaskCancelledException; -import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.json.JsonXContent; -import org.junit.Before; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.MinAndMax; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xcontent.XContentParserConfiguration; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; -import java.util.Comparator; -import java.util.LinkedList; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Semaphore; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.IntConsumer; -import java.util.function.Supplier; - -import static java.util.Collections.emptyList; -import static java.util.Collections.emptyMap; -import static java.util.Collections.singletonList; -import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason.DELETED; -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; -import static org.elasticsearch.search.SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED; -import static org.elasticsearch.search.SearchService.SEARCH_WORKER_THREADS_ENABLED; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; -import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchHits; -import static 
org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.hamcrest.CoreMatchers.startsWith; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.not; -import static org.mockito.Mockito.mock; +import java.util.function.BiFunction; +import java.util.function.Predicate; -public class SearchServiceTests extends ESSingleNodeTestCase { +public class SearchServiceTests extends IndexShardTestCase { - @Override - protected boolean resetNodeAfterTest() { - return true; + public void testCanMatchMatchAll() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); + doTestCanMatch(searchRequest, null, true, null, false); } - @Override - protected Collection> getPlugins() { - return pluginList( - FailOnRewriteQueryPlugin.class, - CustomScriptPlugin.class, - ReaderWrapperCountPlugin.class, - InternalOrPrivateSettingsPlugin.class, - MockSearchService.TestPlugin.class - ); + public void testCanMatchMatchNone() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().query(new MatchNoneQueryBuilder())); + doTestCanMatch(searchRequest, null, false, null, false); } - public static class ReaderWrapperCountPlugin extends Plugin { - @Override - public void onIndexModule(IndexModule indexModule) { - indexModule.setReaderWrapper(service -> SearchServiceTests::apply); - } + public void testCanMatchMatchNoneWithException() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().query(new MatchNoneQueryBuilder())); + doTestCanMatch(searchRequest, null, true, null, true); } - @Before - public void resetCount() { - numWrapInvocations = new AtomicInteger(0); + public void testCanMatchKeywordSortedQueryMatchNone() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().sort("field").query(new MatchNoneQueryBuilder())); + SortField sortField = new SortField("field", SortField.Type.STRING); + doTestCanMatch(searchRequest, sortField, false, null, false); } - private static AtomicInteger numWrapInvocations = new AtomicInteger(0); - - private static DirectoryReader apply(DirectoryReader directoryReader) throws IOException { - numWrapInvocations.incrementAndGet(); - return new FilterDirectoryReader(directoryReader, new FilterDirectoryReader.SubReaderWrapper() { - @Override - public LeafReader wrap(LeafReader reader) { - return reader; - } - }) { - @Override - protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) throws IOException { - return in; - } + public void testCanMatchKeywordSortedQueryMatchAll() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().sort("field").query(new MatchAllQueryBuilder())); + SortField sortField = new SortField("field", SortField.Type.STRING); + MinAndMax expectedMinAndMax = new MinAndMax<>(new BytesRef("value"), new BytesRef("value")); + doTestCanMatch(searchRequest, sortField, true, expectedMinAndMax, false); + } + public void testCanMatchKeywordSortedQueryMatchNoneWithException() 
throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().sort("field").query(new MatchNoneQueryBuilder())); + // provide a sort field that throws exception + SortField sortField = new SortField("field", SortField.Type.STRING) { @Override - public CacheHelper getReaderCacheHelper() { - return directoryReader.getReaderCacheHelper(); + public Type getType() { + throw new UnsupportedOperationException(); } }; + doTestCanMatch(searchRequest, sortField, false, null, false); } - public static class CustomScriptPlugin extends MockScriptPlugin { - - static final String DUMMY_SCRIPT = "dummyScript"; - - @Override - protected Map, Object>> pluginScripts() { - return Collections.singletonMap(DUMMY_SCRIPT, vars -> "dummy"); - } - - @Override - public void onIndexModule(IndexModule indexModule) { - indexModule.addSearchOperationListener(new SearchOperationListener() { - @Override - public void onFetchPhase(SearchContext context, long tookInNanos) { - if ("throttled_threadpool_index".equals(context.indexShard().shardId().getIndex().getName())) { - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search_throttled]")); - } else { - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); - } - } - - @Override - public void onQueryPhase(SearchContext context, long tookInNanos) { - if ("throttled_threadpool_index".equals(context.indexShard().shardId().getIndex().getName())) { - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search_throttled]")); - } else { - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); - } - } - }); - } - } - - @Override - protected Settings nodeSettings() { - return Settings.builder().put("search.default_search_timeout", "5s").build(); - } - - public void testClearOnClose() { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertResponse( - client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), - searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) - ); - SearchService service = getInstanceFromNode(SearchService.class); - - assertEquals(1, service.getActiveContexts()); - service.doClose(); // this kills the keep-alive reaper we have to reset the node after this test - assertEquals(0, service.getActiveContexts()); - } - - public void testClearOnStop() { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertResponse( - client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), - searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) - ); - SearchService service = getInstanceFromNode(SearchService.class); - - assertEquals(1, service.getActiveContexts()); - service.doStop(); - assertEquals(0, service.getActiveContexts()); - } - - public void testClearIndexDelete() { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertResponse( - client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), - searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) - ); - SearchService service = getInstanceFromNode(SearchService.class); - - assertEquals(1, service.getActiveContexts()); - 
assertAcked(indicesAdmin().prepareDelete("index")); - awaitIndexShardCloseAsyncTasks(); - assertEquals(0, service.getActiveContexts()); - } - - public void testCloseSearchContextOnRewriteException() { - // if refresh happens while checking the exception, the subsequent reference count might not match, so we switch it off - createIndex("index", Settings.builder().put("index.refresh_interval", -1).build()); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - - SearchService service = getInstanceFromNode(SearchService.class); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); - - final int activeContexts = service.getActiveContexts(); - final int activeRefs = indexShard.store().refCount(); - expectThrows( - SearchPhaseExecutionException.class, - () -> client().prepareSearch("index").setQuery(new FailOnRewriteQueryBuilder()).get() - ); - assertEquals(activeContexts, service.getActiveContexts()); - assertEquals(activeRefs, indexShard.store().refCount()); - } - - public void testSearchWhileIndexDeleted() throws InterruptedException { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - - SearchService service = getInstanceFromNode(SearchService.class); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); - AtomicBoolean running = new AtomicBoolean(true); - CountDownLatch startGun = new CountDownLatch(1); - final int permitCount = 100; - Semaphore semaphore = new Semaphore(permitCount); - ShardRouting routing = TestShardRouting.newShardRouting( - indexShard.shardId(), - randomAlphaOfLength(5), - randomBoolean(), - ShardRoutingState.INITIALIZING - ); - final Thread thread = new Thread(() -> { - startGun.countDown(); - while (running.get()) { - if (randomBoolean()) { - service.afterIndexRemoved(indexService.index(), indexService.getIndexSettings(), DELETED); - } else { - service.beforeIndexShardCreated(routing, indexService.getIndexSettings().getSettings()); - } - if (randomBoolean()) { - // here we trigger some refreshes to ensure the IR go out of scope such that we hit ACE if we access a search - // context in a non-sane way. - try { - semaphore.acquire(); - } catch (InterruptedException e) { - throw new AssertionError(e); - } - prepareIndex("index").setSource("field", "value") - .setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())) - .execute(ActionListener.running(semaphore::release)); - } - } - }); - thread.start(); - startGun.await(); - try { - final int rounds = scaledRandomIntBetween(100, 10000); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchRequest scrollSearchRequest = new SearchRequest().allowPartialSearchResults(true) - .scroll(new Scroll(TimeValue.timeValueMinutes(1))); - for (int i = 0; i < rounds; i++) { - try { - try { - PlainActionFuture result = new PlainActionFuture<>(); - final boolean useScroll = randomBoolean(); - service.executeQueryPhase( - new ShardSearchRequest( - OriginalIndices.NONE, - useScroll ? 
scrollSearchRequest : searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ), - new SearchShardTask(123L, "", "", "", null, emptyMap()), - result.delegateFailure((l, r) -> { - r.incRef(); - l.onResponse(r); - }) - ); - final SearchPhaseResult searchPhaseResult = result.get(); - try { - List intCursors = new ArrayList<>(1); - intCursors.add(0); - ShardFetchRequest req = new ShardFetchRequest( - searchPhaseResult.getContextId(), - intCursors, - null/* not a scroll */ - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener); - listener.get(); - if (useScroll) { - // have to free context since this test does not remove the index from IndicesService. - service.freeReaderContext(searchPhaseResult.getContextId()); - } - } finally { - searchPhaseResult.decRef(); - } - } catch (ExecutionException ex) { - assertThat(ex.getCause(), instanceOf(RuntimeException.class)); - throw ((RuntimeException) ex.getCause()); - } - } catch (AlreadyClosedException ex) { - throw ex; - } catch (IllegalStateException ex) { - assertEquals(AbstractRefCounted.ALREADY_CLOSED_MESSAGE, ex.getMessage()); - } catch (SearchContextMissingException ex) { - // that's fine - } + public void testCanMatchKeywordSortedQueryMatchAllWithException() throws IOException { + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) + .source(new SearchSourceBuilder().sort("field").query(new MatchAllQueryBuilder())); + // provide a sort field that throws exception + SortField sortField = new SortField("field", SortField.Type.STRING) { + @Override + public Type getType() { + throw new UnsupportedOperationException(); } - } finally { - running.set(false); - thread.join(); - semaphore.acquire(permitCount); - } - - assertEquals(0, service.getActiveContexts()); - - SearchStats.Stats totalStats = indexShard.searchStats().getTotal(); - assertEquals(0, totalStats.getQueryCurrent()); - assertEquals(0, totalStats.getScrollCurrent()); - assertEquals(0, totalStats.getFetchCurrent()); + }; + doTestCanMatch(searchRequest, sortField, true, null, false); } - public void testRankFeaturePhaseSearchPhases() throws InterruptedException, ExecutionException { - final String indexName = "index"; - final String rankFeatureFieldName = "field"; - final String searchFieldName = "search_field"; - final String searchFieldValue = "some_value"; - final String fetchFieldName = "fetch_field"; - final String fetchFieldValue = "fetch_value"; - - final int minDocs = 3; - final int maxDocs = 10; - int numDocs = between(minDocs, maxDocs); - createIndex(indexName); - // index some documents - for (int i = 0; i < numDocs; i++) { - prepareIndex(indexName).setId(String.valueOf(i)) - .setSource( - rankFeatureFieldName, - "aardvark_" + i, - searchFieldName, - searchFieldValue, - fetchFieldName, - fetchFieldValue + "_" + i - ) - .get(); - } - indicesAdmin().prepareRefresh(indexName).get(); - - final SearchService service = getInstanceFromNode(SearchService.class); - - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex(indexName)); - final IndexShard indexShard = indexService.getShard(0); - SearchShardTask searchTask = new SearchShardTask(123L, "", "", "", null, emptyMap()); - - // create a SearchRequest that will return all documents and defines a TestRankBuilder with shard-level only operations - 
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true) - .source( - new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) - .size(DEFAULT_SIZE) - .fetchField(fetchFieldName) - .rankBuilder( - // here we override only the shard-level contexts - new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { - return new QueryPhaseRankShardContext(queries, from) { - - @Override - public int rankWindowSize() { - return DEFAULT_RANK_WINDOW_SIZE; - } - - @Override - public RankShardResult combineQueryPhaseResults(List rankResults) { - // we know we have just 1 query, so return all the docs from it - return new TestRankShardResult( - Arrays.stream(rankResults.get(0).scoreDocs) - .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) - .limit(rankWindowSize()) - .toArray(RankDoc[]::new) - ); - } - }; - } - - @Override - public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { - @Override - public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { - RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; - for (int i = 0; i < hits.getHits().length; i++) { - SearchHit hit = hits.getHits()[i]; - rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); - rankFeatureDocs[i].score = (numDocs - i) + randomFloat(); - rankFeatureDocs[i].rank = i + 1; - } - return new RankFeatureShardResult(rankFeatureDocs); - } - }; - } - } - ) - ); - - ShardSearchRequest request = new ShardSearchRequest( + private void doTestCanMatch( + SearchRequest searchRequest, + SortField sortField, + boolean expectedCanMatch, + MinAndMax expectedMinAndMax, + boolean throwException + ) throws IOException { + ShardSearchRequest shardRequest = new ShardSearchRequest( OriginalIndices.NONE, searchRequest, - indexShard.shardId(), + new ShardId("index", "index", 0), 0, - 1, + 5, AliasFilter.EMPTY, 1.0f, - -1, + 0, null ); - QuerySearchResult queryResult = null; - RankFeatureResult rankResult = null; + IndexFieldData indexFieldData = indexFieldData(sortField); + IndexShard indexShard = newShard(true); try { - // Execute the query phase and store the result in a SearchPhaseResult container using a PlainActionFuture - PlainActionFuture queryPhaseResults = new PlainActionFuture<>(); - service.executeQueryPhase(request, searchTask, queryPhaseResults); - queryResult = (QuerySearchResult) queryPhaseResults.get(); - - // these are the matched docs from the query phase - final RankDoc[] queryRankDocs = ((TestRankShardResult) queryResult.getRankShardResult()).testRankDocs; - - // assume that we have cut down to these from the coordinator node as the top-docs to run the rank feature phase upon - List topRankWindowSizeDocs = randomNonEmptySubsetOf(Arrays.stream(queryRankDocs).map(x -> x.doc).toList()); - - // now we create a RankFeatureShardRequest to extract feature info for the top-docs above - RankFeatureShardRequest rankFeatureShardRequest = new RankFeatureShardRequest( - OriginalIndices.NONE, - queryResult.getContextId(), // use the context from the query phase - request, - topRankWindowSizeDocs - ); - PlainActionFuture rankPhaseResults = new PlainActionFuture<>(); - service.executeRankFeaturePhase(rankFeatureShardRequest, searchTask, 
rankPhaseResults); - rankResult = rankPhaseResults.get(); - - assertNotNull(rankResult); - assertNotNull(rankResult.rankFeatureResult()); - RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult(); - assertNotNull(rankFeatureShardResult); - - List sortedRankWindowDocs = topRankWindowSizeDocs.stream().sorted().toList(); - assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length); - for (int i = 0; i < sortedRankWindowDocs.size(); i++) { - assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc); - assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i)); - } - - List globalTopKResults = randomNonEmptySubsetOf( - Arrays.stream(rankFeatureShardResult.rankFeatureDocs).map(x -> x.doc).toList() - ); - - // finally let's create a fetch request to bring back fetch info for the top results - ShardFetchSearchRequest fetchRequest = new ShardFetchSearchRequest( - OriginalIndices.NONE, - rankResult.getContextId(), - request, - globalTopKResults, - null, - null, - rankResult.getRescoreDocIds(), - null - ); - - // execute fetch phase and perform any validations once we retrieve the response - // the difference in how we do assertions here is needed because once the transport service sends back the response - // it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null - PlainActionFuture fetchListener = new PlainActionFuture<>() { - @Override - public void onResponse(FetchSearchResult fetchSearchResult) { - assertNotNull(fetchSearchResult); - assertNotNull(fetchSearchResult.hits()); - - int totalHits = fetchSearchResult.hits().getHits().length; - assertEquals(globalTopKResults.size(), totalHits); - for (int i = 0; i < totalHits; i++) { - // rank and score are set by the SearchPhaseController#merge so no need to validate that here - SearchHit hit = fetchSearchResult.hits().getAt(i); - assertNotNull(hit.getFields().get(fetchFieldName)); - assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); - } - super.onResponse(fetchSearchResult); - } - - @Override - public void onFailure(Exception e) { - super.onFailure(e); - throw new AssertionError("No failure should have been raised", e); - } - }; - service.executeFetchPhase(fetchRequest, searchTask, fetchListener); - fetchListener.get(); - } catch (Exception ex) { - if (queryResult != null) { - if (queryResult.hasReferences()) { - queryResult.decRef(); - } - service.freeReaderContext(queryResult.getContextId()); - } - if (rankResult != null && rankResult.hasReferences()) { - rankResult.decRef(); - } - throw ex; - } - } - - public void testRankFeaturePhaseUsingClient() { - final String indexName = "index"; - final String rankFeatureFieldName = "field"; - final String searchFieldName = "search_field"; - final String searchFieldValue = "some_value"; - final String fetchFieldName = "fetch_field"; - final String fetchFieldValue = "fetch_value"; - - final int minDocs = 4; - final int maxDocs = 10; - int numDocs = between(minDocs, maxDocs); - createIndex(indexName); - // index some documents - for (int i = 0; i < numDocs; i++) { - prepareIndex(indexName).setId(String.valueOf(i)) - .setSource( - rankFeatureFieldName, - "aardvark_" + i, - searchFieldName, - searchFieldValue, - fetchFieldName, - fetchFieldValue + "_" + i - ) - .get(); - } - indicesAdmin().prepareRefresh(indexName).get(); - - 
ElasticsearchAssertions.assertResponse( - client().prepareSearch(indexName) - .setSource( - new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) - .size(2) - .from(2) - .fetchField(fetchFieldName) - .rankBuilder( - // here we override only the shard-level contexts - new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - - // no need for more than one queries - @Override - public boolean isCompoundBuilder() { - return false; - } - - @Override - public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( - int size, - int from, - Client client - ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { - @Override - protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { - float[] scores = new float[featureDocs.length]; - for (int i = 0; i < featureDocs.length; i++) { - scores[i] = featureDocs[i].score; - } - scoreListener.onResponse(scores); - } - }; - } - - @Override - public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - @Override - public ScoreDoc[] rankQueryPhaseResults( - List querySearchResults, - SearchPhaseController.TopDocsStats topDocStats - ) { - List rankDocs = new ArrayList<>(); - for (int i = 0; i < querySearchResults.size(); i++) { - QuerySearchResult querySearchResult = querySearchResults.get(i); - TestRankShardResult shardResult = (TestRankShardResult) querySearchResult - .getRankShardResult(); - for (RankDoc trd : shardResult.testRankDocs) { - trd.shardIndex = i; - rankDocs.add(trd); - } - } - rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); - RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); - topDocStats.fetchHits = topResults.length; - return topResults; - } - }; - } - - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { - return new QueryPhaseRankShardContext(queries, from) { - - @Override - public int rankWindowSize() { - return DEFAULT_RANK_WINDOW_SIZE; - } - - @Override - public RankShardResult combineQueryPhaseResults(List rankResults) { - // we know we have just 1 query, so return all the docs from it - return new TestRankShardResult( - Arrays.stream(rankResults.get(0).scoreDocs) - .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) - .limit(rankWindowSize()) - .toArray(RankDoc[]::new) - ); - } - }; - } - - @Override - public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { - @Override - public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { - RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; - for (int i = 0; i < hits.getHits().length; i++) { - SearchHit hit = hits.getHits()[i]; - rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); - rankFeatureDocs[i].score = randomFloat(); - rankFeatureDocs[i].rank = i + 1; - } - return new RankFeatureShardResult(rankFeatureDocs); - } - }; - } - } - ) - ), - (response) -> { - SearchHits hits = response.getHits(); - assertEquals(hits.getTotalHits().value(), numDocs); - assertEquals(hits.getHits().length, 2); - int index = 0; - for (SearchHit hit : hits.getHits()) { - 
assertEquals(hit.getRank(), 3 + index); - assertTrue(hit.getScore() >= 0); - assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId()); - index++; - } - } - ); - } - - public void testRankFeaturePhaseExceptionOnCoordinatingNode() { - final String indexName = "index"; - final String rankFeatureFieldName = "field"; - final String searchFieldName = "search_field"; - final String searchFieldValue = "some_value"; - final String fetchFieldName = "fetch_field"; - final String fetchFieldValue = "fetch_value"; - - final int minDocs = 3; - final int maxDocs = 10; - int numDocs = between(minDocs, maxDocs); - createIndex(indexName); - // index some documents - for (int i = 0; i < numDocs; i++) { - prepareIndex(indexName).setId(String.valueOf(i)) - .setSource( - rankFeatureFieldName, - "aardvark_" + i, - searchFieldName, - searchFieldValue, - fetchFieldName, - fetchFieldValue + "_" + i - ) - .get(); - } - indicesAdmin().prepareRefresh(indexName).get(); - - expectThrows( - SearchPhaseExecutionException.class, - () -> client().prepareSearch(indexName) - .setSource( - new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) - .size(2) - .from(2) - .fetchField(fetchFieldName) - .rankBuilder(new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - - // no need for more than one queries - @Override - public boolean isCompoundBuilder() { - return false; - } - - @Override - public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( - int size, - int from, - Client client - ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { - @Override - protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { - throw new IllegalStateException("should have failed earlier"); - } - }; - } - - @Override - public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - @Override - public ScoreDoc[] rankQueryPhaseResults( - List querySearchResults, - SearchPhaseController.TopDocsStats topDocStats - ) { - throw new UnsupportedOperationException("simulated failure"); - } - }; - } - - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { - return new QueryPhaseRankShardContext(queries, from) { - - @Override - public int rankWindowSize() { - return DEFAULT_RANK_WINDOW_SIZE; - } - - @Override - public RankShardResult combineQueryPhaseResults(List rankResults) { - // we know we have just 1 query, so return all the docs from it - return new TestRankShardResult( - Arrays.stream(rankResults.get(0).scoreDocs) - .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) - .limit(rankWindowSize()) - .toArray(RankDoc[]::new) - ); - } - }; - } - - @Override - public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { - @Override - public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { - RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; - for (int i = 0; i < hits.getHits().length; i++) { - SearchHit hit = hits.getHits()[i]; - rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); - rankFeatureDocs[i].score = randomFloat(); - rankFeatureDocs[i].rank = i + 1; 
- } - return new RankFeatureShardResult(rankFeatureDocs); - } - }; - } - }) - ) - .get() - ); - } - - public void testRankFeaturePhaseExceptionAllShardFail() { - final String indexName = "index"; - final String rankFeatureFieldName = "field"; - final String searchFieldName = "search_field"; - final String searchFieldValue = "some_value"; - final String fetchFieldName = "fetch_field"; - final String fetchFieldValue = "fetch_value"; - - final int minDocs = 3; - final int maxDocs = 10; - int numDocs = between(minDocs, maxDocs); - createIndex(indexName); - // index some documents - for (int i = 0; i < numDocs; i++) { - prepareIndex(indexName).setId(String.valueOf(i)) - .setSource( - rankFeatureFieldName, - "aardvark_" + i, - searchFieldName, - searchFieldValue, - fetchFieldName, - fetchFieldValue + "_" + i - ) - .get(); - } - indicesAdmin().prepareRefresh(indexName).get(); - - expectThrows( - SearchPhaseExecutionException.class, - () -> client().prepareSearch(indexName) - .setAllowPartialSearchResults(true) - .setSource( - new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) - .fetchField(fetchFieldName) - .rankBuilder( - // here we override only the shard-level contexts - new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - - // no need for more than one queries - @Override - public boolean isCompoundBuilder() { - return false; - } - - @Override - public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( - int size, - int from, - Client client - ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { - @Override - protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { - float[] scores = new float[featureDocs.length]; - for (int i = 0; i < featureDocs.length; i++) { - scores[i] = featureDocs[i].score; - } - scoreListener.onResponse(scores); - } - }; - } - - @Override - public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - @Override - public ScoreDoc[] rankQueryPhaseResults( - List querySearchResults, - SearchPhaseController.TopDocsStats topDocStats - ) { - List rankDocs = new ArrayList<>(); - for (int i = 0; i < querySearchResults.size(); i++) { - QuerySearchResult querySearchResult = querySearchResults.get(i); - TestRankShardResult shardResult = (TestRankShardResult) querySearchResult - .getRankShardResult(); - for (RankDoc trd : shardResult.testRankDocs) { - trd.shardIndex = i; - rankDocs.add(trd); - } - } - rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); - RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); - topDocStats.fetchHits = topResults.length; - return topResults; - } - }; - } - - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { - return new QueryPhaseRankShardContext(queries, from) { - - @Override - public int rankWindowSize() { - return DEFAULT_RANK_WINDOW_SIZE; - } - - @Override - public RankShardResult combineQueryPhaseResults(List rankResults) { - // we know we have just 1 query, so return all the docs from it - return new TestRankShardResult( - Arrays.stream(rankResults.get(0).scoreDocs) - .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) - .limit(rankWindowSize()) - .toArray(RankDoc[]::new) - ); - } - }; - } - - @Override - public RankFeaturePhaseRankShardContext 
buildRankFeaturePhaseShardContext() { - return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { - @Override - public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { - throw new UnsupportedOperationException("simulated failure"); - } - }; - } - } - ) - ) - .get() - ); - } - - public void testRankFeaturePhaseExceptionOneShardFails() { - // if we have only one shard and it fails, it will fallback to context.onPhaseFailure which will eventually clean up all contexts. - // in this test we want to make sure that even if one shard (of many) fails during the RankFeaturePhase, then the appropriate - // context will have been cleaned up. - final String indexName = "index"; - final String rankFeatureFieldName = "field"; - final String searchFieldName = "search_field"; - final String searchFieldValue = "some_value"; - final String fetchFieldName = "fetch_field"; - final String fetchFieldValue = "fetch_value"; - - final int minDocs = 3; - final int maxDocs = 10; - int numDocs = between(minDocs, maxDocs); - createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).build()); - // index some documents - for (int i = 0; i < numDocs; i++) { - prepareIndex(indexName).setId(String.valueOf(i)) - .setSource( - rankFeatureFieldName, - "aardvark_" + i, - searchFieldName, - searchFieldValue, - fetchFieldName, - fetchFieldValue + "_" + i - ) - .get(); - } - indicesAdmin().prepareRefresh(indexName).get(); - - assertResponse( - client().prepareSearch(indexName) - .setAllowPartialSearchResults(true) - .setSource( - new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue)) - .fetchField(fetchFieldName) - .rankBuilder( - // here we override only the shard-level contexts - new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - - // no need for more than one queries - @Override - public boolean isCompoundBuilder() { - return false; - } - - @Override - public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext( - int size, - int from, - Client client - ) { - return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) { - @Override - protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener scoreListener) { - float[] scores = new float[featureDocs.length]; - for (int i = 0; i < featureDocs.length; i++) { - scores[i] = featureDocs[i].score; - } - scoreListener.onResponse(scores); - } - }; - } - - @Override - public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) { - return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) { - @Override - public ScoreDoc[] rankQueryPhaseResults( - List querySearchResults, - SearchPhaseController.TopDocsStats topDocStats - ) { - List rankDocs = new ArrayList<>(); - for (int i = 0; i < querySearchResults.size(); i++) { - QuerySearchResult querySearchResult = querySearchResults.get(i); - TestRankShardResult shardResult = (TestRankShardResult) querySearchResult - .getRankShardResult(); - for (RankDoc trd : shardResult.testRankDocs) { - trd.shardIndex = i; - rankDocs.add(trd); - } - } - rankDocs.sort(Comparator.comparing((RankDoc doc) -> doc.score).reversed()); - RankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankDoc[]::new); - topDocStats.fetchHits = topResults.length; - return topResults; - } - }; - } - - @Override - public QueryPhaseRankShardContext buildQueryPhaseShardContext(List queries, int from) { - return new 
QueryPhaseRankShardContext(queries, from) { - - @Override - public int rankWindowSize() { - return DEFAULT_RANK_WINDOW_SIZE; - } - - @Override - public RankShardResult combineQueryPhaseResults(List rankResults) { - // we know we have just 1 query, so return all the docs from it - return new TestRankShardResult( - Arrays.stream(rankResults.get(0).scoreDocs) - .map(x -> new RankDoc(x.doc, x.score, x.shardIndex)) - .limit(rankWindowSize()) - .toArray(RankDoc[]::new) - ); - } - }; - } - - @Override - public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() { - return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) { - @Override - public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) { - if (shardId == 0) { - throw new UnsupportedOperationException("simulated failure"); - } else { - RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length]; - for (int i = 0; i < hits.getHits().length; i++) { - SearchHit hit = hits.getHits()[i]; - rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId); - rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue()); - rankFeatureDocs[i].score = randomFloat(); - rankFeatureDocs[i].rank = i + 1; - } - return new RankFeatureShardResult(rankFeatureDocs); - } - } - }; - } - } - ) - ), - (searchResponse) -> { - assertEquals(1, searchResponse.getSuccessfulShards()); - assertEquals("simulated failure", searchResponse.getShardFailures()[0].getCause().getMessage()); - assertNotEquals(0, searchResponse.getHits().getHits().length); - for (SearchHit hit : searchResponse.getHits().getHits()) { - assertEquals(fetchFieldValue + "_" + hit.getId(), hit.getFields().get(fetchFieldName).getValue()); - assertEquals(1, hit.getShard().getShardId().id()); + recoverShardFromStore(indexShard); + assertTrue(indexDoc(indexShard, "_doc", "id", "{\"field\":\"value\"}").isCreated()); + assertTrue(indexShard.refresh("test").refreshed()); + try (Engine.Searcher searcher = indexShard.acquireSearcher("test")) { + SearchExecutionContext searchExecutionContext = createSearchExecutionContext( + (mappedFieldType, fieldDataContext) -> indexFieldData, + searcher + ); + SearchService.CanMatchContext canMatchContext = createCanMatchContext( + shardRequest, + indexShard, + searchExecutionContext, + parserConfig(), + throwException + ); + CanMatchShardResponse canMatchShardResponse = SearchService.canMatch(canMatchContext, false); + assertEquals(expectedCanMatch, canMatchShardResponse.canMatch()); + if (expectedMinAndMax == null) { + assertNull(canMatchShardResponse.estimatedMinAndMax()); + } else { + MinAndMax minAndMax = canMatchShardResponse.estimatedMinAndMax(); + assertNotNull(minAndMax); + assertEquals(expectedMinAndMax.getMin(), minAndMax.getMin()); + assertEquals(expectedMinAndMax.getMin(), minAndMax.getMax()); } - } - ); - } - - public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws ExecutionException, InterruptedException { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); - - MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); - service.setOnPutContext(context -> { - if (context.indexShard() == indexShard) { - 
assertAcked(indicesAdmin().prepareDelete("index")); } - }); - - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchRequest scrollSearchRequest = new SearchRequest().allowPartialSearchResults(true) - .scroll(new Scroll(TimeValue.timeValueMinutes(1))); - - // the scrolls are not explicitly freed, but should all be gone when the test finished. - // for completeness, we also randomly test the regular search path. - final boolean useScroll = randomBoolean(); - PlainActionFuture result = new PlainActionFuture<>(); - service.executeQueryPhase( - new ShardSearchRequest( - OriginalIndices.NONE, - useScroll ? scrollSearchRequest : searchRequest, - new ShardId(resolveIndex("index"), 0), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ), - new SearchShardTask(123L, "", "", "", null, emptyMap()), - result - ); - - try { - result.get(); - } catch (Exception e) { - // ok - } - - expectThrows(IndexNotFoundException.class, () -> indicesAdmin().prepareGetIndex().setIndices("index").get()); - - assertEquals(0, service.getActiveContexts()); - - SearchStats.Stats totalStats = indexShard.searchStats().getTotal(); - assertEquals(0, totalStats.getQueryCurrent()); - assertEquals(0, totalStats.getScrollCurrent()); - assertEquals(0, totalStats.getFetchCurrent()); - } - - public void testBeforeShardLockDuringShardCreate() { - IndexService indexService = createIndex("index", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).build()); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertResponse( - client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), - searchResponse -> assertThat(searchResponse.getScrollId(), is(notNullValue())) - ); - SearchService service = getInstanceFromNode(SearchService.class); - - assertEquals(1, service.getActiveContexts()); - service.beforeIndexShardCreated( - TestShardRouting.newShardRouting( - "test", - 0, - randomAlphaOfLength(5), - randomAlphaOfLength(5), - randomBoolean(), - ShardRoutingState.INITIALIZING - ), - indexService.getIndexSettings().getSettings() - ); - assertEquals(1, service.getActiveContexts()); - - service.beforeIndexShardCreated( - TestShardRouting.newShardRouting( - new ShardId(indexService.index(), 0), - randomAlphaOfLength(5), - randomBoolean(), - ShardRoutingState.INITIALIZING - ), - indexService.getIndexSettings().getSettings() - ); - assertEquals(0, service.getActiveContexts()); - } - - public void testTimeout() throws IOException { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - final ShardSearchRequest requestWithDefaultTimeout = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), + } finally { + closeShards(indexShard); + } + } + + private SearchExecutionContext createSearchExecutionContext( + BiFunction> indexFieldDataLookup, + IndexSearcher searcher + ) { + IndexMetadata indexMetadata = IndexMetadata.builder("index") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) + .numberOfShards(1) + .numberOfReplicas(0) + .build(); + IndexSettings indexSettings = new 
IndexSettings(indexMetadata, Settings.EMPTY); + Predicate indexNameMatcher = pattern -> Regex.simpleMatch(pattern, "index"); + + MapperBuilderContext root = MapperBuilderContext.root(false, false); + RootObjectMapper.Builder builder = new RootObjectMapper.Builder("_doc", ObjectMapper.Defaults.SUBOBJECTS); + Mapping mapping = new Mapping( + builder.build(MapperBuilderContext.root(false, false)), + new MetadataFieldMapper[0], + Collections.emptyMap() + ); + KeywordFieldMapper keywordFieldMapper = new KeywordFieldMapper.Builder("field", IndexVersion.current()).build(root); + MappingLookup mappingLookup = MappingLookup.fromMappers( + mapping, + Collections.singletonList(keywordFieldMapper), + Collections.emptyList() + ); + return new SearchExecutionContext( 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - - try ( - ReaderContext reader = createReaderContext(indexService, indexShard); - SearchContext contextWithDefaultTimeout = service.createContext( - reader, - requestWithDefaultTimeout, - mock(SearchShardTask.class), - ResultsType.NONE, - randomBoolean() - ) - ) { - // the search context should inherit the default timeout - assertThat(contextWithDefaultTimeout.timeout(), equalTo(TimeValue.timeValueSeconds(5))); - } - - final long seconds = randomIntBetween(6, 10); - searchRequest.source(new SearchSourceBuilder().timeout(TimeValue.timeValueSeconds(seconds))); - final ShardSearchRequest requestWithCustomTimeout = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null + indexSettings, + null, + indexFieldDataLookup, + null, + mappingLookup, + null, + null, + parserConfig(), + writableRegistry(), + null, + searcher, + System::currentTimeMillis, + null, + indexNameMatcher, + () -> true, + null, + Collections.emptyMap(), + MapperMetrics.NOOP ); - try ( - ReaderContext reader = createReaderContext(indexService, indexShard); - SearchContext context = service.createContext( - reader, - requestWithCustomTimeout, - mock(SearchShardTask.class), - ResultsType.NONE, - randomBoolean() - ) - ) { - // the search context should inherit the query timeout - assertThat(context.timeout(), equalTo(TimeValue.timeValueSeconds(seconds))); - } } - /** - * test that getting more than the allowed number of docvalue_fields throws an exception - */ - public void testMaxDocvalueFieldsSearch() throws IOException { - final Settings settings = Settings.builder().put(IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey(), 1).build(); - createIndex("index", settings, null, "field1", "keyword", "field2", "keyword"); - prepareIndex("index").setId("1").setSource("field1", "value1", "field2", "value2").setRefreshPolicy(IMMEDIATE).get(); - - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - searchSourceBuilder.docValueField("field1"); - - final ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - try ( - ReaderContext reader = createReaderContext(indexService, 
indexShard); - SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ) { - assertNotNull(context); - } - - searchSourceBuilder.docValueField("unmapped_field"); - try ( - ReaderContext reader = createReaderContext(indexService, indexShard); - SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ) { - assertNotNull(context); - } - - searchSourceBuilder.docValueField("field2"); - try (ReaderContext reader = createReaderContext(indexService, indexShard)) { - IllegalArgumentException ex = expectThrows( - IllegalArgumentException.class, - () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ); - assertEquals( - "Trying to retrieve too many docvalue_fields. Must be less than or equal to: [1] but was [2]. " - + "This limit can be set by changing the [index.max_docvalue_fields_search] index level setting.", - ex.getMessage() - ); - } - } + private static IndexFieldData indexFieldData(SortField sortField) { + return new IndexFieldData<>() { + @Override + public String getFieldName() { + return "field"; + } - public void testDeduplicateDocValuesFields() throws Exception { - createIndex("index", Settings.EMPTY, "_doc", "field1", "type=date", "field2", "type=date"); - prepareIndex("index").setId("1").setSource("field1", "2022-08-03", "field2", "2022-08-04").setRefreshPolicy(IMMEDIATE).get(); - SearchService service = getInstanceFromNode(SearchService.class); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); + @Override + public ValuesSourceType getValuesSourceType() { + throw new UnsupportedOperationException(); + } - try (ReaderContext reader = createReaderContext(indexService, indexShard)) { - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - searchSourceBuilder.docValueField("f*"); - if (randomBoolean()) { - searchSourceBuilder.docValueField("field*"); + @Override + public LeafFieldData load(LeafReaderContext context) { + throw new UnsupportedOperationException(); } - if (randomBoolean()) { - searchSourceBuilder.docValueField("*2"); + + @Override + public LeafFieldData loadDirect(LeafReaderContext context) { + throw new UnsupportedOperationException(); } - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - try ( - SearchContext context = service.createContext( - reader, - request, - mock(SearchShardTask.class), - ResultsType.NONE, - randomBoolean() - ) + + @Override + public SortField sortField( + Object missingValue, + MultiValueMode sortMode, + XFieldComparatorSource.Nested nested, + boolean reverse ) { - Collection fields = context.docValuesContext().fields(); - assertThat(fields, containsInAnyOrder(new FieldAndFormat("field1", null), new FieldAndFormat("field2", null))); + return sortField; } - } - } - - /** - * test that getting more than the allowed number of script_fields throws an exception - */ - public void testMaxScriptFieldsSearch() throws IOException { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); 
- final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - // adding the maximum allowed number of script_fields to retrieve - int maxScriptFields = indexService.getIndexSettings().getMaxScriptFields(); - for (int i = 0; i < maxScriptFields; i++) { - searchSourceBuilder.scriptField( - "field" + i, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) - ); - } - final ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - try (ReaderContext reader = createReaderContext(indexService, indexShard)) { - try ( - SearchContext context = service.createContext( - reader, - request, - mock(SearchShardTask.class), - ResultsType.NONE, - randomBoolean() - ) + @Override + public BucketedSort newBucketedSort( + BigArrays bigArrays, + Object missingValue, + MultiValueMode sortMode, + XFieldComparatorSource.Nested nested, + SortOrder sortOrder, + DocValueFormat format, + int bucketSize, + BucketedSort.ExtraData extra ) { - assertNotNull(context); + throw new UnsupportedOperationException(); } - searchSourceBuilder.scriptField( - "anotherScriptField", - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) - ); - IllegalArgumentException ex = expectThrows( - IllegalArgumentException.class, - () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ); - assertEquals( - "Trying to retrieve too many script_fields. Must be less than or equal to: [" - + maxScriptFields - + "] but was [" - + (maxScriptFields + 1) - + "]. 
This limit can be set by changing the [index.max_script_fields] index level setting.", - ex.getMessage() - ); - } - } - - public void testIgnoreScriptfieldIfSizeZero() throws IOException { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchRequest.source(searchSourceBuilder); - searchSourceBuilder.scriptField( - "field" + 0, - new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap()) - ); - searchSourceBuilder.size(0); - final ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - try ( - ReaderContext reader = createReaderContext(indexService, indexShard); - SearchContext context = service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ) { - assertEquals(0, context.scriptFields().fields().size()); - } + }; } - /** - * test that creating more than the allowed number of scroll contexts throws an exception - */ - public void testMaxOpenScrollContexts() throws Exception { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - - // Open all possible scrolls, clear some of them, then open more until the limit is reached - LinkedList clearScrollIds = new LinkedList<>(); - - for (int i = 0; i < SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY); i++) { - assertResponse(client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)), searchResponse -> { - if (randomInt(4) == 0) clearScrollIds.addLast(searchResponse.getScrollId()); - }); - } + private static SearchService.CanMatchContext createCanMatchContext( + ShardSearchRequest shardRequest, + IndexShard indexShard, + SearchExecutionContext searchExecutionContext, + XContentParserConfiguration parserConfig, + boolean throwException + ) { + return new SearchService.CanMatchContext(shardRequest, null, null, -1, -1) { + @Override + IndexShard getShard() { + return indexShard; + } - ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); - clearScrollRequest.setScrollIds(clearScrollIds); - client().clearScroll(clearScrollRequest).get(); + @Override + QueryRewriteContext getQueryRewriteContext(IndexService indexService) { + if (throwException) { + throw new IllegalArgumentException(); + } + return new QueryRewriteContext(parserConfig, null, System::currentTimeMillis); + } - for (int i = 0; i < clearScrollIds.size(); i++) { - client().prepareSearch("index").setSize(1).setScroll(TimeValue.timeValueMinutes(1)).get().decRef(); - } + @Override + SearchExecutionContext getSearchExecutionContext(Engine.Searcher searcher) { + return searchExecutionContext; + } - final ShardScrollRequestTest request = new 
ShardScrollRequestTest(indexShard.shardId()); - ElasticsearchException ex = expectThrows( - ElasticsearchException.class, - () -> service.createAndPutReaderContext( - request, - indexService, - indexShard, - indexShard.acquireSearcherSupplier(), - SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() - ) - ); - assertEquals( - "Trying to create too many scroll contexts. Must be less than or equal to: [" - + SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY) - + "]. " - + "This limit can be set by changing the [search.max_open_scroll_context] setting.", - ex.getMessage() - ); - assertEquals(RestStatus.TOO_MANY_REQUESTS, ex.status()); - - service.freeAllScrollContexts(); - } - - public void testOpenScrollContextsConcurrently() throws Exception { - createIndex("index"); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - - final int maxScrollContexts = SearchService.MAX_OPEN_SCROLL_CONTEXT.get(Settings.EMPTY); - final SearchService searchService = getInstanceFromNode(SearchService.class); - Thread[] threads = new Thread[randomIntBetween(2, 8)]; - CountDownLatch latch = new CountDownLatch(threads.length); - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - latch.countDown(); - try { - latch.await(); - for (;;) { - final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier(); - try { - final ShardScrollRequestTest request = new ShardScrollRequestTest(indexShard.shardId()); - searchService.createAndPutReaderContext( - request, - indexService, - indexShard, - reader, - SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() - ); - } catch (ElasticsearchException e) { - assertThat( - e.getMessage(), - equalTo( - "Trying to create too many scroll contexts. Must be less than or equal to: " - + "[" - + maxScrollContexts - + "]. " - + "This limit can be set by changing the [search.max_open_scroll_context] setting." 
- ) - ); - return; - } - } - } catch (Exception e) { - throw new AssertionError(e); - } - }); - threads[i].setName("elasticsearch[node_s_0][search]"); - threads[i].start(); - } - for (Thread thread : threads) { - thread.join(); - } - assertThat(searchService.getActiveContexts(), equalTo(maxScrollContexts)); - searchService.freeAllScrollContexts(); - } - - public static class FailOnRewriteQueryPlugin extends Plugin implements SearchPlugin { - @Override - public List> getQueries() { - return singletonList(new QuerySpec<>("fail_on_rewrite_query", FailOnRewriteQueryBuilder::new, parseContext -> { - throw new UnsupportedOperationException("No query parser for this plugin"); - })); - } - } - - public static class FailOnRewriteQueryBuilder extends DummyQueryBuilder { - - public FailOnRewriteQueryBuilder(StreamInput in) throws IOException { - super(in); - } - - public FailOnRewriteQueryBuilder() {} - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { - if (queryRewriteContext.convertToSearchExecutionContext() != null) { - throw new IllegalStateException("Fail on rewrite phase"); - } - return this; - } - } - - private static class ShardScrollRequestTest extends ShardSearchRequest { - private Scroll scroll; - - ShardScrollRequestTest(ShardId shardId) { - super( - OriginalIndices.NONE, - new SearchRequest().allowPartialSearchResults(true), - shardId, - 0, - 1, - AliasFilter.EMPTY, - 1f, - -1, - null - ); - this.scroll = new Scroll(TimeValue.timeValueMinutes(1)); - } - - @Override - public Scroll scroll() { - return this.scroll; - } - } - - public void testCanMatch() throws Exception { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - assertTrue( - service.canMatch( - new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - - searchRequest.source(new SearchSourceBuilder()); - assertTrue( - service.canMatch( - new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - - searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); - assertTrue( - service.canMatch( - new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - - searchRequest.source( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) - .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(0)) - ); - assertTrue( - service.canMatch( - new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - searchRequest.source( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()).aggregation(new GlobalAggregationBuilder("test")) - ); - assertTrue( - service.canMatch( - new ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - - searchRequest.source(new SearchSourceBuilder().query(new MatchNoneQueryBuilder())); - assertFalse( - service.canMatch( - new 
ShardSearchRequest(OriginalIndices.NONE, searchRequest, indexShard.shardId(), 0, 1, AliasFilter.EMPTY, 1f, -1, null) - ).canMatch() - ); - assertEquals(5, numWrapInvocations.get()); - - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - - /* - * Checks that canMatch takes into account the alias filter - */ - // the source cannot be rewritten to a match_none - searchRequest.indices("alias").source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())); - assertFalse( - service.canMatch( - new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.of(new TermQueryBuilder("foo", "bar"), "alias"), - 1f, - -1, - null - ) - ).canMatch() - ); - // the source can match and can be rewritten to a match_none, but not the alias filter - final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); - assertEquals(RestStatus.CREATED, response.status()); - searchRequest.indices("alias").source(new SearchSourceBuilder().query(new TermQueryBuilder("id", "1"))); - assertFalse( - service.canMatch( - new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.of(new TermQueryBuilder("foo", "bar"), "alias"), - 1f, - -1, - null - ) - ).canMatch() - ); - - CountDownLatch latch = new CountDownLatch(1); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - // Because the foo field used in alias filter is unmapped the term query builder rewrite can resolve to a match no docs query, - // without acquiring a searcher and that means the wrapper is not called - assertEquals(5, numWrapInvocations.get()); - service.executeQueryPhase(request, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - try { - // make sure that the wrapper is called when the query is actually executed - assertEquals(6, numWrapInvocations.get()); - } finally { - latch.countDown(); - } - } - - @Override - public void onFailure(Exception e) { - try { - throw new AssertionError(e); - } finally { - latch.countDown(); - } - } - }); - latch.await(); - } - - public void testCanRewriteToMatchNone() { - assertFalse( - SearchService.canRewriteToMatchNone( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()).aggregation(new GlobalAggregationBuilder("test")) - ) - ); - assertFalse(SearchService.canRewriteToMatchNone(new SearchSourceBuilder())); - assertFalse(SearchService.canRewriteToMatchNone(null)); - assertFalse( - SearchService.canRewriteToMatchNone( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) - .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(0)) - ) - ); - assertTrue(SearchService.canRewriteToMatchNone(new SearchSourceBuilder().query(new TermQueryBuilder("foo", "bar")))); - assertTrue( - SearchService.canRewriteToMatchNone( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) - .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(1)) - ) - ); - assertFalse( - SearchService.canRewriteToMatchNone( - new SearchSourceBuilder().query(new MatchNoneQueryBuilder()) - .aggregation(new TermsAggregationBuilder("test").userValueTypeHint(ValueType.STRING).minDocCount(1)) - .suggest(new SuggestBuilder()) - ) - ); - assertFalse( - SearchService.canRewriteToMatchNone( - new 
SearchSourceBuilder().query(new TermQueryBuilder("foo", "bar")).suggest(new SuggestBuilder()) - ) - ); - } - - public void testSetSearchThrottled() throws IOException { - createIndex("throttled_threadpool_index"); - client().execute( - InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, - new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( - "throttled_threadpool_index", - IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), - "true" - ) - ).actionGet(); - final SearchService service = getInstanceFromNode(SearchService.class); - Index index = resolveIndex("throttled_threadpool_index"); - assertTrue(service.getIndicesService().indexServiceSafe(index).getIndexSettings().isSearchThrottled()); - prepareIndex("throttled_threadpool_index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertSearchHits( - client().prepareSearch("throttled_threadpool_index") - .setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED) - .setSize(1), - "1" - ); - // we add a search action listener in a plugin above to assert that this is actually used - client().execute( - InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, - new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( - "throttled_threadpool_index", - IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), - "false" - ) - ).actionGet(); - - IllegalArgumentException iae = expectThrows( - IllegalArgumentException.class, - () -> indicesAdmin().prepareUpdateSettings("throttled_threadpool_index") - .setSettings(Settings.builder().put(IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), false)) - .get() - ); - assertEquals("can not update private setting [index.search.throttled]; this setting is managed by Elasticsearch", iae.getMessage()); - assertFalse(service.getIndicesService().indexServiceSafe(index).getIndexSettings().isSearchThrottled()); - } - - public void testAggContextGetsMatchAll() throws IOException { - createIndex("test"); - withAggregationContext("test", context -> assertThat(context.query(), equalTo(new MatchAllDocsQuery()))); - } - - public void testAggContextGetsNestedFilter() throws IOException { - XContentBuilder mapping = JsonXContent.contentBuilder().startObject().startObject("properties"); - mapping.startObject("nested").field("type", "nested").endObject(); - mapping.endObject().endObject(); - - createIndex("test", Settings.EMPTY, mapping); - withAggregationContext("test", context -> assertThat(context.query(), equalTo(new MatchAllDocsQuery()))); - } - - /** - * Build an {@link AggregationContext} with the named index. 
- */ - private void withAggregationContext(String index, Consumer check) throws IOException { - IndexService indexService = getInstanceFromNode(IndicesService.class).indexServiceSafe(resolveIndex(index)); - ShardId shardId = new ShardId(indexService.index(), 0); - - SearchRequest request = new SearchRequest().indices(index) - .source(new SearchSourceBuilder().aggregation(new FiltersAggregationBuilder("test", new MatchAllQueryBuilder()))) - .allowPartialSearchResults(false); - ShardSearchRequest shardRequest = new ShardSearchRequest( - OriginalIndices.NONE, - request, - shardId, - 0, - 1, - AliasFilter.EMPTY, - 1, - 0, - null - ); - - try (ReaderContext readerContext = createReaderContext(indexService, indexService.getShard(0))) { - try ( - SearchContext context = getInstanceFromNode(SearchService.class).createContext( - readerContext, - shardRequest, - mock(SearchShardTask.class), - ResultsType.QUERY, - true - ) - ) { - check.accept(context.aggregations().factories().context()); - } - } - } - - public void testExpandSearchThrottled() { - createIndex("throttled_threadpool_index"); - client().execute( - InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, - new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request( - "throttled_threadpool_index", - IndexSettings.INDEX_SEARCH_THROTTLED.getKey(), - "true" - ) - ).actionGet(); - - prepareIndex("throttled_threadpool_index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertHitCount(client().prepareSearch(), 1L); - assertHitCount(client().prepareSearch().setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED), 1L); - } - - public void testExpandSearchFrozen() { - String indexName = "frozen_index"; - createIndex(indexName); - client().execute( - InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.INSTANCE, - new InternalOrPrivateSettingsPlugin.UpdateInternalOrPrivateAction.Request(indexName, "index.frozen", "true") - ).actionGet(); - - prepareIndex(indexName).setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - assertHitCount(client().prepareSearch(), 0L); - assertHitCount(client().prepareSearch().setIndicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_FORBID_CLOSED), 1L); - assertWarnings(TransportSearchAction.FROZEN_INDICES_DEPRECATION_MESSAGE.replace("{}", indexName)); - } - - public void testCreateReduceContext() { - SearchService service = getInstanceFromNode(SearchService.class); - AggregationReduceContext.Builder reduceContextBuilder = service.aggReduceContextBuilder( - () -> false, - new SearchRequest().source(new SearchSourceBuilder()).source().aggregations() - ); - { - AggregationReduceContext reduceContext = reduceContextBuilder.forFinalReduction(); - expectThrows( - MultiBucketConsumerService.TooManyBucketsException.class, - () -> reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1) - ); - } - { - AggregationReduceContext reduceContext = reduceContextBuilder.forPartialReduction(); - reduceContext.consumeBucketsAndMaybeBreak(MultiBucketConsumerService.DEFAULT_MAX_BUCKETS + 1); - } - } - - public void testMultiBucketConsumerServiceCB() { - MultiBucketConsumerService service = new MultiBucketConsumerService( - getInstanceFromNode(ClusterService.class), - Settings.EMPTY, - new NoopCircuitBreaker("test") { - - @Override - public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { - throw new CircuitBreakingException("tripped", getDurability()); - } - } 
- ); - // for partial - { - IntConsumer consumer = service.createForPartial(); - for (int i = 0; i < 1023; i++) { - consumer.accept(0); - } - CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); - assertThat(ex.getMessage(), equalTo("tripped")); - } - // for final - { - IntConsumer consumer = service.createForFinal(); - for (int i = 0; i < 1023; i++) { - consumer.accept(0); - } - CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); - assertThat(ex.getMessage(), equalTo("tripped")); - } - } - - public void testCreateSearchContext() throws IOException { - String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); - IndexService indexService = createIndex(index); - final SearchService service = getInstanceFromNode(SearchService.class); - ShardId shardId = new ShardId(indexService.index(), 0); - long nowInMillis = System.currentTimeMillis(); - String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10); - SearchRequest searchRequest = new SearchRequest(); - searchRequest.allowPartialSearchResults(randomBoolean()); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - shardId, - 0, - indexService.numberOfShards(), - AliasFilter.EMPTY, - 1f, - nowInMillis, - clusterAlias - ); - try (SearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) { - SearchShardTarget searchShardTarget = searchContext.shardTarget(); - SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext(); - String expectedIndexName = clusterAlias == null ? index : clusterAlias + ":" + index; - assertEquals(expectedIndexName, searchExecutionContext.getFullyQualifiedIndex().getName()); - assertEquals(expectedIndexName, searchShardTarget.getFullyQualifiedIndexName()); - assertEquals(clusterAlias, searchShardTarget.getClusterAlias()); - assertEquals(shardId, searchShardTarget.getShardId()); - - assertNull(searchContext.dfsResult()); - searchContext.addDfsResult(); - assertSame(searchShardTarget, searchContext.dfsResult().getSearchShardTarget()); - - assertNull(searchContext.queryResult()); - searchContext.addQueryResult(); - assertSame(searchShardTarget, searchContext.queryResult().getSearchShardTarget()); - - assertNull(searchContext.fetchResult()); - searchContext.addFetchResult(); - assertSame(searchShardTarget, searchContext.fetchResult().getSearchShardTarget()); - } - } - - /** - * While we have no NPE in DefaultContext constructor anymore, we still want to guard against it (or other failures) in the future to - * avoid leaking searchers. 
- */ - public void testCreateSearchContextFailure() throws Exception { - final String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); - final IndexService indexService = createIndex(index); - final SearchService service = getInstanceFromNode(SearchService.class); - final ShardId shardId = new ShardId(indexService.index(), 0); - final ShardSearchRequest request = new ShardSearchRequest(shardId, 0, null) { @Override - public SearchType searchType() { - // induce an artificial NPE - throw new NullPointerException("expected"); + IndexService getIndexService() { + // it's ok to return null because the three above methods are overridden + return null; } }; - try (ReaderContext reader = createReaderContext(indexService, indexService.getShard(shardId.id()))) { - NullPointerException e = expectThrows( - NullPointerException.class, - () -> service.createContext(reader, request, mock(SearchShardTask.class), ResultsType.NONE, randomBoolean()) - ); - assertEquals("expected", e.getMessage()); - } - // Needs to busily assert because Engine#refreshNeeded can increase the refCount. - assertBusy( - () -> assertEquals("should have 2 store refs (IndexService + InternalEngine)", 2, indexService.getShard(0).store().refCount()) - ); - } - - public void testMatchNoDocsEmptyResponse() throws InterruptedException { - createIndex("index"); - Thread currentThread = Thread.currentThread(); - SearchService service = getInstanceFromNode(SearchService.class); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false) - .source(new SearchSourceBuilder().aggregation(AggregationBuilders.count("count").field("value"))); - ShardSearchRequest shardRequest = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 5, - AliasFilter.EMPTY, - 1.0f, - 0, - null - ); - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - - { - CountDownLatch latch = new CountDownLatch(1); - shardRequest.source().query(new MatchAllQueryBuilder()); - service.executeQueryPhase(shardRequest, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult result) { - try { - assertNotSame(Thread.currentThread(), currentThread); - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); - assertThat(result, instanceOf(QuerySearchResult.class)); - assertFalse(result.queryResult().isNull()); - assertNotNull(result.queryResult().topDocs()); - assertNotNull(result.queryResult().aggregations()); - } finally { - latch.countDown(); - } - } - - @Override - public void onFailure(Exception exc) { - try { - throw new AssertionError(exc); - } finally { - latch.countDown(); - } - } - }); - latch.await(); - } - - { - CountDownLatch latch = new CountDownLatch(1); - shardRequest.source().query(new MatchNoneQueryBuilder()); - service.executeQueryPhase(shardRequest, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult result) { - try { - assertNotSame(Thread.currentThread(), currentThread); - assertThat(Thread.currentThread().getName(), startsWith("elasticsearch[node_s_0][search]")); - assertThat(result, instanceOf(QuerySearchResult.class)); - assertFalse(result.queryResult().isNull()); - assertNotNull(result.queryResult().topDocs()); - 
assertNotNull(result.queryResult().aggregations()); - } finally { - latch.countDown(); - } - } - - @Override - public void onFailure(Exception exc) { - try { - throw new AssertionError(exc); - } finally { - latch.countDown(); - } - } - }); - latch.await(); - } - - { - CountDownLatch latch = new CountDownLatch(1); - shardRequest.canReturnNullResponseIfMatchNoDocs(true); - service.executeQueryPhase(shardRequest, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult result) { - try { - // make sure we don't use the search threadpool - assertSame(Thread.currentThread(), currentThread); - assertThat(result, instanceOf(QuerySearchResult.class)); - assertTrue(result.queryResult().isNull()); - } finally { - latch.countDown(); - } - } - - @Override - public void onFailure(Exception e) { - try { - throw new AssertionError(e); - } finally { - latch.countDown(); - } - } - }); - latch.await(); - } - } - - public void testDeleteIndexWhileSearch() throws Exception { - createIndex("test"); - int numDocs = randomIntBetween(1, 20); - for (int i = 0; i < numDocs; i++) { - prepareIndex("test").setSource("f", "v").get(); - } - indicesAdmin().prepareRefresh("test").get(); - AtomicBoolean stopped = new AtomicBoolean(false); - Thread[] searchers = new Thread[randomIntBetween(1, 4)]; - CountDownLatch latch = new CountDownLatch(searchers.length); - for (int i = 0; i < searchers.length; i++) { - searchers[i] = new Thread(() -> { - latch.countDown(); - while (stopped.get() == false) { - try { - client().prepareSearch("test").setRequestCache(false).get().decRef(); - } catch (Exception ignored) { - return; - } - } - }); - searchers[i].start(); - } - latch.await(); - indicesAdmin().prepareDelete("test").get(); - stopped.set(true); - for (Thread searcher : searchers) { - searcher.join(); - } - } - - public void testLookUpSearchContext() throws Exception { - createIndex("index"); - SearchService searchService = getInstanceFromNode(SearchService.class); - IndicesService indicesService = getInstanceFromNode(IndicesService.class); - IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - IndexShard indexShard = indexService.getShard(0); - List contextIds = new ArrayList<>(); - int numContexts = randomIntBetween(1, 10); - CountDownLatch latch = new CountDownLatch(1); - indexShard.getThreadPool().executor(ThreadPool.Names.SEARCH).execute(() -> { - try { - for (int i = 0; i < numContexts; i++) { - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - new SearchRequest().allowPartialSearchResults(true), - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - final ReaderContext context = searchService.createAndPutReaderContext( - request, - indexService, - indexShard, - indexShard.acquireSearcherSupplier(), - SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() - ); - assertThat(context.id().getId(), equalTo((long) (i + 1))); - contextIds.add(context.id()); - } - assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); - while (contextIds.isEmpty() == false) { - final ShardSearchContextId contextId = randomFrom(contextIds); - assertFalse(searchService.freeReaderContext(new ShardSearchContextId(UUIDs.randomBase64UUID(), contextId.getId()))); - assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); - if (randomBoolean()) { - assertTrue(searchService.freeReaderContext(contextId)); - } else { - assertTrue( - searchService.freeReaderContext((new 
ShardSearchContextId(contextId.getSessionId(), contextId.getId()))) - ); - } - contextIds.remove(contextId); - assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); - assertFalse(searchService.freeReaderContext(contextId)); - assertThat(searchService.getActiveContexts(), equalTo(contextIds.size())); - } - } finally { - latch.countDown(); - } - }); - latch.await(); - } - - public void testOpenReaderContext() { - createIndex("index"); - SearchService searchService = getInstanceFromNode(SearchService.class); - PlainActionFuture future = new PlainActionFuture<>(); - searchService.openReaderContext(new ShardId(resolveIndex("index"), 0), TimeValue.timeValueMinutes(between(1, 10)), future); - future.actionGet(); - assertThat(searchService.getActiveContexts(), equalTo(1)); - assertTrue(searchService.freeReaderContext(future.actionGet())); - } - - public void testCancelQueryPhaseEarly() throws Exception { - createIndex("index"); - final MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - - CountDownLatch latch1 = new CountDownLatch(1); - SearchShardTask task = new SearchShardTask(1, "", "", "", TaskId.EMPTY_TASK_ID, emptyMap()); - service.executeQueryPhase(request, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - service.freeReaderContext(searchPhaseResult.getContextId()); - latch1.countDown(); - } - - @Override - public void onFailure(Exception e) { - try { - fail("Search should not be cancelled"); - } finally { - latch1.countDown(); - } - } - }); - latch1.await(); - - CountDownLatch latch2 = new CountDownLatch(1); - service.executeDfsPhase(request, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - service.freeReaderContext(searchPhaseResult.getContextId()); - latch2.countDown(); - } - - @Override - public void onFailure(Exception e) { - try { - fail("Search should not be cancelled"); - } finally { - latch2.countDown(); - } - } - }); - latch2.await(); - - AtomicBoolean searchContextCreated = new AtomicBoolean(false); - service.setOnCreateSearchContext(c -> searchContextCreated.set(true)); - CountDownLatch latch3 = new CountDownLatch(1); - TaskCancelHelper.cancel(task, "simulated"); - service.executeQueryPhase(request, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - try { - fail("Search not cancelled early"); - } finally { - service.freeReaderContext(searchPhaseResult.getContextId()); - searchPhaseResult.decRef(); - latch3.countDown(); - } - } - - @Override - public void onFailure(Exception e) { - assertThat(e, is(instanceOf(TaskCancelledException.class))); - assertThat(e.getMessage(), is("task cancelled [simulated]")); - assertThat(((TaskCancelledException) e).status(), is(RestStatus.BAD_REQUEST)); - assertThat(searchContextCreated.get(), is(false)); - latch3.countDown(); - } - }); - latch3.await(); - - searchContextCreated.set(false); - CountDownLatch latch4 = new 
CountDownLatch(1); - service.executeDfsPhase(request, task, new ActionListener<>() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - try { - fail("Search not cancelled early"); - } finally { - service.freeReaderContext(searchPhaseResult.getContextId()); - latch4.countDown(); - } - } - - @Override - public void onFailure(Exception e) { - assertThat(e, is(instanceOf(TaskCancelledException.class))); - assertThat(e.getMessage(), is("task cancelled [simulated]")); - assertThat(((TaskCancelledException) e).status(), is(RestStatus.BAD_REQUEST)); - assertThat(searchContextCreated.get(), is(false)); - latch4.countDown(); - } - }); - latch4.await(); - } - - public void testCancelFetchPhaseEarly() throws Exception { - createIndex("index"); - final MockSearchService service = (MockSearchService) getInstanceFromNode(SearchService.class); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - - AtomicBoolean searchContextCreated = new AtomicBoolean(false); - service.setOnCreateSearchContext(c -> searchContextCreated.set(true)); - - // Test fetch phase is cancelled early - String scrollId; - var searchResponse = client().search(searchRequest.allowPartialSearchResults(false).scroll(TimeValue.timeValueMinutes(10))).get(); - try { - scrollId = searchResponse.getScrollId(); - } finally { - searchResponse.decRef(); - } - - client().searchScroll(new SearchScrollRequest(scrollId)).get().decRef(); - assertThat(searchContextCreated.get(), is(true)); - - ClearScrollRequest clearScrollRequest = new ClearScrollRequest(); - clearScrollRequest.addScrollId(scrollId); - client().clearScroll(clearScrollRequest); - - searchResponse = client().search(searchRequest.allowPartialSearchResults(false).scroll(TimeValue.timeValueMinutes(10))).get(); - try { - scrollId = searchResponse.getScrollId(); - } finally { - searchResponse.decRef(); - } - searchContextCreated.set(false); - service.setOnCheckCancelled(t -> { - SearchShardTask task = new SearchShardTask(randomLong(), "transport", "action", "", TaskId.EMPTY_TASK_ID, emptyMap()); - TaskCancelHelper.cancel(task, "simulated"); - return task; - }); - CountDownLatch latch = new CountDownLatch(1); - client().searchScroll(new SearchScrollRequest(scrollId), new ActionListener<>() { - @Override - public void onResponse(SearchResponse searchResponse) { - try { - fail("Search not cancelled early"); - } finally { - latch.countDown(); - } - } - - @Override - public void onFailure(Exception e) { - Throwable cancelledExc = e.getCause().getCause(); - assertThat(cancelledExc, is(instanceOf(TaskCancelledException.class))); - assertThat(cancelledExc.getMessage(), is("task cancelled [simulated]")); - assertThat(((TaskCancelledException) cancelledExc).status(), is(RestStatus.BAD_REQUEST)); - latch.countDown(); - } - }); - latch.await(); - assertThat(searchContextCreated.get(), is(false)); - - clearScrollRequest.setScrollIds(singletonList(scrollId)); - client().clearScroll(clearScrollRequest); - } - - public void testWaitOnRefresh() throws ExecutionException, InterruptedException { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - 
searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30)); - searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); - - final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); - assertEquals(RestStatus.CREATED, response.status()); - - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null, - null, - null - ); - PlainActionFuture future = new PlainActionFuture<>(); - service.executeQueryPhase(request, task, future.delegateFailure((l, r) -> { - assertEquals(1, r.queryResult().getTotalHits().value()); - l.onResponse(null); - })); - future.get(); - } - - public void testWaitOnRefreshFailsWithRefreshesDisabled() { - createIndex("index", Settings.builder().put("index.refresh_interval", "-1").build()); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueSeconds(30)); - searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); - - final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); - assertEquals(RestStatus.CREATED, response.status()); - - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - PlainActionFuture future = new PlainActionFuture<>(); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null, - null, - null - ); - service.executeQueryPhase(request, task, future); - IllegalArgumentException illegalArgumentException = expectThrows(IllegalArgumentException.class, future::actionGet); - assertThat( - illegalArgumentException.getMessage(), - containsString("Cannot use wait_for_checkpoints with [index.refresh_interval=-1]") - ); - } - - public void testWaitOnRefreshFailsIfCheckpointNotIndexed() { - createIndex("index"); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - // Increased timeout to avoid cancelling the search task prior to its completion, - // as we expect to raise an Exception. Timeout itself is tested on the following `testWaitOnRefreshTimeout` test. 
- searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(200, 300))); - searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 1 })); - - final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); - assertEquals(RestStatus.CREATED, response.status()); - - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - PlainActionFuture future = new PlainActionFuture<>(); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null, - null, - null - ); - service.executeQueryPhase(request, task, future); - - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, future::actionGet); - assertThat( - ex.getMessage(), - containsString("Cannot wait for unissued seqNo checkpoint [wait_for_checkpoint=1, max_issued_seqNo=0]") - ); - } - - public void testWaitOnRefreshTimeout() { - createIndex("index", Settings.builder().put("index.refresh_interval", "60s").build()); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - searchRequest.setWaitForCheckpointsTimeout(TimeValue.timeValueMillis(randomIntBetween(10, 100))); - searchRequest.setWaitForCheckpoints(Collections.singletonMap("index", new long[] { 0 })); - - final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get(); - assertEquals(RestStatus.CREATED, response.status()); - - SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap()); - PlainActionFuture future = new PlainActionFuture<>(); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null, - null, - null - ); - service.executeQueryPhase(request, task, future); - - SearchTimeoutException ex = expectThrows(SearchTimeoutException.class, future::actionGet); - assertThat(ex.getMessage(), containsString("Wait for seq_no [0] refreshed timed out [")); - } - - public void testMinimalSearchSourceInShardRequests() { - createIndex("test"); - int numDocs = between(0, 10); - for (int i = 0; i < numDocs; i++) { - prepareIndex("test").setSource("id", Integer.toString(i)).get(); - } - indicesAdmin().prepareRefresh("test").get(); - - BytesReference pitId = client().execute( - TransportOpenPointInTimeAction.TYPE, - new OpenPointInTimeRequest("test").keepAlive(TimeValue.timeValueMinutes(10)) - ).actionGet().getPointInTimeId(); - final MockSearchService searchService = (MockSearchService) getInstanceFromNode(SearchService.class); - final List shardRequests = new CopyOnWriteArrayList<>(); - searchService.setOnCreateSearchContext(ctx -> shardRequests.add(ctx.request())); - try { - assertHitCount( - client().prepareSearch() - .setSource( - new SearchSourceBuilder().size(between(numDocs, numDocs * 2)).pointInTimeBuilder(new PointInTimeBuilder(pitId)) - ), - numDocs - ); - } finally { - client().execute(TransportClosePointInTimeAction.TYPE, new ClosePointInTimeRequest(pitId)).actionGet(); - } - assertThat(shardRequests, not(emptyList())); - for (ShardSearchRequest shardRequest : shardRequests) { - 
assertNotNull(shardRequest.source()); - assertNotNull(shardRequest.source().pointInTimeBuilder()); - assertThat(shardRequest.source().pointInTimeBuilder().getEncodedId(), equalTo(BytesArray.EMPTY)); - } - } - - public void testDfsQueryPhaseRewrite() { - createIndex("index"); - prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); - final SearchService service = getInstanceFromNode(SearchService.class); - final IndicesService indicesService = getInstanceFromNode(IndicesService.class); - final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index")); - final IndexShard indexShard = indexService.getShard(0); - SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); - searchRequest.source(SearchSourceBuilder.searchSource().query(new TestRewriteCounterQueryBuilder())); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - indexShard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1.0f, - -1, - null - ); - final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier(); - ReaderContext context = service.createAndPutReaderContext( - request, - indexService, - indexShard, - reader, - SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis() - ); - PlainActionFuture plainActionFuture = new PlainActionFuture<>(); - service.executeQueryPhase( - new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)), - new SearchShardTask(42L, "", "", "", null, emptyMap()), - plainActionFuture - ); - - plainActionFuture.actionGet(); - assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1)); - final ShardSearchContextId contextId = context.id(); - assertTrue(service.freeReaderContext(contextId)); - } - - public void testEnableSearchWorkerThreads() throws IOException { - IndexService indexService = createIndex("index", Settings.EMPTY); - IndexShard indexShard = indexService.getShard(0); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - new SearchRequest().allowPartialSearchResults(randomBoolean()), - indexShard.shardId(), - 0, - indexService.numberOfShards(), - AliasFilter.EMPTY, - 1f, - System.currentTimeMillis(), - null - ); - try (ReaderContext readerContext = createReaderContext(indexService, indexShard)) { - SearchService service = getInstanceFromNode(SearchService.class); - SearchShardTask task = new SearchShardTask(0, "type", "action", "description", null, emptyMap()); - - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { - assertTrue(searchContext.searcher().hasExecutor()); - } - - try { - ClusterUpdateSettingsResponse response = client().admin() - .cluster() - .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) - .setPersistentSettings(Settings.builder().put(SEARCH_WORKER_THREADS_ENABLED.getKey(), false).build()) - .get(); - assertTrue(response.isAcknowledged()); - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { - assertFalse(searchContext.searcher().hasExecutor()); - } - } finally { - // reset original default setting - client().admin() - .cluster() - .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) - .setPersistentSettings(Settings.builder().putNull(SEARCH_WORKER_THREADS_ENABLED.getKey()).build()) - .get(); - try (SearchContext searchContext = 
service.createContext(readerContext, request, task, ResultsType.DFS, randomBoolean())) { - assertTrue(searchContext.searcher().hasExecutor()); - } - } - } - } - - /** - * Verify that a single slice is created for requests that don't support parallel collection, while an executor is still - * provided to the searcher to parallelize other operations. Also ensure multiple slices are created for requests that do support - * parallel collection. - */ - public void testSlicingBehaviourForParallelCollection() throws Exception { - IndexService indexService = createIndex("index", Settings.EMPTY); - ThreadPoolExecutor executor = (ThreadPoolExecutor) indexService.getThreadPool().executor(ThreadPool.Names.SEARCH); - final int configuredMaxPoolSize = 10; - executor.setMaximumPoolSize(configuredMaxPoolSize); // We set this explicitly to be independent of CPU cores. - int numDocs = randomIntBetween(50, 100); - for (int i = 0; i < numDocs; i++) { - prepareIndex("index").setId(String.valueOf(i)).setSource("field", "value").get(); - if (i % 5 == 0) { - indicesAdmin().prepareRefresh("index").get(); - } - } - final IndexShard indexShard = indexService.getShard(0); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - new SearchRequest().allowPartialSearchResults(randomBoolean()), - indexShard.shardId(), - 0, - indexService.numberOfShards(), - AliasFilter.EMPTY, - 1f, - System.currentTimeMillis(), - null - ); - SearchService service = getInstanceFromNode(SearchService.class); - NonCountingTermQuery termQuery = new NonCountingTermQuery(new Term("field", "value")); - assertEquals(0, executor.getCompletedTaskCount()); - try (ReaderContext readerContext = createReaderContext(indexService, indexShard)) { - SearchShardTask task = new SearchShardTask(0, "type", "action", "description", null, emptyMap()); - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.DFS, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertTrue(searcher.hasExecutor()); - - final int maxPoolSize = executor.getMaximumPoolSize(); - assertEquals( - "Sanity check to ensure this isn't the default of 1 when pool size is unset", - configuredMaxPoolSize, - maxPoolSize - ); - - final int expectedSlices = ContextIndexSearcher.computeSlices( - searcher.getIndexReader().leaves(), - maxPoolSize, - 1 - ).length; - assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool size is unset", 1, expectedSlices); - - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "DFS supports parallel collection, so the number of slices should be > 1.", - expectedSlices - 1, // one slice executes on the calling thread - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertTrue(searcher.hasExecutor()); - - final int maxPoolSize = executor.getMaximumPoolSize(); - assertEquals( - "Sanity check to ensure this isn't the default of 1 when pool size is unset", - configuredMaxPoolSize, - maxPoolSize - ); - - final int expectedSlices = ContextIndexSearcher.computeSlices( - searcher.getIndexReader().leaves(), - maxPoolSize, - 1 - ).length; - assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool 
size is unset", 1, expectedSlices); - - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "QUERY supports parallel collection when enabled, so the number of slices should be > 1.", - expectedSlices - 1, // one slice executes on the calling thread - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.FETCH, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertFalse(searcher.hasExecutor()); - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "The number of slices should be 1 as FETCH does not support parallel collection and thus runs on the calling" - + " thread.", - 0, - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.NONE, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertFalse(searcher.hasExecutor()); - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "The number of slices should be 1 as NONE does not support parallel collection.", - 0, // zero since one slice executes on the calling thread - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - - try { - ClusterUpdateSettingsResponse response = client().admin() - .cluster() - .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) - .setPersistentSettings(Settings.builder().put(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.getKey(), false).build()) - .get(); - assertTrue(response.isAcknowledged()); - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertFalse(searcher.hasExecutor()); - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "The number of slices should be 1 when QUERY parallel collection is disabled.", - 0, // zero since one slice executes on the calling thread - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - } finally { - // Reset to the original default setting and check to ensure it takes effect. 
- client().admin() - .cluster() - .prepareUpdateSettings(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT) - .setPersistentSettings(Settings.builder().putNull(QUERY_PHASE_PARALLEL_COLLECTION_ENABLED.getKey()).build()) - .get(); - { - try (SearchContext searchContext = service.createContext(readerContext, request, task, ResultsType.QUERY, true)) { - ContextIndexSearcher searcher = searchContext.searcher(); - assertTrue(searcher.hasExecutor()); - - final int maxPoolSize = executor.getMaximumPoolSize(); - assertEquals( - "Sanity check to ensure this isn't the default of 1 when pool size is unset", - configuredMaxPoolSize, - maxPoolSize - ); - - final int expectedSlices = ContextIndexSearcher.computeSlices( - searcher.getIndexReader().leaves(), - maxPoolSize, - 1 - ).length; - assertNotEquals("Sanity check to ensure this isn't the default of 1 when pool size is unset", 1, expectedSlices); - - final long priorExecutorTaskCount = executor.getCompletedTaskCount(); - searcher.search(termQuery, new TotalHitCountCollectorManager(searcher.getSlices())); - assertBusy( - () -> assertEquals( - "QUERY supports parallel collection when enabled, so the number of slices should be > 1.", - expectedSlices - 1, // one slice executes on the calling thread - executor.getCompletedTaskCount() - priorExecutorTaskCount - ) - ); - } - } - } - } - } - - private static ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) { - return new ReaderContext( - new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), - indexService, - indexShard, - indexShard.acquireSearcherSupplier(), - randomNonNegativeLong(), - false - ); - } - - private static class TestRewriteCounterQueryBuilder extends AbstractQueryBuilder { - - final int asyncRewriteCount; - final Supplier fetched; - - TestRewriteCounterQueryBuilder() { - asyncRewriteCount = 0; - fetched = null; - } - - private TestRewriteCounterQueryBuilder(int asyncRewriteCount, Supplier fetched) { - this.asyncRewriteCount = asyncRewriteCount; - this.fetched = fetched; - } - - @Override - public String getWriteableName() { - return "test_query"; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ZERO; - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException {} - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException {} - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - return new MatchAllDocsQuery(); - } - - @Override - protected boolean doEquals(TestRewriteCounterQueryBuilder other) { - return true; - } - - @Override - protected int doHashCode() { - return 42; - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (asyncRewriteCount > 0) { - return this; - } - if (fetched != null) { - if (fetched.get() == null) { - return this; - } - assert fetched.get(); - return new TestRewriteCounterQueryBuilder(1, null); - } - if (queryRewriteContext.convertToDataRewriteContext() != null) { - SetOnce awaitingFetch = new SetOnce<>(); - queryRewriteContext.registerAsyncAction((c, l) -> { - awaitingFetch.set(true); - l.onResponse(null); - }); - return new TestRewriteCounterQueryBuilder(0, awaitingFetch::get); - } - return this; - } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/termsenum/action/TransportTermsEnumAction.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/termsenum/action/TransportTermsEnumAction.java index 9164fd88b6395..08e89a0fcab00 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/termsenum/action/TransportTermsEnumAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/termsenum/action/TransportTermsEnumAction.java @@ -492,7 +492,7 @@ private boolean hasMatchAllEquivalent( return false; } - private boolean canMatchShard(ShardId shardId, NodeTermsEnumRequest req) throws IOException { + private boolean canMatchShard(ShardId shardId, NodeTermsEnumRequest req) { if (req.indexFilter() == null || req.indexFilter() instanceof MatchAllQueryBuilder) { return true; } diff --git a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java index 8ba88865e361a..89d80cf34aec5 100644 --- a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java +++ b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.protocol.xpack.frozen.FreezeRequest; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.search.SearchContextMissingException; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -469,27 +468,25 @@ public void testCanMatch() throws IOException { ).canMatch() ); - expectThrows(SearchContextMissingException.class, () -> { - ShardSearchContextId withoutCommitId = new ShardSearchContextId(contextId.getSessionId(), contextId.getId(), null); - sourceBuilder.query(QueryBuilders.rangeQuery("field").gt("2010-01-06T02:00").lt("2010-01-07T02:00")); - assertFalse( - searchService.canMatch( - new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - shard.shardId(), - 0, - 1, - AliasFilter.EMPTY, - 1f, - -1, - null, - withoutCommitId, - null - ) - ).canMatch() - ); - }); + ShardSearchContextId withoutCommitId = new ShardSearchContextId(contextId.getSessionId(), contextId.getId(), null); + sourceBuilder.query(QueryBuilders.rangeQuery("field").gt("2010-01-06T02:00").lt("2010-01-07T02:00")); + assertTrue( + searchService.canMatch( + new ShardSearchRequest( + OriginalIndices.NONE, + searchRequest, + shard.shardId(), + 0, + 1, + AliasFilter.EMPTY, + 1f, + -1, + null, + withoutCommitId, + null + ) + ).canMatch() + ); } } diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java index 3ee49cce85a8a..6115bec91ad62 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/BaseSearchableSnapshotsIntegTestCase.java @@ -355,6 +355,18 @@ protected void assertExecutorIsIdle(String executorName) throws Exception { }); } + protected static void waitUntilRecoveryIsDone(String 
index) throws Exception { + assertBusy(() -> { + RecoveryResponse recoveryResponse = indicesAdmin().prepareRecoveries(index).get(); + assertThat(recoveryResponse.hasRecoveries(), equalTo(true)); + for (List value : recoveryResponse.shardRecoveryStates().values()) { + for (RecoveryState recoveryState : value) { + assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE)); + } + } + }); + } + public static class LicensedSnapshotBasedRecoveriesPlugin extends SnapshotBasedRecoveriesPlugin { public LicensedSnapshotBasedRecoveriesPlugin(Settings settings) { diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java index 23e414c0dc1bf..291161c090c27 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.searchablesnapshots; -import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequest; @@ -36,7 +35,6 @@ import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.indices.DateFieldRangeInfo; import org.elasticsearch.indices.IndicesService; -import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.snapshots.SnapshotId; @@ -1324,18 +1322,6 @@ private static IndexMetadata getIndexMetadata(String indexName) { .index(indexName); } - private static void waitUntilRecoveryIsDone(String index) throws Exception { - assertBusy(() -> { - RecoveryResponse recoveryResponse = indicesAdmin().prepareRecoveries(index).get(); - assertThat(recoveryResponse.hasRecoveries(), equalTo(true)); - for (List value : recoveryResponse.shardRecoveryStates().values()) { - for (RecoveryState recoveryState : value) { - assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE)); - } - } - }); - } - private void waitUntilAllShardsAreUnassigned(Index index) throws Exception { awaitClusterState(state -> state.getRoutingTable().index(index).allPrimaryShardsUnassigned()); } diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsSearchIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsSearchIntegTests.java new file mode 100644 index 0000000000000..f77a3f5698c98 --- /dev/null +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsSearchIntegTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.searchablesnapshots; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchType; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.snapshots.SnapshotId; +import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotAction; +import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotRequest; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_ROUTING_REQUIRE_GROUP_SETTING; +import static org.elasticsearch.index.IndexSettings.INDEX_SOFT_DELETES_SETTING; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.Matchers.equalTo; + +public class SearchableSnapshotsSearchIntegTests extends BaseFrozenSearchableSnapshotsIntegTestCase { + + /** + * Tests basic search functionality with a query sorted by field against partially mounted indices + * The can match phase is always executed against read only indices, and for sorted queries it extracts the min and max range from + * each shard. This will happen not only in the can match phase, but optionally also in the query phase. + * See {@link org.elasticsearch.search.internal.ShardSearchRequest#canReturnNullResponseIfMatchNoDocs()}. + * For keyword fields, it is not possible to retrieve min and max from the index reader on frozen, hence we need to make sure that + * while that fails, the query will go ahead and won't return shard failures. 
+ */ + public void testKeywordSortedQueryOnFrozen() throws Exception { + internalCluster().startMasterOnlyNode(); + internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); + final String dataNodeHoldingRegularIndex = internalCluster().startDataOnlyNode(); + String dataNodeHoldingSearchableSnapshot = internalCluster().startDataOnlyNode(); + + String[] indices = new String[] { "index-0001", "index-0002" }; + for (String index : indices) { + Settings extraSettings = Settings.builder() + .put(INDEX_ROUTING_REQUIRE_GROUP_SETTING.getConcreteSettingForNamespace("_name").getKey(), dataNodeHoldingRegularIndex) + .build(); + // we use a high number of shards because that's more likely to trigger can match as part of query phase: + // see ShardSearchRequest#canReturnNullResponseIfMatchNoDocs + assertAcked( + indicesAdmin().prepareCreate(index) + .setSettings(indexSettingsNoReplicas(10).put(INDEX_SOFT_DELETES_SETTING.getKey(), true).put(extraSettings)) + ); + } + ensureGreen(indices); + + for (String index : indices) { + final List indexRequestBuilders = new ArrayList<>(); + indexRequestBuilders.add(prepareIndex(index).setSource("keyword", "value1")); + indexRequestBuilders.add(prepareIndex(index).setSource("keyword", "value2")); + indexRandom(true, false, indexRequestBuilders); + assertThat( + indicesAdmin().prepareForceMerge(index).setOnlyExpungeDeletes(true).setFlush(true).get().getFailedShards(), + equalTo(0) + ); + refresh(index); + forceMerge(); + } + + final String repositoryName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + createRepository(repositoryName, "mock"); + + final SnapshotId snapshotId = createSnapshot(repositoryName, "snapshot-1", List.of(indices[0])).snapshotId(); + assertAcked(indicesAdmin().prepareDelete(indices[0])); + + // Block the repository for the node holding the searchable snapshot shards + // to delay its restore + blockDataNode(repositoryName, dataNodeHoldingSearchableSnapshot); + + // Force the searchable snapshot to be allocated in a particular node + Settings restoredIndexSettings = Settings.builder() + .put(INDEX_ROUTING_REQUIRE_GROUP_SETTING.getConcreteSettingForNamespace("_name").getKey(), dataNodeHoldingSearchableSnapshot) + .build(); + String[] mountedIndices = new String[indices.length]; + for (int i = 0; i < indices.length; i++) { + + String index = indices[i]; + String mountedIndex = index + "-mounted"; + mountedIndices[i] = mountedIndex; + final MountSearchableSnapshotRequest mountRequest = new MountSearchableSnapshotRequest( + TEST_REQUEST_TIMEOUT, + mountedIndex, + repositoryName, + snapshotId.getName(), + indices[0], + restoredIndexSettings, + Strings.EMPTY_ARRAY, + false, + randomFrom(MountSearchableSnapshotRequest.Storage.values()) + ); + client().execute(MountSearchableSnapshotAction.INSTANCE, mountRequest).actionGet(); + } + + // Allow the searchable snapshots to be finally mounted + unblockNode(repositoryName, dataNodeHoldingSearchableSnapshot); + for (String mountedIndex : mountedIndices) { + waitUntilRecoveryIsDone(mountedIndex); + } + ensureGreen(mountedIndices); + + SearchRequest request = new SearchRequest(mountedIndices).searchType(SearchType.QUERY_THEN_FETCH) + .source(SearchSourceBuilder.searchSource().sort("keyword.keyword")) + .allowPartialSearchResults(false); + if (randomBoolean()) { + request.setPreFilterShardSize(100); + } + + assertResponse(client().search(request), searchResponse -> { + assertThat(searchResponse.getSuccessfulShards(), equalTo(20)); + assertThat(searchResponse.getFailedShards(), equalTo(0)); + 
assertThat(searchResponse.getSkippedShards(), equalTo(0)); + assertThat(searchResponse.getTotalShards(), equalTo(20)); + assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(4L)); + }); + } +} From cca7bb12501475ce967ba8baa8ca0149083119a2 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:04:56 +1100 Subject: [PATCH 09/28] Mute org.elasticsearch.docker.test.DockerYmlTestSuiteIT test {p0=/10_info/Info} #118394 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 3d0a57c8d650a..0d485a0f2723c 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -294,6 +294,9 @@ tests: - class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests method: testSearcherId issue: https://github.com/elastic/elasticsearch/issues/118374 +- class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT + method: test {p0=/10_info/Info} + issue: https://github.com/elastic/elasticsearch/issues/118394 # Examples: # From 637ab7f846d18139b658e0cdba8266556253ab8a Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:05:14 +1100 Subject: [PATCH 10/28] Mute org.elasticsearch.docker.test.DockerYmlTestSuiteIT test {p0=/11_nodes/Additional disk information} #118395 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 0d485a0f2723c..f6e9bf21dc0a6 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -297,6 +297,9 @@ tests: - class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT method: test {p0=/10_info/Info} issue: https://github.com/elastic/elasticsearch/issues/118394 +- class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT + method: test {p0=/11_nodes/Additional disk information} + issue: https://github.com/elastic/elasticsearch/issues/118395 # Examples: # From 33e3a59cc91cc2a5794714db8ed6afa88ae0afa3 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:05:29 +1100 Subject: [PATCH 11/28] Mute org.elasticsearch.docker.test.DockerYmlTestSuiteIT test {p0=/11_nodes/Test cat nodes output with full_id set} #118396 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index f6e9bf21dc0a6..5ebe7a78bf3ae 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -300,6 +300,9 @@ tests: - class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT method: test {p0=/11_nodes/Additional disk information} issue: https://github.com/elastic/elasticsearch/issues/118395 +- class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT + method: test {p0=/11_nodes/Test cat nodes output with full_id set} + issue: https://github.com/elastic/elasticsearch/issues/118396 # Examples: # From b2d07fc967616963957b5845a495b04048ca6184 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:05:44 +1100 Subject: [PATCH 12/28] Mute org.elasticsearch.docker.test.DockerYmlTestSuiteIT test {p0=/11_nodes/Test cat nodes output} #118397 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 5ebe7a78bf3ae..d2b9957419009 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -303,6 +303,9 @@ tests: - class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT method: test {p0=/11_nodes/Test cat nodes 
output with full_id set} issue: https://github.com/elastic/elasticsearch/issues/118396 +- class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT + method: test {p0=/11_nodes/Test cat nodes output} + issue: https://github.com/elastic/elasticsearch/issues/118397 # Examples: # From d30d2727dd9b861c41f6160452454a861cd418d1 Mon Sep 17 00:00:00 2001 From: Keith Massey Date: Tue, 10 Dec 2024 16:31:01 -0600 Subject: [PATCH 13/28] Only perform operator check for migration reindex actions when feature flag is present (#118388) --- muted-tests.yml | 3 --- .../org/elasticsearch/xpack/security/operator/Constants.java | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index d2b9957419009..fa74d269b7094 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -274,9 +274,6 @@ tests: - class: org.elasticsearch.datastreams.DataStreamsClientYamlTestSuiteIT method: test {p0=data_stream/120_data_streams_stats/Multiple data stream} issue: https://github.com/elastic/elasticsearch/issues/118217 -- class: org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT - method: testEveryActionIsEitherOperatorOnlyOrNonOperator - issue: https://github.com/elastic/elasticsearch/issues/118220 - class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT issue: https://github.com/elastic/elasticsearch/issues/118224 - class: org.elasticsearch.packaging.test.ArchiveTests diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 4450b5a5119de..db87fdbcb8f1f 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.security.operator; import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.common.util.FeatureFlag; import java.util.Objects; import java.util.Set; @@ -636,6 +637,8 @@ public class Constants { "internal:gateway/local/started_shards", "internal:admin/indices/prevalidate_shard_path", "internal:index/metadata/migration_version/update", + new FeatureFlag("reindex_data_stream").isEnabled() ? "indices:admin/migration/reindex_status" : null, + new FeatureFlag("reindex_data_stream").isEnabled() ? 
"indices:admin/data_stream/reindex" : null, "internal:admin/repository/verify", "internal:admin/repository/verify/coordinate" ).filter(Objects::nonNull).collect(Collectors.toUnmodifiableSet()); From 254dd0de20b9b5243b52eada8f4459348c0cf481 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:31:42 +1100 Subject: [PATCH 14/28] Mute org.elasticsearch.xpack.test.rest.XPackRestIT test {p0=migrate/20_reindex_status/Test get reindex status with nonexistent task id} #118401 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index fa74d269b7094..e129b35b8699d 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -303,6 +303,9 @@ tests: - class: org.elasticsearch.docker.test.DockerYmlTestSuiteIT method: test {p0=/11_nodes/Test cat nodes output} issue: https://github.com/elastic/elasticsearch/issues/118397 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=migrate/20_reindex_status/Test get reindex status with nonexistent task id} + issue: https://github.com/elastic/elasticsearch/issues/118401 # Examples: # From 18707afa5e48e759fefe2c64b7d29d4b67c29a06 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:31:59 +1100 Subject: [PATCH 15/28] Mute org.elasticsearch.xpack.test.rest.XPackRestIT test {p0=migrate/10_reindex/Test Reindex With Nonexistent Data Stream} #118274 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index e129b35b8699d..6eda900ff8bf9 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -306,6 +306,9 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=migrate/20_reindex_status/Test get reindex status with nonexistent task id} issue: https://github.com/elastic/elasticsearch/issues/118401 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=migrate/10_reindex/Test Reindex With Nonexistent Data Stream} + issue: https://github.com/elastic/elasticsearch/issues/118274 # Examples: # From 0d1514302bd595209f8dd6c19bdc6ed9f0f82111 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:32:14 +1100 Subject: [PATCH 16/28] Mute org.elasticsearch.xpack.test.rest.XPackRestIT test {p0=migrate/10_reindex/Test Reindex With Bad Data Stream Name} #118272 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 6eda900ff8bf9..14a06ecd8e8d1 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -309,6 +309,9 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=migrate/10_reindex/Test Reindex With Nonexistent Data Stream} issue: https://github.com/elastic/elasticsearch/issues/118274 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=migrate/10_reindex/Test Reindex With Bad Data Stream Name} + issue: https://github.com/elastic/elasticsearch/issues/118272 # Examples: # From f3ab96b9dc6962552b27f49e765281e4fd681c4c Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:32:27 +1100 Subject: [PATCH 17/28] Mute org.elasticsearch.xpack.test.rest.XPackRestIT test {p0=migrate/10_reindex/Test Reindex With Unsupported Mode} #118273 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/muted-tests.yml b/muted-tests.yml index 14a06ecd8e8d1..8a173f3ca0d63 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -312,6 +312,9 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=migrate/10_reindex/Test Reindex With Bad Data Stream Name} issue: https://github.com/elastic/elasticsearch/issues/118272 +- class: org.elasticsearch.xpack.test.rest.XPackRestIT + method: test {p0=migrate/10_reindex/Test Reindex With Unsupported Mode} + issue: https://github.com/elastic/elasticsearch/issues/118273 # Examples: # From 16c4e14b37ae2029cc2f4022cdf21dc3d4ea1dea Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:53:33 +1100 Subject: [PATCH 18/28] Mute org.elasticsearch.xpack.inference.InferenceCrudIT testUnifiedCompletionInference #118405 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index 8a173f3ca0d63..df22d10062d8d 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -315,6 +315,9 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=migrate/10_reindex/Test Reindex With Unsupported Mode} issue: https://github.com/elastic/elasticsearch/issues/118273 +- class: org.elasticsearch.xpack.inference.InferenceCrudIT + method: testUnifiedCompletionInference + issue: https://github.com/elastic/elasticsearch/issues/118405 # Examples: # From d213efd272f6ed5cece5d88fc09975f624303363 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Tue, 10 Dec 2024 23:05:27 +0000 Subject: [PATCH 19/28] Add a new index setting to skip recovery source when synthetic source is enabled (#114618) This change adds a new undocumented index settings that allows to use synthetic source for recovery and CCR without storing a recovery source. 
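For illustration, a minimal sketch (not part of this patch) of index settings that would satisfy the validator introduced below: the new setting is only accepted when synthetic _source is in effect, either via an index.mode whose default source mode is synthetic (e.g. logsdb) or via index.mapping.source.mode=synthetic, and only on indices created on a version that supports it.

    import org.elasticsearch.common.settings.Settings;

    // Sketch only: enable the new recovery path on a logsdb index, whose default
    // source mode is synthetic, so the validator's source-mode check passes.
    Settings settings = Settings.builder()
        .put("index.mode", "logsdb")
        .put("index.recovery.use_synthetic_source", true)
        .build();
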
--- docs/changelog/114618.yaml | 5 + .../admin/indices/create/CloneIndexIT.java | 47 ++ .../index/shard/IndexShardIT.java | 10 +- .../indices/recovery/IndexRecoveryIT.java | 5 +- .../recovery/TruncatedRecoveryIT.java | 11 +- .../snapshots/RestoreSnapshotIT.java | 18 + .../metadata/MetadataCreateIndexService.java | 1 + .../common/settings/ClusterSettings.java | 1 + .../common/settings/IndexScopedSettings.java | 1 + .../elasticsearch/index/IndexSettings.java | 69 ++- .../elasticsearch/index/IndexVersions.java | 1 + .../index/engine/CombinedDocValues.java | 10 + .../elasticsearch/index/engine/Engine.java | 24 +- .../index/engine/InternalEngine.java | 53 +- .../index/engine/LuceneChangesSnapshot.java | 166 +----- .../LuceneSyntheticSourceChangesSnapshot.java | 244 +++++++++ .../index/engine/ReadOnlyEngine.java | 15 +- .../RecoverySourcePruneMergePolicy.java | 53 +- .../engine/SearchBasedChangesSnapshot.java | 233 ++++++++ .../fieldvisitor/LeafStoredFieldLoader.java | 1 - .../index/mapper/SourceFieldMapper.java | 17 +- .../elasticsearch/index/shard/IndexShard.java | 8 +- .../index/shard/PrimaryReplicaSyncer.java | 2 +- .../indices/recovery/RecoverySettings.java | 18 +- .../recovery/RecoverySourceHandler.java | 3 +- .../snapshots/RestoreService.java | 1 + .../index/engine/InternalEngineTests.java | 12 +- .../engine/LuceneChangesSnapshotTests.java | 473 ++-------------- ...neSyntheticSourceChangesSnapshotTests.java | 58 ++ .../RecoverySourcePruneMergePolicyTests.java | 327 ++++++----- .../SearchBasedChangesSnapshotTests.java | 507 ++++++++++++++++++ .../index/mapper/SourceFieldMapperTests.java | 118 +++- .../IndexLevelReplicationTests.java | 19 +- .../index/shard/IndexShardTests.java | 20 +- .../index/shard/RefreshListenersTests.java | 2 +- .../indices/recovery/RecoveryTests.java | 4 +- .../index/engine/EngineTestCase.java | 136 +++-- .../index/engine/TranslogHandler.java | 4 + .../index/mapper/MapperTestCase.java | 81 +++ .../AbstractIndexRecoveryIntegTestCase.java | 2 - .../java/org/elasticsearch/node/MockNode.java | 11 - .../node/RecoverySettingsChunkSizePlugin.java | 40 -- .../xpack/ccr/action/ShardChangesAction.java | 12 +- .../ShardFollowTaskReplicationTests.java | 13 +- .../action/bulk/BulkShardOperationsTests.java | 11 +- .../index/engine/FollowingEngineTests.java | 62 ++- 46 files changed, 2032 insertions(+), 897 deletions(-) create mode 100644 docs/changelog/114618.yaml create mode 100644 server/src/main/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshot.java create mode 100644 server/src/main/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshot.java create mode 100644 server/src/test/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshotTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshotTests.java delete mode 100644 test/framework/src/main/java/org/elasticsearch/node/RecoverySettingsChunkSizePlugin.java diff --git a/docs/changelog/114618.yaml b/docs/changelog/114618.yaml new file mode 100644 index 0000000000000..ada402fe35742 --- /dev/null +++ b/docs/changelog/114618.yaml @@ -0,0 +1,5 @@ +pr: 114618 +summary: Add a new index setting to skip recovery source when synthetic source is enabled +area: Logs +type: enhancement +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/CloneIndexIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/CloneIndexIT.java index b6930d06c11ec..47f96aebacd7d 
100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/CloneIndexIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/create/CloneIndexIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.seqno.SeqNoStats; import org.elasticsearch.test.ESIntegTestCase; @@ -26,6 +27,7 @@ import static org.elasticsearch.action.admin.indices.create.ShrinkIndexIT.assertNoResizeSourceIndexSettings; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -143,6 +145,51 @@ public void testResizeChangeSyntheticSource() { assertThat(error.getMessage(), containsString("can't change setting [index.mapping.source.mode] during resize")); } + public void testResizeChangeRecoveryUseSyntheticSource() { + prepareCreate("source").setSettings( + indexSettings(between(1, 5), 0).put("index.mode", "logsdb") + .put( + "index.version.created", + IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.USE_SYNTHETIC_SOURCE_FOR_RECOVERY, + IndexVersion.current() + ) + ) + ).setMapping("@timestamp", "type=date", "host.name", "type=keyword").get(); + updateIndexSettings(Settings.builder().put("index.blocks.write", true), "source"); + IllegalArgumentException error = expectThrows(IllegalArgumentException.class, () -> { + indicesAdmin().prepareResizeIndex("source", "target") + .setResizeType(ResizeType.CLONE) + .setSettings( + Settings.builder() + .put( + "index.version.created", + IndexVersionUtils.randomVersionBetween( + random(), + IndexVersions.USE_SYNTHETIC_SOURCE_FOR_RECOVERY, + IndexVersion.current() + ) + ) + .put("index.recovery.use_synthetic_source", true) + .put("index.mode", "logsdb") + .putNull("index.blocks.write") + .build() + ) + .get(); + }); + // The index.recovery.use_synthetic_source setting requires either index.mode or index.mapping.source.mode + // to be present in the settings. Since these are all unmodifiable settings with a non-deterministic evaluation + // order, any of them may trigger a failure first. 
+ assertThat( + error.getMessage(), + anyOf( + containsString("can't change setting [index.mode] during resize"), + containsString("can't change setting [index.recovery.use_synthetic_source] during resize") + ) + ); + } + public void testResizeChangeIndexSorts() { prepareCreate("source").setSettings(indexSettings(between(1, 5), 0)) .setMapping("@timestamp", "type=date", "host.name", "type=keyword") diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java index 6ffd5808cea73..870947db5bd85 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java @@ -715,7 +715,15 @@ public void testShardChangesWithDefaultDocType() throws Exception { } IndexShard shard = indexService.getShard(0); try ( - Translog.Snapshot luceneSnapshot = shard.newChangesSnapshot("test", 0, numOps - 1, true, randomBoolean(), randomBoolean()); + Translog.Snapshot luceneSnapshot = shard.newChangesSnapshot( + "test", + 0, + numOps - 1, + true, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ); Translog.Snapshot translogSnapshot = getTranslog(shard).newSnapshot() ) { List opsFromLucene = TestTranslog.drainSnapshot(luceneSnapshot, true); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java index 73359078908e7..7d4269550bb88 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java @@ -156,7 +156,6 @@ import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED; import static org.elasticsearch.indices.IndexingMemoryController.SHARD_INACTIVE_TIME_SETTING; import static org.elasticsearch.node.NodeRoleSettings.NODE_ROLES_SETTING; -import static org.elasticsearch.node.RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.empty; @@ -257,7 +256,7 @@ private void assertOnGoingRecoveryState( public Settings.Builder createRecoverySettingsChunkPerSecond(long chunkSizeBytes) { return Settings.builder() // Set the chunk size in bytes - .put(CHUNK_SIZE_SETTING.getKey(), new ByteSizeValue(chunkSizeBytes, ByteSizeUnit.BYTES)) + .put(RecoverySettings.INDICES_RECOVERY_CHUNK_SIZE.getKey(), new ByteSizeValue(chunkSizeBytes, ByteSizeUnit.BYTES)) // Set one chunk of bytes per second. .put(RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.getKey(), chunkSizeBytes, ByteSizeUnit.BYTES); } @@ -280,7 +279,7 @@ private void unthrottleRecovery() { Settings.builder() // 200mb is an arbitrary number intended to be large enough to avoid more throttling. 
.put(RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.getKey(), "200mb") - .put(CHUNK_SIZE_SETTING.getKey(), RecoverySettings.DEFAULT_CHUNK_SIZE) + .put(RecoverySettings.INDICES_RECOVERY_CHUNK_SIZE.getKey(), RecoverySettings.DEFAULT_CHUNK_SIZE) ); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/recovery/TruncatedRecoveryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/recovery/TruncatedRecoveryIT.java index 38eef4f720623..ca2ff69ac9b17 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/recovery/TruncatedRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/recovery/TruncatedRecoveryIT.java @@ -24,7 +24,7 @@ import org.elasticsearch.indices.recovery.PeerRecoveryTargetService; import org.elasticsearch.indices.recovery.RecoveryFileChunkRequest; import org.elasticsearch.indices.recovery.RecoveryFilesInfoRequest; -import org.elasticsearch.node.RecoverySettingsChunkSizePlugin; +import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.transport.MockTransportService; @@ -41,7 +41,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; -import static org.elasticsearch.node.RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -52,7 +51,7 @@ public class TruncatedRecoveryIT extends ESIntegTestCase { @Override protected Collection> nodePlugins() { - return Arrays.asList(MockTransportService.TestPlugin.class, RecoverySettingsChunkSizePlugin.class); + return Arrays.asList(MockTransportService.TestPlugin.class); } /** @@ -63,7 +62,11 @@ protected Collection> nodePlugins() { */ public void testCancelRecoveryAndResume() throws Exception { updateClusterSettings( - Settings.builder().put(CHUNK_SIZE_SETTING.getKey(), new ByteSizeValue(randomIntBetween(50, 300), ByteSizeUnit.BYTES)) + Settings.builder() + .put( + RecoverySettings.INDICES_RECOVERY_CHUNK_SIZE.getKey(), + new ByteSizeValue(randomIntBetween(50, 300), ByteSizeUnit.BYTES) + ) ); NodesStatsResponse nodeStats = clusterAdmin().prepareNodesStats().get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java index b490c7efd52cd..4ba06a34ca3a7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java @@ -809,6 +809,24 @@ public void testRestoreChangeSyntheticSource() { assertThat(error.getMessage(), containsString("cannot modify setting [index.mapping.source.mode] on restore")); } + public void testRestoreChangeRecoveryUseSyntheticSource() { + Client client = client(); + createRepository("test-repo", "fs"); + String indexName = "test-idx"; + assertAcked(client.admin().indices().prepareCreate(indexName).setSettings(Settings.builder().put(indexSettings()))); + createSnapshot("test-repo", "test-snap", Collections.singletonList(indexName)); + cluster().wipeIndices(indexName); + var error = expectThrows(SnapshotRestoreException.class, () -> { + client.admin() + .cluster() + .prepareRestoreSnapshot(TEST_REQUEST_TIMEOUT, "test-repo", 
"test-snap") + .setIndexSettings(Settings.builder().put("index.recovery.use_synthetic_source", true)) + .setWaitForCompletion(true) + .get(); + }); + assertThat(error.getMessage(), containsString("cannot modify setting [index.recovery.use_synthetic_source] on restore")); + } + public void testRestoreChangeIndexSorts() { Client client = client(); createRepository("test-repo", "fs"); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java index 52e4d75ac5116..75c2c06f36c8e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java @@ -1591,6 +1591,7 @@ static void validateCloneIndex( private static final Set UNMODIFIABLE_SETTINGS_DURING_RESIZE = Set.of( IndexSettings.MODE.getKey(), SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), IndexSortConfig.INDEX_SORT_MODE_SETTING.getKey(), diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index 16af7ca2915d4..a01571b8c237d 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -257,6 +257,7 @@ public void apply(Settings value, Settings current, Settings previous) { RecoverySettings.INDICES_RECOVERY_USE_SNAPSHOTS_SETTING, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_SNAPSHOT_FILE_DOWNLOADS, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_SNAPSHOT_FILE_DOWNLOADS_PER_NODE, + RecoverySettings.INDICES_RECOVERY_CHUNK_SIZE, RecoverySettings.NODE_BANDWIDTH_RECOVERY_FACTOR_READ_SETTING, RecoverySettings.NODE_BANDWIDTH_RECOVERY_FACTOR_WRITE_SETTING, RecoverySettings.NODE_BANDWIDTH_RECOVERY_OPERATOR_FACTOR_SETTING, diff --git a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java index 884ce38fba391..fc8f128e92f32 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java @@ -188,6 +188,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings { IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING, IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING, + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING, // validate that built-in similarities don't get redefined Setting.groupSetting("index.similarity.", (s) -> { diff --git a/server/src/main/java/org/elasticsearch/index/IndexSettings.java b/server/src/main/java/org/elasticsearch/index/IndexSettings.java index 5bea838f9d70c..8f0373d951319 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexSettings.java +++ b/server/src/main/java/org/elasticsearch/index/IndexSettings.java @@ -38,6 +38,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -51,6 +52,7 @@ import static 
org.elasticsearch.index.mapper.MapperService.INDEX_MAPPING_NESTED_DOCS_LIMIT_SETTING; import static org.elasticsearch.index.mapper.MapperService.INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING; import static org.elasticsearch.index.mapper.MapperService.INDEX_MAPPING_TOTAL_FIELDS_LIMIT_SETTING; +import static org.elasticsearch.index.mapper.SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING; /** * This class encapsulates all index level settings and handles settings updates. @@ -653,6 +655,62 @@ public Iterator> settings() { Property.Final ); + public static final Setting RECOVERY_USE_SYNTHETIC_SOURCE_SETTING = Setting.boolSetting( + "index.recovery.use_synthetic_source", + false, + new Setting.Validator<>() { + @Override + public void validate(Boolean value) {} + + @Override + public void validate(Boolean enabled, Map, Object> settings) { + if (enabled == false) { + return; + } + + // Verify if synthetic source is enabled on the index; fail if it is not + var indexMode = (IndexMode) settings.get(MODE); + if (indexMode.defaultSourceMode() != SourceFieldMapper.Mode.SYNTHETIC) { + var sourceMode = (SourceFieldMapper.Mode) settings.get(INDEX_MAPPER_SOURCE_MODE_SETTING); + if (sourceMode != SourceFieldMapper.Mode.SYNTHETIC) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "The setting [%s] is only permitted when [%s] is set to [%s]. Current mode: [%s].", + RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), + INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), + SourceFieldMapper.Mode.SYNTHETIC.name(), + sourceMode.name() + ) + ); + } + } + + // Verify that all nodes can handle this setting + var version = (IndexVersion) settings.get(SETTING_INDEX_VERSION_CREATED); + if (version.before(IndexVersions.USE_SYNTHETIC_SOURCE_FOR_RECOVERY)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "The setting [%s] is unavailable on this cluster because some nodes are running older " + + "versions that do not support it. Please upgrade all nodes to the latest version " + + "and try again.", + RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey() + ) + ); + } + } + + @Override + public Iterator> settings() { + List> res = List.of(INDEX_MAPPER_SOURCE_MODE_SETTING, SETTING_INDEX_VERSION_CREATED, MODE); + return res.iterator(); + } + }, + Property.IndexScope, + Property.Final + ); + /** * Returns true if TSDB encoding is enabled. The default is true */ @@ -824,6 +882,7 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) { private volatile boolean skipIgnoredSourceRead; private final SourceFieldMapper.Mode indexMappingSourceMode; private final boolean recoverySourceEnabled; + private final boolean recoverySourceSyntheticEnabled; /** * The maximum number of refresh listeners allows on this shard. 
@@ -984,8 +1043,9 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti es87TSDBCodecEnabled = scopedSettings.get(TIME_SERIES_ES87TSDB_CODEC_ENABLED_SETTING); skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING); skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING); - indexMappingSourceMode = scopedSettings.get(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING); + indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING); recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings); + recoverySourceSyntheticEnabled = scopedSettings.get(RECOVERY_USE_SYNTHETIC_SOURCE_SETTING); scopedSettings.addSettingsUpdateConsumer( MergePolicyConfig.INDEX_COMPOUND_FORMAT_SETTING, @@ -1677,6 +1737,13 @@ public boolean isRecoverySourceEnabled() { return recoverySourceEnabled; } + /** + * @return Whether recovery source should always be bypassed in favor of using synthetic source. + */ + public boolean isRecoverySourceSyntheticEnabled() { + return recoverySourceSyntheticEnabled; + } + /** * The bounds for {@code @timestamp} on this index or * {@code null} if there are no bounds. diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 6344aa2a72ca9..96a70c3cc432b 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -136,6 +136,7 @@ private static Version parseUnchecked(String version) { public static final IndexVersion LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT = def(9_001_00_0, Version.LUCENE_10_0_0); public static final IndexVersion TIME_BASED_K_ORDERED_DOC_ID = def(9_002_00_0, Version.LUCENE_10_0_0); public static final IndexVersion DEPRECATE_SOURCE_MODE_MAPPER = def(9_003_00_0, Version.LUCENE_10_0_0); + public static final IndexVersion USE_SYNTHETIC_SOURCE_FOR_RECOVERY = def(9_004_00_0, Version.LUCENE_10_0_0); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/index/engine/CombinedDocValues.java b/server/src/main/java/org/elasticsearch/index/engine/CombinedDocValues.java index 48fc76063f815..190a1ed8b457a 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CombinedDocValues.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CombinedDocValues.java @@ -24,6 +24,7 @@ final class CombinedDocValues { private final NumericDocValues primaryTermDV; private final NumericDocValues tombstoneDV; private final NumericDocValues recoverySource; + private final NumericDocValues recoverySourceSize; CombinedDocValues(LeafReader leafReader) throws IOException { this.versionDV = Objects.requireNonNull(leafReader.getNumericDocValues(VersionFieldMapper.NAME), "VersionDV is missing"); @@ -34,6 +35,7 @@ final class CombinedDocValues { ); this.tombstoneDV = leafReader.getNumericDocValues(SeqNoFieldMapper.TOMBSTONE_NAME); this.recoverySource = leafReader.getNumericDocValues(SourceFieldMapper.RECOVERY_SOURCE_NAME); + this.recoverySourceSize = leafReader.getNumericDocValues(SourceFieldMapper.RECOVERY_SOURCE_SIZE_NAME); } long docVersion(int segmentDocId) throws IOException { @@ -79,4 +81,12 @@ boolean hasRecoverySource(int segmentDocId) throws IOException { assert recoverySource.docID() < segmentDocId; return recoverySource.advanceExact(segmentDocId); } + + long recoverySourceSize(int segmentDocId) throws IOException { + if (recoverySourceSize == null) { + return -1; + } + assert recoverySourceSize.docID() < segmentDocId; + return recoverySourceSize.advanceExact(segmentDocId) ? recoverySourceSize.longValue() : -1; + } } diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index edafa1ca922fb..dcdff09191667 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -937,14 +937,15 @@ public boolean allowSearchIdleOptimization() { * @param source the source of the request * @param fromSeqNo the start sequence number (inclusive) * @param toSeqNo the end sequence number (inclusive) - * @see #newChangesSnapshot(String, long, long, boolean, boolean, boolean) + * @see #newChangesSnapshot(String, long, long, boolean, boolean, boolean, long) */ public abstract int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOException; /** - * Creates a new history snapshot from Lucene for reading operations whose seqno in the requesting seqno range (both inclusive). - * This feature requires soft-deletes enabled. If soft-deletes are disabled, this method will throw an {@link IllegalStateException}. + * @deprecated This method is deprecated will and be removed once #114618 is applied to the serverless repository. + * @see #newChangesSnapshot(String, long, long, boolean, boolean, boolean, long) */ + @Deprecated public abstract Translog.Snapshot newChangesSnapshot( String source, long fromSeqNo, @@ -954,6 +955,23 @@ public abstract Translog.Snapshot newChangesSnapshot( boolean accessStats ) throws IOException; + /** + * Creates a new history snapshot from Lucene for reading operations whose seqno in the requesting seqno range (both inclusive). + * This feature requires soft-deletes enabled. If soft-deletes are disabled, this method will throw an {@link IllegalStateException}. 
+ */ + public Translog.Snapshot newChangesSnapshot( + String source, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + long maxChunkSize + ) throws IOException { + // TODO: Remove this default implementation once the deprecated newChangesSnapshot is removed + return newChangesSnapshot(source, fromSeqNo, toSeqNo, requiredFullRange, singleConsumer, accessStats); + } + /** * Checks if this engine has every operations since {@code startingSeqNo}(inclusive) in its history (either Lucene or translog) */ diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index 8d43252d178ee..fe310dc45c94c 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -2709,7 +2709,10 @@ private IndexWriterConfig getIndexWriterConfig() { // always configure soft-deletes field so an engine with soft-deletes disabled can open a Lucene index with soft-deletes. iwc.setSoftDeletesField(Lucene.SOFT_DELETES_FIELD); mergePolicy = new RecoverySourcePruneMergePolicy( - SourceFieldMapper.RECOVERY_SOURCE_NAME, + engineConfig.getIndexSettings().isRecoverySourceSyntheticEnabled() ? null : SourceFieldMapper.RECOVERY_SOURCE_NAME, + engineConfig.getIndexSettings().isRecoverySourceSyntheticEnabled() + ? SourceFieldMapper.RECOVERY_SOURCE_SIZE_NAME + : SourceFieldMapper.RECOVERY_SOURCE_NAME, engineConfig.getIndexSettings().getMode() == IndexMode.TIME_SERIES, softDeletesPolicy::getRetentionQuery, new SoftDeletesRetentionMergePolicy( @@ -3141,6 +3144,19 @@ public Translog.Snapshot newChangesSnapshot( boolean requiredFullRange, boolean singleConsumer, boolean accessStats + ) throws IOException { + return newChangesSnapshot(source, fromSeqNo, toSeqNo, requiredFullRange, singleConsumer, accessStats, -1); + } + + @Override + public Translog.Snapshot newChangesSnapshot( + String source, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + long maxChunkSize ) throws IOException { if (enableRecoverySource == false) { throw new IllegalStateException( @@ -3153,16 +3169,31 @@ public Translog.Snapshot newChangesSnapshot( refreshIfNeeded(source, toSeqNo); Searcher searcher = acquireSearcher(source, SearcherScope.INTERNAL); try { - LuceneChangesSnapshot snapshot = new LuceneChangesSnapshot( - searcher, - LuceneChangesSnapshot.DEFAULT_BATCH_SIZE, - fromSeqNo, - toSeqNo, - requiredFullRange, - singleConsumer, - accessStats, - config().getIndexSettings().getIndexVersionCreated() - ); + final Translog.Snapshot snapshot; + if (engineConfig.getIndexSettings().isRecoverySourceSyntheticEnabled()) { + snapshot = new LuceneSyntheticSourceChangesSnapshot( + engineConfig.getMapperService().mappingLookup(), + searcher, + SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE, + maxChunkSize, + fromSeqNo, + toSeqNo, + requiredFullRange, + accessStats, + config().getIndexSettings().getIndexVersionCreated() + ); + } else { + snapshot = new LuceneChangesSnapshot( + searcher, + SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE, + fromSeqNo, + toSeqNo, + requiredFullRange, + singleConsumer, + accessStats, + config().getIndexSettings().getIndexVersionCreated() + ); + } searcher = null; return snapshot; } catch (Exception e) { diff --git a/server/src/main/java/org/elasticsearch/index/engine/LuceneChangesSnapshot.java 
b/server/src/main/java/org/elasticsearch/index/engine/LuceneChangesSnapshot.java index e44b344d3b283..d4466cbc17c54 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/LuceneChangesSnapshot.java +++ b/server/src/main/java/org/elasticsearch/index/engine/LuceneChangesSnapshot.java @@ -10,61 +10,33 @@ package org.elasticsearch.index.engine; import org.apache.lucene.codecs.StoredFieldsReader; -import org.apache.lucene.document.LongPoint; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.NumericDocValues; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.FieldDoc; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.Sort; -import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopFieldCollectorManager; import org.apache.lucene.util.ArrayUtil; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Assertions; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.fieldvisitor.FieldsVisitor; -import org.elasticsearch.index.mapper.SeqNoFieldMapper; import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.transport.Transports; -import java.io.Closeable; import java.io.IOException; import java.util.Comparator; -import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; /** * A {@link Translog.Snapshot} from changes in a Lucene index */ -final class LuceneChangesSnapshot implements Translog.Snapshot { - static final int DEFAULT_BATCH_SIZE = 1024; - - private final int searchBatchSize; - private final long fromSeqNo, toSeqNo; +public final class LuceneChangesSnapshot extends SearchBasedChangesSnapshot { private long lastSeenSeqNo; private int skippedOperations; - private final boolean requiredFullRange; private final boolean singleConsumer; - private final IndexSearcher indexSearcher; private int docIndex = 0; - private final boolean accessStats; - private final int totalHits; - private ScoreDoc[] scoreDocs; + private int maxDocIndex; private final ParallelArray parallelArray; - private final Closeable onClose; - - private final IndexVersion indexVersionCreated; private int storedFieldsReaderOrd = -1; private StoredFieldsReader storedFieldsReader = null; @@ -83,7 +55,7 @@ final class LuceneChangesSnapshot implements Translog.Snapshot { * @param accessStats true if the stats of the snapshot can be accessed via {@link #totalOperations()} * @param indexVersionCreated the version on which this index was created */ - LuceneChangesSnapshot( + public LuceneChangesSnapshot( Engine.Searcher engineSearcher, int searchBatchSize, long fromSeqNo, @@ -93,50 +65,26 @@ final class LuceneChangesSnapshot implements Translog.Snapshot { boolean accessStats, IndexVersion indexVersionCreated ) throws IOException { - if (fromSeqNo < 0 || toSeqNo < 0 || fromSeqNo > toSeqNo) { - throw new IllegalArgumentException("Invalid range; from_seqno [" + fromSeqNo + "], to_seqno [" + toSeqNo + "]"); - } - if (searchBatchSize <= 0) { - throw new 
IllegalArgumentException("Search_batch_size must be positive [" + searchBatchSize + "]"); - } - final AtomicBoolean closed = new AtomicBoolean(); - this.onClose = () -> { - if (closed.compareAndSet(false, true)) { - IOUtils.close(engineSearcher); - } - }; - final long requestingSize = (toSeqNo - fromSeqNo) == Long.MAX_VALUE ? Long.MAX_VALUE : (toSeqNo - fromSeqNo + 1L); - this.creationThread = Thread.currentThread(); - this.searchBatchSize = requestingSize < searchBatchSize ? Math.toIntExact(requestingSize) : searchBatchSize; - this.fromSeqNo = fromSeqNo; - this.toSeqNo = toSeqNo; - this.lastSeenSeqNo = fromSeqNo - 1; - this.requiredFullRange = requiredFullRange; + super(engineSearcher, searchBatchSize, fromSeqNo, toSeqNo, requiredFullRange, accessStats, indexVersionCreated); + this.creationThread = Assertions.ENABLED ? Thread.currentThread() : null; this.singleConsumer = singleConsumer; - this.indexSearcher = newIndexSearcher(engineSearcher); - this.indexSearcher.setQueryCache(null); - this.accessStats = accessStats; this.parallelArray = new ParallelArray(this.searchBatchSize); - this.indexVersionCreated = indexVersionCreated; - final TopDocs topDocs = searchOperations(null, accessStats); - this.totalHits = Math.toIntExact(topDocs.totalHits.value()); - this.scoreDocs = topDocs.scoreDocs; - fillParallelArray(scoreDocs, parallelArray); + this.lastSeenSeqNo = fromSeqNo - 1; + final TopDocs topDocs = nextTopDocs(); + this.maxDocIndex = topDocs.scoreDocs.length; + fillParallelArray(topDocs.scoreDocs, parallelArray); } @Override public void close() throws IOException { assert assertAccessingThread(); - onClose.close(); + super.close(); } @Override public int totalOperations() { assert assertAccessingThread(); - if (accessStats == false) { - throw new IllegalStateException("Access stats of a snapshot created with [access_stats] is false"); - } - return totalHits; + return super.totalOperations(); } @Override @@ -146,7 +94,7 @@ public int skippedOperations() { } @Override - public Translog.Operation next() throws IOException { + protected Translog.Operation nextOperation() throws IOException { assert assertAccessingThread(); Translog.Operation op = null; for (int idx = nextDocIndex(); idx != -1; idx = nextDocIndex()) { @@ -155,12 +103,6 @@ public Translog.Operation next() throws IOException { break; } } - if (requiredFullRange) { - rangeCheck(op); - } - if (op != null) { - lastSeenSeqNo = op.seqNo(); - } return op; } @@ -171,48 +113,15 @@ private boolean assertAccessingThread() { return true; } - private void rangeCheck(Translog.Operation op) { - if (op == null) { - if (lastSeenSeqNo < toSeqNo) { - throw new MissingHistoryOperationsException( - "Not all operations between from_seqno [" - + fromSeqNo - + "] " - + "and to_seqno [" - + toSeqNo - + "] found; prematurely terminated last_seen_seqno [" - + lastSeenSeqNo - + "]" - ); - } - } else { - final long expectedSeqNo = lastSeenSeqNo + 1; - if (op.seqNo() != expectedSeqNo) { - throw new MissingHistoryOperationsException( - "Not all operations between from_seqno [" - + fromSeqNo - + "] " - + "and to_seqno [" - + toSeqNo - + "] found; expected seqno [" - + expectedSeqNo - + "]; found [" - + op - + "]" - ); - } - } - } - private int nextDocIndex() throws IOException { // we have processed all docs in the current search - fetch the next batch - if (docIndex == scoreDocs.length && docIndex > 0) { - final ScoreDoc prev = scoreDocs[scoreDocs.length - 1]; - scoreDocs = searchOperations((FieldDoc) prev, false).scoreDocs; + if (docIndex == maxDocIndex && 
docIndex > 0) { + var scoreDocs = nextTopDocs().scoreDocs; fillParallelArray(scoreDocs, parallelArray); docIndex = 0; + maxDocIndex = scoreDocs.length; } - if (docIndex < scoreDocs.length) { + if (docIndex < maxDocIndex) { int idx = docIndex; docIndex++; return idx; @@ -237,14 +146,13 @@ private void fillParallelArray(ScoreDoc[] scoreDocs, ParallelArray parallelArray } int docBase = -1; int maxDoc = 0; - List leaves = indexSearcher.getIndexReader().leaves(); int readerIndex = 0; CombinedDocValues combinedDocValues = null; LeafReaderContext leaf = null; for (ScoreDoc scoreDoc : scoreDocs) { if (scoreDoc.doc >= docBase + maxDoc) { do { - leaf = leaves.get(readerIndex++); + leaf = leaves().get(readerIndex++); docBase = leaf.docBase; maxDoc = leaf.reader().maxDoc(); } while (scoreDoc.doc >= docBase + maxDoc); @@ -253,6 +161,7 @@ private void fillParallelArray(ScoreDoc[] scoreDocs, ParallelArray parallelArray final int segmentDocID = scoreDoc.doc - docBase; final int index = scoreDoc.shardIndex; parallelArray.leafReaderContexts[index] = leaf; + parallelArray.docID[index] = scoreDoc.doc; parallelArray.seqNo[index] = combinedDocValues.docSeqNo(segmentDocID); parallelArray.primaryTerm[index] = combinedDocValues.docPrimaryTerm(segmentDocID); parallelArray.version[index] = combinedDocValues.docVersion(segmentDocID); @@ -275,16 +184,6 @@ private static boolean hasSequentialAccess(ScoreDoc[] scoreDocs) { return true; } - private static IndexSearcher newIndexSearcher(Engine.Searcher engineSearcher) throws IOException { - return new IndexSearcher(Lucene.wrapAllDocsLive(engineSearcher.getDirectoryReader())); - } - - private static Query rangeQuery(long fromSeqNo, long toSeqNo, IndexVersion indexVersionCreated) { - return new BooleanQuery.Builder().add(LongPoint.newRangeQuery(SeqNoFieldMapper.NAME, fromSeqNo, toSeqNo), BooleanClause.Occur.MUST) - .add(Queries.newNonNestedFilter(indexVersionCreated), BooleanClause.Occur.MUST) // exclude non-root nested documents - .build(); - } - static int countOperations(Engine.Searcher engineSearcher, long fromSeqNo, long toSeqNo, IndexVersion indexVersionCreated) throws IOException { if (fromSeqNo < 0 || toSeqNo < 0 || fromSeqNo > toSeqNo) { @@ -293,23 +192,9 @@ static int countOperations(Engine.Searcher engineSearcher, long fromSeqNo, long return newIndexSearcher(engineSearcher).count(rangeQuery(fromSeqNo, toSeqNo, indexVersionCreated)); } - private TopDocs searchOperations(FieldDoc after, boolean accurateTotalHits) throws IOException { - final Query rangeQuery = rangeQuery(Math.max(fromSeqNo, lastSeenSeqNo), toSeqNo, indexVersionCreated); - assert accurateTotalHits == false || after == null : "accurate total hits is required by the first batch only"; - final SortField sortBySeqNo = new SortField(SeqNoFieldMapper.NAME, SortField.Type.LONG); - TopFieldCollectorManager topFieldCollectorManager = new TopFieldCollectorManager( - new Sort(sortBySeqNo), - searchBatchSize, - after, - accurateTotalHits ? 
Integer.MAX_VALUE : 0, - false - ); - return indexSearcher.search(rangeQuery, topFieldCollectorManager); - } - private Translog.Operation readDocAsOp(int docIndex) throws IOException { final LeafReaderContext leaf = parallelArray.leafReaderContexts[docIndex]; - final int segmentDocID = scoreDocs[docIndex].doc - leaf.docBase; + final int segmentDocID = parallelArray.docID[docIndex] - leaf.docBase; final long primaryTerm = parallelArray.primaryTerm[docIndex]; assert primaryTerm > 0 : "nested child document must be excluded"; final long seqNo = parallelArray.seqNo[docIndex]; @@ -385,19 +270,13 @@ private Translog.Operation readDocAsOp(int docIndex) throws IOException { + "], op [" + op + "]"; + lastSeenSeqNo = op.seqNo(); return op; } - private static boolean assertDocSoftDeleted(LeafReader leafReader, int segmentDocId) throws IOException { - final NumericDocValues ndv = leafReader.getNumericDocValues(Lucene.SOFT_DELETES_FIELD); - if (ndv == null || ndv.advanceExact(segmentDocId) == false) { - throw new IllegalStateException("DocValues for field [" + Lucene.SOFT_DELETES_FIELD + "] is not found"); - } - return ndv.longValue() == 1; - } - private static final class ParallelArray { final LeafReaderContext[] leafReaderContexts; + final int[] docID; final long[] version; final long[] seqNo; final long[] primaryTerm; @@ -406,6 +285,7 @@ private static final class ParallelArray { boolean useSequentialStoredFieldsReader = false; ParallelArray(int size) { + docID = new int[size]; version = new long[size]; seqNo = new long[size]; primaryTerm = new long[size]; diff --git a/server/src/main/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshot.java b/server/src/main/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshot.java new file mode 100644 index 0000000000000..3d3d2f6f66d56 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshot.java @@ -0,0 +1,244 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.engine; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.ArrayUtil; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.fieldvisitor.LeafStoredFieldLoader; +import org.elasticsearch.index.fieldvisitor.StoredFieldLoader; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.mapper.SourceFieldMetrics; +import org.elasticsearch.index.mapper.SourceLoader; +import org.elasticsearch.index.translog.Translog; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; + +/** + * A {@link SearchBasedChangesSnapshot} that utilizes a synthetic field loader to rebuild the recovery source. 
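The javadoc that continues below ties activation to RECOVERY_USE_SYNTHETIC_SOURCE_SETTING. A minimal opt-in sketch, assuming the source-mode setting accepts the literal "synthetic" and that both settings are supplied at index creation (illustrative only, not part of this change):

    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.index.IndexSettings;
    import org.elasticsearch.index.mapper.SourceFieldMapper;

    // Index settings under which changes snapshots are served by this class rather than
    // LuceneChangesSnapshot: synthetic _source plus the new recovery flag.
    static Settings syntheticRecoveryIndexSettings() {
        return Settings.builder()
            .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), "synthetic")
            .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true)
            .build();
    }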
+ * This snapshot is activated when {@link IndexSettings#RECOVERY_USE_SYNTHETIC_SOURCE_SETTING} + * is enabled on the underlying index. + * + * The {@code maxMemorySizeInBytes} parameter limits the total size of uncompressed _sources + * loaded into memory during batch retrieval. + */ +public class LuceneSyntheticSourceChangesSnapshot extends SearchBasedChangesSnapshot { + private final long maxMemorySizeInBytes; + private final StoredFieldLoader storedFieldLoader; + private final SourceLoader sourceLoader; + + private int skippedOperations; + private long lastSeenSeqNo; + + private record SearchRecord(FieldDoc doc, boolean isTombstone, long seqNo, long primaryTerm, long version, long size) { + int index() { + return doc.shardIndex; + } + + int docID() { + return doc.doc; + } + + boolean hasRecoverySourceSize() { + return size != -1; + } + } + + private final Deque pendingDocs = new LinkedList<>(); + private final Deque operationQueue = new LinkedList<>(); + + public LuceneSyntheticSourceChangesSnapshot( + MappingLookup mappingLookup, + Engine.Searcher engineSearcher, + int searchBatchSize, + long maxMemorySizeInBytes, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean accessStats, + IndexVersion indexVersionCreated + ) throws IOException { + super(engineSearcher, searchBatchSize, fromSeqNo, toSeqNo, requiredFullRange, accessStats, indexVersionCreated); + assert mappingLookup.isSourceSynthetic(); + // ensure we can buffer at least one document + this.maxMemorySizeInBytes = maxMemorySizeInBytes > 0 ? maxMemorySizeInBytes : 1; + this.sourceLoader = mappingLookup.newSourceLoader(SourceFieldMetrics.NOOP); + Set storedFields = sourceLoader.requiredStoredFields(); + assert mappingLookup.isSourceSynthetic() : "synthetic source must be enabled for proper functionality."; + this.storedFieldLoader = StoredFieldLoader.create(false, storedFields); + this.lastSeenSeqNo = fromSeqNo - 1; + } + + @Override + public int skippedOperations() { + return skippedOperations; + } + + @Override + protected Translog.Operation nextOperation() throws IOException { + while (true) { + if (operationQueue.isEmpty()) { + loadNextBatch(); + } + if (operationQueue.isEmpty()) { + return null; + } + var op = operationQueue.pollFirst(); + if (op.seqNo() == lastSeenSeqNo) { + skippedOperations++; + continue; + } + lastSeenSeqNo = op.seqNo(); + return op; + } + } + + private void loadNextBatch() throws IOException { + List documentsToLoad = new ArrayList<>(); + long accumulatedSize = 0; + while (accumulatedSize < maxMemorySizeInBytes) { + if (pendingDocs.isEmpty()) { + ScoreDoc[] topDocs = nextTopDocs().scoreDocs; + if (topDocs.length == 0) { + break; + } + pendingDocs.addAll(Arrays.asList(transformScoreDocsToRecords(topDocs))); + } + SearchRecord document = pendingDocs.pollFirst(); + document.doc().shardIndex = documentsToLoad.size(); + documentsToLoad.add(document); + accumulatedSize += document.size(); + } + + for (var op : loadDocuments(documentsToLoad)) { + if (op == null) { + skippedOperations++; + continue; + } + operationQueue.add(op); + } + } + + private SearchRecord[] transformScoreDocsToRecords(ScoreDoc[] scoreDocs) throws IOException { + ArrayUtil.introSort(scoreDocs, Comparator.comparingInt(doc -> doc.doc)); + SearchRecord[] documentRecords = new SearchRecord[scoreDocs.length]; + CombinedDocValues combinedDocValues = null; + int docBase = -1; + int maxDoc = 0; + int readerIndex = 0; + LeafReaderContext leafReaderContext; + + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = 
scoreDocs[i]; + if (scoreDoc.doc >= docBase + maxDoc) { + do { + leafReaderContext = leaves().get(readerIndex++); + docBase = leafReaderContext.docBase; + maxDoc = leafReaderContext.reader().maxDoc(); + } while (scoreDoc.doc >= docBase + maxDoc); + combinedDocValues = new CombinedDocValues(leafReaderContext.reader()); + } + int segmentDocID = scoreDoc.doc - docBase; + int index = scoreDoc.shardIndex; + var primaryTerm = combinedDocValues.docPrimaryTerm(segmentDocID); + assert primaryTerm > 0 : "nested child document must be excluded"; + documentRecords[index] = new SearchRecord( + (FieldDoc) scoreDoc, + combinedDocValues.isTombstone(segmentDocID), + combinedDocValues.docSeqNo(segmentDocID), + primaryTerm, + combinedDocValues.docVersion(segmentDocID), + combinedDocValues.recoverySourceSize(segmentDocID) + ); + } + return documentRecords; + } + + private Translog.Operation[] loadDocuments(List documentRecords) throws IOException { + documentRecords.sort(Comparator.comparingInt(doc -> doc.docID())); + Translog.Operation[] operations = new Translog.Operation[documentRecords.size()]; + + int docBase = -1; + int maxDoc = 0; + int readerIndex = 0; + LeafReaderContext leafReaderContext = null; + LeafStoredFieldLoader leafFieldLoader = null; + SourceLoader.Leaf leafSourceLoader = null; + for (int i = 0; i < documentRecords.size(); i++) { + SearchRecord docRecord = documentRecords.get(i); + if (docRecord.docID() >= docBase + maxDoc) { + do { + leafReaderContext = leaves().get(readerIndex++); + docBase = leafReaderContext.docBase; + maxDoc = leafReaderContext.reader().maxDoc(); + } while (docRecord.docID() >= docBase + maxDoc); + + leafFieldLoader = storedFieldLoader.getLoader(leafReaderContext, null); + leafSourceLoader = sourceLoader.leaf(leafReaderContext.reader(), null); + } + int segmentDocID = docRecord.docID() - docBase; + leafFieldLoader.advanceTo(segmentDocID); + operations[docRecord.index()] = createOperation(docRecord, leafFieldLoader, leafSourceLoader, segmentDocID, leafReaderContext); + } + return operations; + } + + private Translog.Operation createOperation( + SearchRecord docRecord, + LeafStoredFieldLoader fieldLoader, + SourceLoader.Leaf sourceLoader, + int segmentDocID, + LeafReaderContext context + ) throws IOException { + if (docRecord.isTombstone() && fieldLoader.id() == null) { + assert docRecord.version() == 1L : "Noop tombstone should have version 1L; actual version [" + docRecord.version() + "]"; + assert assertDocSoftDeleted(context.reader(), segmentDocID) : "Noop but soft_deletes field is not set [" + docRecord + "]"; + return new Translog.NoOp(docRecord.seqNo(), docRecord.primaryTerm(), "null"); + } else if (docRecord.isTombstone()) { + assert assertDocSoftDeleted(context.reader(), segmentDocID) : "Delete op but soft_deletes field is not set [" + docRecord + "]"; + return new Translog.Delete(fieldLoader.id(), docRecord.seqNo(), docRecord.primaryTerm(), docRecord.version()); + } else { + if (docRecord.hasRecoverySourceSize() == false) { + // TODO: Callers should ask for the range that source should be retained. Thus we should always + // check for the existence source once we make peer-recovery to send ops after the local checkpoint. 
+ if (requiredFullRange) { + throw new MissingHistoryOperationsException( + "source not found for seqno=" + docRecord.seqNo() + " from_seqno=" + fromSeqNo + " to_seqno=" + toSeqNo + ); + } else { + skippedOperations++; + return null; + } + } + BytesReference source = sourceLoader.source(fieldLoader, segmentDocID).internalSourceRef(); + return new Translog.Index( + fieldLoader.id(), + docRecord.seqNo(), + docRecord.primaryTerm(), + docRecord.version(), + source, + fieldLoader.routing(), + -1 // autogenerated timestamp + ); + } + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java index d4a2fe1b57903..1cca1ed5df6ea 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java @@ -356,7 +356,7 @@ public Closeable acquireHistoryRetentionLock() { @Override public int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOException { - try (Translog.Snapshot snapshot = newChangesSnapshot(source, fromSeqNo, toSeqNo, false, true, true)) { + try (Translog.Snapshot snapshot = newChangesSnapshot(source, fromSeqNo, toSeqNo, false, true, true, -1)) { return snapshot.totalOperations(); } } @@ -369,6 +369,19 @@ public Translog.Snapshot newChangesSnapshot( boolean requiredFullRange, boolean singleConsumer, boolean accessStats + ) throws IOException { + return Translog.Snapshot.EMPTY; + } + + @Override + public Translog.Snapshot newChangesSnapshot( + String source, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + long maxChunkSize ) { return Translog.Snapshot.EMPTY; } diff --git a/server/src/main/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicy.java b/server/src/main/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicy.java index 3e99818d1827b..35a2d0b438fe5 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicy.java +++ b/server/src/main/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicy.java @@ -33,17 +33,18 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.search.internal.FilterStoredFieldVisitor; import java.io.IOException; import java.util.Arrays; -import java.util.Objects; import java.util.function.Supplier; final class RecoverySourcePruneMergePolicy extends OneMergeWrappingMergePolicy { RecoverySourcePruneMergePolicy( - String recoverySourceField, + @Nullable String pruneStoredFieldName, + String pruneNumericDVFieldName, boolean pruneIdField, Supplier retainSourceQuerySupplier, MergePolicy in @@ -52,18 +53,19 @@ final class RecoverySourcePruneMergePolicy extends OneMergeWrappingMergePolicy { @Override public CodecReader wrapForMerge(CodecReader reader) throws IOException { CodecReader wrapped = toWrap.wrapForMerge(reader); - return wrapReader(recoverySourceField, pruneIdField, wrapped, retainSourceQuerySupplier); + return wrapReader(pruneStoredFieldName, pruneNumericDVFieldName, pruneIdField, wrapped, retainSourceQuerySupplier); } }); } private static CodecReader wrapReader( - String recoverySourceField, + String pruneStoredFieldName, + String pruneNumericDVFieldName, boolean pruneIdField, CodecReader reader, Supplier 
retainSourceQuerySupplier ) throws IOException { - NumericDocValues recoverySource = reader.getNumericDocValues(recoverySourceField); + NumericDocValues recoverySource = reader.getNumericDocValues(pruneNumericDVFieldName); if (recoverySource == null || recoverySource.nextDoc() == DocIdSetIterator.NO_MORE_DOCS) { return reader; // early terminate - nothing to do here since non of the docs has a recovery source anymore. } @@ -78,21 +80,35 @@ private static CodecReader wrapReader( if (recoverySourceToKeep.cardinality() == reader.maxDoc()) { return reader; // keep all source } - return new SourcePruningFilterCodecReader(recoverySourceField, pruneIdField, reader, recoverySourceToKeep); + return new SourcePruningFilterCodecReader( + pruneStoredFieldName, + pruneNumericDVFieldName, + pruneIdField, + reader, + recoverySourceToKeep + ); } else { - return new SourcePruningFilterCodecReader(recoverySourceField, pruneIdField, reader, null); + return new SourcePruningFilterCodecReader(pruneStoredFieldName, pruneNumericDVFieldName, pruneIdField, reader, null); } } private static class SourcePruningFilterCodecReader extends FilterCodecReader { private final BitSet recoverySourceToKeep; - private final String recoverySourceField; + private final String pruneStoredFieldName; + private final String pruneNumericDVFieldName; private final boolean pruneIdField; - SourcePruningFilterCodecReader(String recoverySourceField, boolean pruneIdField, CodecReader reader, BitSet recoverySourceToKeep) { + SourcePruningFilterCodecReader( + @Nullable String pruneStoredFieldName, + String pruneNumericDVFieldName, + boolean pruneIdField, + CodecReader reader, + BitSet recoverySourceToKeep + ) { super(reader); - this.recoverySourceField = recoverySourceField; + this.pruneStoredFieldName = pruneStoredFieldName; this.recoverySourceToKeep = recoverySourceToKeep; + this.pruneNumericDVFieldName = pruneNumericDVFieldName; this.pruneIdField = pruneIdField; } @@ -103,8 +119,8 @@ public DocValuesProducer getDocValuesReader() { @Override public NumericDocValues getNumeric(FieldInfo field) throws IOException { NumericDocValues numeric = super.getNumeric(field); - if (recoverySourceField.equals(field.name)) { - assert numeric != null : recoverySourceField + " must have numeric DV but was null"; + if (field.name.equals(pruneNumericDVFieldName)) { + assert numeric != null : pruneNumericDVFieldName + " must have numeric DV but was null"; final DocIdSetIterator intersection; if (recoverySourceToKeep == null) { // we can't return null here lucenes DocIdMerger expects an instance @@ -139,10 +155,14 @@ public boolean advanceExact(int target) { @Override public StoredFieldsReader getFieldsReader() { + if (pruneStoredFieldName == null && pruneIdField == false) { + // nothing to prune, we can use the original fields reader + return super.getFieldsReader(); + } return new RecoverySourcePruningStoredFieldsReader( super.getFieldsReader(), recoverySourceToKeep, - recoverySourceField, + pruneStoredFieldName, pruneIdField ); } @@ -241,12 +261,13 @@ private static class RecoverySourcePruningStoredFieldsReader extends FilterStore RecoverySourcePruningStoredFieldsReader( StoredFieldsReader in, BitSet recoverySourceToKeep, - String recoverySourceField, + @Nullable String recoverySourceField, boolean pruneIdField ) { super(in); + assert recoverySourceField != null || pruneIdField : "nothing to prune"; this.recoverySourceToKeep = recoverySourceToKeep; - this.recoverySourceField = Objects.requireNonNull(recoverySourceField); + this.recoverySourceField = 
recoverySourceField; this.pruneIdField = pruneIdField; } @@ -258,7 +279,7 @@ public void document(int docID, StoredFieldVisitor visitor) throws IOException { super.document(docID, new FilterStoredFieldVisitor(visitor) { @Override public Status needsField(FieldInfo fieldInfo) throws IOException { - if (recoverySourceField.equals(fieldInfo.name)) { + if (fieldInfo.name.equals(recoverySourceField)) { return Status.NO; } if (pruneIdField && IdFieldMapper.NAME.equals(fieldInfo.name)) { diff --git a/server/src/main/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshot.java b/server/src/main/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshot.java new file mode 100644 index 0000000000000..191125c59705e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshot.java @@ -0,0 +1,233 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.engine; + +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldCollectorManager; +import org.elasticsearch.common.lucene.Lucene; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.SeqNoFieldMapper; +import org.elasticsearch.index.translog.Translog; + +import java.io.Closeable; +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Abstract class that provides a snapshot mechanism to retrieve operations from a live Lucene index + * within a specified range of sequence numbers. Subclasses are expected to define the + * method to fetch the next batch of operations. + */ +public abstract class SearchBasedChangesSnapshot implements Translog.Snapshot, Closeable { + public static final int DEFAULT_BATCH_SIZE = 1024; + + private final IndexVersion indexVersionCreated; + private final IndexSearcher indexSearcher; + private final Closeable onClose; + + protected final long fromSeqNo, toSeqNo; + protected final boolean requiredFullRange; + protected final int searchBatchSize; + + private final boolean accessStats; + private final int totalHits; + private FieldDoc afterDoc; + private long lastSeenSeqNo; + + /** + * Constructs a new snapshot for fetching changes within a sequence number range. + * + * @param engineSearcher Engine searcher instance. + * @param searchBatchSize Number of documents to retrieve per batch. + * @param fromSeqNo Starting sequence number. + * @param toSeqNo Ending sequence number. 
+ * @param requiredFullRange Whether the full range is required. + * @param accessStats If true, enable access statistics for counting total operations. + * @param indexVersionCreated Version of the index when it was created. + */ + protected SearchBasedChangesSnapshot( + Engine.Searcher engineSearcher, + int searchBatchSize, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean accessStats, + IndexVersion indexVersionCreated + ) throws IOException { + + if (fromSeqNo < 0 || toSeqNo < 0 || fromSeqNo > toSeqNo) { + throw new IllegalArgumentException("Invalid range; from_seqno [" + fromSeqNo + "], to_seqno [" + toSeqNo + "]"); + } + if (searchBatchSize <= 0) { + throw new IllegalArgumentException("Search_batch_size must be positive [" + searchBatchSize + "]"); + } + + final AtomicBoolean closed = new AtomicBoolean(); + this.onClose = () -> { + if (closed.compareAndSet(false, true)) { + IOUtils.close(engineSearcher); + } + }; + + this.indexVersionCreated = indexVersionCreated; + this.fromSeqNo = fromSeqNo; + this.toSeqNo = toSeqNo; + this.lastSeenSeqNo = fromSeqNo - 1; + this.requiredFullRange = requiredFullRange; + this.indexSearcher = newIndexSearcher(engineSearcher); + this.indexSearcher.setQueryCache(null); + + long requestingSize = (toSeqNo - fromSeqNo == Long.MAX_VALUE) ? Long.MAX_VALUE : (toSeqNo - fromSeqNo + 1L); + this.searchBatchSize = (int) Math.min(requestingSize, searchBatchSize); + + this.accessStats = accessStats; + this.totalHits = accessStats ? indexSearcher.count(rangeQuery(fromSeqNo, toSeqNo, indexVersionCreated)) : -1; + } + + /** + * Abstract method for retrieving the next operation. Should be implemented by subclasses. + * + * @return The next Translog.Operation in the snapshot. + * @throws IOException If an I/O error occurs. + */ + protected abstract Translog.Operation nextOperation() throws IOException; + + /** + * Returns the list of index leaf reader contexts. + * + * @return List of LeafReaderContext. + */ + public List leaves() { + return indexSearcher.getIndexReader().leaves(); + } + + @Override + public int totalOperations() { + if (accessStats == false) { + throw new IllegalStateException("Access stats of a snapshot created with [access_stats] is false"); + } + return totalHits; + } + + @Override + public final Translog.Operation next() throws IOException { + Translog.Operation op = nextOperation(); + if (requiredFullRange) { + verifyRange(op); + } + if (op != null) { + assert fromSeqNo <= op.seqNo() && op.seqNo() <= toSeqNo && lastSeenSeqNo < op.seqNo() + : "Unexpected operation; last_seen_seqno [" + + lastSeenSeqNo + + "], from_seqno [" + + fromSeqNo + + "], to_seqno [" + + toSeqNo + + "], op [" + + op + + "]"; + lastSeenSeqNo = op.seqNo(); + } + return op; + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + /** + * Retrieves the next batch of top documents based on the sequence range. + * + * @return TopDocs instance containing the documents in the current batch. 
+ */ + protected TopDocs nextTopDocs() throws IOException { + Query rangeQuery = rangeQuery(Math.max(fromSeqNo, lastSeenSeqNo), toSeqNo, indexVersionCreated); + SortField sortBySeqNo = new SortField(SeqNoFieldMapper.NAME, SortField.Type.LONG); + + TopFieldCollectorManager collectorManager = new TopFieldCollectorManager( + new Sort(sortBySeqNo), + searchBatchSize, + afterDoc, + 0, + false + ); + TopDocs results = indexSearcher.search(rangeQuery, collectorManager); + + if (results.scoreDocs.length > 0) { + afterDoc = (FieldDoc) results.scoreDocs[results.scoreDocs.length - 1]; + } + for (int i = 0; i < results.scoreDocs.length; i++) { + results.scoreDocs[i].shardIndex = i; + } + return results; + } + + static IndexSearcher newIndexSearcher(Engine.Searcher engineSearcher) throws IOException { + return new IndexSearcher(Lucene.wrapAllDocsLive(engineSearcher.getDirectoryReader())); + } + + static Query rangeQuery(long fromSeqNo, long toSeqNo, IndexVersion indexVersionCreated) { + return new BooleanQuery.Builder().add(LongPoint.newRangeQuery(SeqNoFieldMapper.NAME, fromSeqNo, toSeqNo), BooleanClause.Occur.MUST) + .add(Queries.newNonNestedFilter(indexVersionCreated), BooleanClause.Occur.MUST) + .build(); + } + + private void verifyRange(Translog.Operation op) { + if (op == null && lastSeenSeqNo < toSeqNo) { + throw new MissingHistoryOperationsException( + "Not all operations between from_seqno [" + + fromSeqNo + + "] " + + "and to_seqno [" + + toSeqNo + + "] found; prematurely terminated last_seen_seqno [" + + lastSeenSeqNo + + "]" + ); + } else if (op != null && op.seqNo() != lastSeenSeqNo + 1) { + throw new MissingHistoryOperationsException( + "Not all operations between from_seqno [" + + fromSeqNo + + "] " + + "and to_seqno [" + + toSeqNo + + "] found; expected seqno [" + + lastSeenSeqNo + + 1 + + "]; found [" + + op + + "]" + ); + } + } + + protected static boolean assertDocSoftDeleted(LeafReader leafReader, int segmentDocId) throws IOException { + NumericDocValues docValues = leafReader.getNumericDocValues(Lucene.SOFT_DELETES_FIELD); + if (docValues == null || docValues.advanceExact(segmentDocId) == false) { + throw new IllegalStateException("DocValues for field [" + Lucene.SOFT_DELETES_FIELD + "] is not found"); + } + return docValues.longValue() == 1; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/fieldvisitor/LeafStoredFieldLoader.java b/server/src/main/java/org/elasticsearch/index/fieldvisitor/LeafStoredFieldLoader.java index c8709d3422213..3ed4c856ccc71 100644 --- a/server/src/main/java/org/elasticsearch/index/fieldvisitor/LeafStoredFieldLoader.java +++ b/server/src/main/java/org/elasticsearch/index/fieldvisitor/LeafStoredFieldLoader.java @@ -47,5 +47,4 @@ public interface LeafStoredFieldLoader { * @return stored fields for the current document */ Map> storedFields(); - } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index 1cea8154aad43..d491eb9de5886 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -60,6 +60,8 @@ public class SourceFieldMapper extends MetadataFieldMapper { public static final String NAME = "_source"; public static final String RECOVERY_SOURCE_NAME = "_recovery_source"; + public static final String RECOVERY_SOURCE_SIZE_NAME = "_recovery_source_size"; + public static final String CONTENT_TYPE = "_source"; public 
static final String LOSSY_PARAMETERS_ALLOWED_SETTING_NAME = "index.lossy.source-mapping-parameters"; @@ -413,8 +415,19 @@ public void preParse(DocumentParserContext context) throws IOException { if (enableRecoverySource && originalSource != null && adaptedSource != originalSource) { // if we omitted source or modified it we add the _recovery_source to ensure we have it for ops based recovery BytesRef ref = originalSource.toBytesRef(); - context.doc().add(new StoredField(RECOVERY_SOURCE_NAME, ref.bytes, ref.offset, ref.length)); - context.doc().add(new NumericDocValuesField(RECOVERY_SOURCE_NAME, 1)); + if (context.indexSettings().isRecoverySourceSyntheticEnabled()) { + assert isSynthetic() : "recovery source should not be disabled on non-synthetic source"; + /** + * We use synthetic source for recovery, so we omit the recovery source. + * Instead, we record only the size of the uncompressed source. + * This size is used in {@link LuceneSyntheticSourceChangesSnapshot} to control memory + * usage during the recovery process when loading a batch of synthetic sources. + */ + context.doc().add(new NumericDocValuesField(RECOVERY_SOURCE_SIZE_NAME, ref.length)); + } else { + context.doc().add(new StoredField(RECOVERY_SOURCE_NAME, ref.bytes, ref.offset, ref.length)); + context.doc().add(new NumericDocValuesField(RECOVERY_SOURCE_NAME, 1)); + } } } diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index f84ac22cd78e4..a76feff84e61b 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -2600,7 +2600,7 @@ public long getMinRetainedSeqNo() { * @param source the source of the request * @param fromSeqNo the start sequence number (inclusive) * @param toSeqNo the end sequence number (inclusive) - * @see #newChangesSnapshot(String, long, long, boolean, boolean, boolean) + * @see #newChangesSnapshot(String, long, long, boolean, boolean, boolean, long) */ public int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOException { return getEngine().countChanges(source, fromSeqNo, toSeqNo); @@ -2619,6 +2619,7 @@ public int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOEx * @param singleConsumer true if the snapshot is accessed by only the thread that creates the snapshot. In this case, the * snapshot can enable some optimizations to improve the performance. * @param accessStats true if the stats of the snapshot is accessed via {@link Translog.Snapshot#totalOperations()} + * @param maxChunkSize The maximum allowable size, in bytes, for buffering source documents during recovery. 
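A hedged shard-level sketch of how the maxChunkSize parameter documented here is expected to be threaded through from the node's recovery settings, mirroring the PrimaryReplicaSyncer and RecoverySourceHandler changes later in this patch (the helper name is hypothetical):

    import java.io.IOException;
    import org.elasticsearch.index.shard.IndexShard;
    import org.elasticsearch.index.translog.Translog;
    import org.elasticsearch.indices.recovery.RecoverySettings;

    // Opens a resync-style snapshot from startingSeqNo onwards, passing the configured
    // recovery chunk size as the per-batch source buffering limit.
    static Translog.Snapshot openResyncSnapshot(IndexShard shard, RecoverySettings recoverySettings, long startingSeqNo)
        throws IOException {
        return shard.newChangesSnapshot(
            "resync",
            startingSeqNo,
            Long.MAX_VALUE,
            false,
            false,
            true,
            recoverySettings.getChunkSize().getBytes()
        );
    }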
*/ public Translog.Snapshot newChangesSnapshot( String source, @@ -2626,9 +2627,10 @@ public Translog.Snapshot newChangesSnapshot( long toSeqNo, boolean requiredFullRange, boolean singleConsumer, - boolean accessStats + boolean accessStats, + long maxChunkSize ) throws IOException { - return getEngine().newChangesSnapshot(source, fromSeqNo, toSeqNo, requiredFullRange, singleConsumer, accessStats); + return getEngine().newChangesSnapshot(source, fromSeqNo, toSeqNo, requiredFullRange, singleConsumer, accessStats, maxChunkSize); } public List segments() { diff --git a/server/src/main/java/org/elasticsearch/index/shard/PrimaryReplicaSyncer.java b/server/src/main/java/org/elasticsearch/index/shard/PrimaryReplicaSyncer.java index f843357e056c4..1143da30c2952 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/PrimaryReplicaSyncer.java +++ b/server/src/main/java/org/elasticsearch/index/shard/PrimaryReplicaSyncer.java @@ -81,7 +81,7 @@ public void resync(final IndexShard indexShard, final ActionListener // Wrap translog snapshot to make it synchronized as it is accessed by different threads through SnapshotSender. // Even though those calls are not concurrent, snapshot.next() uses non-synchronized state and is not multi-thread-compatible // Also fail the resync early if the shard is shutting down - snapshot = indexShard.newChangesSnapshot("resync", startingSeqNo, Long.MAX_VALUE, false, false, true); + snapshot = indexShard.newChangesSnapshot("resync", startingSeqNo, Long.MAX_VALUE, false, false, true, chunkSize.getBytes()); final Translog.Snapshot originalSnapshot = snapshot; final Translog.Snapshot wrappedSnapshot = new Translog.Snapshot() { @Override diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySettings.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySettings.java index 1ec187ea4a34b..475f83de9cae3 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySettings.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySettings.java @@ -399,6 +399,18 @@ public Iterator> settings() { public static final ByteSizeValue DEFAULT_CHUNK_SIZE = new ByteSizeValue(512, ByteSizeUnit.KB); + /** + * The maximum allowable size, in bytes, for buffering source documents during recovery. 
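A minimal configuration sketch for the setting declared just below, assuming it is registered with the node's ClusterSettings the way other node-scope recovery settings are (illustrative only):

    import org.elasticsearch.common.settings.ClusterSettings;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.indices.recovery.RecoverySettings;

    // Overrides the 512kb default at node construction time; because the setting is
    // dynamic, it can also be updated later through the cluster settings API.
    static RecoverySettings recoverySettingsWithSmallChunks() {
        Settings nodeSettings = Settings.builder()
            .put(RecoverySettings.INDICES_RECOVERY_CHUNK_SIZE.getKey(), "256kb")
            .build();
        return new RecoverySettings(nodeSettings, new ClusterSettings(nodeSettings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));
    }

With this in place, getChunkSize() reports 256kb and is used both for peer-recovery file chunks and, via RecoverySourceHandler, as the per-batch source buffering limit for changes snapshots.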
+ */ + public static final Setting INDICES_RECOVERY_CHUNK_SIZE = Setting.byteSizeSetting( + "indices.recovery.chunk_size", + DEFAULT_CHUNK_SIZE, + ByteSizeValue.ZERO, + ByteSizeValue.ofBytes(Integer.MAX_VALUE), + Property.NodeScope, + Property.Dynamic + ); + private volatile ByteSizeValue maxBytesPerSec; private volatile int maxConcurrentFileChunks; private volatile int maxConcurrentOperations; @@ -417,7 +429,7 @@ public Iterator> settings() { private final AdjustableSemaphore maxSnapshotFileDownloadsPerNodeSemaphore; - private volatile ByteSizeValue chunkSize = DEFAULT_CHUNK_SIZE; + private volatile ByteSizeValue chunkSize; private final ByteSizeValue availableNetworkBandwidth; private final ByteSizeValue availableDiskReadBandwidth; @@ -444,6 +456,7 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { this.availableNetworkBandwidth = NODE_BANDWIDTH_RECOVERY_NETWORK_SETTING.get(settings); this.availableDiskReadBandwidth = NODE_BANDWIDTH_RECOVERY_DISK_READ_SETTING.get(settings); this.availableDiskWriteBandwidth = NODE_BANDWIDTH_RECOVERY_DISK_WRITE_SETTING.get(settings); + this.chunkSize = INDICES_RECOVERY_CHUNK_SIZE.get(settings); validateNodeBandwidthRecoverySettings(settings); this.nodeBandwidthSettingsExist = hasNodeBandwidthRecoverySettings(settings); computeMaxBytesPerSec(settings); @@ -493,6 +506,7 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { CLUSTER_ROUTING_ALLOCATION_NODE_CONCURRENT_INCOMING_RECOVERIES_SETTING, this::setMaxConcurrentIncomingRecoveries ); + clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_CHUNK_SIZE, this::setChunkSize); } private void computeMaxBytesPerSec(Settings settings) { @@ -597,7 +611,7 @@ public ByteSizeValue getChunkSize() { return chunkSize; } - public void setChunkSize(ByteSizeValue chunkSize) { // only settable for tests + public void setChunkSize(ByteSizeValue chunkSize) { if (chunkSize.bytesAsInt() <= 0) { throw new IllegalArgumentException("chunkSize must be > 0"); } diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java index 3603b984fb148..622e56f596e19 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java @@ -324,7 +324,8 @@ && isTargetSameHistory() Long.MAX_VALUE, false, false, - true + true, + chunkSizeInBytes ); resources.add(phase2Snapshot); retentionLock.close(); diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index de241301cfef9..ddb1e3d384fbe 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -160,6 +160,7 @@ public final class RestoreService implements ClusterStateApplier { SETTING_HISTORY_UUID, IndexSettings.MODE.getKey(), SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), IndexSortConfig.INDEX_SORT_MODE_SETTING.getKey(), diff --git a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java index 9d8c5649f0dce..3e3be6a315af2 100644 --- 
a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java @@ -89,6 +89,7 @@ import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver; import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver.DocIdAndSeqNo; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; @@ -3448,7 +3449,7 @@ public void testTranslogReplay() throws IOException { assertThat(indexResult.getVersion(), equalTo(1L)); } assertVisibleCount(engine, numDocs); - translogHandler = createTranslogHandler(engine.engineConfig.getIndexSettings()); + translogHandler = createTranslogHandler(mapperService); engine.close(); // we need to reuse the engine config unless the parser.mappingModified won't work @@ -3460,7 +3461,7 @@ public void testTranslogReplay() throws IOException { assertEquals(numDocs, translogHandler.appliedOperations()); engine.close(); - translogHandler = createTranslogHandler(engine.engineConfig.getIndexSettings()); + translogHandler = createTranslogHandler(mapperService); engine = createEngine(store, primaryTranslogDir, inSyncGlobalCheckpointSupplier); engine.refresh("warm_up"); assertVisibleCount(engine, numDocs, false); @@ -3514,7 +3515,7 @@ public void testTranslogReplay() throws IOException { } engine.close(); - translogHandler = createTranslogHandler(engine.engineConfig.getIndexSettings()); + translogHandler = createTranslogHandler(mapperService); engine = createEngine(store, primaryTranslogDir, inSyncGlobalCheckpointSupplier); engine.refresh("warm_up"); try (Engine.Searcher searcher = engine.acquireSearcher("test")) { @@ -6447,7 +6448,8 @@ protected void doRun() throws Exception { max, true, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) {} } else { @@ -7673,7 +7675,7 @@ public void testDisableRecoverySource() throws Exception { ) { IllegalStateException exc = expectThrows( IllegalStateException.class, - () -> engine.newChangesSnapshot("test", 0, 1000, true, true, true) + () -> engine.newChangesSnapshot("test", 0, 1000, true, true, true, randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes())) ); assertThat(exc.getMessage(), containsString("unavailable")); } diff --git a/server/src/test/java/org/elasticsearch/index/engine/LuceneChangesSnapshotTests.java b/server/src/test/java/org/elasticsearch/index/engine/LuceneChangesSnapshotTests.java index 85ba368165ceb..5863d2f932968 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/LuceneChangesSnapshotTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/LuceneChangesSnapshotTests.java @@ -10,289 +10,37 @@ package org.elasticsearch.index.engine; import org.apache.lucene.index.NoMergePolicy; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.IOUtils; -import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.mapper.Uid; +import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.store.Store; -import org.elasticsearch.index.translog.SnapshotMatchers; 
import org.elasticsearch.index.translog.Translog; -import org.elasticsearch.test.IndexSettingsModule; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.LongSupplier; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; - -public class LuceneChangesSnapshotTests extends EngineTestCase { +public class LuceneChangesSnapshotTests extends SearchBasedChangesSnapshotTests { @Override - protected Settings indexSettings() { - return Settings.builder() - .put(super.indexSettings()) - .put(IndexSettings.INDEX_SOFT_DELETES_SETTING.getKey(), true) // always enable soft-deletes - .build(); - } - - public void testBasics() throws Exception { - long fromSeqNo = randomNonNegativeLong(); - long toSeqNo = randomLongBetween(fromSeqNo, Long.MAX_VALUE); - // Empty engine - try (Translog.Snapshot snapshot = engine.newChangesSnapshot("test", fromSeqNo, toSeqNo, true, randomBoolean(), randomBoolean())) { - IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); - assertThat( - error.getMessage(), - containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") - ); - } - try (Translog.Snapshot snapshot = engine.newChangesSnapshot("test", fromSeqNo, toSeqNo, false, randomBoolean(), randomBoolean())) { - assertThat(snapshot, SnapshotMatchers.size(0)); - } - int numOps = between(1, 100); - int refreshedSeqNo = -1; - for (int i = 0; i < numOps; i++) { - String id = Integer.toString(randomIntBetween(i, i + 5)); - ParsedDocument doc = createParsedDoc(id, null, randomBoolean()); - if (randomBoolean()) { - engine.index(indexForDoc(doc)); - } else { - engine.delete(new Engine.Delete(doc.id(), Uid.encodeId(doc.id()), primaryTerm.get())); - } - if (rarely()) { - if (randomBoolean()) { - engine.flush(); - } else { - engine.refresh("test"); - } - refreshedSeqNo = i; - } - } - if (refreshedSeqNo == -1) { - fromSeqNo = between(0, numOps); - toSeqNo = randomLongBetween(fromSeqNo, numOps * 2); - - Engine.Searcher searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, LuceneChangesSnapshot.DEFAULT_BATCH_SIZE), - fromSeqNo, - toSeqNo, - false, - randomBoolean(), - randomBoolean(), - IndexVersion.current() - ) - ) { - searcher = null; - assertThat(snapshot, SnapshotMatchers.size(0)); - } finally { - IOUtils.close(searcher); - } - - searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, LuceneChangesSnapshot.DEFAULT_BATCH_SIZE), - fromSeqNo, - toSeqNo, - true, - randomBoolean(), - randomBoolean(), - IndexVersion.current() - ) - ) { - searcher = null; - IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); - assertThat( - error.getMessage(), - containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") - ); - } finally { - IOUtils.close(searcher); - } - } else { - fromSeqNo = randomLongBetween(0, refreshedSeqNo); - toSeqNo = randomLongBetween(refreshedSeqNo + 1, numOps * 2); - Engine.Searcher searcher = engine.acquireSearcher("test", 
Engine.SearcherScope.INTERNAL); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, LuceneChangesSnapshot.DEFAULT_BATCH_SIZE), - fromSeqNo, - toSeqNo, - false, - randomBoolean(), - randomBoolean(), - IndexVersion.current() - ) - ) { - searcher = null; - assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, refreshedSeqNo)); - } finally { - IOUtils.close(searcher); - } - searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, LuceneChangesSnapshot.DEFAULT_BATCH_SIZE), - fromSeqNo, - toSeqNo, - true, - randomBoolean(), - randomBoolean(), - IndexVersion.current() - ) - ) { - searcher = null; - IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); - assertThat( - error.getMessage(), - containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") - ); - } finally { - IOUtils.close(searcher); - } - toSeqNo = randomLongBetween(fromSeqNo, refreshedSeqNo); - searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, LuceneChangesSnapshot.DEFAULT_BATCH_SIZE), - fromSeqNo, - toSeqNo, - true, - randomBoolean(), - randomBoolean(), - IndexVersion.current() - ) - ) { - searcher = null; - assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, toSeqNo)); - } finally { - IOUtils.close(searcher); - } - } - // Get snapshot via engine will auto refresh - fromSeqNo = randomLongBetween(0, numOps - 1); - toSeqNo = randomLongBetween(fromSeqNo, numOps - 1); - try ( - Translog.Snapshot snapshot = engine.newChangesSnapshot( - "test", - fromSeqNo, - toSeqNo, - randomBoolean(), - randomBoolean(), - randomBoolean() - ) - ) { - assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, toSeqNo)); - } - } - - /** - * A nested document is indexed into Lucene as multiple documents. While the root document has both sequence number and primary term, - * non-root documents don't have primary term but only sequence numbers. This test verifies that {@link LuceneChangesSnapshot} - * correctly skip non-root documents and returns at most one operation per sequence number. 
- */ - public void testSkipNonRootOfNestedDocuments() throws Exception { - Map seqNoToTerm = new HashMap<>(); - List operations = generateHistoryOnReplica(between(1, 100), randomBoolean(), randomBoolean(), randomBoolean()); - for (Engine.Operation op : operations) { - if (engine.getLocalCheckpointTracker().hasProcessed(op.seqNo()) == false) { - seqNoToTerm.put(op.seqNo(), op.primaryTerm()); - } - applyOperation(engine, op); - if (rarely()) { - engine.refresh("test"); - } - if (rarely()) { - engine.rollTranslogGeneration(); - } - if (rarely()) { - engine.flush(); - } - } - long maxSeqNo = engine.getLocalCheckpointTracker().getMaxSeqNo(); - engine.refresh("test"); - Engine.Searcher searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); - final boolean accessStats = randomBoolean(); - try ( - Translog.Snapshot snapshot = new LuceneChangesSnapshot( - searcher, - between(1, 100), - 0, - maxSeqNo, - false, - randomBoolean(), - accessStats, - IndexVersion.current() - ) - ) { - if (accessStats) { - assertThat(snapshot.totalOperations(), equalTo(seqNoToTerm.size())); - } - Translog.Operation op; - while ((op = snapshot.next()) != null) { - assertThat(op.toString(), op.primaryTerm(), equalTo(seqNoToTerm.get(op.seqNo()))); - } - assertThat(snapshot.skippedOperations(), equalTo(0)); - } - } - - public void testUpdateAndReadChangesConcurrently() throws Exception { - Follower[] followers = new Follower[between(1, 3)]; - CountDownLatch readyLatch = new CountDownLatch(followers.length + 1); - AtomicBoolean isDone = new AtomicBoolean(); - for (int i = 0; i < followers.length; i++) { - followers[i] = new Follower(engine, isDone, readyLatch); - followers[i].start(); - } - boolean onPrimary = randomBoolean(); - List operations = new ArrayList<>(); - int numOps = frequently() ? scaledRandomIntBetween(1, 1500) : scaledRandomIntBetween(5000, 20_000); - for (int i = 0; i < numOps; i++) { - String id = Integer.toString(randomIntBetween(0, randomBoolean() ? 
10 : numOps * 2)); - ParsedDocument doc = createParsedDoc(id, randomAlphaOfLengthBetween(1, 5), randomBoolean()); - final Engine.Operation op; - if (onPrimary) { - if (randomBoolean()) { - op = new Engine.Index(newUid(doc), primaryTerm.get(), doc); - } else { - op = new Engine.Delete(doc.id(), Uid.encodeId(doc.id()), primaryTerm.get()); - } - } else { - if (randomBoolean()) { - op = replicaIndexForDoc(doc, randomNonNegativeLong(), i, randomBoolean()); - } else { - op = replicaDeleteForDoc(doc.id(), randomNonNegativeLong(), i, randomNonNegativeLong()); - } - } - operations.add(op); - } - readyLatch.countDown(); - readyLatch.await(); - Randomness.shuffle(operations); - concurrentlyApplyOps(operations, engine); - assertThat(engine.getLocalCheckpointTracker().getProcessedCheckpoint(), equalTo(operations.size() - 1L)); - isDone.set(true); - for (Follower follower : followers) { - follower.join(); - IOUtils.close(follower.engine, follower.engine.store); - } + protected Translog.Snapshot newRandomSnapshot( + MappingLookup mappingLookup, + Engine.Searcher engineSearcher, + int searchBatchSize, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + IndexVersion indexVersionCreated + ) throws IOException { + return new LuceneChangesSnapshot( + engineSearcher, + searchBatchSize, + fromSeqNo, + toSeqNo, + requiredFullRange, + singleConsumer, + accessStats, + indexVersionCreated + ); } public void testAccessStoredFieldsSequentially() throws Exception { @@ -319,7 +67,8 @@ public void testAccessStoredFieldsSequentially() throws Exception { between(1, smallBatch), false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { while ((op = snapshot.next()) != null) { @@ -335,7 +84,8 @@ public void testAccessStoredFieldsSequentially() throws Exception { between(20, 100), false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { while ((op = snapshot.next()) != null) { @@ -351,7 +101,8 @@ public void testAccessStoredFieldsSequentially() throws Exception { between(21, 100), false, true, - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { while ((op = snapshot.next()) != null) { @@ -367,7 +118,8 @@ public void testAccessStoredFieldsSequentially() throws Exception { between(21, 100), false, false, - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { while ((op = snapshot.next()) != null) { @@ -377,165 +129,4 @@ public void testAccessStoredFieldsSequentially() throws Exception { } } } - - class Follower extends Thread { - private final InternalEngine leader; - private final InternalEngine engine; - private final TranslogHandler translogHandler; - private final AtomicBoolean isDone; - private final CountDownLatch readLatch; - - Follower(InternalEngine leader, AtomicBoolean isDone, CountDownLatch readLatch) throws IOException { - this.leader = leader; - this.isDone = isDone; - this.readLatch = readLatch; - this.translogHandler = new TranslogHandler( - xContentRegistry(), - IndexSettingsModule.newIndexSettings(shardId.getIndexName(), leader.engineConfig.getIndexSettings().getSettings()) - ); - this.engine = createEngine(createStore(), createTempDir()); - } - - void pullOperations(InternalEngine follower) throws IOException { - long leaderCheckpoint = leader.getLocalCheckpointTracker().getProcessedCheckpoint(); - long followerCheckpoint 
= follower.getLocalCheckpointTracker().getProcessedCheckpoint(); - if (followerCheckpoint < leaderCheckpoint) { - long fromSeqNo = followerCheckpoint + 1; - long batchSize = randomLongBetween(0, 100); - long toSeqNo = Math.min(fromSeqNo + batchSize, leaderCheckpoint); - try ( - Translog.Snapshot snapshot = leader.newChangesSnapshot( - "test", - fromSeqNo, - toSeqNo, - true, - randomBoolean(), - randomBoolean() - ) - ) { - translogHandler.run(follower, snapshot); - } - } - } - - @Override - public void run() { - try { - readLatch.countDown(); - readLatch.await(); - while (isDone.get() == false - || engine.getLocalCheckpointTracker().getProcessedCheckpoint() < leader.getLocalCheckpointTracker() - .getProcessedCheckpoint()) { - pullOperations(engine); - } - assertConsistentHistoryBetweenTranslogAndLuceneIndex(engine); - // have to verify without source since we are randomly testing without _source - List docsWithoutSourceOnFollower = getDocIds(engine, true).stream() - .map(d -> new DocIdSeqNoAndSource(d.id(), null, d.seqNo(), d.primaryTerm(), d.version())) - .toList(); - List docsWithoutSourceOnLeader = getDocIds(leader, true).stream() - .map(d -> new DocIdSeqNoAndSource(d.id(), null, d.seqNo(), d.primaryTerm(), d.version())) - .toList(); - assertThat(docsWithoutSourceOnFollower, equalTo(docsWithoutSourceOnLeader)); - } catch (Exception ex) { - throw new AssertionError(ex); - } - } - } - - private List drainAll(Translog.Snapshot snapshot) throws IOException { - List operations = new ArrayList<>(); - Translog.Operation op; - while ((op = snapshot.next()) != null) { - final Translog.Operation newOp = op; - logger.trace("Reading [{}]", op); - assert operations.stream().allMatch(o -> o.seqNo() < newOp.seqNo()) : "Operations [" + operations + "], op [" + op + "]"; - operations.add(newOp); - } - return operations; - } - - public void testOverFlow() throws Exception { - long fromSeqNo = randomLongBetween(0, 5); - long toSeqNo = randomLongBetween(Long.MAX_VALUE - 5, Long.MAX_VALUE); - try (Translog.Snapshot snapshot = engine.newChangesSnapshot("test", fromSeqNo, toSeqNo, true, randomBoolean(), randomBoolean())) { - IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); - assertThat( - error.getMessage(), - containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") - ); - } - } - - public void testStats() throws Exception { - try (Store store = createStore(); Engine engine = createEngine(defaultSettings, store, createTempDir(), NoMergePolicy.INSTANCE)) { - int numOps = between(100, 5000); - long startingSeqNo = randomLongBetween(0, Integer.MAX_VALUE); - List operations = generateHistoryOnReplica( - numOps, - startingSeqNo, - randomBoolean(), - randomBoolean(), - randomBoolean() - ); - applyOperations(engine, operations); - - LongSupplier fromSeqNo = () -> { - if (randomBoolean()) { - return 0L; - } else if (randomBoolean()) { - return startingSeqNo; - } else { - return randomLongBetween(0, startingSeqNo); - } - }; - - LongSupplier toSeqNo = () -> { - final long maxSeqNo = engine.getSeqNoStats(-1).getMaxSeqNo(); - if (randomBoolean()) { - return maxSeqNo; - } else if (randomBoolean()) { - return Long.MAX_VALUE; - } else { - return randomLongBetween(maxSeqNo, Long.MAX_VALUE); - } - }; - // Can't access stats if didn't request it - try ( - Translog.Snapshot snapshot = engine.newChangesSnapshot( - "test", - fromSeqNo.getAsLong(), - toSeqNo.getAsLong(), - false, - randomBoolean(), - false - ) - ) { - 
IllegalStateException error = expectThrows(IllegalStateException.class, snapshot::totalOperations); - assertThat(error.getMessage(), equalTo("Access stats of a snapshot created with [access_stats] is false")); - final List translogOps = drainAll(snapshot); - assertThat(translogOps, hasSize(numOps)); - error = expectThrows(IllegalStateException.class, snapshot::totalOperations); - assertThat(error.getMessage(), equalTo("Access stats of a snapshot created with [access_stats] is false")); - } - // Access stats and operations - try ( - Translog.Snapshot snapshot = engine.newChangesSnapshot( - "test", - fromSeqNo.getAsLong(), - toSeqNo.getAsLong(), - false, - randomBoolean(), - true - ) - ) { - assertThat(snapshot.totalOperations(), equalTo(numOps)); - final List translogOps = drainAll(snapshot); - assertThat(translogOps, hasSize(numOps)); - assertThat(snapshot.totalOperations(), equalTo(numOps)); - } - // Verify count - assertThat(engine.countChanges("test", fromSeqNo.getAsLong(), toSeqNo.getAsLong()), equalTo(numOps)); - } - } } diff --git a/server/src/test/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshotTests.java b/server/src/test/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshotTests.java new file mode 100644 index 0000000000000..2a6c3428d6d45 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/engine/LuceneSyntheticSourceChangesSnapshotTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.engine; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.mapper.SourceFieldMapper; +import org.elasticsearch.index.translog.Translog; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING; + +public class LuceneSyntheticSourceChangesSnapshotTests extends SearchBasedChangesSnapshotTests { + @Override + protected Settings indexSettings() { + return Settings.builder() + .put(super.indexSettings()) + .put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + } + + @Override + protected Translog.Snapshot newRandomSnapshot( + MappingLookup mappingLookup, + Engine.Searcher engineSearcher, + int searchBatchSize, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + IndexVersion indexVersionCreated + ) throws IOException { + return new LuceneSyntheticSourceChangesSnapshot( + mappingLookup, + engineSearcher, + searchBatchSize, + randomLongBetween(0, ByteSizeValue.ofBytes(Integer.MAX_VALUE).getBytes()), + fromSeqNo, + toSeqNo, + requiredFullRange, + accessStats, + indexVersionCreated + ); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java b/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java index c0e365909429a..74d6e83aff266 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/RecoverySourcePruneMergePolicyTests.java @@ -39,83 +39,99 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + public class RecoverySourcePruneMergePolicyTests extends ESTestCase { public void testPruneAll() throws IOException { - try (Directory dir = newDirectory()) { - boolean pruneIdField = randomBoolean(); - IndexWriterConfig iwc = newIndexWriterConfig(); - RecoverySourcePruneMergePolicy mp = new RecoverySourcePruneMergePolicy( - "extra_source", - pruneIdField, - MatchNoDocsQuery::new, - newLogMergePolicy() - ); - iwc.setMergePolicy(new ShuffleForcedMergePolicy(mp)); - try (IndexWriter writer = new IndexWriter(dir, iwc)) { - for (int i = 0; i < 20; i++) { - if (i > 0 && randomBoolean()) { - writer.flush(); - } - Document doc = new Document(); - doc.add(new StoredField(IdFieldMapper.NAME, "_id")); - doc.add(new StoredField("source", "hello world")); - doc.add(new StoredField("extra_source", "hello world")); - doc.add(new NumericDocValuesField("extra_source", 1)); - writer.addDocument(doc); - } - writer.forceMerge(1); - writer.commit(); - try (DirectoryReader reader = DirectoryReader.open(writer)) { - StoredFields storedFields = reader.storedFields(); - for (int i = 0; i < reader.maxDoc(); i++) { - Document document = storedFields.document(i); - if (pruneIdField) { - assertEquals(1, document.getFields().size()); - assertEquals("source", document.getFields().get(0).name()); - } else { - assertEquals(2, 
document.getFields().size()); - assertEquals(IdFieldMapper.NAME, document.getFields().get(0).name()); - assertEquals("source", document.getFields().get(1).name()); + for (boolean pruneIdField : List.of(true, false)) { + for (boolean syntheticRecoverySource : List.of(true, false)) { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + RecoverySourcePruneMergePolicy mp = new RecoverySourcePruneMergePolicy( + syntheticRecoverySource ? null : "extra_source", + syntheticRecoverySource ? "extra_source_size" : "extra_source", + pruneIdField, + MatchNoDocsQuery::new, + newLogMergePolicy() + ); + iwc.setMergePolicy(new ShuffleForcedMergePolicy(mp)); + try (IndexWriter writer = new IndexWriter(dir, iwc)) { + for (int i = 0; i < 20; i++) { + if (i > 0 && randomBoolean()) { + writer.flush(); + } + Document doc = new Document(); + doc.add(new StoredField(IdFieldMapper.NAME, "_id")); + doc.add(new StoredField("source", "hello world")); + if (syntheticRecoverySource) { + doc.add(new NumericDocValuesField("extra_source_size", randomIntBetween(10, 10000))); + } else { + doc.add(new StoredField("extra_source", "hello world")); + doc.add(new NumericDocValuesField("extra_source", 1)); + } + writer.addDocument(doc); } - } - assertEquals(1, reader.leaves().size()); - LeafReader leafReader = reader.leaves().get(0).reader(); - NumericDocValues extra_source = leafReader.getNumericDocValues("extra_source"); - if (extra_source != null) { - assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); - } - if (leafReader instanceof CodecReader codecReader && reader instanceof StandardDirectoryReader sdr) { - SegmentInfos segmentInfos = sdr.getSegmentInfos(); - MergePolicy.MergeSpecification forcedMerges = mp.findForcedDeletesMerges( - segmentInfos, - new MergePolicy.MergeContext() { - @Override - public int numDeletesToMerge(SegmentCommitInfo info) { - return info.info.maxDoc() - 1; + writer.forceMerge(1); + writer.commit(); + try (DirectoryReader reader = DirectoryReader.open(writer)) { + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < reader.maxDoc(); i++) { + Document document = storedFields.document(i); + if (pruneIdField) { + assertEquals(1, document.getFields().size()); + assertEquals("source", document.getFields().get(0).name()); + } else { + assertEquals(2, document.getFields().size()); + assertEquals(IdFieldMapper.NAME, document.getFields().get(0).name()); + assertEquals("source", document.getFields().get(1).name()); } + } - @Override - public int numDeletedDocs(SegmentCommitInfo info) { - return info.info.maxDoc() - 1; - } + assertEquals(1, reader.leaves().size()); + LeafReader leafReader = reader.leaves().get(0).reader(); - @Override - public InfoStream getInfoStream() { - return new NullInfoStream(); - } + NumericDocValues extra_source = leafReader.getNumericDocValues( + syntheticRecoverySource ? 
"extra_source_size" : "extra_source" + ); + if (extra_source != null) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); + } + if (leafReader instanceof CodecReader codecReader && reader instanceof StandardDirectoryReader sdr) { + SegmentInfos segmentInfos = sdr.getSegmentInfos(); + MergePolicy.MergeSpecification forcedMerges = mp.findForcedDeletesMerges( + segmentInfos, + new MergePolicy.MergeContext() { + @Override + public int numDeletesToMerge(SegmentCommitInfo info) { + return info.info.maxDoc() - 1; + } - @Override - public Set getMergingSegments() { - return Collections.emptySet(); - } + @Override + public int numDeletedDocs(SegmentCommitInfo info) { + return info.info.maxDoc() - 1; + } + + @Override + public InfoStream getInfoStream() { + return new NullInfoStream(); + } + + @Override + public Set getMergingSegments() { + return Collections.emptySet(); + } + } + ); + // don't wrap if there is nothing to do + assertSame(codecReader, forcedMerges.merges.get(0).wrapForMerge(codecReader)); } - ); - // don't wrap if there is nothing to do - assertSame(codecReader, forcedMerges.merges.get(0).wrapForMerge(codecReader)); + } } } } @@ -123,87 +139,126 @@ public Set getMergingSegments() { } public void testPruneSome() throws IOException { - try (Directory dir = newDirectory()) { - boolean pruneIdField = randomBoolean(); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergePolicy( - new RecoverySourcePruneMergePolicy( - "extra_source", - pruneIdField, - () -> new TermQuery(new Term("even", "true")), - iwc.getMergePolicy() - ) - ); - try (IndexWriter writer = new IndexWriter(dir, iwc)) { - for (int i = 0; i < 20; i++) { - if (i > 0 && randomBoolean()) { - writer.flush(); - } - Document doc = new Document(); - doc.add(new StoredField(IdFieldMapper.NAME, "_id")); - doc.add(new StringField("even", Boolean.toString(i % 2 == 0), Field.Store.YES)); - doc.add(new StoredField("source", "hello world")); - doc.add(new StoredField("extra_source", "hello world")); - doc.add(new NumericDocValuesField("extra_source", 1)); - writer.addDocument(doc); - } - writer.forceMerge(1); - writer.commit(); - try (DirectoryReader reader = DirectoryReader.open(writer)) { - assertEquals(1, reader.leaves().size()); - NumericDocValues extra_source = reader.leaves().get(0).reader().getNumericDocValues("extra_source"); - assertNotNull(extra_source); - StoredFields storedFields = reader.storedFields(); - for (int i = 0; i < reader.maxDoc(); i++) { - Document document = storedFields.document(i); - Set collect = document.getFields().stream().map(IndexableField::name).collect(Collectors.toSet()); - assertTrue(collect.contains("source")); - assertTrue(collect.contains("even")); - if (collect.size() == 4) { - assertTrue(collect.contains("extra_source")); - assertTrue(collect.contains(IdFieldMapper.NAME)); - assertEquals("true", document.getField("even").stringValue()); - assertEquals(i, extra_source.nextDoc()); - } else { - assertEquals(pruneIdField ? 2 : 3, document.getFields().size()); + for (boolean pruneIdField : List.of(true, false)) { + for (boolean syntheticRecoverySource : List.of(true, false)) { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergePolicy( + new RecoverySourcePruneMergePolicy( + syntheticRecoverySource ? null : "extra_source", + syntheticRecoverySource ? 
"extra_source_size" : "extra_source", + pruneIdField, + () -> new TermQuery(new Term("even", "true")), + iwc.getMergePolicy() + ) + ); + try (IndexWriter writer = new IndexWriter(dir, iwc)) { + for (int i = 0; i < 20; i++) { + if (i > 0 && randomBoolean()) { + writer.flush(); + } + Document doc = new Document(); + doc.add(new StoredField(IdFieldMapper.NAME, "_id")); + doc.add(new StringField("even", Boolean.toString(i % 2 == 0), Field.Store.YES)); + doc.add(new StoredField("source", "hello world")); + if (syntheticRecoverySource) { + doc.add(new NumericDocValuesField("extra_source_size", randomIntBetween(10, 10000))); + } else { + doc.add(new StoredField("extra_source", "hello world")); + doc.add(new NumericDocValuesField("extra_source", 1)); + } + writer.addDocument(doc); + } + writer.forceMerge(1); + writer.commit(); + try (DirectoryReader reader = DirectoryReader.open(writer)) { + assertEquals(1, reader.leaves().size()); + String extraSourceDVName = syntheticRecoverySource ? "extra_source_size" : "extra_source"; + NumericDocValues extra_source = reader.leaves().get(0).reader().getNumericDocValues(extraSourceDVName); + assertNotNull(extra_source); + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < reader.maxDoc(); i++) { + Document document = storedFields.document(i); + Set collect = document.getFields().stream().map(IndexableField::name).collect(Collectors.toSet()); + assertTrue(collect.contains("source")); + assertTrue(collect.contains("even")); + boolean isEven = Boolean.parseBoolean(document.getField("even").stringValue()); + if (isEven) { + assertTrue(collect.contains(IdFieldMapper.NAME)); + assertThat(collect.contains("extra_source"), equalTo(syntheticRecoverySource == false)); + if (extra_source.docID() < i) { + extra_source.advance(i); + } + assertEquals(i, extra_source.docID()); + if (syntheticRecoverySource) { + assertThat(extra_source.longValue(), greaterThan(10L)); + } else { + assertThat(extra_source.longValue(), equalTo(1L)); + } + } else { + assertThat(collect.contains(IdFieldMapper.NAME), equalTo(pruneIdField == false)); + assertFalse(collect.contains("extra_source")); + if (extra_source.docID() < i) { + extra_source.advance(i); + } + assertNotEquals(i, extra_source.docID()); + } + } + if (extra_source.docID() != DocIdSetIterator.NO_MORE_DOCS) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); + } } } - assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); } } } } public void testPruneNone() throws IOException { - try (Directory dir = newDirectory()) { - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergePolicy(new RecoverySourcePruneMergePolicy("extra_source", false, MatchAllDocsQuery::new, iwc.getMergePolicy())); - try (IndexWriter writer = new IndexWriter(dir, iwc)) { - for (int i = 0; i < 20; i++) { - if (i > 0 && randomBoolean()) { - writer.flush(); + for (boolean syntheticRecoverySource : List.of(true, false)) { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergePolicy( + new RecoverySourcePruneMergePolicy( + syntheticRecoverySource ? null : "extra_source", + syntheticRecoverySource ? 
"extra_source_size" : "extra_source", + false, + MatchAllDocsQuery::new, + iwc.getMergePolicy() + ) + ); + try (IndexWriter writer = new IndexWriter(dir, iwc)) { + for (int i = 0; i < 20; i++) { + if (i > 0 && randomBoolean()) { + writer.flush(); + } + Document doc = new Document(); + doc.add(new StoredField("source", "hello world")); + if (syntheticRecoverySource) { + doc.add(new NumericDocValuesField("extra_source_size", randomIntBetween(10, 10000))); + } else { + doc.add(new StoredField("extra_source", "hello world")); + doc.add(new NumericDocValuesField("extra_source", 1)); + } + writer.addDocument(doc); } - Document doc = new Document(); - doc.add(new StoredField("source", "hello world")); - doc.add(new StoredField("extra_source", "hello world")); - doc.add(new NumericDocValuesField("extra_source", 1)); - writer.addDocument(doc); - } - writer.forceMerge(1); - writer.commit(); - try (DirectoryReader reader = DirectoryReader.open(writer)) { - assertEquals(1, reader.leaves().size()); - NumericDocValues extra_source = reader.leaves().get(0).reader().getNumericDocValues("extra_source"); - assertNotNull(extra_source); - StoredFields storedFields = reader.storedFields(); - for (int i = 0; i < reader.maxDoc(); i++) { - Document document = storedFields.document(i); - Set collect = document.getFields().stream().map(IndexableField::name).collect(Collectors.toSet()); - assertTrue(collect.contains("source")); - assertTrue(collect.contains("extra_source")); - assertEquals(i, extra_source.nextDoc()); + writer.forceMerge(1); + writer.commit(); + try (DirectoryReader reader = DirectoryReader.open(writer)) { + assertEquals(1, reader.leaves().size()); + String extraSourceDVName = syntheticRecoverySource ? "extra_source_size" : "extra_source"; + NumericDocValues extra_source = reader.leaves().get(0).reader().getNumericDocValues(extraSourceDVName); + assertNotNull(extra_source); + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < reader.maxDoc(); i++) { + Document document = storedFields.document(i); + Set collect = document.getFields().stream().map(IndexableField::name).collect(Collectors.toSet()); + assertTrue(collect.contains("source")); + assertThat(collect.contains("extra_source"), equalTo(syntheticRecoverySource == false)); + assertEquals(i, extra_source.nextDoc()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); } - assertEquals(DocIdSetIterator.NO_MORE_DOCS, extra_source.nextDoc()); } } } diff --git a/server/src/test/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshotTests.java b/server/src/test/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshotTests.java new file mode 100644 index 0000000000000..9cfa7321973a4 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/engine/SearchBasedChangesSnapshotTests.java @@ -0,0 +1,507 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.engine; + +import org.apache.lucene.index.NoMergePolicy; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.Uid; +import org.elasticsearch.index.store.Store; +import org.elasticsearch.index.translog.SnapshotMatchers; +import org.elasticsearch.index.translog.Translog; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.LongSupplier; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public abstract class SearchBasedChangesSnapshotTests extends EngineTestCase { + @Override + protected Settings indexSettings() { + return Settings.builder() + .put(super.indexSettings()) + .put(IndexSettings.INDEX_SOFT_DELETES_SETTING.getKey(), true) // always enable soft-deletes + .build(); + } + + protected abstract Translog.Snapshot newRandomSnapshot( + MappingLookup mappingLookup, + Engine.Searcher engineSearcher, + int searchBatchSize, + long fromSeqNo, + long toSeqNo, + boolean requiredFullRange, + boolean singleConsumer, + boolean accessStats, + IndexVersion indexVersionCreated + ) throws IOException; + + public void testBasics() throws Exception { + long fromSeqNo = randomNonNegativeLong(); + long toSeqNo = randomLongBetween(fromSeqNo, Long.MAX_VALUE); + // Empty engine + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); + assertThat( + error.getMessage(), + containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") + ); + } + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + false, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + assertThat(snapshot, SnapshotMatchers.size(0)); + } + int numOps = between(1, 100); + int refreshedSeqNo = -1; + for (int i = 0; i < numOps; i++) { + String id = Integer.toString(randomIntBetween(i, i + 5)); + ParsedDocument doc = parseDocument(engine.engineConfig.getMapperService(), id, null); + if (randomBoolean()) { + engine.index(indexForDoc(doc)); + } else { + engine.delete(new Engine.Delete(doc.id(), Uid.encodeId(doc.id()), primaryTerm.get())); + } + if (rarely()) { + if (randomBoolean()) { + engine.flush(); + } else { + engine.refresh("test"); + } + refreshedSeqNo = i; + } + } + if (refreshedSeqNo == -1) { + fromSeqNo = between(0, numOps); + toSeqNo = randomLongBetween(fromSeqNo, numOps * 2); + + Engine.Searcher searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, 
SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE), + fromSeqNo, + toSeqNo, + false, + randomBoolean(), + randomBoolean(), + IndexVersion.current() + ) + ) { + searcher = null; + assertThat(snapshot, SnapshotMatchers.size(0)); + } finally { + IOUtils.close(searcher); + } + + searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE), + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + IndexVersion.current() + ) + ) { + searcher = null; + IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); + assertThat( + error.getMessage(), + containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") + ); + } finally { + IOUtils.close(searcher); + } + } else { + fromSeqNo = randomLongBetween(0, refreshedSeqNo); + toSeqNo = randomLongBetween(refreshedSeqNo + 1, numOps * 2); + Engine.Searcher searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE), + fromSeqNo, + toSeqNo, + false, + randomBoolean(), + randomBoolean(), + IndexVersion.current() + ) + ) { + searcher = null; + assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, refreshedSeqNo)); + } finally { + IOUtils.close(searcher); + } + searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE), + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + IndexVersion.current() + ) + ) { + searcher = null; + IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); + assertThat( + error.getMessage(), + containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") + ); + } finally { + IOUtils.close(searcher); + } + toSeqNo = randomLongBetween(fromSeqNo, refreshedSeqNo); + searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, SearchBasedChangesSnapshot.DEFAULT_BATCH_SIZE), + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + IndexVersion.current() + ) + ) { + searcher = null; + assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, toSeqNo)); + } finally { + IOUtils.close(searcher); + } + } + // Get snapshot via engine will auto refresh + fromSeqNo = randomLongBetween(0, numOps - 1); + toSeqNo = randomLongBetween(fromSeqNo, numOps - 1); + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + randomBoolean(), + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + assertThat(snapshot, SnapshotMatchers.containsSeqNoRange(fromSeqNo, toSeqNo)); + } + } + + /** + * A nested document is indexed into Lucene as multiple documents. 
While the root document has both sequence number and primary term, + * non-root documents don't have primary term but only sequence numbers. This test verifies that {@link LuceneChangesSnapshot} + * correctly skip non-root documents and returns at most one operation per sequence number. + */ + public void testSkipNonRootOfNestedDocuments() throws Exception { + Map seqNoToTerm = new HashMap<>(); + List operations = generateHistoryOnReplica(between(1, 100), randomBoolean(), randomBoolean(), randomBoolean()); + for (Engine.Operation op : operations) { + if (engine.getLocalCheckpointTracker().hasProcessed(op.seqNo()) == false) { + seqNoToTerm.put(op.seqNo(), op.primaryTerm()); + } + applyOperation(engine, op); + if (rarely()) { + engine.refresh("test"); + } + if (rarely()) { + engine.rollTranslogGeneration(); + } + if (rarely()) { + engine.flush(); + } + } + long maxSeqNo = engine.getLocalCheckpointTracker().getMaxSeqNo(); + engine.refresh("test"); + Engine.Searcher searcher = engine.acquireSearcher("test", Engine.SearcherScope.INTERNAL); + final boolean accessStats = randomBoolean(); + try ( + Translog.Snapshot snapshot = newRandomSnapshot( + engine.engineConfig.getMapperService().mappingLookup(), + searcher, + between(1, 100), + 0, + maxSeqNo, + false, + randomBoolean(), + accessStats, + IndexVersion.current() + ) + ) { + if (accessStats) { + assertThat(snapshot.totalOperations(), equalTo(seqNoToTerm.size())); + } + Translog.Operation op; + while ((op = snapshot.next()) != null) { + assertThat(op.toString(), op.primaryTerm(), equalTo(seqNoToTerm.get(op.seqNo()))); + } + assertThat(snapshot.skippedOperations(), equalTo(0)); + } + } + + public void testUpdateAndReadChangesConcurrently() throws Exception { + Follower[] followers = new Follower[between(1, 3)]; + CountDownLatch readyLatch = new CountDownLatch(followers.length + 1); + AtomicBoolean isDone = new AtomicBoolean(); + for (int i = 0; i < followers.length; i++) { + followers[i] = new Follower(engine, isDone, readyLatch); + followers[i].start(); + } + boolean onPrimary = randomBoolean(); + List operations = new ArrayList<>(); + int numOps = frequently() ? scaledRandomIntBetween(1, 1500) : scaledRandomIntBetween(5000, 20_000); + for (int i = 0; i < numOps; i++) { + String id = Integer.toString(randomIntBetween(0, randomBoolean() ? 
10 : numOps * 2)); + ParsedDocument doc = parseDocument(engine.engineConfig.getMapperService(), id, randomAlphaOfLengthBetween(1, 5)); + final Engine.Operation op; + if (onPrimary) { + if (randomBoolean()) { + op = new Engine.Index(newUid(doc), primaryTerm.get(), doc); + } else { + op = new Engine.Delete(doc.id(), Uid.encodeId(doc.id()), primaryTerm.get()); + } + } else { + if (randomBoolean()) { + op = replicaIndexForDoc(doc, randomNonNegativeLong(), i, randomBoolean()); + } else { + op = replicaDeleteForDoc(doc.id(), randomNonNegativeLong(), i, randomNonNegativeLong()); + } + } + operations.add(op); + } + readyLatch.countDown(); + readyLatch.await(); + Randomness.shuffle(operations); + concurrentlyApplyOps(operations, engine); + assertThat(engine.getLocalCheckpointTracker().getProcessedCheckpoint(), equalTo(operations.size() - 1L)); + isDone.set(true); + for (Follower follower : followers) { + follower.join(); + IOUtils.close(follower.engine, follower.engine.store); + } + } + + class Follower extends Thread { + private final InternalEngine leader; + private final InternalEngine engine; + private final TranslogHandler translogHandler; + private final AtomicBoolean isDone; + private final CountDownLatch readLatch; + + Follower(InternalEngine leader, AtomicBoolean isDone, CountDownLatch readLatch) throws IOException { + this.leader = leader; + this.isDone = isDone; + this.readLatch = readLatch; + this.engine = createEngine(defaultSettings, createStore(), createTempDir(), newMergePolicy()); + this.translogHandler = new TranslogHandler(engine.engineConfig.getMapperService()); + } + + void pullOperations(InternalEngine follower) throws IOException { + long leaderCheckpoint = leader.getLocalCheckpointTracker().getProcessedCheckpoint(); + long followerCheckpoint = follower.getLocalCheckpointTracker().getProcessedCheckpoint(); + if (followerCheckpoint < leaderCheckpoint) { + long fromSeqNo = followerCheckpoint + 1; + long batchSize = randomLongBetween(0, 100); + long toSeqNo = Math.min(fromSeqNo + batchSize, leaderCheckpoint); + try ( + Translog.Snapshot snapshot = leader.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + translogHandler.run(follower, snapshot); + } + } + } + + @Override + public void run() { + try { + readLatch.countDown(); + readLatch.await(); + while (isDone.get() == false + || engine.getLocalCheckpointTracker().getProcessedCheckpoint() < leader.getLocalCheckpointTracker() + .getProcessedCheckpoint()) { + pullOperations(engine); + } + assertConsistentHistoryBetweenTranslogAndLuceneIndex(engine); + // have to verify without source since we are randomly testing without _source + List docsWithoutSourceOnFollower = getDocIds(engine, true).stream() + .map(d -> new DocIdSeqNoAndSource(d.id(), null, d.seqNo(), d.primaryTerm(), d.version())) + .toList(); + List docsWithoutSourceOnLeader = getDocIds(leader, true).stream() + .map(d -> new DocIdSeqNoAndSource(d.id(), null, d.seqNo(), d.primaryTerm(), d.version())) + .toList(); + assertThat(docsWithoutSourceOnFollower, equalTo(docsWithoutSourceOnLeader)); + } catch (Exception ex) { + throw new AssertionError(ex); + } + } + } + + private List drainAll(Translog.Snapshot snapshot) throws IOException { + List operations = new ArrayList<>(); + Translog.Operation op; + while ((op = snapshot.next()) != null) { + final Translog.Operation newOp = op; + logger.trace("Reading [{}]", op); + assert operations.stream().allMatch(o -> 
o.seqNo() < newOp.seqNo()) : "Operations [" + operations + "], op [" + op + "]"; + operations.add(newOp); + } + return operations; + } + + public void testOverFlow() throws Exception { + long fromSeqNo = randomLongBetween(0, 5); + long toSeqNo = randomLongBetween(Long.MAX_VALUE - 5, Long.MAX_VALUE); + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + IllegalStateException error = expectThrows(IllegalStateException.class, () -> drainAll(snapshot)); + assertThat( + error.getMessage(), + containsString("Not all operations between from_seqno [" + fromSeqNo + "] and to_seqno [" + toSeqNo + "] found") + ); + } + } + + public void testStats() throws Exception { + try (Store store = createStore(); Engine engine = createEngine(defaultSettings, store, createTempDir(), NoMergePolicy.INSTANCE)) { + int numOps = between(100, 5000); + long startingSeqNo = randomLongBetween(0, Integer.MAX_VALUE); + List operations = generateHistoryOnReplica( + numOps, + startingSeqNo, + randomBoolean(), + randomBoolean(), + randomBoolean() + ); + applyOperations(engine, operations); + + LongSupplier fromSeqNo = () -> { + if (randomBoolean()) { + return 0L; + } else if (randomBoolean()) { + return startingSeqNo; + } else { + return randomLongBetween(0, startingSeqNo); + } + }; + + LongSupplier toSeqNo = () -> { + final long maxSeqNo = engine.getSeqNoStats(-1).getMaxSeqNo(); + if (randomBoolean()) { + return maxSeqNo; + } else if (randomBoolean()) { + return Long.MAX_VALUE; + } else { + return randomLongBetween(maxSeqNo, Long.MAX_VALUE); + } + }; + // Can't access stats if didn't request it + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo.getAsLong(), + toSeqNo.getAsLong(), + false, + randomBoolean(), + false, + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + IllegalStateException error = expectThrows(IllegalStateException.class, snapshot::totalOperations); + assertThat(error.getMessage(), equalTo("Access stats of a snapshot created with [access_stats] is false")); + final List translogOps = drainAll(snapshot); + assertThat(translogOps, hasSize(numOps)); + error = expectThrows(IllegalStateException.class, snapshot::totalOperations); + assertThat(error.getMessage(), equalTo("Access stats of a snapshot created with [access_stats] is false")); + } + // Access stats and operations + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + fromSeqNo.getAsLong(), + toSeqNo.getAsLong(), + false, + randomBoolean(), + true, + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { + assertThat(snapshot.totalOperations(), equalTo(numOps)); + final List translogOps = drainAll(snapshot); + assertThat(translogOps, hasSize(numOps)); + assertThat(snapshot.totalOperations(), equalTo(numOps)); + } + // Verify count + assertThat(engine.countChanges("test", fromSeqNo.getAsLong(), toSeqNo.getAsLong()), equalTo(numOps)); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java index 4d6a30849e263..bec9cb5fa9be0 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.List; +import 
java.util.Locale; import java.util.Map; import static org.elasticsearch.indices.recovery.RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING; @@ -405,16 +406,114 @@ public void testRecoverySourceWithSourceDisabled() throws IOException { } } - public void testRecoverySourceWithSyntheticSource() throws IOException { + public void testRecoverySourceWitInvalidSettings() { { - MapperService mapperService = createMapperService( - topMapping(b -> b.startObject(SourceFieldMapper.NAME).field("mode", "synthetic").endObject()) + Settings settings = Settings.builder().put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true).build(); + IllegalArgumentException exc = expectThrows( + IllegalArgumentException.class, + () -> createMapperService(settings, topMapping(b -> {})) + ); + assertThat( + exc.getMessage(), + containsString( + String.format( + Locale.ROOT, + "The setting [%s] is only permitted", + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey() + ) + ) + ); + } + + { + Settings settings = Settings.builder() + .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.STORED.toString()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + IllegalArgumentException exc = expectThrows( + IllegalArgumentException.class, + () -> createMapperService(settings, topMapping(b -> {})) + ); + assertThat( + exc.getMessage(), + containsString( + String.format( + Locale.ROOT, + "The setting [%s] is only permitted", + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey() + ) + ) + ); + } + { + Settings settings = Settings.builder() + .put(IndexSettings.MODE.getKey(), IndexMode.STANDARD.toString()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + IllegalArgumentException exc = expectThrows( + IllegalArgumentException.class, + () -> createMapperService(settings, topMapping(b -> {})) + ); + assertThat( + exc.getMessage(), + containsString( + String.format( + Locale.ROOT, + "The setting [%s] is only permitted", + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey() + ) + ) ); + } + { + Settings settings = Settings.builder() + .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.toString()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + IllegalArgumentException exc = expectThrows( + IllegalArgumentException.class, + () -> createMapperService( + IndexVersionUtils.randomPreviousCompatibleVersion(random(), IndexVersions.USE_SYNTHETIC_SOURCE_FOR_RECOVERY), + settings, + () -> false, + topMapping(b -> {}) + ) + ); + assertThat( + exc.getMessage(), + containsString( + String.format( + Locale.ROOT, + "The setting [%s] is unavailable on this cluster", + IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey() + ) + ) + ); + } + } + + public void testRecoverySourceWithSyntheticSource() throws IOException { + { + Settings settings = Settings.builder() + .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.toString()) + .build(); + MapperService mapperService = createMapperService(settings, topMapping(b -> {})); DocumentMapper docMapper = mapperService.documentMapper(); - ParsedDocument doc = docMapper.parse(source(b -> { b.field("field1", "value1"); })); + ParsedDocument doc = docMapper.parse(source(b -> b.field("field1", "value1"))); assertNotNull(doc.rootDoc().getField("_recovery_source")); 
assertThat(doc.rootDoc().getField("_recovery_source").binaryValue(), equalTo(new BytesRef("{\"field1\":\"value1\"}"))); } + { + Settings settings = Settings.builder() + .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.toString()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + MapperService mapperService = createMapperService(settings, topMapping(b -> {})); + DocumentMapper docMapper = mapperService.documentMapper(); + ParsedDocument doc = docMapper.parse(source(b -> b.field("field1", "value1"))); + assertNotNull(doc.rootDoc().getField("_recovery_source_size")); + assertThat(doc.rootDoc().getField("_recovery_source_size").numericValue(), equalTo(19L)); + } { Settings settings = Settings.builder().put(INDICES_RECOVERY_SOURCE_ENABLED_SETTING.getKey(), false).build(); MapperService mapperService = createMapperService( @@ -436,6 +535,17 @@ public void testRecoverySourceWithLogs() throws IOException { assertNotNull(doc.rootDoc().getField("_recovery_source")); assertThat(doc.rootDoc().getField("_recovery_source").binaryValue(), equalTo(new BytesRef("{\"@timestamp\":\"2012-02-13\"}"))); } + { + Settings settings = Settings.builder() + .put(IndexSettings.MODE.getKey(), IndexMode.LOGSDB.getName()) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + MapperService mapperService = createMapperService(settings, mapping(b -> {})); + DocumentMapper docMapper = mapperService.documentMapper(); + ParsedDocument doc = docMapper.parse(source(b -> { b.field("@timestamp", "2012-02-13"); })); + assertNotNull(doc.rootDoc().getField("_recovery_source_size")); + assertThat(doc.rootDoc().getField("_recovery_source_size").numericValue(), equalTo(27L)); + } { Settings settings = Settings.builder() .put(IndexSettings.MODE.getKey(), IndexMode.LOGSDB.getName()) diff --git a/server/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java b/server/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java index 49b1362436ec7..0357d02dbbb98 100644 --- a/server/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.iterable.Iterables; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexSettings; @@ -486,7 +487,8 @@ protected EngineFactory getEngineFactory(ShardRouting routing) { Long.MAX_VALUE, false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { assertThat(snapshot, SnapshotMatchers.containsOperationsInAnyOrder(expectedTranslogOps)); @@ -513,7 +515,8 @@ protected EngineFactory getEngineFactory(ShardRouting routing) { Long.MAX_VALUE, false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { assertThat(snapshot, SnapshotMatchers.containsOperationsInAnyOrder(expectedTranslogOps)); @@ -608,7 +611,17 @@ public void testSeqNoCollision() throws Exception { shards.promoteReplicaToPrimary(replica2).get(); logger.info("--> Recover replica3 from replica2"); recoverReplica(replica3, replica2, true); - try 
(Translog.Snapshot snapshot = replica3.newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), true)) { + try ( + Translog.Snapshot snapshot = replica3.newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + true, + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { assertThat(snapshot.totalOperations(), equalTo(initDocs + 1)); final List expectedOps = new ArrayList<>(initOperations); expectedOps.add(op2); diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index d480f7bfc8d7f..eacb4cf35a422 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -1819,7 +1819,15 @@ public void testShardFieldStats() throws IOException { shard.refresh("test"); } else { // trigger internal refresh - shard.newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), randomBoolean()).close(); + shard.newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ).close(); } assertThat(shard.getShardFieldStats(), sameInstance(stats)); // index more docs @@ -1837,7 +1845,15 @@ public void testShardFieldStats() throws IOException { shard.refresh("test"); } else { // trigger internal refresh - shard.newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), randomBoolean()).close(); + shard.newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ).close(); } stats = shard.getShardFieldStats(); assertThat(stats.numSegments(), equalTo(2)); diff --git a/server/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java b/server/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java index 9e7f5fbbce1a3..ca616dc619ec9 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/RefreshListenersTests.java @@ -158,7 +158,7 @@ public void onFailedEngine(String reason, @Nullable Exception e) { System::nanoTime, null, true, - null + EngineTestCase.createMapperService() ); engine = new InternalEngine(config); EngineTestCase.recoverFromTranslog(engine, (e, s) -> 0, Long.MAX_VALUE); diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java index 315eaaf9ffaf1..aef58cee04899 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.lucene.uid.Versions; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.MergePolicyConfig; import org.elasticsearch.index.VersionType; @@ -211,7 +212,8 @@ public void testRecoveryWithOutOfOrderDeleteWithSoftDeletes() throws Exception { Long.MAX_VALUE, false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { assertThat(snapshot, SnapshotMatchers.size(6)); diff --git 
a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java index 87c566d543d0f..46f6a0b503bfb 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java @@ -57,13 +57,12 @@ import org.elasticsearch.cluster.routing.AllocationId; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.uid.Versions; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.core.IOUtils; @@ -142,6 +141,7 @@ import static org.elasticsearch.index.engine.Engine.Operation.Origin.PEER_RECOVERY; import static org.elasticsearch.index.engine.Engine.Operation.Origin.PRIMARY; import static org.elasticsearch.index.engine.Engine.Operation.Origin.REPLICA; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -160,6 +160,8 @@ public abstract class EngineTestCase extends ESTestCase { protected Store store; protected Store storeReplica; + protected MapperService mapperService; + protected InternalEngine engine; protected InternalEngine replicaEngine; @@ -198,6 +200,27 @@ protected Settings indexSettings() { .build(); } + protected String defaultMapping() { + return """ + { + "dynamic": false, + "properties": { + "value": { + "type": "keyword" + }, + "nested_field": { + "type": "nested", + "properties": { + "field-0": { + "type": "keyword" + } + } + } + } + } + """; + } + @Override @Before public void setUp() throws Exception { @@ -212,15 +235,16 @@ public void setUp() throws Exception { } else { codecName = "default"; } - defaultSettings = IndexSettingsModule.newIndexSettings("test", indexSettings()); + defaultSettings = IndexSettingsModule.newIndexSettings("index", indexSettings()); threadPool = new TestThreadPool(getClass().getName()); store = createStore(); storeReplica = createStore(); Lucene.cleanLuceneIndex(store.directory()); Lucene.cleanLuceneIndex(storeReplica.directory()); primaryTranslogDir = createTempDir("translog-primary"); - translogHandler = createTranslogHandler(defaultSettings); - engine = createEngine(store, primaryTranslogDir); + mapperService = createMapperService(defaultSettings.getSettings(), defaultMapping()); + translogHandler = createTranslogHandler(mapperService); + engine = createEngine(defaultSettings, store, primaryTranslogDir, newMergePolicy()); LiveIndexWriterConfig currentIndexWriterConfig = engine.getCurrentIndexWriterConfig(); assertEquals(engine.config().getCodec().getName(), codecService.codec(codecName).getName()); @@ -230,7 +254,7 @@ public void setUp() throws Exception { engine.config().setEnableGcDeletes(false); } replicaTranslogDir = createTempDir("translog-replica"); - replicaEngine = createEngine(storeReplica, replicaTranslogDir); + replicaEngine = 
createEngine(defaultSettings, storeReplica, replicaTranslogDir, newMergePolicy()); currentIndexWriterConfig = replicaEngine.getCurrentIndexWriterConfig(); assertEquals(replicaEngine.config().getCodec().getName(), codecService.codec(codecName).getName()); @@ -433,37 +457,9 @@ protected static ParsedDocument testParsedDocument( ); } - public static CheckedBiFunction nestedParsedDocFactory() throws Exception { - final MapperService mapperService = createMapperService(); - final String nestedMapping = Strings.toString( - XContentFactory.jsonBuilder() - .startObject() - .startObject("type") - .startObject("properties") - .startObject("nested_field") - .field("type", "nested") - .endObject() - .endObject() - .endObject() - .endObject() - ); - final DocumentMapper nestedMapper = mapperService.merge( - "type", - new CompressedXContent(nestedMapping), - MapperService.MergeReason.MAPPING_UPDATE - ); - return (docId, nestedFieldValues) -> { - final XContentBuilder source = XContentFactory.jsonBuilder().startObject().field("field", "value"); - if (nestedFieldValues > 0) { - XContentBuilder nestedField = source.startObject("nested_field"); - for (int i = 0; i < nestedFieldValues; i++) { - nestedField.field("field-" + i, "value-" + i); - } - source.endObject(); - } - source.endObject(); - return nestedMapper.parse(new SourceToParse(docId, BytesReference.bytes(source), XContentType.JSON)); - }; + public static ParsedDocument parseDocument(MapperService mapperService, String id, String routing) { + SourceToParse sourceToParse = new SourceToParse(id, new BytesArray("{ \"value\" : \"test\" }"), XContentType.JSON, routing); + return mapperService.documentMapper().parse(sourceToParse); } protected Store createStore() throws IOException { @@ -500,8 +496,8 @@ protected Translog createTranslog(Path translogPath, LongSupplier primaryTermSup ); } - protected TranslogHandler createTranslogHandler(IndexSettings indexSettings) { - return new TranslogHandler(xContentRegistry(), indexSettings); + protected TranslogHandler createTranslogHandler(MapperService mapperService) { + return new TranslogHandler(mapperService); } protected InternalEngine createEngine(Store store, Path translogPath) throws IOException { @@ -857,7 +853,7 @@ public EngineConfig config( this::relativeTimeInNanos, indexCommitListener, true, - null + mapperService ); } @@ -1031,6 +1027,22 @@ public static List generateSingleDocHistory( return ops; } + private CheckedBiFunction nestedParsedDocFactory(MapperService mapperService) { + final DocumentMapper nestedMapper = mapperService.documentMapper(); + return (docId, nestedFieldValues) -> { + final XContentBuilder source = XContentFactory.jsonBuilder().startObject().field("value", "test"); + if (nestedFieldValues > 0) { + XContentBuilder nestedField = source.startObject("nested_field"); + for (int i = 0; i < nestedFieldValues; i++) { + nestedField.field("field-" + i, "value-" + i); + } + source.endObject(); + } + source.endObject(); + return nestedMapper.parse(new SourceToParse(docId, BytesReference.bytes(source), XContentType.JSON)); + }; + } + public List generateHistoryOnReplica( int numOps, boolean allowGapInSeqNo, @@ -1050,7 +1062,9 @@ public List generateHistoryOnReplica( long seqNo = startingSeqNo; final int maxIdValue = randomInt(numOps * 2); final List operations = new ArrayList<>(numOps); - CheckedBiFunction nestedParsedDocFactory = nestedParsedDocFactory(); + CheckedBiFunction nestedParsedDocFactory = nestedParsedDocFactory( + engine.engineConfig.getMapperService() + ); for (int i = 0; i < 
numOps; i++) { final String id = Integer.toString(randomInt(maxIdValue)); final Engine.Operation.TYPE opType = randomFrom(Engine.Operation.TYPE.values()); @@ -1059,7 +1073,9 @@ public List generateHistoryOnReplica( final long startTime = threadPool.relativeTimeInNanos(); final int copies = allowDuplicate && rarely() ? between(2, 4) : 1; for (int copy = 0; copy < copies; copy++) { - final ParsedDocument doc = isNestedDoc ? nestedParsedDocFactory.apply(id, nestedValues) : createParsedDoc(id, null); + final ParsedDocument doc = isNestedDoc + ? nestedParsedDocFactory.apply(id, nestedValues) + : parseDocument(engine.engineConfig.getMapperService(), id, null); switch (opType) { case INDEX -> operations.add( new Engine.Index( @@ -1274,7 +1290,17 @@ public static List getDocIds(Engine engine, boolean refresh */ public static List readAllOperationsInLucene(Engine engine) throws IOException { final List operations = new ArrayList<>(); - try (Translog.Snapshot snapshot = engine.newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), randomBoolean())) { + try ( + Translog.Snapshot snapshot = engine.newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) + ) { Translog.Operation op; while ((op = snapshot.next()) != null) { operations.add(op); @@ -1345,7 +1371,15 @@ public static void assertConsistentHistoryBetweenTranslogAndLuceneIndex(Engine e assertThat(luceneOp.toString(), luceneOp.primaryTerm(), equalTo(translogOp.primaryTerm())); assertThat(luceneOp.opType(), equalTo(translogOp.opType())); if (luceneOp.opType() == Translog.Operation.Type.INDEX) { - assertThat(((Translog.Index) luceneOp).source(), equalTo(((Translog.Index) translogOp).source())); + if (engine.engineConfig.getIndexSettings().isRecoverySourceSyntheticEnabled()) { + assertToXContentEquivalent( + ((Translog.Index) luceneOp).source(), + ((Translog.Index) translogOp).source(), + XContentFactory.xContentType(((Translog.Index) luceneOp).source().array()) + ); + } else { + assertThat(((Translog.Index) luceneOp).source(), equalTo(((Translog.Index) translogOp).source())); + } } } } @@ -1401,15 +1435,19 @@ public static void assertAtMostOneLuceneDocumentPerSequenceNumber(IndexSettings } public static MapperService createMapperService() throws IOException { - IndexMetadata indexMetadata = IndexMetadata.builder("test") - .settings(indexSettings(1, 1).put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())) - .putMapping("{\"properties\": {}}") + return createMapperService(Settings.EMPTY, "{}"); + } + + public static MapperService createMapperService(Settings settings, String mappings) throws IOException { + IndexMetadata indexMetadata = IndexMetadata.builder("index") + .settings(indexSettings(1, 1).put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).put(settings)) + .putMapping(mappings) .build(); MapperService mapperService = MapperTestUtils.newMapperService( new NamedXContentRegistry(ClusterModule.getNamedXWriteables()), createTempDir(), - Settings.EMPTY, - "test" + indexMetadata.getSettings(), + "index" ); mapperService.merge(indexMetadata, MapperService.MergeReason.MAPPING_UPDATE); return mapperService; diff --git a/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java b/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java index 57cca12f99c41..33c745de25438 100644 --- 
a/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/index/engine/TranslogHandler.java @@ -43,6 +43,10 @@ long appliedOperations() { return appliedOperations.get(); } + public TranslogHandler(MapperService mapperService) { + this.mapperService = mapperService; + } + public TranslogHandler(NamedXContentRegistry xContentRegistry, IndexSettings indexSettings) { SimilarityService similarityService = new SimilarityService(indexSettings, null, emptyMap()); MapperRegistry mapperRegistry = new IndicesModule(emptyList()).getMapperRegistry(); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java index 29bb3b15a9f86..7dcbbce9fa8e0 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java @@ -9,9 +9,12 @@ package org.elasticsearch.index.mapper; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableFieldType; import org.apache.lucene.index.LeafReader; @@ -20,7 +23,11 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.UsageTrackingQueryCachingPolicy; +import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.index.RandomIndexWriter; @@ -30,11 +37,14 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.engine.Engine; +import org.elasticsearch.index.engine.LuceneSyntheticSourceChangesSnapshot; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.fielddata.IndexFieldDataCache; @@ -43,6 +53,7 @@ import org.elasticsearch.index.fieldvisitor.StoredFieldLoader; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.termvectors.TermVectorsService; +import org.elasticsearch.index.translog.Translog; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptContext; @@ -1130,6 +1141,11 @@ public final void testSyntheticSource() throws IOException { assertSyntheticSource(syntheticSourceSupport(shouldUseIgnoreMalformed()).example(5)); } + public final void testSyntheticSourceWithTranslogSnapshot() throws IOException { + 
assertSyntheticSourceWithTranslogSnapshot(syntheticSourceSupport(shouldUseIgnoreMalformed()), true); + assertSyntheticSourceWithTranslogSnapshot(syntheticSourceSupport(shouldUseIgnoreMalformed()), false); + } + public void testSyntheticSourceIgnoreMalformedExamples() throws IOException { assumeTrue("type doesn't support ignore_malformed", supportsIgnoreMalformed()); // We need to call this in order to hit the assumption inside so that @@ -1155,6 +1171,71 @@ private void assertSyntheticSource(SyntheticSourceExample example) throws IOExce assertThat(syntheticSource(mapper, example::buildInput), equalTo(example.expected())); } + private void assertSyntheticSourceWithTranslogSnapshot(SyntheticSourceSupport support, boolean doIndexSort) throws IOException { + var firstExample = support.example(1); + int maxDocs = randomIntBetween(20, 50); + var settings = Settings.builder() + .put(SourceFieldMapper.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC) + .put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true) + .build(); + var mapperService = createMapperService(getVersion(), settings, () -> true, mapping(b -> { + b.startObject("field"); + firstExample.mapping().accept(b); + b.endObject(); + })); + var docMapper = mapperService.documentMapper(); + try (var directory = newDirectory()) { + List examples = new ArrayList<>(); + IndexWriterConfig config = newIndexWriterConfig(random(), new StandardAnalyzer()); + config.setIndexSort(new Sort(new SortField("sort", SortField.Type.LONG))); + try (var iw = new RandomIndexWriter(random(), directory, config)) { + for (int seqNo = 0; seqNo < maxDocs; seqNo++) { + var example = support.example(randomIntBetween(1, 5)); + examples.add(example); + var doc = docMapper.parse(source(example::buildInput)); + assertNull(doc.dynamicMappingsUpdate()); + doc.updateSeqID(seqNo, 1); + doc.version().setLongValue(0); + if (doIndexSort) { + doc.rootDoc().add(new NumericDocValuesField("sort", randomLong())); + } + iw.addDocuments(doc.docs()); + if (frequently()) { + iw.flush(); + } + } + } + try (var indexReader = wrapInMockESDirectoryReader(DirectoryReader.open(directory))) { + int start = randomBoolean() ? 
0 : randomIntBetween(1, maxDocs - 10); + var snapshot = new LuceneSyntheticSourceChangesSnapshot( + mapperService.mappingLookup(), + new Engine.Searcher( + "recovery", + indexReader, + new BM25Similarity(), + null, + new UsageTrackingQueryCachingPolicy(), + () -> {} + ), + randomIntBetween(1, maxDocs), + randomLongBetween(0, ByteSizeValue.ofBytes(Integer.MAX_VALUE).getBytes()), + start, + maxDocs, + true, + randomBoolean(), + IndexVersion.current() + ); + for (int i = start; i < maxDocs; i++) { + var example = examples.get(i); + var op = snapshot.next(); + if (op instanceof Translog.Index opIndex) { + assertThat(opIndex.source().utf8ToString(), equalTo(example.expected())); + } + } + } + } + } + protected boolean supportsEmptyInputArray() { return true; } diff --git a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java index ec85feb200984..59c44925f920f 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java @@ -26,7 +26,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.MockEngineFactoryPlugin; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.node.RecoverySettingsChunkSizePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.snapshots.SnapshotState; @@ -75,7 +74,6 @@ protected Collection> nodePlugins() { return Arrays.asList( MockTransportService.TestPlugin.class, MockFSIndexStore.TestPlugin.class, - RecoverySettingsChunkSizePlugin.class, InternalSettingsPlugin.class, MockEngineFactoryPlugin.class ); diff --git a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java index d3bfacdf7691a..510aa25f9b98e 100644 --- a/test/framework/src/main/java/org/elasticsearch/node/MockNode.java +++ b/test/framework/src/main/java/org/elasticsearch/node/MockNode.java @@ -28,7 +28,6 @@ import org.elasticsearch.indices.ExecutorSelector; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.breaker.CircuitBreakerService; -import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.plugins.MockPluginsService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsLoader; @@ -194,16 +193,6 @@ protected TransportService newTransportService( } } - @Override - void processRecoverySettings(PluginsService pluginsService, ClusterSettings clusterSettings, RecoverySettings recoverySettings) { - if (pluginsService.filterPlugins(RecoverySettingsChunkSizePlugin.class).findAny().isEmpty() == false) { - clusterSettings.addSettingsUpdateConsumer( - RecoverySettingsChunkSizePlugin.CHUNK_SIZE_SETTING, - recoverySettings::setChunkSize - ); - } - } - @Override protected ClusterInfoService newClusterInfoService( PluginsService pluginsService, diff --git a/test/framework/src/main/java/org/elasticsearch/node/RecoverySettingsChunkSizePlugin.java b/test/framework/src/main/java/org/elasticsearch/node/RecoverySettingsChunkSizePlugin.java deleted file mode 100644 index 489c9f704f419..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/node/RecoverySettingsChunkSizePlugin.java +++ /dev/null @@ -1,40 +0,0 @@ 
-/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.node; - -import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Setting.Property; -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.indices.recovery.RecoverySettings; -import org.elasticsearch.plugins.Plugin; - -import java.util.List; - -import static java.util.Collections.singletonList; - -/** - * Marker plugin that will trigger {@link MockNode} making {@link #CHUNK_SIZE_SETTING} dynamic. - */ -public class RecoverySettingsChunkSizePlugin extends Plugin { - /** - * The chunk size. Only exposed by tests. - */ - public static final Setting CHUNK_SIZE_SETTING = Setting.byteSizeSetting( - "indices.recovery.chunk_size", - RecoverySettings.DEFAULT_CHUNK_SIZE, - Property.Dynamic, - Property.NodeScope - ); - - @Override - public List> getSettings() { - return singletonList(CHUNK_SIZE_SETTING); - } -} diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardChangesAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardChangesAction.java index 7429fdd2a27b5..b35dca2881455 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardChangesAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardChangesAction.java @@ -557,7 +557,17 @@ static Translog.Operation[] getOperations( long toSeqNo = Math.min(globalCheckpoint, (fromSeqNo + maxOperationCount) - 1); assert fromSeqNo <= toSeqNo : "invalid range from_seqno[" + fromSeqNo + "] > to_seqno[" + toSeqNo + "]"; final List operations = new ArrayList<>(); - try (Translog.Snapshot snapshot = indexShard.newChangesSnapshot("ccr", fromSeqNo, toSeqNo, true, true, false)) { + try ( + Translog.Snapshot snapshot = indexShard.newChangesSnapshot( + "ccr", + fromSeqNo, + toSeqNo, + true, + true, + false, + maxBatchSize.getBytes() + ) + ) { Translog.Operation op; while ((op = snapshot.next()) != null) { operations.add(op); diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java index 5cd9f8bc5b78c..573c66cbb614a 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java @@ -755,7 +755,15 @@ private void assertConsistentHistoryBetweenLeaderAndFollower( final Map operationsOnLeader = new HashMap<>(); try ( Translog.Snapshot snapshot = leader.getPrimary() - .newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), randomBoolean()) + .newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) ) { Translog.Operation op; while ((op = snapshot.next()) != null) { @@ -780,7 +788,8 @@ private void assertConsistentHistoryBetweenLeaderAndFollower( 
Long.MAX_VALUE, false, randomBoolean(), - randomBoolean() + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) ) ) { Translog.Operation op; diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/bulk/BulkShardOperationsTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/bulk/BulkShardOperationsTests.java index 4e3aea2cad205..e3f26eed0c2e9 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/bulk/BulkShardOperationsTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/bulk/BulkShardOperationsTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardTestCase; import org.elasticsearch.index.translog.Translog; @@ -84,7 +85,15 @@ public void testPrimaryTermFromFollower() throws IOException { boolean accessStats = randomBoolean(); try ( - Translog.Snapshot snapshot = followerPrimary.newChangesSnapshot("test", 0, Long.MAX_VALUE, false, randomBoolean(), accessStats) + Translog.Snapshot snapshot = followerPrimary.newChangesSnapshot( + "test", + 0, + Long.MAX_VALUE, + false, + randomBoolean(), + accessStats, + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) ) { if (accessStats) { assertThat(snapshot.totalOperations(), equalTo(operations.size())); diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowingEngineTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowingEngineTests.java index 150eddf039cec..62dc3313a1172 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowingEngineTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowingEngineTests.java @@ -15,7 +15,11 @@ import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; @@ -31,7 +35,10 @@ import org.elasticsearch.index.engine.EngineTestCase; import org.elasticsearch.index.engine.InternalEngine; import org.elasticsearch.index.engine.TranslogHandler; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.seqno.RetentionLeases; import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.index.shard.IndexShard; @@ -44,6 +51,9 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import java.io.IOException; import java.nio.file.Path; @@ -94,7 +104,7 @@ public void 
tearDown() throws Exception { super.tearDown(); } - public void testFollowingEngineRejectsNonFollowingIndex() { + public void testFollowingEngineRejectsNonFollowingIndex() throws IOException { final Settings.Builder builder = indexSettings(IndexVersion.current(), 1, 0); if (randomBoolean()) { builder.put("index.xpack.ccr.following_index", false); @@ -212,7 +222,7 @@ private EngineConfig engineConfig( final IndexSettings indexSettings, final ThreadPool threadPool, final Store store - ) { + ) throws IOException { final IndexWriterConfig indexWriterConfig = newIndexWriterConfig(); final Path translogPath = createTempDir("translog"); final TranslogConfig translogConfig = new TranslogConfig( @@ -221,6 +231,7 @@ private EngineConfig engineConfig( indexSettings, BigArrays.NON_RECYCLING_INSTANCE ); + final MapperService mapperService = EngineTestCase.createMapperService(); return new EngineConfig( shardIdValue, threadPool, @@ -253,7 +264,7 @@ public void onFailedEngine(String reason, Exception e) { System::nanoTime, null, true, - null + mapperService ); } @@ -641,7 +652,15 @@ private void fetchOperations(AtomicBoolean stopped, AtomicLong lastFetchedSeqNo, final long toSeqNo = randomLongBetween(nextSeqNo, Math.min(nextSeqNo + 5, checkpoint)); try ( Translog.Snapshot snapshot = shuffleSnapshot( - leader.newChangesSnapshot("test", fromSeqNo, toSeqNo, true, randomBoolean(), randomBoolean()) + leader.newChangesSnapshot( + "test", + fromSeqNo, + toSeqNo, + true, + randomBoolean(), + randomBoolean(), + randomLongBetween(1, ByteSizeValue.ofMb(32).getBytes()) + ) ) ) { follower.advanceMaxSeqNoOfUpdatesOrDeletes(leader.getMaxSeqNoOfUpdatesOrDeletes()); @@ -689,6 +708,39 @@ public void close() throws IOException { }; } + private CheckedBiFunction nestedParsedDocFactory() throws Exception { + final MapperService mapperService = EngineTestCase.createMapperService(); + final String nestedMapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("type") + .startObject("properties") + .startObject("nested_field") + .field("type", "nested") + .endObject() + .endObject() + .endObject() + .endObject() + ); + final DocumentMapper nestedMapper = mapperService.merge( + "type", + new CompressedXContent(nestedMapping), + MapperService.MergeReason.MAPPING_UPDATE + ); + return (docId, nestedFieldValues) -> { + final XContentBuilder source = XContentFactory.jsonBuilder().startObject().field("field", "value"); + if (nestedFieldValues > 0) { + XContentBuilder nestedField = source.startObject("nested_field"); + for (int i = 0; i < nestedFieldValues; i++) { + nestedField.field("field-" + i, "value-" + i); + } + source.endObject(); + } + source.endObject(); + return nestedMapper.parse(new SourceToParse(docId, BytesReference.bytes(source), XContentType.JSON)); + }; + } + public void testProcessOnceOnPrimary() throws Exception { final Settings.Builder settingsBuilder = indexSettings(IndexVersion.current(), 1, 0).put("index.xpack.ccr.following_index", true); switch (indexMode) { @@ -709,7 +761,7 @@ public void testProcessOnceOnPrimary() throws Exception { final Settings settings = settingsBuilder.build(); final IndexMetadata indexMetadata = IndexMetadata.builder(index.getName()).settings(settings).build(); final IndexSettings indexSettings = new IndexSettings(indexMetadata, settings); - final CheckedBiFunction nestedDocFunc = EngineTestCase.nestedParsedDocFactory(); + final CheckedBiFunction nestedDocFunc = nestedParsedDocFactory(); int numOps = between(10, 100); List operations = new 
ArrayList<>(numOps); for (int i = 0; i < numOps; i++) { From 840bddd6508e813f552f0aeae07c5abb0bd3ebf0 Mon Sep 17 00:00:00 2001 From: getsolaris Date: Wed, 11 Dec 2024 09:27:26 +0900 Subject: [PATCH 20/28] =?UTF-8?q?Update=20numeral=20format=20from=20'?= =?UTF-8?q?=EC=A1=B0'=20to=20'=EC=9D=BC=EC=A1=B0'=20for=20clarity=20(#1182?= =?UTF-8?q?32)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/plugins/analysis-nori.asciidoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/plugins/analysis-nori.asciidoc b/docs/plugins/analysis-nori.asciidoc index 0d3e76f71d238..9eb3bf07fbd30 100644 --- a/docs/plugins/analysis-nori.asciidoc +++ b/docs/plugins/analysis-nori.asciidoc @@ -475,7 +475,7 @@ The input is untokenized text and the result is the single term attribute emitte - 영영칠 -> 7 - 일영영영 -> 1000 - 삼천2백2십삼 -> 3223 -- 조육백만오천일 -> 1000006005001 +- 일조육백만오천일 -> 1000006005001 - 3.2천 -> 3200 - 1.2만345.67 -> 12345.67 - 4,647.100 -> 4647.1 From d5589d7fca0e3481d6617a28da9483ed1db83494 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine <58790826+elasticsearchmachine@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:05:06 +1100 Subject: [PATCH 21/28] Mute org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT testEveryActionIsEitherOperatorOnlyOrNonOperator #118220 --- muted-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/muted-tests.yml b/muted-tests.yml index df22d10062d8d..1b84d4dbe33de 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -318,6 +318,9 @@ tests: - class: org.elasticsearch.xpack.inference.InferenceCrudIT method: testUnifiedCompletionInference issue: https://github.com/elastic/elasticsearch/issues/118405 +- class: org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT + method: testEveryActionIsEitherOperatorOnlyOrNonOperator + issue: https://github.com/elastic/elasticsearch/issues/118220 # Examples: # From 59967727cfb20f4f7ed0c969cd2ddd92b1f7b712 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:14:18 +0100 Subject: [PATCH 22/28] kNN vector rescoring for quantized vectors (#116663) --- docs/changelog/116663.yaml | 5 + .../percolator/PercolatorQuerySearchIT.java | 2 +- .../search.retrievers/20_knn_retriever.yml | 58 ++++- .../search.vectors/210_knn_search_profile.yml | 137 ++++++++++ .../test/search.vectors/40_knn_search.yml | 55 ++++ .../search.vectors/41_knn_search_bbq_hnsw.yml | 69 +++++ .../41_knn_search_byte_quantized.yml | 59 +++++ .../41_knn_search_half_byte_quantized.yml | 56 ++++ .../search.vectors/42_knn_search_bbq_flat.yml | 69 +++++ .../search.vectors/42_knn_search_flat.yml | 55 ++++ .../42_knn_search_int4_flat.yml | 55 ++++ .../42_knn_search_int8_flat.yml | 54 ++++ .../test/search.vectors/45_knn_search_bit.yml | 56 ++++ .../search.vectors/45_knn_search_bit_flat.yml | 56 ++++ .../search.vectors/45_knn_search_byte.yml | 57 +++++ .../search/nested/VectorNestedIT.java | 2 +- .../search/profile/dfs/DfsProfilerIT.java | 2 + .../retriever/RetrieverTelemetryIT.java | 10 +- .../org/elasticsearch/TransportVersions.java | 1 + .../vectors/DenseVectorFieldMapper.java | 64 ++++- .../VectorSimilarityFloatValueSource.java | 105 ++++++++ .../action/search/SearchCapabilities.java | 2 + .../elasticsearch/search/dfs/DfsPhase.java | 6 +- .../search/profile/query/QueryProfiler.java | 12 +- .../search/retriever/KnnRetrieverBuilder.java | 28 +- ...iversifyingChildrenByteKnnVectorQuery.java | 4 +- 
...versifyingChildrenFloatKnnVectorQuery.java | 4 +- .../search/vectors/ESKnnByteVectorQuery.java | 8 +- .../search/vectors/ESKnnFloatVectorQuery.java | 8 +- .../search/vectors/KnnScoreDocQuery.java | 38 ++- .../vectors/KnnScoreDocQueryBuilder.java | 24 +- .../search/vectors/KnnSearchBuilder.java | 95 ++++++- .../vectors/KnnSearchRequestParser.java | 2 +- .../search/vectors/KnnVectorQueryBuilder.java | 128 ++++++++-- ...gQuery.java => QueryProfilerProvider.java} | 2 +- .../search/vectors/RescoreKnnVectorQuery.java | 140 ++++++++++ .../search/vectors/RescoreVectorBuilder.java | 85 ++++++ .../search/vectors/VectorSimilarityQuery.java | 13 +- .../action/search/DfsQueryPhaseTests.java | 4 +- .../search/KnnSearchSingleNodeTests.java | 24 +- .../action/search/SearchRequestTests.java | 48 +++- .../search/TransportSearchActionTests.java | 4 +- .../vectors/DenseVectorFieldMapperTests.java | 40 ++- .../vectors/DenseVectorFieldTypeTests.java | 159 ++++++++++-- .../index/query/NestedQueryBuilderTests.java | 1 + .../action/search/RestSearchActionTests.java | 2 +- .../builder/SearchSourceBuilderTests.java | 2 +- .../KnnRetrieverBuilderParsingTests.java | 17 +- .../RankDocsRetrieverBuilderTests.java | 2 + ...AbstractKnnVectorQueryBuilderTestCase.java | 183 +++++++++---- .../KnnByteVectorQueryBuilderTests.java | 17 +- .../KnnFloatVectorQueryBuilderTests.java | 17 +- .../search/vectors/KnnSearchBuilderTests.java | 121 +++++++-- .../vectors/RescoreKnnVectorQueryTests.java | 241 ++++++++++++++++++ .../search/RandomSearchRequestGenerator.java | 9 +- .../AbstractQueryVectorBuilderTestCase.java | 3 + .../mapper/SemanticTextFieldMapper.java | 2 +- .../SemanticTextHighlighterTests.java | 2 +- ...SimilarityRankRetrieverTelemetryTests.java | 10 +- .../xpack/rank/rrf/RRFRankMultiShardIT.java | 40 +-- .../xpack/rank/rrf/RRFRankSingleShardIT.java | 34 +-- .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 27 +- .../rrf/RRFRetrieverBuilderNestedDocsIT.java | 2 +- .../rank/rrf/RRFRetrieverTelemetryIT.java | 12 +- .../xpack/rank/rrf/RRFRankBuilder.java | 1 + .../DocumentLevelSecurityTests.java | 2 +- .../integration/FieldLevelSecurityTests.java | 6 +- 67 files changed, 2362 insertions(+), 296 deletions(-) create mode 100644 docs/changelog/116663.yaml create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java rename server/src/main/java/org/elasticsearch/search/vectors/{ProfilingQuery.java => QueryProfilerProvider.java} (96%) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java diff --git a/docs/changelog/116663.yaml b/docs/changelog/116663.yaml new file mode 100644 index 0000000000000..40bcdea29bc31 --- /dev/null +++ b/docs/changelog/116663.yaml @@ -0,0 +1,5 @@ +pr: 116663 +summary: KNN vector rescoring for quantized vectors +area: Vector Search +type: feature +issues: [] diff --git a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java index 05f456b7f2229..8a7f1405f8f4e 100644 --- 
a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java +++ b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java @@ -1359,7 +1359,7 @@ public void testKnnQueryNotSupportedInPercolator() throws IOException { """); indicesAdmin().prepareCreate("index1").setMapping(mappings).get(); ensureGreen(); - QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null); + QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null, null); IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1") .setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject()); diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index d08a8e2a6d39c..e49f0634a4887 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -18,7 +18,7 @@ setup: dims: 5 index: true index_options: - type: hnsw + type: int8_hnsw similarity: l2_norm - do: @@ -73,3 +73,59 @@ setup: - match: {hits.total.value: 1} - match: {hits.hits.0._id: "3"} - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + knn: + field: vector + query_vector: [2, 2, 2, 2, 3] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [2, 2, 2, 2, 3] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml new file mode 100644 index 0000000000000..d4bf5e7e9807f --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -0,0 +1,137 @@ +setup: + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [ capabilities ] 
+ capabilities: + - method: GET + path: /_search + capabilities: [ knn_quantized_vector_rescore ] + - skip: + features: "headers" + + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 +--- +"Profile rescored knn search": + + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + "rescore_vector": + "num_candidates_factor": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } + + # Search with similarity to check number of operations are propagated correctly + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, 
-0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + similarity: 100000 + "rescore_vector": + "num_candidates_factor": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 78a86e4026b30..7d4690204acc7 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -550,3 +550,58 @@ setup: num_candidates: 3 - match: { hits.total.value: 0 } +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 5767c895fbe7e..2567a4ac597d9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -108,6 +108,75 @@ setup: - match: { hits.hits.1._id: "3" } - match: { hits.hits.2._id: "2" } --- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + 
index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test bad quantization parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index b7a5517309949..b1e35789e8737 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -368,6 +368,65 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +# Won't be true for larger datasets, but this helps checking kNN vs rescoring vs exact search +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: hnsw_byte_quantized + body: + size: 3 + query: + knn: + k: 3 + num_candidates: 3 + field: vector + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 
3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 5f1af2ca5c52f..54e9eadf42e0b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -549,6 +549,62 @@ setup: - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} --- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: hnsw_byte_quantized + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test odd dimensions fail indexing": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index dcdae04aeabb4..a3cd624ef0ab8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -107,6 +107,75 @@ setup: - match: { hits.hits.1._id: "3" } - match: { hits.hits.2._id: "2" } --- +"Vector rescoring has same 
scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index 1b439967ba163..a59aedceff3d3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -257,6 +257,61 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: 
vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index b9a0b16f2bd7a..6796a92122f9a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -344,3 +344,58 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int4_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 139747c5e7ee5..d1d312449cb70 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -262,6 +262,60 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: 
{hits.hits.2._score: 0.791} --- +"Vector rescoring has the same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int8_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index 02576ad1b2b01..effa3fff61525 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -405,3 +405,59 @@ setup: - match: {hits.hits.0._id: "1"} - match: {hits.hits.0._source.vector1: [2, -1, 1, 4, -3]} - match: {hits.hits.0._source.vector2: [2, -1, 1, 4, -3]} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: 
$knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index ec7bde4de8435..cdc1d9c64763e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -221,3 +221,59 @@ setup: similarity: l2_norm index_options: type: int8_hnsw + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 0cedfaa873095..213b571a0b4be 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -254,3 +254,60 @@ setup: filter: {"term": {"name": "cow.jpg"}} - length: {hits.hits: 0} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: 
[127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } + diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java index d1021715ceffc..aaab14941d4bb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java @@ -69,7 +69,7 @@ public void testSimpleNested() throws Exception { assertResponse( prepareSearch("test").setKnnSearch( - List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null).innerHit(new InnerHitBuilder())) + List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null, null).innerHit(new InnerHitBuilder())) ).setAllowPartialSearchResults(false), response -> assertThat(response.getHits().getHits().length, greaterThan(0)) ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java index 876edc282c903..95d69a6ebaa86 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentFactory; @@ -71,6 +72,7 @@ public void testProfileDfs() throws Exception { new float[] { randomFloat(), randomFloat(), randomFloat() }, randomIntBetween(5, 10), 50, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? 
null : randomFloat() ); if (randomBoolean()) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java index 537ace30e88f0..40849bea5512e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -84,7 +84,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -98,7 +100,7 @@ public void testTelemetryForRetrievers() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -112,7 +114,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#5 - t // his will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 3e3556f554089..9f50df02b4fe8 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -140,6 +140,7 @@ static TransportVersion def(int id) { public static final TransportVersion NEW_REFRESH_CLUSTER_BLOCK = def(8_803_00_0); public static final TransportVersion RETRIES_AND_OPERATIONS_IN_BLOBSTORE_STATS = def(8_804_00_0); public static final TransportVersion ADD_DATA_STREAM_OPTIONS_TO_TEMPLATES = def(8_805_00_0); + public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_806_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index d780faad96f2d..8c6e874ff577f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -70,6 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; @@ -123,6 +124,7 @@ public static boolean isNotUnitVector(float magnitude) { public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; + public static final int NUM_CANDS_OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -1216,7 +1218,7 @@ public final int hashCode() { } private enum VectorIndexType { - HNSW("hnsw") { + HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1243,7 +1245,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_HNSW("int8_hnsw") { + INT8_HNSW("int8_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1275,7 +1277,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_HNSW("int4_hnsw") { + INT4_HNSW("int4_hnsw", true) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); @@ -1306,7 +1308,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - FLAT("flat") { + FLAT("flat", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1323,7 +1325,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_FLAT("int8_flat") { + INT8_FLAT("int8_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1345,7 +1347,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_FLAT("int4_flat") { + INT4_FLAT("int4_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1367,7 +1369,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - BBQ_HNSW("bbq_hnsw") { + BBQ_HNSW("bbq_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1394,7 +1396,7 @@ public boolean supportsDimension(int dims) { return dims >= BBQ_MIN_DIMS; } }, - BBQ_FLAT("bbq_flat") { 
+ BBQ_FLAT("bbq_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1417,9 +1419,11 @@ static Optional fromString(String type) { } private final String name; + private final boolean quantized; - VectorIndexType(String name) { + VectorIndexType(String name, boolean quantized) { this.name = name; + this.quantized = quantized; } abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); @@ -1428,6 +1432,10 @@ static Optional fromString(String type) { public abstract boolean supportsDimension(int dims); + public boolean isQuantized() { + return quantized; + } + @Override public String toString() { return name; @@ -2000,6 +2008,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2011,11 +2020,23 @@ public Query createKnnQuery( } return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); - case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter); + case FLOAT -> createKnnFloatQuery( + queryVector.asFloatVector(), + k, + numCands, + numCandsFactor, + filter, + similarityThreshold, + parentFilter + ); case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } + private boolean needsRescore(Float rescoreOversample) { + return rescoreOversample != null && (indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized()); + } + private Query createKnnBitQuery( byte[] queryVector, Integer k, @@ -2069,6 +2090,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2088,9 +2110,27 @@ && isNotUnitVector(squaredMagnitude)) { } } } + + Integer adjustedK = k; + int adjustedNumCands = numCands; + if (needsRescore(numCandsFactor)) { + // Get all candidates, get top k as part of rescoring + adjustedK = null; + // numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. + adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT); + } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter); + ? 
new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); + if (needsRescore(numCandsFactor)) { + knnQuery = new RescoreKnnVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), + k, + knnQuery + ); + } if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java new file mode 100644 index 0000000000000..74a7dbe168e6b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DoubleValues; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.search.vectors.QueryProfilerProvider; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the + * original vector values stored in the index + */ +public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider { + + private final String field; + private final float[] target; + private final VectorSimilarityFunction vectorSimilarityFunction; + private long vectorOpsCount; + + public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { + this.field = field; + this.target = target; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + final LeafReader reader = ctx.reader(); + + FloatVectorValues vectorValues = reader.getFloatVectorValues(field); + final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + + return new DoubleValues() { + @Override + public double doubleValue() throws IOException { + vectorOpsCount++; + return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index())); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; + } + }; + } + + @Override + public boolean needsScores() 
{ + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o; + return Objects.equals(field, that.field) + && Arrays.equals(target, that.target) + && vectorSimilarityFunction == that.vectorSimilarityFunction; + } + + @Override + public String toString() { + return "VectorSimilarityFloatValueSource(" + field + ", [" + target[0] + ",...], " + vectorSimilarityFunction + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index c9d9569abe93f..86304c8c4bde2 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -45,6 +45,7 @@ private SearchCapabilities() {} private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; + private static final String KNN_QUANTIZED_VECTOR_RESCORE = "knn_quantized_vector_rescore"; public static final Set CAPABILITIES; static { @@ -57,6 +58,7 @@ private SearchCapabilities() {} capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); + capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); if (RankVectorsFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(RANK_VECTORS_FIELD_MAPPER); capabilities.add(RANK_VECTORS_SCRIPT_ACCESS); diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 76b3f45ffb84a..6a99b51ac679c 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -34,7 +34,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.search.vectors.ProfilingQuery; +import org.elasticsearch.search.vectors.QueryProfilerProvider; import org.elasticsearch.tasks.TaskCancelledException; import java.io.IOException; @@ -224,8 +224,8 @@ static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ); topDocs = searcher.search(knnQuery, ipcm); - if (knnQuery instanceof ProfilingQuery profilingQuery) { - profilingQuery.profile(knnProfiler); + if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(knnProfiler); } knnProfiler.setCollectorResult(ipcm.getCollectorTree()); diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java index 
98ddfa95bf156..23ce52b6c5b82 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java +++ b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java @@ -39,10 +39,18 @@ public QueryProfiler() { super(new InternalQueryProfileTree()); } - public void setVectorOpsCount(long vectorOpsCount) { - this.vectorOpsCount = vectorOpsCount; + /** + * Adds a number of vector operations to the current count + * @param vectorOpsCount number of vector ops to add to the profiler + */ + public void addVectorOpsCount(long vectorOpsCount) { + this.vectorOpsCount += vectorOpsCount; } + /** + * Retrieves the number of vector operations performed by the queries + * @return number of vector operations performed by the queries + */ public long getVectorOpsCount() { return this.vectorOpsCount; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index f1464c41ca3be..b29546ded75cd 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -20,8 +20,10 @@ import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -52,6 +54,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -73,6 +76,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { (QueryVectorBuilder) args[2], (int) args[3], (int) args[4], + (RescoreVectorBuilder) args[6], (Float) args[5] ); } @@ -89,6 +93,12 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { PARSER.declareInt(constructorArg(), K_FIELD); PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -104,6 +114,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; + private final RescoreVectorBuilder rescoreVectorBuilder; private final Float similarity; public KnnRetrieverBuilder( @@ -112,6 +123,7 @@ public KnnRetrieverBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { if (queryVector == null && 
queryVectorBuilder == null) { @@ -137,6 +149,7 @@ public KnnRetrieverBuilder( this.k = k; this.numCands = numCands; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVector, QueryVectorBuilder queryVectorBuilder) { @@ -148,6 +161,7 @@ private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVe this.similarity = clone.similarity; this.retrieverName = clone.retrieverName; this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.rescoreVectorBuilder = clone.rescoreVectorBuilder; } @Override @@ -228,6 +242,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder null, k, numCands, + rescoreVectorBuilder, similarity ); if (preFilterQueryBuilders != null) { @@ -241,6 +256,10 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder searchSourceBuilder.knnSearch(knnSearchBuilders); } + RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + // ---- FOR TESTING XCONTENT PARSING ---- @Override @@ -260,6 +279,10 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept if (similarity != null) { builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity); } + + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } } @Override @@ -271,12 +294,13 @@ public boolean doEquals(Object o) { && ((queryVector == null && that.queryVector == null) || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder); } @Override public int doHashCode() { - int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); + int result = Objects.hash(field, queryVectorBuilder, k, numCands, rescoreVectorBuilder, similarity); result = 31 * result + Arrays.hashCode(queryVector != null ? 
queryVector.get() : null); return result; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 9f3d83b4da082..b7f129f674036 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index 3907bdf89bc6f..cb323bbe3932a 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 9363f67a7350b..5c199f42093b1 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements ProfilingQuery { +public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -33,6 +33,10 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + public Integer kParam() { + return kParam; } } diff --git 
a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index be0437af9131d..b7b9d092ceeac 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery { +public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; @@ -33,6 +33,10 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + + public Integer kParam() { + return kParam; } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index bb83b8528c6c8..3d13f3cd82b9c 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -37,7 +38,13 @@ public class KnnScoreDocQuery extends Query { private final int[] docs; private final float[] scores; + + // the indexes in docs and scores corresponding to the first matching document in each segment. + // If a segment has no matching documents, it should be assigned the index of the next segment that does. + // There should be a final entry that is always docs.length-1. private final int[] segmentStarts; + + // an object identifying the reader context that was used to build this query private final Object contextIdentity; /** @@ -45,18 +52,31 @@ public class KnnScoreDocQuery extends Query { * * @param docs the global doc IDs of documents that match, in ascending order * @param scores the scores of the matching documents - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. There should be a final entry that is always - * docs.length-1. 
- * @param contextIdentity an object identifying the reader context that was used to build this - * query + * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { this.docs = docs; this.scores = scores; - this.segmentStarts = segmentStarts; - this.contextIdentity = contextIdentity; + this.segmentStarts = findSegmentStarts(reader, docs); + this.contextIdentity = reader.getContext().id(); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index b5ba97906f0ec..6fa83ccfb6ac2 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.TransportVersion; @@ -25,7 +24,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.Objects; /** @@ -151,9 +149,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { scores[i] = scoreDocs[i].score; } - IndexReader reader = context.getIndexReader(); - int[] segmentStarts = findSegmentStarts(reader, docs); - return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id()); + return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); } @Override @@ -167,24 +163,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return super.doRewrite(queryRewriteContext); } - private static int[] findSegmentStarts(IndexReader reader, int[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - @Override protected boolean doEquals(KnnScoreDocQueryBuilder other) { if (scoreDocs.length != other.scoreDocs.length) { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 41673a0e7edb0..b18ce2dff65cb 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -56,6 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final 
ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD; public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD; public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { @@ -65,7 +66,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea .queryVectorBuilder((QueryVectorBuilder) args[4]) .k((Integer) args[2]) .numCandidates((Integer) args[3]) - .similarity((Float) args[5]); + .similarity((Float) args[5]) + .rescoreVectorBuilder((RescoreVectorBuilder) args[6]); }); static { @@ -78,13 +80,18 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea ); PARSER.declareInt(optionalConstructorArg(), K_FIELD); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); - PARSER.declareNamedObject( optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); PARSER.declareFieldArray( KnnSearchBuilder.Builder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -116,6 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw String queryName; float boost = DEFAULT_BOOST; InnerHitBuilder innerHitBuilder; + private final RescoreVectorBuilder rescoreVectorBuilder; /** * Defines a kNN search. @@ -124,14 +132,23 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw * @param queryVector the query vector * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param rescoreVectorBuilder rescore vector information */ - public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, + rescoreVectorBuilder, similarity ); } @@ -144,8 +161,15 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) { - this(field, queryVector, null, k, numCands, similarity); + public KnnSearchBuilder( + String field, + VectorData queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity); } /** @@ -156,13 +180,21 @@ public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCand * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, QueryVectorBuilder 
queryVectorBuilder, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + QueryVectorBuilder queryVectorBuilder, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, null, Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())), k, numCands, + rescoreVectorBuilder, similarity ); } @@ -173,9 +205,22 @@ public KnnSearchBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST); + this( + field, + queryVectorBuilder, + queryVector, + new ArrayList<>(), + k, + numCands, + rescoreVectorBuilder, + similarity, + null, + null, + DEFAULT_BOOST + ); } private KnnSearchBuilder( @@ -183,6 +228,7 @@ private KnnSearchBuilder( Supplier querySupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, List filterQueries, Float similarity ) { @@ -194,6 +240,7 @@ private KnnSearchBuilder( this.filterQueries = filterQueries; this.querySupplier = querySupplier; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnSearchBuilder( @@ -203,6 +250,7 @@ private KnnSearchBuilder( List filterQueries, int k, int numCandidates, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity, InnerHitBuilder innerHitBuilder, String queryName, @@ -242,6 +290,7 @@ private KnnSearchBuilder( this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCandidates; + this.rescoreVectorBuilder = rescoreVectorBuilder; this.innerHitBuilder = innerHitBuilder; this.similarity = similarity; this.queryName = queryName; @@ -280,6 +329,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(V_8_11_X)) { this.innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new); } + if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); + } else { + this.rescoreVectorBuilder = null; + } } public int k() { @@ -290,6 +344,10 @@ public int getNumCands() { return numCands; } + public RescoreVectorBuilder getRescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public QueryVectorBuilder getQueryVectorBuilder() { return queryVectorBuilder; } @@ -358,7 +416,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (querySupplier.get() == null) { return this; } - return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries) .innerHit(innerHitBuilder); @@ -381,7 +439,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost) + return new KnnSearchBuilder(field, toSet::get, k, numCands, rescoreVectorBuilder, filterQueries, similarity).boost(boost) .queryName(queryName) .innerHit(innerHitBuilder); } @@ -395,7 +453,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnSearchBuilder(field, 
queryVector, k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, queryVector, k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(rewrittenQueries) .innerHit(innerHitBuilder); @@ -407,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -423,6 +481,7 @@ public boolean equals(Object o) { KnnSearchBuilder that = (KnnSearchBuilder) o; return k == that.k && numCands == that.numCands + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder) && Objects.equals(field, that.field) && Objects.equals(queryVector, that.queryVector) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) @@ -442,6 +501,7 @@ public int hashCode() { numCands, querySupplier, queryVectorBuilder, + rescoreVectorBuilder, similarity, Objects.hashCode(queryVector), Objects.hashCode(filterQueries), @@ -486,6 +546,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (queryName != null) { builder.field(NAME_FIELD.getPreferredName(), queryName); } + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } return builder; } @@ -526,6 +589,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(V_8_11_X)) { out.writeOptionalWriteable(innerHitBuilder); } + if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalWriteable(rescoreVectorBuilder); + } } public static class Builder { @@ -540,6 +606,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; private InnerHitBuilder innerHitBuilder; + private RescoreVectorBuilder rescoreVectorBuilder; public Builder addFilterQueries(List filterQueries) { Objects.requireNonNull(filterQueries); @@ -592,6 +659,11 @@ public Builder similarity(Float similarity) { return this; } + public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) { + this.rescoreVectorBuilder = rescoreVectorBuilder; + return this; + } + public KnnSearchBuilder build(int size) { int requestSize = size < 0 ? DEFAULT_SIZE : size; int adjustedK = k == null ? 
requestSize : k; @@ -605,6 +677,7 @@ public KnnSearchBuilder build(int size) { filterQueries, adjustedK, adjustedNumCandidates, + rescoreVectorBuilder, similarity, innerHitBuilder, queryName, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index a28448336ab3f..81b00f1329591 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -256,7 +256,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null); + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 5dd2cbf32dd12..88f6312fa7e6f 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -45,6 +45,7 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -68,8 +69,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( "knn", args -> new KnnVectorQueryBuilder( @@ -79,6 +80,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT + ); PARSER.declareFieldArray( KnnVectorQueryBuilder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -115,14 +123,22 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { private final String fieldName; private final VectorData queryVector; private final Integer k; - private Integer numCands; + private final Integer numCands; private final List filterQueries = new ArrayList<>(); private final Float vectorSimilarity; private final QueryVectorBuilder queryVectorBuilder; private final Supplier queryVectorSupplier; + private final RescoreVectorBuilder rescoreVectorBuilder; - public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + float[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } public KnnVectorQueryBuilder( @@ -132,15 +148,29 @@ public KnnVectorQueryBuilder( Integer numCands, Float vectorSimilarity ) { - this(fieldName, null, 
queryVectorBuilder, null, k, numCands, vectorSimilarity); + this(fieldName, null, queryVectorBuilder, null, k, numCands, null, vectorSimilarity); } - public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + byte[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } - public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + VectorData queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, queryVector, null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } private KnnVectorQueryBuilder( @@ -150,6 +180,7 @@ private KnnVectorQueryBuilder( Supplier queryVectorSupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { if (k != null && k < 1) { @@ -187,6 +218,7 @@ private KnnVectorQueryBuilder( this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; + this.rescoreVectorBuilder = rescoreVectorBuilder; } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -227,6 +259,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.queryVectorBuilder = null; } + if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); + } else { + this.rescoreVectorBuilder = null; + } + this.queryVectorSupplier = null; } @@ -261,6 +299,10 @@ public QueryVectorBuilder queryVectorBuilder() { return queryVectorBuilder; } + public RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); @@ -327,6 +369,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalNamedWriteable(queryVectorBuilder); } + if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalWriteable(rescoreVectorBuilder); + } } @Override @@ -360,6 +405,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep } builder.endArray(); } + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -375,7 +423,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { if (queryVectorSupplier.get() == null) { return this; } - return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost) + return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, rescoreVectorBuilder, vectorSimilarity) + .boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -397,9 
+446,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost( - boost - ).queryName(queryName).addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + toSet::get, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); } if (ctx.convertToInnerHitsRewriteContext() != null) { return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName); @@ -417,14 +473,25 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity) - .boost(boost) - .queryName(queryName) - .addFilterQueries(rewrittenQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + queryVectorSupplier, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); } return this; } + @Override + protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { + return super.doIndexMetadataRewrite(context); + } + @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); @@ -459,6 +526,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); + Float numCandidatesFactor = rescoreVectorBuilder() == null ? 
null : rescoreVectorBuilder.numCandidatesFactor(); if (parentPath != null) { final BitSetProducer parentBitSet; @@ -490,14 +558,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet); + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + numCandidatesFactor, + filterQuery, + vectorSimilarity, + parentBitSet + ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder); + return Objects.hash( + fieldName, + Objects.hashCode(queryVector), + k, + numCands, + filterQueries, + vectorSimilarity, + queryVectorBuilder, + rescoreVectorBuilder + ); } @Override @@ -508,7 +593,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(numCands, other.numCands) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) - && Objects.equals(queryVectorBuilder, other.queryVectorBuilder); + && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) + && Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java similarity index 96% rename from server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java rename to server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java index 4d36d8eae57cc..47b0e1e299968 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java @@ -18,7 +18,7 @@ * must provide an implementation for profile() to store profiling information in the {@link QueryProfiler}. */ -public interface ProfilingQuery { +public interface QueryProfilerProvider { /** * Store the profiling information in the {@link QueryProfiler} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java new file mode 100644 index 0000000000000..a9c606b1f8618 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.queries.function.FunctionScoreQuery;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
+import org.elasticsearch.search.profile.query.QueryProfiler;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field
+ */
+public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
+    private final String fieldName;
+    private final float[] floatTarget;
+    private final VectorSimilarityFunction vectorSimilarityFunction;
+    private final Integer k;
+    private final Query innerQuery;
+
+    private QueryProfilerProvider vectorProfiling;
+
+    public RescoreKnnVectorQuery(
+        String fieldName,
+        float[] floatTarget,
+        VectorSimilarityFunction vectorSimilarityFunction,
+        Integer k,
+        Query innerQuery
+    ) {
+        this.fieldName = fieldName;
+        this.floatTarget = floatTarget;
+        this.vectorSimilarityFunction = vectorSimilarityFunction;
+        this.k = k;
+        this.innerQuery = innerQuery;
+    }
+
+    @Override
+    public Query rewrite(IndexSearcher searcher) throws IOException {
+        DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+        // VectorSimilarityFloatValueSource keeps track of how many vectors it compares. Keep a reference to it so that, even when we
+        // skip the top-k calculation and return the rewritten query directly, profiling can still report how many comparisons were done.
+        vectorProfiling = (QueryProfilerProvider) valueSource;
+        FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
+        Query query = searcher.rewrite(functionScoreQuery);
+
+        if (k == null) {
+            // No need to calculate top k - let the request size limit the results.
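+            // Returning the rewritten FunctionScoreQuery here means the similarity-based scores are computed lazily for
+            // whatever hits the request collects, and the request size (rather than k) bounds the number of results.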
+ return query; + } + + // Retrieve top k documents from the rescored query + TopDocs topDocs = searcher.search(query, k); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + int[] docIds = new int[scoreDocs.length]; + float[] scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docIds[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } + + return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + } + + public Query innerQuery() { + return innerQuery; + } + + public Integer k() { + return k; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); + } + + if (vectorProfiling == null) { + throw new IllegalStateException("Query should have been rewritten"); + } + vectorProfiling.profile(queryProfiler); + } + + @Override + public void visit(QueryVisitor visitor) { + innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o; + return Objects.equals(fieldName, that.fieldName) + && Arrays.equals(floatTarget, that.floatTarget) + && vectorSimilarityFunction == that.vectorSimilarityFunction + && Objects.equals(k, that.k) + && Objects.equals(innerQuery, that.innerQuery); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); + } + + @Override + public String toString(String field) { + return "KnnRescoreVectorQuery{" + + "fieldName='" + + fieldName + + '\'' + + ", floatTarget=" + + floatTarget[0] + + "..." + + ", vectorSimilarityFunction=" + + vectorSimilarityFunction + + ", k=" + + k + + ", vectorQuery=" + + innerQuery + + '}'; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java new file mode 100644 index 0000000000000..4604d4f0ea325 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class RescoreVectorBuilder implements Writeable, ToXContentObject { + + public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor"); + public static final float MIN_OVERSAMPLE = 1.0F; + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "rescore_vector", + args -> new RescoreVectorBuilder((Float) args[0]) + ); + + static { + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), NUM_CANDIDATES_FACTOR_FIELD); + } + + // Oversample is required as of now as it is the only field in the rescore vector + private final float numCandidatesFactor; + + public RescoreVectorBuilder(float numCandidatesFactor) { + Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); + if (numCandidatesFactor < MIN_OVERSAMPLE) { + throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); + } + this.numCandidatesFactor = numCandidatesFactor; + } + + public RescoreVectorBuilder(StreamInput in) throws IOException { + this.numCandidatesFactor = in.readFloat(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloat(numCandidatesFactor); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUM_CANDIDATES_FACTOR_FIELD.getPreferredName(), numCandidatesFactor); + builder.endObject(); + return builder; + } + + public static RescoreVectorBuilder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreVectorBuilder that = (RescoreVectorBuilder) o; + return Objects.equals(numCandidatesFactor, that.numCandidatesFactor); + } + + @Override + public int hashCode() { + return Objects.hashCode(numCandidatesFactor); + } + + public float numCandidatesFactor() { + return numCandidatesFactor; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java index 5219778047bcd..4cd267bd07186 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.elasticsearch.common.lucene.search.function.MinScoreScorer; +import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Objects; @@ -28,9 +29,10 @@ import static org.elasticsearch.common.Strings.format; /** - * This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery. 
+ * This query provides a simple post-filter for the provided Query to limit the results of the inner query to those that have a similarity + * above a certain threshold */ -public class VectorSimilarityQuery extends Query { +public class VectorSimilarityQuery extends Query implements QueryProfilerProvider { private final float similarity; private final float docScore; private final Query innerKnnQuery; @@ -78,6 +80,13 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new MinScoreWeight(innerWeight, docScore, similarity, this, boost); } + @Override + public void profile(QueryProfiler queryProfiler) { + if (innerKnnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); + } + } + @Override public String toString(String field) { return "VectorSimilarityQuery[" diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 64362daf7f75c..193855a4c835f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -344,8 +344,8 @@ public void testRewriteShardSearchRequestWithRank() { SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) .knnSearch( List.of( - new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null), - new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null) + new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null, null), + new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null, null) ) ) .rankBuilder(new TestRankBuilder(100)); diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 042890001c2ea..353188af8be3c 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -63,7 +63,7 @@ public void testKnnSearchRemovedVector() throws IOException { client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null, null).boost(5.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -107,7 +107,7 @@ public void testKnnWithQuery() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f).queryName("knn"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f).queryName("knn"); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -156,7 +156,7 @@ public void testKnnFilter() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( 
QueryBuilders.termsQuery("field", "second") ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { @@ -199,7 +199,7 @@ public void testKnnFilterWithRewrite() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field")) ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { @@ -246,8 +246,8 @@ public void testMultiKnnClauses() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(20f, 21f); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null).boost(10.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null).boost(10.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch, knnSearch2)) @@ -308,8 +308,8 @@ public void testMultiKnnClausesSameDoc() throws IOException { float[] queryVector = randomVector(); // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -381,7 +381,7 @@ public void testKnnFilteredAlias() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null, null); final int expectedHitCount = expectedHits; assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { assertHitCount(response, expectedHitCount); @@ -417,7 +417,9 @@ public void testKnnSearchAction() throws IOException { // how the action works (it builds a kNN query under the hood) float[] queryVector = randomVector(); assertResponse( - client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2), + client().prepareSearch("index1", "index2") + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard assertHitCount(response, 5 * 2); @@ -450,7 +452,7 @@ public void testKnnVectorsWith4096Dims() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(4096); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, 
null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null, null).boost(5.0f); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { assertHitCount(response, 3); assertEquals(3, response.getHits().getHits().length); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index 0c11123960622..152e5a9e6eac6 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; @@ -115,8 +116,22 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( List.of( - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat()), - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5, randomBoolean() ? null : randomFloat()) + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ), + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 4, 12, 41 }, + 3, + 5, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) ) ); expectThrows( @@ -131,7 +146,16 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( - List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat())) + List.of( + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) + ) ); // Shouldn't throw because its just one KNN request copyWriteable( @@ -142,6 +166,10 @@ public void testSerializationMultiKNN() throws Exception { ); } + private static RescoreVectorBuilder randomRescoreVectorBuilder() { + return randomBoolean() ? 
null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + public void testRandomVersionSerialization() throws IOException { SearchRequest searchRequest = createSearchRequest(); TransportVersion version = TransportVersionUtils.randomVersion(random()); @@ -454,7 +482,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(0) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -466,7 +494,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(1)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(2) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -493,7 +521,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ).scroll(new TimeValue(1000)); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -504,7 +532,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(9)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -518,7 +546,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(3)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(3) .from(4) ); @@ -529,7 +557,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .addRescorer(new QueryRescorerBuilder(QueryBuilders.termQuery("rescore", "another term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -541,7 +569,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) 
.query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .suggest(new SuggestBuilder().setGlobalText("test").addSuggestion("suggestion", new TermSuggestionBuilder("term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 367508283bb93..a94427cd0df6f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1367,7 +1367,7 @@ public void testShouldMinimizeRoundtrips() throws Exception { { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); searchRequest.setCcsMinimizeRoundtrips(true); @@ -1382,7 +1382,7 @@ public void testAdjustSearchType() { // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); TransportSearchAction.adjustSearchType(searchRequest, randomBoolean()); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index c043b9ffb381a..342d61b78defd 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1683,7 +1683,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1692,7 +1700,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, 
null) ); assertThat( e.getMessage(), @@ -1701,7 +1709,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1710,7 +1718,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1722,6 +1738,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1738,6 +1755,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1765,7 +1783,15 @@ public void testFloatVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1777,6 +1803,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1793,6 +1820,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 9e819f38eae6e..d37b4a4bacb4e 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -23,6 +23,9 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; +import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; @@ -31,8 +34,12 @@ import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; public class DenseVectorFieldTypeTests extends FieldTypeTestCase { private 
final boolean indexed; @@ -69,11 +76,27 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { ); } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { + return randomFrom( + new DenseVectorFieldMapper.Int8HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.Int4HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + ); + } + private DenseVectorFieldType createFloatFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, indexed, VectorSimilarity.COSINE, @@ -86,7 +109,7 @@ private DenseVectorFieldType createByteFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 5, true, VectorSimilarity.COSINE, @@ -159,7 +182,7 @@ public void testCreateNestedKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -170,14 +193,14 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -191,11 +214,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -209,7 +232,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -227,7 +250,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -247,7 +270,7 @@ public void testFloatCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4, false, VectorSimilarity.COSINE, @@ -256,14 +279,22 @@ public void testFloatCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> 
unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), + 10, + 10, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); DenseVectorFieldType dotProductField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.DOT_PRODUCT, @@ -276,14 +307,14 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.COSINE, @@ -292,7 +323,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -302,7 +333,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4096, true, VectorSimilarity.COSINE, @@ -313,7 +344,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -321,7 +352,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 4096, true, VectorSimilarity.COSINE, @@ -333,7 +364,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -342,7 +373,7 @@ public void testByteCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, false, VectorSimilarity.COSINE, @@ -351,14 +382,14 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null) + () -> 
unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, true, VectorSimilarity.COSINE, @@ -367,14 +398,94 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } + + public void testRescoreOversampleUsedWithoutQuantization() { + DenseVectorFieldMapper.ElementType elementType = randomFrom(FLOAT, BYTE); + DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( + "f", + IndexVersion.current(), + elementType, + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsNonQuantized(), + Collections.emptyMap() + ); + + Query knnQuery = nonQuantizedField.createKnnQuery( + new VectorData(null, new byte[] { 1, 4, 10 }), + 10, + 100, + randomFloatBetween(1.0F, 10.0F, false), + null, + null, + null + ); + + if (elementType == BYTE) { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } + } + + public void testRescoreOversampleModifiesNumCandidates() { + DenseVectorFieldType fieldType = new DenseVectorFieldType( + "f", + IndexVersion.current(), + FLOAT, + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsHnswQuantized(), + Collections.emptyMap() + ); + + // Total results is k, internal k is multiplied by oversample + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10); + // If numCands < k, update numCands to k + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10); + // Oversampling limits for num candidates + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000); + checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000); + } + + private static void checkRescoreQueryParameters( + DenseVectorFieldType fieldType, + Integer k, + int candidates, + float numCandsFactor, + Integer expectedK, + int expectedCandidates, + int expectedResults + ) { + Query query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + k, + candidates, + numCandsFactor, + null, + null, + null + ); + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + 
assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 6076665e26824..7f4f95cdd2416 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, null, 1, + null, null ); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java index d6953e79a0c3f..580dad8128494 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java @@ -83,7 +83,7 @@ public void testValidateSearchRequest() { .build(); SearchRequest searchRequest = new SearchRequest(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null, null); searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch))); Exception ex = expectThrows( diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 380b5189b3efc..070773a2a7d42 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -825,7 +825,7 @@ public void testSearchSectionsUsageCollection() throws IOException { searchSourceBuilder.fetchField("field"); // these are not correct runtime mappings but they are counted compared to empty object searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword")); - searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null))); + searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null, null))); searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("pitid"))); searchSourceBuilder.docValueField("field"); searchSourceBuilder.storedField("field"); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 7923cb5f0d918..da28b0eff441f 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -51,8 +52,19 
@@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(field, vector, null, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + field, + vector, + null, + k, + numCands, + rescoreVectorBuilder, + similarity + ); List preFilterQueryBuilders = new ArrayList<>(); @@ -93,6 +105,7 @@ public void testRewrite() throws IOException { assertNull(source.query()); assertThat(source.knnSearch().size(), equalTo(1)); assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size())); + assertThat(source.knnSearch().get(0).getRescoreVectorBuilder(), equalTo(knnRetriever.rescoreVectorBuilder())); for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) { assertThat( source.knnSearch().get(0).getFilterQueries().get(j), diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index ccf33c0b71b6b..eafab1d25c38e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -69,6 +70,7 @@ private List innerRetrievers(QueryRewriteContext queryRewriteC null, randomInt(10), randomIntBetween(10, 100), + randomBoolean() ? 
null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat() ); if (randomBoolean()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f93bdd14f0645..375712ee60861 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -37,11 +37,16 @@ import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -52,23 +57,70 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final String VECTOR_ALIAS_FIELD = "vector_alias"; - static final int VECTOR_DIMENSION = 3; + protected static final Set QUANTIZED_INDEX_TYPES = Set.of( + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "int8_flat", + "int4_flat", + "bbq_flat" + ); + protected static final Set NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat"); + protected static final Set ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream()) + .collect(Collectors.toUnmodifiableSet()); + protected static String indexType; + protected static int vectorDimensions; + + @Before + private void checkIndexTypeAndDimensions() { + // Check that these are initialized - should be done as part of the createAdditionalMappings method + assertNotNull(indexType); + assertNotEquals(0, vectorDimensions); + } abstract DenseVectorFieldMapper.ElementType elementType(); - abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity); + abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ); + + protected boolean isQuantizedElementType() { + return QUANTIZED_INDEX_TYPES.contains(indexType); + } + + protected abstract String randomIndexType(); @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + + // These fields are initialized here, as mappings are initialized only once per test class. + // We want the subclasses to be able to override the index type and vector dimensions so we don't make this static / BeforeClass + // for initialization. 
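+        // The dimension count below is chosen to stay compatible with the selected index type (the bbq variants
+        // require higher-dimensional vectors than the int4 and default cases).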
+ indexType = randomIndexType(); + if (indexType.contains("bbq")) { + vectorDimensions = 64; + } else if (indexType.contains("int4")) { + vectorDimensions = 4; + } else { + vectorDimensions = 3; + } + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(VECTOR_FIELD) .field("type", "dense_vector") - .field("dims", VECTOR_DIMENSION) + .field("dims", vectorDimensions) .field("index", true) .field("similarity", "l2_norm") .field("element_type", elementType()) + .startObject("index_options") + .field("type", indexType) + .endObject() .endObject() .startObject(VECTOR_ALIAS_FIELD) .field("type", "alias") @@ -88,7 +140,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; Integer k = randomBoolean() ? null : randomIntBetween(1, 100); int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); - KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat()); + KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( + fieldName, + k, + numCands, + randomRescoreVectorBuilder(), + randomFloat() + ); if (randomBoolean()) { List filters = new ArrayList<>(); @@ -99,24 +157,32 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { } queryBuilder.addFilterQueries(filters); } + return queryBuilder; } + protected RescoreVectorBuilder randomRescoreVectorBuilder() { + if (randomBoolean()) { + return null; + } + + return new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { if (queryBuilder.getVectorSimilarity() != null) { assertTrue(query instanceof VectorSimilarityQuery); - Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery(); assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); - switch (elementType()) { - case FLOAT -> assertTrue(knnQuery instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(knnQuery instanceof ESKnnByteVectorQuery); - } - } else { - switch (elementType()) { - case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); - } + query = ((VectorSimilarityQuery) query).getInnerKnnQuery(); + } + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + query = rescoreQuery.innerQuery(); + } + switch (elementType()) { + case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); + case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -126,21 +192,18 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que BooleanQuery booleanQuery = builder.build(); Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; // The field should always be resolved to the concrete field + Integer k = queryBuilder.k(); + Integer numCands = queryBuilder.numCands(); + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { + Float numCandsFactor = queryBuilder.rescoreVectorBuilder().numCandidatesFactor(); + int minCands = k == null ? 
1 : k; + numCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); + numCands = Math.min(numCands, NUM_CANDS_OVERSAMPLE_LIMIT); + } + Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asByteVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); - case FLOAT -> new ESKnnFloatVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asFloatVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); + case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); + case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); @@ -150,17 +213,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), - containsString("The query vector has a different number of dimensions [2] than the document vectors [3]") + containsString("The query vector has a different number of dimensions [2] than the document vectors [" + vectorDimensions + "]") ); } public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); context.setAllowUnmappedFields(false); QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); @@ -168,7 +231,7 @@ public void testNonexistentField() { public void testNonexistentFieldReturnEmpty() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); Query queryNone = query.doToQuery(context); assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } @@ -180,6 +243,7 @@ public void testWrongFieldType() { new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, + null, null ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); @@ -191,14 +255,14 @@ public void testNumCandsLessThanK() { int numCands = 3; IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null) + () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); 
} @Override public void testValidOutput() { - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null); String expected = """ { "knn" : { @@ -213,7 +277,7 @@ public void testValidOutput() { }"""; assertEquals(expected, query.toString()); - KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null); + KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null); String expected2 = """ { "knn" : { @@ -238,7 +302,8 @@ public void testMustRewrite() throws IOException { KnnVectorQueryBuilder query = new KnnVectorQueryBuilder( VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, - VECTOR_DIMENSION, + vectorDimensions, + null, null, null ); @@ -254,9 +319,14 @@ public void testMustRewrite() throws IOException { public void testBWCVersionSerializationFilters() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()); + KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( random(), TransportVersions.V_8_0_0, @@ -268,10 +338,14 @@ public void testBWCVersionSerializationFilters() throws IOException { public void testBWCVersionSerializationSimilarity() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()) - .addFilterQueries(query.filterQueries()); + KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); } @@ -289,11 +363,34 @@ public void testBWCVersionSerializationQuery() throws IOException { vectorData, null, query.numCands(), + null, similarity ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); } + public void testBWCVersionSerializationRescoreVector() throws IOException { + KnnVectorQueryBuilder query = createTestQueryBuilder(); + TransportVersion version = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_8_1, + TransportVersionUtils.getPreviousVersion(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) + ); + VectorData vectorData = version.onOrAfter(TransportVersions.V_8_14_0) + ? query.queryVector() + : VectorData.fromFloats(query.queryVector().asFloatVector()); + Integer k = version.before(TransportVersions.V_8_15_0) ? 
null : query.k(); + KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + k, + query.numCands(), + null, + query.getVectorSimilarity() + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + assertBWCSerialization(query, queryNoRescoreVector, version); + } + private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { assertSerialization(bwcQuery, version); try (BytesStreamOutput output = new BytesStreamOutput()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 0fc2304e904a4..f6c2e754cec63 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -18,11 +18,22 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { - byte[] vector = new byte[VECTOR_DIMENSION]; + protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + byte[] vector = new byte[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + } + + @Override + protected String randomIndexType() { + return randomFrom(NON_QUANTIZED_INDEX_TYPES); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index ba2245ced3305..6f67e4be29a06 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -18,11 +18,22 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { - float[] vector = new float[VECTOR_DIMENSION]; + KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + float[] vector = new float[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); + } + + @Override + protected String randomIndexType() { + return randomFrom(ALL_INDEX_TYPES); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 2184e8af54aed..a39438af5b72a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -52,8 +52,18 @@ public static KnnSearchBuilder 
randomTestInstance() { float[] vector = randomVector(dim); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); - - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat()); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnSearchBuilder builder = new KnnSearchBuilder( + field, + vector, + k, + numCands, + rescoreVectorBuilder, + randomBoolean() ? null : randomFloat() + ); if (randomBoolean()) { builder.boost(randomFloat()); } @@ -100,46 +110,90 @@ protected KnnSearchBuilder createTestInstance() { @Override protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { - switch (random().nextInt(7)) { + switch (random().nextInt(8)) { case 0: String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5)); - return new KnnSearchBuilder(newField, instance.queryVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + newField, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 1: float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5)); - return new KnnSearchBuilder(instance.field, newVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + newVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 2: // given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10)); - return new KnnSearchBuilder(instance.field, instance.queryVector, newK, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + newK, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 3: Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100)); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, newNumCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + newNumCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).boost(instance.boost); case 4: - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.getRescoreVectorBuilder(), + instance.similarity + ).addFilterQueries(instance.filterQueries) .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) .boost(instance.boost); case 5: float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) - .boost(newBoost); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, 
+ instance.getRescoreVectorBuilder(), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(newBoost); case 6: return new KnnSearchBuilder( instance.field, instance.queryVector, instance.k, instance.numCands, + instance.getRescoreVectorBuilder(), randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) ).addFilterQueries(instance.filterQueries).boost(instance.boost); + case 7: + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + randomValueOtherThan( + instance.getRescoreVectorBuilder(), + () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) + ), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(instance.boost); default: throw new IllegalStateException(); } @@ -151,7 +205,10 @@ public void testToQueryBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, similarity); float boost = AbstractQueryBuilder.DEFAULT_BOOST; if (randomBoolean()) { @@ -167,15 +224,16 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries) - .boost(boost); + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, rescoreVectorBuilder, similarity).addFilterQueries( + filterQueries + ).boost(boost); assertEquals(expected, builder.toQueryBuilder()); } public void testNumCandsLessThanK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null) + () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @@ -183,7 +241,7 @@ public void testNumCandsLessThanK() { public void testNumCandsExceedsLimit() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null) + () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); } @@ -191,18 +249,28 @@ public void testNumCandsExceedsLimit() { public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null) + () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null, null) ); assertThat(e.getMessage(), containsString("[k] must be greater than 0")); } + public void testInvalidRescoreVectorBuilder() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null) + ); + assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0")); + } + public void testRewrite() throws Exception { float[] expectedArray = randomVector(randomIntBetween(10, 1024)); + RescoreVectorBuilder expectedRescore = new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, 
false)); KnnSearchBuilder searchBuilder = new KnnSearchBuilder( "field", new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray), 5, 10, + expectedRescore, 1f ); searchBuilder.boost(randomFloat()); @@ -220,6 +288,7 @@ public void testRewrite() throws Exception { assertThat(rewritten.filterQueries, hasSize(1)); assertThat(rewritten.similarity, equalTo(1f)); assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1)); + assertThat(rewritten.getRescoreVectorBuilder(), equalTo(expectedRescore)); } public static float[] randomVector(int dim) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java new file mode 100644 index 0000000000000..7bbe7dcc155c5 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; + +public class RescoreKnnVectorQueryTests extends ESTestCase { + + public static final String FIELD_NAME = "float_vector"; + private final int numDocs; + private final Integer k; + + public RescoreKnnVectorQueryTests(boolean useK) { + this.numDocs = randomIntBetween(10, 100); + this.k = useK ? 
randomIntBetween(1, numDocs - 1) : null; + } + + public void testRescoreDocs() throws Exception { + int numDims = randomIntBetween(5, 100); + + Integer adjustedK = k; + if (k == null) { + adjustedK = numDocs; + } + + try (Directory d = newDirectory()) { + addRandomDocuments(numDocs, d, numDims); + + try (IndexReader reader = DirectoryReader.open(d)) { + + // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query + // and thus we're rescoring the top k docs. + float[] queryVector = randomVector(numDims); + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + adjustedK, + new MatchAllDocsQuery() + ); + + IndexSearcher searcher = newSearcher(reader, true, false); + TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); + Map rescoredDocs = Arrays.stream(docs.scoreDocs) + .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); + + assertThat(rescoredDocs.size(), equalTo(adjustedK)); + + Collection rescoredScores = new HashSet<>(rescoredDocs.values()); + + // Collect all docs sequentially, and score them using the similarity function to get the top K scores + PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); + + for (LeafReaderContext leafReaderContext : reader.leaves()) { + FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + float[] vectorData = vectorValues.vectorValue(iterator.docID()); + float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); + topK.add(score); + int docId = iterator.docID(); + // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it + // to ensure we found them all + if (rescoredDocs.containsKey(docId)) { + assertThat(rescoredDocs.get(docId), equalTo(score)); + rescoredDocs.remove(docId); + } + } + } + + assertThat(rescoredDocs.size(), equalTo(0)); + + // Check top scoring docs are contained in rescored docs + for (int i = 0; i < adjustedK; i++) { + Float topScore = topK.poll(); + if (rescoredScores.contains(topScore) == false) { + fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + } + } + } + } + } + + public void testProfiling() throws Exception { + int numDims = randomIntBetween(5, 100); + + try (Directory d = newDirectory()) { + addRandomDocuments(numDocs, d, numDims); + + try (IndexReader reader = DirectoryReader.open(d)) { + float[] queryVector = randomVector(numDims); + + checkProfiling(queryVector, reader, new MatchAllDocsQuery()); + checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); + } + } + } + + private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + innerQuery + ); + IndexSearcher searcher = newSearcher(reader, true, false); + searcher.search(rescoreKnnVectorQuery, numDocs); + + QueryProfiler queryProfiler = new QueryProfiler(); + rescoreKnnVectorQuery.profile(queryProfiler); + + long expectedVectorOpsCount = numDocs; + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { + QueryProfiler anotherProfiler = new QueryProfiler(); + 
queryProfilerProvider.profile(anotherProfiler); + assertThat(anotherProfiler.getVectorOpsCount(), greaterThan(0L)); + expectedVectorOpsCount += anotherProfiler.getVectorOpsCount(); + } + + assertThat(queryProfiler.getVectorOpsCount(), equalTo(expectedVectorOpsCount)); + } + + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } + + /** + * A mock query that is used to test profiling + */ + private static class MockQueryProfilerProvider extends Query implements QueryProfilerProvider { + + private final long vectorOpsCount; + + private MockQueryProfilerProvider(long vectorOpsCount) { + this.vectorOpsCount = vectorOpsCount; + } + + @Override + public String toString(String field) { + return ""; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + throw new UnsupportedEncodingException("Should have been rewritten"); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return new MatchAllDocsQuery(); + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof MockQueryProfilerProvider; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + } + + private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { + try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + Document document = new Document(); + float[] vector = randomVector(numDims); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + document.add(vectorField); + w.addDocument(document); + } + w.commit(); + w.forceMerge(1); + } + } + + @ParametersFactory + public static Iterable parameters() { + List params = new ArrayList<>(); + params.add(new Object[] { true }); + params.add(new Object[] { false }); + + return params; + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java index 363d34ca3ff86..6e8cf735983aa 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java +++ b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -61,6 +62,7 @@ import static org.elasticsearch.test.ESTestCase.randomByte; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomFloat; +import static org.elasticsearch.test.ESTestCase.randomFloatBetween; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; import static org.elasticsearch.test.ESTestCase.randomIntBetween; @@ -264,7 +266,12 @@ public static SearchSourceBuilder 
randomSearchSourceBuilder( } int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); - knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat())); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + knnSearchBuilders.add( + new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, randomBoolean() ? null : randomFloat()) + ); } builder.knnSearch(knnSearchBuilders); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index e00dc9f693ff3..1ca6ef0b43a38 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -97,6 +98,7 @@ public final void testKnnSearchBuilderWireSerialization() throws IOException { createTestInstance(), 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); searchBuilder.queryName(randomAlphaOfLengthBetween(5, 10)); @@ -120,6 +122,7 @@ public final void testKnnSearchRewrite() throws Exception { queryVectorBuilder, 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? 
null : randomFloat() ); KnnSearchBuilder serialized = copyWriteable( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 683bb5a53028b..cda77233bdfd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -560,7 +560,7 @@ yield new SparseVectorQueryBuilder( k = Math.max(k, DEFAULT_SIZE); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null); } default -> throw new IllegalStateException( "Field [" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java index 7dc4d99e06acc..78743409ca178 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -91,7 +91,7 @@ public void testDenseVector() throws Exception { Map queryMap = (Map) queries.get("dense_vector_1"); float[] vector = readDenseVector(queryMap.get("embeddings")); var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); - KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); var shardRequest = createShardSearchRequest(nestedQueryBuilder); var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 916703446995d..084a7f3de4a53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -102,7 +102,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -116,7 +118,7 @@ public void 
testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -146,7 +148,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java index b4cf409f5fd72..457c57410d168 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java @@ -136,7 +136,7 @@ public void setupSuiteScopeCluster() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( prepareSearch("tiny_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -167,7 +167,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -208,8 +208,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -260,8 +260,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, 
null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(true) @@ -389,8 +389,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -457,8 +457,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -704,7 +704,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -762,7 +762,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -837,8 +837,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", 
queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -899,8 +899,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -979,7 +979,7 @@ public void testBasicRRFExplain() { // the first result should be the one present in both queries (i.e. doc with text0: 10 and vector: [10]) and the other ones // should only match the knn query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1045,7 +1045,7 @@ public void testRRFExplainUnknownField() { // in this test we try knn with a query on an unknown field that would be rewritten to MatchNoneQuery // so we expect results and explanations only for the first part float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1112,7 +1112,7 @@ public void testRRFExplainOneUnknownFieldSubSearches() { // while the other one would produce a match. 
// So, we'd have a total of 3 queries, a (rewritten) MatchNoneQuery, a TermQuery, and a kNN query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java index ed26aa50ffa62..a4e7db3b3e3fe 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java @@ -131,7 +131,7 @@ public void setupIndices() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( client().prepareSearch("tiny_index") @@ -164,7 +164,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -206,8 +206,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -259,8 +259,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( 
client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -390,8 +390,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -459,8 +459,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -709,7 +709,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -768,7 +768,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -844,8 +844,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -907,8 +907,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder 
knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index ae35153b6f39f..6854fc436038f 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -183,7 +183,15 @@ public void testRRFPagination() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -233,7 +241,7 @@ public void testRRFWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -288,7 +296,7 @@ public void testRRFWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -345,7 +353,7 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -411,7 +419,7 @@ public void testMultipleRRFRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -430,7 +438,7 @@ public void testMultipleRRFRetrievers() { ), // 
this one bring just doc 7 which should be ranked first eventually new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), null ) ), @@ -477,7 +485,7 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -536,7 +544,7 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( @@ -756,6 +764,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), 10, 10, + null, null ); source.retriever( @@ -809,7 +818,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); var rrf = new RRFRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index b1358f11bf633..a00b940bbed62 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index 
4eaea9a596361..9bc1cd80ea381 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -103,7 +103,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -117,7 +119,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } @@ -136,7 +138,7 @@ public void testTelemetryForRRFRetriever() throws IOException { new RRFRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null), + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null), null ), new CompoundRetrieverBuilder.RetrieverSource( @@ -153,7 +155,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index df65aac5b79b8..b5bca57478684 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -215,6 +215,7 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate Date: Wed, 11 Dec 2024 12:38:42 +0400 Subject: [PATCH 23/28] Fix rest-api-spec and docs for bulk API (#118415) --- docs/reference/docs/bulk.asciidoc | 4 ++++ .../src/main/resources/rest-api-spec/api/bulk.json | 8 ++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/reference/docs/bulk.asciidoc b/docs/reference/docs/bulk.asciidoc index 69bf3d1b7db5a..6edccfcdb13f5 100644 --- a/docs/reference/docs/bulk.asciidoc +++ b/docs/reference/docs/bulk.asciidoc @@ -257,6 +257,10 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=refresh] (Optional, Boolean) If `true`, the request's actions must target an index alias. Defaults to `false`. +`require_data_stream`:: +(Optional, Boolean) If `true`, the request's actions must target a data stream (existing or to-be-created). +Defaults to `false`. 
+ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=routing] include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=source] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/bulk.json b/rest-api-spec/src/main/resources/rest-api-spec/api/bulk.json index 9ced5d3e8c454..f9c8041d7221f 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/bulk.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/bulk.json @@ -56,10 +56,6 @@ "type":"time", "description":"Explicit operation timeout" }, - "type":{ - "type":"string", - "description":"Default document type for items which don't provide one" - }, "_source":{ "type":"list", "description":"True or false to return the _source field or not, or default list of fields to return, can be overridden on each sub-request" @@ -78,11 +74,11 @@ }, "require_alias": { "type": "boolean", - "description": "Sets require_alias for all incoming documents. Defaults to unset (false)" + "description": "If true, the request’s actions must target an index alias. Defaults to false." }, "require_data_stream": { "type": "boolean", - "description": "When true, requires the destination to be a data stream (existing or to-be-created). Default is false" + "description": "If true, the request's actions must target a data stream (existing or to-be-created). Default to false" }, "list_executed_pipelines": { "type": "boolean", From 1e5d9b92a6e5d90a788cf9a562b284ceb10ebd98 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 11 Dec 2024 08:57:08 +0000 Subject: [PATCH 24/28] Update owner of IndexVersions removal (#118360) --- server/src/main/java/org/elasticsearch/index/IndexVersions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/IndexVersions.java b/server/src/main/java/org/elasticsearch/index/IndexVersions.java index 96a70c3cc432b..56f40309a825b 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexVersions.java +++ b/server/src/main/java/org/elasticsearch/index/IndexVersions.java @@ -58,7 +58,7 @@ private static Version parseUnchecked(String version) { } } - @UpdateForV9(owner = UpdateForV9.Owner.CORE_INFRA) // remove the index versions with which v9 will not need to interact + @UpdateForV9(owner = UpdateForV9.Owner.SEARCH_FOUNDATIONS) // remove the index versions with which v9 will not need to interact public static final IndexVersion ZERO = def(0, Version.LATEST); public static final IndexVersion V_7_0_0 = def(7_00_00_99, parseUnchecked("8.0.0")); From 368d47131f2d74742ae89005680352389387bdb8 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Wed, 11 Dec 2024 10:10:49 +0000 Subject: [PATCH 25/28] Remove various references to 7.0 and 7.1 transport versions (#118075) --- .../mustache/MultiSearchTemplateResponse.java | 5 +--- .../elasticsearch/ElasticsearchException.java | 2 +- .../org/elasticsearch/TransportVersions.java | 1 - .../GeoTileGridAggregationBuilder.java | 2 +- .../ExceptionSerializationTests.java | 2 +- .../MetadataCreateIndexServiceTests.java | 2 +- .../ClusterSerializationTests.java | 1 - .../common/settings/SettingsTests.java | 2 +- .../xpack/ccr/CCRInfoTransportAction.java | 2 +- .../core/graph/GraphFeatureSetUsage.java | 2 +- .../ilm/IndexLifecycleFeatureSetUsage.java | 2 +- .../logstash/LogstashFeatureSetUsage.java | 2 +- .../ml/MachineLearningFeatureSetUsage.java | 2 +- .../monitoring/MonitoringFeatureSetUsage.java | 2 +- .../core/rollup/RollupFeatureSetUsage.java | 2 +- .../security/SecurityFeatureSetUsage.java | 2 +- 
.../support/mapper/ExpressionRoleMapping.java | 11 ++------ .../xpack/core/sql/SqlFeatureSetUsage.java | 2 +- .../core/watcher/WatcherFeatureSetUsage.java | 2 +- .../ml/utils/TransportVersionUtilsTests.java | 8 +++--- .../mapper/ExpressionRoleMappingTests.java | 25 +------------------ 21 files changed, 23 insertions(+), 58 deletions(-) diff --git a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MultiSearchTemplateResponse.java b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MultiSearchTemplateResponse.java index 93dedb5cb9645..95af247447688 100644 --- a/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MultiSearchTemplateResponse.java +++ b/modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MultiSearchTemplateResponse.java @@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Iterators; @@ -131,9 +130,7 @@ public TimeValue getTook() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeArray(items); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_0_0)) { - out.writeVLong(tookInMillis); - } + out.writeVLong(tookInMillis); } @Override diff --git a/server/src/main/java/org/elasticsearch/ElasticsearchException.java b/server/src/main/java/org/elasticsearch/ElasticsearchException.java index fcb5c20c28162..11736bfe07deb 100644 --- a/server/src/main/java/org/elasticsearch/ElasticsearchException.java +++ b/server/src/main/java/org/elasticsearch/ElasticsearchException.java @@ -1781,7 +1781,7 @@ private enum ElasticsearchExceptionHandle { org.elasticsearch.cluster.coordination.CoordinationStateRejectedException.class, org.elasticsearch.cluster.coordination.CoordinationStateRejectedException::new, 150, - TransportVersions.V_7_0_0 + UNKNOWN_VERSION_ADDED ), SNAPSHOT_IN_PROGRESS_EXCEPTION( org.elasticsearch.snapshots.SnapshotInProgressException.class, diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 9f50df02b4fe8..b22df6c374b4d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -52,7 +52,6 @@ static TransportVersion def(int id) { @UpdateForV9(owner = UpdateForV9.Owner.CORE_INFRA) // remove the transport versions with which v9 will not need to interact public static final TransportVersion ZERO = def(0); public static final TransportVersion V_7_0_0 = def(7_00_00_99); - public static final TransportVersion V_7_0_1 = def(7_00_01_99); public static final TransportVersion V_7_2_0 = def(7_02_00_99); public static final TransportVersion V_7_2_1 = def(7_02_01_99); public static final TransportVersion V_7_3_0 = def(7_03_00_99); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoTileGridAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoTileGridAggregationBuilder.java index 2499328caca4d..2219c1d9da4ab 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoTileGridAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/geogrid/GeoTileGridAggregationBuilder.java @@ -109,6 +109,6 @@ public String getType() { 
@Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } } diff --git a/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java b/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java index 2c6be01c851e4..2abe4157583cd 100644 --- a/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java @@ -359,7 +359,7 @@ public void testSearchContextMissingException() throws IOException { public void testCircuitBreakingException() throws IOException { CircuitBreakingException ex = serialize( new CircuitBreakingException("Too large", 0, 100, CircuitBreaker.Durability.TRANSIENT), - TransportVersions.V_7_0_0 + TransportVersions.V_8_0_0 ); assertEquals("Too large", ex.getMessage()); assertEquals(100, ex.getByteLimit()); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java index 1876a1f2da556..c0e397c9fb9c9 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java @@ -1299,7 +1299,7 @@ public void testBuildIndexMetadataWithTransportVersionBeforeEventIngestedRangeAd 4, sourceIndexMetadata, false, - randomFrom(TransportVersions.V_7_0_0, TransportVersions.V_8_0_0) + TransportVersions.V_8_0_0 ); assertThat(indexMetadata.getAliases().size(), is(1)); diff --git a/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java index d8f3c4f1af48e..bad2b07ba7972 100644 --- a/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java @@ -88,7 +88,6 @@ public void testClusterStateSerialization() throws Exception { public void testClusterStateSerializationWithTimestampRangesWithOlderTransportVersion() throws Exception { TransportVersion versionBeforeEventIngestedInClusterState = randomFrom( - TransportVersions.V_7_0_0, TransportVersions.V_8_0_0, TransportVersionUtils.getPreviousVersion(TransportVersions.V_8_15_0) ); diff --git a/server/src/test/java/org/elasticsearch/common/settings/SettingsTests.java b/server/src/test/java/org/elasticsearch/common/settings/SettingsTests.java index 5fefd92d176a5..b6e21e7bb911c 100644 --- a/server/src/test/java/org/elasticsearch/common/settings/SettingsTests.java +++ b/server/src/test/java/org/elasticsearch/common/settings/SettingsTests.java @@ -619,7 +619,7 @@ public void testMissingValue() throws Exception { public void testReadWriteArray() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); - output.setTransportVersion(randomFrom(TransportVersion.current(), TransportVersions.V_7_0_0)); + output.setTransportVersion(randomFrom(TransportVersion.current(), TransportVersions.V_8_0_0)); Settings settings = Settings.builder().putList("foo.bar", "0", "1", "2", "3").put("foo.bar.baz", "baz").build(); settings.writeTo(output); StreamInput in = StreamInput.wrap(BytesReference.toBytes(output.bytes())); diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/CCRInfoTransportAction.java 
b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/CCRInfoTransportAction.java index 6e6b54af4d0e1..8cb98bb4458aa 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/CCRInfoTransportAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/CCRInfoTransportAction.java @@ -90,7 +90,7 @@ public Usage(StreamInput in) throws IOException { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } public int getNumberOfFollowerIndices() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/graph/GraphFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/graph/GraphFeatureSetUsage.java index b046efaa30082..53994f310c949 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/graph/GraphFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/graph/GraphFeatureSetUsage.java @@ -26,7 +26,7 @@ public GraphFeatureSetUsage(boolean available, boolean enabled) { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleFeatureSetUsage.java index 822f15d1ed74a..543a5be812a1c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleFeatureSetUsage.java @@ -40,7 +40,7 @@ public IndexLifecycleFeatureSetUsage(StreamInput input) throws IOException { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/logstash/LogstashFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/logstash/LogstashFeatureSetUsage.java index f3f0214b89c04..4a5ebad732d78 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/logstash/LogstashFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/logstash/LogstashFeatureSetUsage.java @@ -26,7 +26,7 @@ public LogstashFeatureSetUsage(boolean available) { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java index 8c4611f05e72a..82f50ba3a889f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java @@ -75,7 +75,7 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/monitoring/MonitoringFeatureSetUsage.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/monitoring/MonitoringFeatureSetUsage.java index b181bd78dfe41..48c14f79abcbf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/monitoring/MonitoringFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/monitoring/MonitoringFeatureSetUsage.java @@ -40,7 +40,7 @@ public MonitoringFeatureSetUsage(boolean collectionEnabled, Map @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } public Map getExporters() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/rollup/RollupFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/rollup/RollupFeatureSetUsage.java index 82253ba08165a..59d8926a6dcf8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/rollup/RollupFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/rollup/RollupFeatureSetUsage.java @@ -50,7 +50,7 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java index 33f1a9a469b69..726797d2e563a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java @@ -114,7 +114,7 @@ public SecurityFeatureSetUsage( @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/mapper/ExpressionRoleMapping.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/mapper/ExpressionRoleMapping.java index 41fd3c6938dfc..df732ee4a5cae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/mapper/ExpressionRoleMapping.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/mapper/ExpressionRoleMapping.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.security.authc.support.mapper; import org.apache.logging.log4j.Logger; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -139,11 +138,7 @@ public ExpressionRoleMapping(StreamInput in) throws IOException { this.name = in.readString(); this.enabled = in.readBoolean(); this.roles = in.readStringCollectionAsList(); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_2_0)) { - this.roleTemplates = in.readCollectionAsList(TemplateRoleName::new); - } else { - this.roleTemplates = Collections.emptyList(); - } + this.roleTemplates = in.readCollectionAsList(TemplateRoleName::new); this.expression = ExpressionParser.readExpression(in); this.metadata = in.readGenericMap(); } @@ -175,9 +170,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeBoolean(enabled); 
out.writeStringCollection(roles); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_2_0)) { - out.writeCollection(roleTemplates); - } + out.writeCollection(roleTemplates); ExpressionParser.writeExpression(expression, out); out.writeGenericMap(metadata); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlFeatureSetUsage.java index 2f41c8d2e2bb6..0b86e27a62f17 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlFeatureSetUsage.java @@ -34,7 +34,7 @@ public SqlFeatureSetUsage(Map stats) { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_7_0_0; + return TransportVersions.ZERO; } public Map stats() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/WatcherFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/WatcherFeatureSetUsage.java index 77280c727be3b..de4dbf601f50b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/WatcherFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/WatcherFeatureSetUsage.java @@ -33,7 +33,7 @@ public WatcherFeatureSetUsage(boolean available, boolean enabled, Map stats() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TransportVersionUtilsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TransportVersionUtilsTests.java index 015614e56c02b..a316bf84d00eb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TransportVersionUtilsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TransportVersionUtilsTests.java @@ -23,9 +23,9 @@ public class TransportVersionUtilsTests extends ESTestCase { private static final Map transportVersions = Map.of( "Alfredo", - new CompatibilityVersions(TransportVersions.V_7_0_0, Map.of()), + new CompatibilityVersions(TransportVersions.V_8_1_0, Map.of()), "Bertram", - new CompatibilityVersions(TransportVersions.V_7_0_1, Map.of()), + new CompatibilityVersions(TransportVersions.V_8_6_0, Map.of()), "Charles", new CompatibilityVersions(TransportVersions.V_8_9_X, Map.of()), "Dominic", @@ -48,7 +48,7 @@ public class TransportVersionUtilsTests extends ESTestCase { ); public void testGetMinTransportVersion() { - assertThat(TransportVersionUtils.getMinTransportVersion(state), equalTo(TransportVersions.V_7_0_0)); + assertThat(TransportVersionUtils.getMinTransportVersion(state), equalTo(TransportVersions.V_8_0_0)); } public void testIsMinTransformVersionSameAsCurrent() { @@ -78,7 +78,7 @@ public void testIsMinTransformVersionSameAsCurrent() { } public void testIsMinTransportVersionOnOrAfter() { - assertThat(TransportVersionUtils.isMinTransportVersionOnOrAfter(state, TransportVersions.V_7_0_0), equalTo(true)); + assertThat(TransportVersionUtils.isMinTransportVersionOnOrAfter(state, TransportVersions.V_8_0_0), equalTo(true)); assertThat(TransportVersionUtils.isMinTransportVersionOnOrAfter(state, TransportVersions.V_8_9_X), equalTo(false)); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ExpressionRoleMappingTests.java 
b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ExpressionRoleMappingTests.java index fc5eb135343b9..19286ebf58dc6 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ExpressionRoleMappingTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/mapper/ExpressionRoleMappingTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.security.authc.support.mapper; import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; @@ -408,29 +407,7 @@ public void testToXContentWithTemplates() throws Exception { public void testSerialization() throws Exception { final ExpressionRoleMapping original = randomRoleMapping(true); - TransportVersion version = TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_7_2_0, null); - BytesStreamOutput output = new BytesStreamOutput(); - output.setTransportVersion(version); - original.writeTo(output); - - final NamedWriteableRegistry registry = new NamedWriteableRegistry(new XPackClientPlugin().getNamedWriteables()); - StreamInput streamInput = new NamedWriteableAwareStreamInput( - ByteBufferStreamInput.wrap(BytesReference.toBytes(output.bytes())), - registry - ); - streamInput.setTransportVersion(version); - final ExpressionRoleMapping serialized = new ExpressionRoleMapping(streamInput); - assertEquals(original, serialized); - } - - public void testSerializationPreV71() throws Exception { - final ExpressionRoleMapping original = randomRoleMapping(false); - - TransportVersion version = TransportVersionUtils.randomVersionBetween( - random(), - TransportVersions.V_7_0_0, - TransportVersions.V_7_0_1 - ); + TransportVersion version = TransportVersionUtils.randomCompatibleVersion(random()); BytesStreamOutput output = new BytesStreamOutput(); output.setTransportVersion(version); original.writeTo(output); From e3bddd0e953b7e977a68d69ab5cf0b841d166f48 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Wed, 11 Dec 2024 11:13:24 +0100 Subject: [PATCH 26/28] Fix issues with ReinitializingSourceProvider (#118370) The previous fix to ensure that each thread uses its own SearchProvider wasn't good enough. The read from `perThreadProvider` field could be stale and therefore returning a previous source provider. Instead the source provider should be returned from `provider` local variable. This change also addresses another issue, sometimes current docid goes backwards compared to last seen docid and this causes issue when synthetic source provider is used, as doc values can't advance backwards. This change addresses that by returning a new source provider if backwards docid is detected. 
Closes #118238 --- docs/changelog/118370.yaml | 6 +++++ .../xpack/esql/action/EsqlActionIT.java | 26 +++++++++++++++---- .../plugin/ReinitializingSourceProvider.java | 17 +++++++++--- 3 files changed, 41 insertions(+), 8 deletions(-) create mode 100644 docs/changelog/118370.yaml diff --git a/docs/changelog/118370.yaml b/docs/changelog/118370.yaml new file mode 100644 index 0000000000000..e6a429448e493 --- /dev/null +++ b/docs/changelog/118370.yaml @@ -0,0 +1,6 @@ +pr: 118370 +summary: Fix concurrency issue with `ReinitializingSourceProvider` +area: Mapping +type: bug +issues: + - 118238 diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java index 00f53d31165b1..de9eb166688f9 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.junit.Before; import java.io.IOException; @@ -1673,17 +1674,32 @@ public void testScriptField() throws Exception { String sourceMode = randomBoolean() ? "stored" : "synthetic"; Settings.Builder settings = indexSettings(1, 0).put(indexSettings()).put("index.mapping.source.mode", sourceMode); client().admin().indices().prepareCreate("test-script").setMapping(mapping).setSettings(settings).get(); - for (int i = 0; i < 10; i++) { + int numDocs = 256; + for (int i = 0; i < numDocs; i++) { index("test-script", Integer.toString(i), Map.of("k1", i, "k2", "b-" + i, "meter", 10000 * i)); } refresh("test-script"); - try (EsqlQueryResponse resp = run("FROM test-script | SORT k1 | LIMIT 10")) { + + var pragmas = randomPragmas(); + if (canUseQueryPragmas()) { + Settings.Builder pragmaSettings = Settings.builder().put(pragmas.getSettings()); + pragmaSettings.put("task_concurrency", 10); + pragmaSettings.put("data_partitioning", "doc"); + pragmas = new QueryPragmas(pragmaSettings.build()); + } + try (EsqlQueryResponse resp = run("FROM test-script | SORT k1 | LIMIT " + numDocs, pragmas)) { List k1Column = Iterators.toList(resp.column(0)); - assertThat(k1Column, contains(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L)); + assertThat(k1Column, equalTo(LongStream.range(0L, numDocs).boxed().toList())); List k2Column = Iterators.toList(resp.column(1)); - assertThat(k2Column, contains(null, null, null, null, null, null, null, null, null, null)); + assertThat(k2Column, equalTo(Collections.nCopies(numDocs, null))); List meterColumn = Iterators.toList(resp.column(2)); - assertThat(meterColumn, contains(0.0, 10000.0, 20000.0, 30000.0, 40000.0, 50000.0, 60000.0, 70000.0, 80000.0, 90000.0)); + var expectedMeterColumn = new ArrayList<>(numDocs); + double val = 0.0; + for (int i = 0; i < numDocs; i++) { + expectedMeterColumn.add(val); + val += 10000.0; + } + assertThat(meterColumn, equalTo(expectedMeterColumn)); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java index b6b2c6dfec755..8dee3478b3b64 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java @@ -15,13 +15,23 @@ import java.util.function.Supplier; /** - * This is a workaround for when compute engine executes concurrently with data partitioning by docid. + * This class exists as a workaround for using SourceProvider in the compute engine. + *
<p>
+ * The main issue is when compute engine executes concurrently with data partitioning by docid (inter segment parallelization). + * A {@link SourceProvider} can only be used by a single thread and this wrapping source provider ensures that each thread uses + * its own {@link SourceProvider}. + *
<p>
+ * Additionally, this source provider protects against going backwards, which the synthetic source provider can't handle. */ final class ReinitializingSourceProvider implements SourceProvider { private PerThreadSourceProvider perThreadProvider; private final Supplier sourceProviderFactory; + // Keeping track of last seen doc and if current doc is before last seen doc then source provider is initialized: + // (when source mode is synthetic then _source is read from doc values and doc values don't support going backwards) + private int lastSeenDocId; + ReinitializingSourceProvider(Supplier sourceProviderFactory) { this.sourceProviderFactory = sourceProviderFactory; } @@ -30,11 +40,12 @@ final class ReinitializingSourceProvider implements SourceProvider { public Source getSource(LeafReaderContext ctx, int doc) throws IOException { var currentThread = Thread.currentThread(); PerThreadSourceProvider provider = perThreadProvider; - if (provider == null || provider.creatingThread != currentThread) { + if (provider == null || provider.creatingThread != currentThread || doc < lastSeenDocId) { provider = new PerThreadSourceProvider(sourceProviderFactory.get(), currentThread); this.perThreadProvider = provider; } - return perThreadProvider.source.getSource(ctx, doc); + lastSeenDocId = doc; + return provider.source.getSource(ctx, doc); } private record PerThreadSourceProvider(SourceProvider source, Thread creatingThread) { From 1265f3e55499a70a3493321f244cfe479ea52bc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?= Date: Wed, 11 Dec 2024 11:16:52 +0100 Subject: [PATCH 27/28] Added extra javadocs on getKeys and nonEmpty in BlockHash (#118357) --- .../compute/aggregation/blockhash/BlockHash.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 30afa7ae3128d..9b53e6558f4db 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -91,6 +91,9 @@ public abstract class BlockHash implements Releasable, SeenGroupIds { /** * Returns a {@link Block} that contains all the keys that are inserted by {@link #add}. + *
<p>
+ * Keys must be in the same order as the IDs returned by {@link #nonEmpty()}. + *
</p>
*/ public abstract Block[] getKeys(); @@ -100,6 +103,9 @@ public abstract class BlockHash implements Releasable, SeenGroupIds { * {@link BooleanBlockHash} does this by always assigning {@code false} to {@code 0} * and {@code true} to {@code 1}. It's only after collection when we * know if there actually were any {@code true} or {@code false} values received. + *
<p>
+ * IDs must be in the same order as the keys returned by {@link #getKeys()}. + *
</p>
*/ public abstract IntVector nonEmpty(); From a7fdc10bd4271eefe33801b0021eb40ab63bf474 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 11 Dec 2024 10:51:24 +0000 Subject: [PATCH 28/28] [DOCS][ML] Use elasticsearch service instead of deprecated elser service in tutorials (#118007) --- .../tab-widgets/inference-api/infer-api-task.asciidoc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc index 02995591d9c8a..227eb774a4a9f 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc @@ -36,7 +36,7 @@ the `cosine` measures are equivalent. ------------------------------------------------------------ PUT _inference/sparse_embedding/elser_embeddings <1> { - "service": "elser", + "service": "elasticsearch", "service_settings": { "num_allocations": 1, "num_threads": 1 @@ -206,7 +206,7 @@ PUT _inference/text_embedding/google_vertex_ai_embeddings <1> <2> A valid service account in JSON format for the Google Vertex AI API. <3> For the list of the available models, refer to the https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api[Text embeddings API] page. <4> The name of the location to use for the {infer} task. Refer to https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations[Generative AI on Vertex AI locations] for available locations. -<5> The name of the project to use for the {infer} task. +<5> The name of the project to use for the {infer} task. // end::google-vertex-ai[]