diff --git a/CHANGELOG.md b/CHANGELOG.md index 076eebb65f19c..ce67273f2a6f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,6 +182,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Redefine telemetry context restoration and propagation ([#9617](https://github.com/opensearch-project/OpenSearch/pull/9617)) - Use non-concurrent path for sort request on timeseries index and field([#9562](https://github.com/opensearch-project/OpenSearch/pull/9562)) - Added sampler based on `Blanket Probabilistic Sampling rate` and `Override for on demand` ([#9621](https://github.com/opensearch-project/OpenSearch/issues/9621)) +- Improve performance of rounding dates in date_histogram aggregation ([#9727](https://github.com/opensearch-project/OpenSearch/pull/9727)) - [Remote Store] Add support for Remote Translog Store stats in `_remotestore/stats/` API ([#9263](https://github.com/opensearch-project/OpenSearch/pull/9263)) - Add support for query profiler with concurrent aggregation ([#9248](https://github.com/opensearch-project/OpenSearch/pull/9248)) - Cleanup Unreferenced file on segment merge failure ([#9503](https://github.com/opensearch-project/OpenSearch/pull/9503)) diff --git a/benchmarks/src/main/java/org/opensearch/common/ArrayRoundingBenchmark.java b/benchmarks/src/main/java/org/opensearch/common/ArrayRoundingBenchmark.java new file mode 100644 index 0000000000000..64c0a9e1d7aa6 --- /dev/null +++ b/benchmarks/src/main/java/org/opensearch/common/ArrayRoundingBenchmark.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.common; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.util.Random; +import java.util.function.Supplier; + +@Fork(value = 3) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 1, time = 1) +@BenchmarkMode(Mode.Throughput) +public class ArrayRoundingBenchmark { + + @Benchmark + public void round(Blackhole bh, Options opts) { + Rounding.Prepared rounding = opts.supplier.get(); + for (long key : opts.queries) { + bh.consume(rounding.round(key)); + } + } + + @State(Scope.Benchmark) + public static class Options { + @Param({ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "12", + "14", + "16", + "18", + "20", + "22", + "24", + "26", + "29", + "32", + "37", + "41", + "45", + "49", + "54", + "60", + "64", + "74", + "83", + "90", + "98", + "108", + "118", + "128", + "144", + "159", + "171", + "187", + "204", + "229", + "256" }) + public Integer size; + + @Param({ "binary", "linear" }) + public String type; + + @Param({ "uniform", "skewed_edge", "skewed_center" }) + public String distribution; + + public long[] queries; + public Supplier<Rounding.Prepared> supplier; + + @Setup + public void setup() { + Random random = new Random(size); + long[] values = new long[size]; + for (int i = 1; i < values.length; i++) { + values[i] = values[i - 1] + 100; + } + + long range = values[values.length - 1] - values[0] + 100; + long mean, stddev; + queries = new long[1000000]; + + switch (distribution) { + case "uniform": // all values equally likely. 
+ for (int i = 0; i < queries.length; i++) { + queries[i] = values[0] + (nextPositiveLong(random) % range); + } + break; + case "skewed_edge": // distribution centered at p90 with ± 5% stddev. + mean = values[0] + (long) (range * 0.9); + stddev = (long) (range * 0.05); + for (int i = 0; i < queries.length; i++) { + queries[i] = Math.max(values[0], mean + (long) (random.nextGaussian() * stddev)); + } + break; + case "skewed_center": // distribution centered at p50 with ± 5% stddev. + mean = values[0] + (long) (range * 0.5); + stddev = (long) (range * 0.05); + for (int i = 0; i < queries.length; i++) { + queries[i] = Math.max(values[0], mean + (long) (random.nextGaussian() * stddev)); + } + break; + default: + throw new IllegalArgumentException("invalid distribution: " + distribution); + } + + switch (type) { + case "binary": + supplier = () -> new Rounding.BinarySearchArrayRounding(values, size, null); + break; + case "linear": + supplier = () -> new Rounding.BidirectionalLinearSearchArrayRounding(values, size, null); + break; + default: + throw new IllegalArgumentException("invalid type: " + type); + } + } + + private static long nextPositiveLong(Random random) { + return random.nextLong() & Long.MAX_VALUE; + } + } +} diff --git a/server/src/main/java/org/opensearch/common/Rounding.java b/server/src/main/java/org/opensearch/common/Rounding.java index 65ffdafc423fd..667eb4529fe38 100644 --- a/server/src/main/java/org/opensearch/common/Rounding.java +++ b/server/src/main/java/org/opensearch/common/Rounding.java @@ -37,6 +37,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.common.LocalTimeOffset.Gap; import org.opensearch.common.LocalTimeOffset.Overlap; +import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.time.DateUtils; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.io.stream.StreamInput; @@ -412,6 +413,21 @@ public Rounding build() { } private abstract class PreparedRounding 
implements Prepared { + /** + * The maximum limit up to which array-based prepared rounding is used. + * 128 is a power of two that isn't huge. We might be able to do + * better if the limit was based on the actual type of prepared + * rounding but this'll do for now. + */ + private static final int DEFAULT_ARRAY_ROUNDING_MAX_THRESHOLD = 128; + + /** + * The maximum limit up to which linear search is used, otherwise binary search is used. + * This is because linear search is much faster on small arrays. + * Benchmark results: PR #9727 + */ + private static final int LINEAR_SEARCH_ARRAY_ROUNDING_MAX_THRESHOLD = 64; + /** * Attempt to build a {@link Prepared} implementation that relies on pre-calcuated * "round down" points. If there would be more than {@code max} points then return @@ -435,7 +451,9 @@ protected Prepared maybeUseArray(long minUtcMillis, long maxUtcMillis, int max) values = ArrayUtil.grow(values, i + 1); values[i++] = rounded; } - return new ArrayRounding(values, i, this); + return i <= LINEAR_SEARCH_ARRAY_ROUNDING_MAX_THRESHOLD + ? new BidirectionalLinearSearchArrayRounding(values, i, this) + : new BinarySearchArrayRounding(values, i, this); } } @@ -521,12 +539,11 @@ private LocalDateTime truncateLocalDateTime(LocalDateTime localDateTime) { @Override public Prepared prepare(long minUtcMillis, long maxUtcMillis) { - /* - * 128 is a power of two that isn't huge. We might be able to do - * better if the limit was based on the actual type of prepared - * rounding but this'll do for now. 
- */ - return prepareOffsetOrJavaTimeRounding(minUtcMillis, maxUtcMillis).maybeUseArray(minUtcMillis, maxUtcMillis, 128); + return prepareOffsetOrJavaTimeRounding(minUtcMillis, maxUtcMillis).maybeUseArray( + minUtcMillis, + maxUtcMillis, + PreparedRounding.DEFAULT_ARRAY_ROUNDING_MAX_THRESHOLD + ); } private TimeUnitPreparedRounding prepareOffsetOrJavaTimeRounding(long minUtcMillis, long maxUtcMillis) { @@ -1330,14 +1347,19 @@ public static Rounding read(StreamInput in) throws IOException { /** * Implementation of {@link Prepared} using pre-calculated "round down" points. * + *

<p> + * It uses binary search to find the greatest round-down point less than or equal to the given timestamp. + * + * @opensearch.internal */ - private static class ArrayRounding implements Prepared { + @InternalApi + static class BinarySearchArrayRounding implements Prepared { private final long[] values; private final int max; private final Prepared delegate; - private ArrayRounding(long[] values, int max, Prepared delegate) { + BinarySearchArrayRounding(long[] values, int max, Prepared delegate) { + assert max > 0 : "at least one round-down point must be present"; this.values = values; this.max = max; this.delegate = delegate; @@ -1365,4 +1387,64 @@ public double roundingSize(long utcMillis, DateTimeUnit timeUnit) { return delegate.roundingSize(utcMillis, timeUnit); } } + + /** + * Implementation of {@link Prepared} using pre-calculated "round down" points. + * + *

<p> + * It uses linear search to find the greatest round-down point less than or equal to the given timestamp. + * For small inputs (≤ 64 elements), this can be much faster than binary search as it avoids the penalty of + * branch mispredictions and pipeline stalls, and accesses memory sequentially. + * + *

<p> + * It uses "meet in the middle" linear search to avoid the worst case scenario when the desired element is present + * at either side of the array. This is helpful for time-series data where velocity increases over time, so more + * documents are likely to find a greater timestamp which is likely to be present on the right end of the array. + * + * @opensearch.internal + */ + @InternalApi + static class BidirectionalLinearSearchArrayRounding implements Prepared { + private final long[] ascending; + private final long[] descending; + private final Prepared delegate; + + BidirectionalLinearSearchArrayRounding(long[] values, int max, Prepared delegate) { + assert max > 0 : "at least one round-down point must be present"; + this.delegate = delegate; + int len = (max + 1) >>> 1; // rounded-up to handle odd number of values + ascending = new long[len]; + descending = new long[len]; + + for (int i = 0; i < len; i++) { + ascending[i] = values[i]; + descending[i] = values[max - i - 1]; + } + } + + @Override + public long round(long utcMillis) { + int i = 0; + for (; i < ascending.length; i++) { + if (descending[i] <= utcMillis) { + return descending[i]; + } + if (ascending[i] > utcMillis) { + assert i > 0 : "utcMillis must be after " + ascending[0]; + return ascending[i - 1]; + } + } + return ascending[i - 1]; + } + + @Override + public long nextRoundingValue(long utcMillis) { + return delegate.nextRoundingValue(utcMillis); + } + + @Override + public double roundingSize(long utcMillis, DateTimeUnit timeUnit) { + return delegate.roundingSize(utcMillis, timeUnit); + } + } } diff --git a/server/src/test/java/org/opensearch/common/RoundingTests.java b/server/src/test/java/org/opensearch/common/RoundingTests.java index e0c44e3516e7b..0ebfe02dc7641 100644 --- a/server/src/test/java/org/opensearch/common/RoundingTests.java +++ b/server/src/test/java/org/opensearch/common/RoundingTests.java @@ -1143,6 +1143,28 @@ public void testNonMillisecondsBasedUnitCalendarRoundingSize() { 
assertThat(prepared.roundingSize(thirdQuarter, Rounding.DateTimeUnit.HOUR_OF_DAY), closeTo(2208.0, 0.000001)); } + public void testArrayRoundingImplementations() { + int length = randomIntBetween(1, 256); + long[] values = new long[length]; + for (int i = 1; i < values.length; i++) { + values[i] = values[i - 1] + (randomNonNegativeLong() % 100); + } + + Rounding.Prepared binarySearchImpl = new Rounding.BinarySearchArrayRounding(values, length, null); + Rounding.Prepared linearSearchImpl = new Rounding.BidirectionalLinearSearchArrayRounding(values, length, null); + + for (int i = 0; i < 100000; i++) { + long key = values[0] + (randomNonNegativeLong() % (100 + values[length - 1] - values[0])); + assertEquals(binarySearchImpl.round(key), linearSearchImpl.round(key)); + } + + AssertionError exception = expectThrows(AssertionError.class, () -> { binarySearchImpl.round(values[0] - 1); }); + assertEquals("utcMillis must be after " + values[0], exception.getMessage()); + + exception = expectThrows(AssertionError.class, () -> { linearSearchImpl.round(values[0] - 1); }); + assertEquals("utcMillis must be after " + values[0], exception.getMessage()); + } + private void assertInterval(long rounded, long nextRoundingValue, Rounding rounding, int minutes, ZoneId tz) { assertInterval(rounded, dateBetween(rounded, nextRoundingValue), nextRoundingValue, rounding, tz); long millisPerMinute = 60_000;