Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TopN Query Aggregator Memory Guardrails #17439

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/configuration/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2147,9 +2147,10 @@ context). If query does have `maxQueuedBytes` in the context, then that value is

### TopN query config

|Property|Description|Default|
|--------|-----------|-------|
|`druid.query.topN.minTopNThreshold`|See [TopN Aliasing](../querying/topnquery.md#aliasing) for details.|1000|
|Property| Description | Default |
|--------|-------------------------------------------------------------------------------|---------|
|`druid.query.topN.minTopNThreshold`| See [TopN Aliasing](../querying/topnquery.md#aliasing) for details. | 1000 |
|`druid.query.topN.maxTopNAggregatorHeapSizeBytes`| The maximum amount of aggregator heap bytes a given segment runner can accrue. | 10MB |

### Search query config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.topn.TopNQuery;
import org.apache.druid.query.topn.TopNQueryBuilder;
import org.apache.druid.query.topn.TopNQueryConfig;
import org.apache.druid.query.topn.TopNQueryEngine;
import org.apache.druid.query.topn.TopNResultValue;
import org.apache.druid.segment.IncrementalIndexSegment;
Expand Down Expand Up @@ -133,6 +134,7 @@ public void testTopNWithDistinctCountAgg() throws Exception
final Iterable<Result<TopNResultValue>> results =
engine.query(
query,
new TopNQueryConfig(),
new IncrementalIndexSegment(index, SegmentId.dummy(QueryRunnerTestHelper.DATA_SOURCE)),
null
).toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public class QueryContexts
public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner";
public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit";
public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold";
public static final String MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES = "maxTopNAggregatorHeapSizeBytes";
public static final String CATALOG_VALIDATION_ENABLED = "catalogValidationEnabled";

// projection context keys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@
public abstract class BaseTopNAlgorithm<DimValSelector, DimValAggregateStore, Parameters extends TopNParams>
implements TopNAlgorithm<DimValSelector, Parameters>
{
public static Aggregator[] makeAggregators(Cursor cursor, List<AggregatorFactory> aggregatorSpecs)

public static Aggregator[] makeAggregators(TopNQuery query, Cursor cursor)
{
query.getAggregatorHelper().addAggregatorMemory();
jtuglu-netflix marked this conversation as resolved.
Show resolved Hide resolved
final List<AggregatorFactory> aggregatorSpecs = query.getAggregatorSpecs();
Aggregator[] aggregators = new Aggregator[aggregatorSpecs.size()];
int aggregatorIndex = 0;
for (AggregatorFactory spec : aggregatorSpecs) {
Expand All @@ -52,8 +55,10 @@ public static Aggregator[] makeAggregators(Cursor cursor, List<AggregatorFactory
return aggregators;
}

protected static BufferAggregator[] makeBufferAggregators(Cursor cursor, List<AggregatorFactory> aggregatorSpecs)
protected static BufferAggregator[] makeBufferAggregators(TopNQuery query, Cursor cursor)
{
query.getAggregatorHelper().addAggregatorMemory();
final List<AggregatorFactory> aggregatorSpecs = query.getAggregatorSpecs();
BufferAggregator[] aggregators = new BufferAggregator[aggregatorSpecs.size()];
int aggregatorIndex = 0;
for (AggregatorFactory spec : aggregatorSpecs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ public int[] build()
resultsBuf.clear();

final int numBytesToWorkWith = resultsBuf.remaining();

query.getAggregatorHelper().addAggregatorMemory();
final int[] aggregatorSizes = new int[query.getAggregatorSpecs().size()];
int numBytesPerRecord = 0;

Expand Down Expand Up @@ -329,7 +331,7 @@ protected int[] updateDimValSelector(int[] dimValSelector, int numProcessed, int
@Override
protected BufferAggregator[] makeDimValAggregateStore(PooledTopNParams params)
{
return makeBufferAggregators(params.getCursor(), query.getAggregatorSpecs());
return makeBufferAggregators(query, params.getCursor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected long scanAndAggregate(

Aggregator[] theAggregators = aggregatesStore.computeIfAbsent(
key,
k -> makeAggregators(cursor, query.getAggregatorSpecs())
k -> makeAggregators(query, cursor)
);

for (Aggregator aggregator : theAggregators) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.query.topn;

import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.ResourceLimitExceededException;

import java.util.concurrent.atomic.AtomicLong;

/**
 * Tracks the estimated heap footprint of TopN aggregators created while running a query and
 * enforces a query-level cap. Each call to {@link #addAggregatorMemory()} charges one batch of
 * aggregators (sized at construction time) against a running total; once that total exceeds the
 * configured maximum, a {@link ResourceLimitExceededException} is thrown so the query fails fast
 * instead of exhausting the heap.
 *
 * The running total is kept in an {@link AtomicLong}, so a single instance may be charged from
 * multiple segment runners concurrently.
 */
public class TopNAggregatorResourceHelper
{
  /**
   * Immutable limit configuration: the maximum number of aggregator heap bytes permitted.
   */
  public static class Config
  {
    public final long maxAggregatorHeapSize;

    public Config(final long maxAggregatorHeapSize)
    {
      this.maxAggregatorHeapSize = maxAggregatorHeapSize;
    }
  }

  private final Config config;
  private final long newAggregatorEstimatedMemorySize;
  private final AtomicLong used = new AtomicLong(0);

  TopNAggregatorResourceHelper(final long newAggregatorEstimatedMemorySize, final Config config)
  {
    this.newAggregatorEstimatedMemorySize = newAggregatorEstimatedMemorySize;
    this.config = config;
  }

  /**
   * Charges one aggregator batch's estimated size against the running total.
   *
   * @throws ResourceLimitExceededException if the updated total exceeds the configured maximum
   */
  public void addAggregatorMemory()
  {
    final long totalAfterCharge = used.addAndGet(newAggregatorEstimatedMemorySize);
    final long limit = config.maxAggregatorHeapSize;
    if (totalAfterCharge > limit) {
      throw new ResourceLimitExceededException(StringUtils.format(
          "Query ran out of memory. Maximum allowed bytes=[%d], Hit bytes=[%d]",
          limit,
          totalAfterCharge
      ));
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.druid.query.PerSegmentQueryOptimizationContext;
import org.apache.druid.query.Queries;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class TopNQuery extends BaseQuery<Result<TopNResultValue>>
private final DimFilter dimFilter;
private final List<AggregatorFactory> aggregatorSpecs;
private final List<PostAggregator> postAggregatorSpecs;
private TopNAggregatorResourceHelper aggregatorHelper;
jtuglu-netflix marked this conversation as resolved.
Show resolved Hide resolved

@JsonCreator
public TopNQuery(
Expand Down Expand Up @@ -97,9 +99,18 @@ public TopNQuery(
: postAggregatorSpecs
);

final long expectedAllocBytes = aggregatorSpecs.stream().mapToLong(AggregatorFactory::getMaxIntermediateSizeWithNulls).sum();
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved
final long maxAggregatorHeapSizeBytes = this.context().getLong(QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES, TopNQueryConfig.DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES);
this.aggregatorHelper = new TopNAggregatorResourceHelper(expectedAllocBytes, new TopNAggregatorResourceHelper.Config(maxAggregatorHeapSizeBytes));

topNMetricSpec.verifyPreconditions(this.aggregatorSpecs, this.postAggregatorSpecs);
}

public TopNAggregatorResourceHelper getAggregatorHelper()
jtuglu-netflix marked this conversation as resolved.
Show resolved Hide resolved
{
return aggregatorHelper;
}

@Override
public boolean hasFilters()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
public class TopNQueryConfig
{
public static final int DEFAULT_MIN_TOPN_THRESHOLD = 1000;
public static final long DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES = 10 * (2 << 20); // 10mb

@JsonProperty
@Min(1)
Expand All @@ -37,4 +38,13 @@ public int getMinTopNThreshold()
{
return minTopNThreshold;
}

@JsonProperty
@Min(0)
private long maxTopNAggregatorHeapSizeBytes = DEFAULT_MAX_AGGREGATOR_HEAP_SIZE_BYTES;

public long getMaxTopNAggregatorHeapSizeBytes()
{
return maxTopNAggregatorHeapSizeBytes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.collections.NonBlockingPool;
import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.ColumnSelectorPlus;
import org.apache.druid.query.CursorGranularizer;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory;
Expand Down Expand Up @@ -75,6 +77,7 @@ public TopNQueryEngine(NonBlockingPool<ByteBuffer> bufferPool)
*/
public Sequence<Result<TopNResultValue>> query(
TopNQuery query,
TopNQueryConfig config,
final Segment segment,
@Nullable final TopNQueryMetrics queryMetrics
)
Expand All @@ -86,6 +89,13 @@ public Sequence<Result<TopNResultValue>> query(
);
}

if (!query.context().containsKey(QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES)) {
query = query.withOverriddenContext(ImmutableMap.of(
QueryContexts.MAX_TOP_N_AGGREGATOR_HEAP_SIZE_BYTES,
config.getMaxTopNAggregatorHeapSizeBytes()
));
}

final CursorBuildSpec buildSpec = makeCursorBuildSpec(query, queryMetrics);
final CursorHolder cursorHolder = cursorFactory.makeCursorHolder(buildSpec);
if (cursorHolder.isPreAggregated()) {
Expand Down Expand Up @@ -178,7 +188,14 @@ private TopNMapFn getMapFn(
);

final TopNAlgorithm<?, ?> topNAlgorithm;
if (canUsePooledAlgorithm(selector, query, columnCapabilities, bufferPool, cursorInspector.getDimensionCardinality(), numBytesPerRecord)) {
if (canUsePooledAlgorithm(
selector,
query,
columnCapabilities,
bufferPool,
cursorInspector.getDimensionCardinality(),
numBytesPerRecord
)) {
// pool based algorithm selection, if we can
if (selector.isAggregateAllMetrics()) {
// if sorted by dimension we should aggregate all metrics in a single pass, use the regular pooled algorithm for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,11 @@ public Sequence<Object[]> resultsAsArrays(TopNQuery query, Sequence<Result<TopNR
);
}

public TopNQueryConfig getConfig()
{
return this.config;
}

/**
* This returns a single frame containing the rows of the topN query's results
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public Sequence<Result<TopNResultValue>> run(
TopNQuery query = (TopNQuery) input.getQuery();
return queryEngine.query(
query,
toolchest.getConfig(),
segment,
(TopNQueryMetrics) input.getQueryMetrics()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Aggregator[] getValueAggregators(
long key = Double.doubleToLongBits(selector.getDouble());
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Aggregator[] getValueAggregators(
int key = Float.floatToIntBits(selector.getFloat());
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Aggregator[] getValueAggregators(TopNQuery query, BaseLongColumnValueSelector se
long key = selector.getLong();
return aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public long scanAndAggregate(
while (!cursor.isDone()) {
if (hasNulls && selector.isNull()) {
if (nullValueAggregates == null) {
nullValueAggregates = BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs());
nullValueAggregates = BaseTopNAlgorithm.makeAggregators(query, cursor);
}
for (Aggregator aggregator : nullValueAggregates) {
aggregator.aggregate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ private long scanAndAggregateWithCardinalityKnown(
final Object key = dimensionValueConverter.apply(selector.lookupName(dimIndex));
aggs = aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
rowSelector[dimIndex] = aggs;
}
Expand Down Expand Up @@ -199,7 +199,7 @@ private long scanAndAggregateWithCardinalityUnknown(
final Object key = dimensionValueConverter.apply(selector.lookupName(dimIndex));
Aggregator[] aggs = aggregatesStore.computeIfAbsent(
key,
k -> BaseTopNAlgorithm.makeAggregators(cursor, query.getAggregatorSpecs())
k -> BaseTopNAlgorithm.makeAggregators(query, cursor)
);
for (Aggregator aggregator : aggs) {
aggregator.aggregate();
Expand Down
Loading