Skip to content

Commit

Permalink
Merge pull request duckdb#13093 from hawkfish/window-distinct
Browse files Browse the repository at this point in the history
Feature duckdb#1272: Windowed DISTINCT Sink
  • Loading branch information
Mytherin authored Jul 23, 2024
2 parents b076575 + 42c3d34 commit 40e2fcd
Showing 1 changed file with 103 additions and 121 deletions.
224 changes: 103 additions & 121 deletions src/execution/window_segment_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAll
class WindowAggregatorGlobalState : public WindowAggregatorState {
public:
WindowAggregatorGlobalState(const WindowAggregator &aggregator_p, idx_t group_count)
: aggregator(aggregator_p), winputs(inputs) {
: aggregator(aggregator_p), winputs(inputs), locals(0) {

if (!aggregator.arg_types.empty()) {
winputs.Initialize(Allocator::DefaultAllocator(), aggregator.arg_types, group_count);
Expand All @@ -46,6 +46,9 @@ class WindowAggregatorGlobalState : public WindowAggregatorState {
//! Lock for single threading
mutable mutex lock;

//! Count of local tasks
mutable std::atomic<idx_t> locals;

//! Number of finalised states
idx_t finalized = 0;
};
Expand Down Expand Up @@ -177,8 +180,6 @@ class WindowConstantAggregatorGlobalState : public WindowAggregatorGlobalState {

//! Partition starts
vector<idx_t> partition_offsets;
//! Count of local tasks
mutable std::atomic<idx_t> locals;
//! Reused result state container for the window functions
WindowAggregateStates statef;
//! Aggregate results
Expand Down Expand Up @@ -212,7 +213,7 @@ class WindowConstantAggregatorLocalState : public WindowAggregatorState {
WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(const WindowConstantAggregator &aggregator,
idx_t group_count,
const ValidityMask &partition_mask)
: WindowAggregatorGlobalState(aggregator, STANDARD_VECTOR_SIZE), locals(0), statef(aggregator.aggr) {
: WindowAggregatorGlobalState(aggregator, STANDARD_VECTOR_SIZE), statef(aggregator.aggr) {

// Locate the partition boundaries
if (partition_mask.AllValid()) {
Expand Down Expand Up @@ -1394,18 +1395,17 @@ class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState {
WindowDistinctAggregatorGlobalState(const WindowDistinctAggregator &aggregator, idx_t group_count);
~WindowDistinctAggregatorGlobalState() override;

void Sink(DataChunk &arg_chunk, idx_t input_idx, optional_ptr<SelectionVector> filter_sel, idx_t filtered);
void Finalize(const FrameStats &stats);

// Single threaded sorting for now
ClientContext &context;
GlobalSortStatePtr global_sort;
LocalSortState local_sort;
idx_t memory_per_thread;

//! The sorted payload data types (partition index)
vector<LogicalType> payload_types;
DataChunk sort_chunk;
DataChunk payload_chunk;
//! The aggregate arguments + partition index
vector<LogicalType> sort_types;

//! The merge sort tree for the aggregate.
unique_ptr<DistinctSortTree> merge_sort_tree;
Expand All @@ -1423,7 +1423,28 @@ WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(const W
idx_t group_count)
: WindowAggregatorGlobalState(aggregator_p, group_count), context(aggregator_p.context) {
payload_types.emplace_back(LogicalType::UBIGINT);
payload_chunk.Initialize(Allocator::DefaultAllocator(), payload_types);

// 1: functionComputePrevIdcs(𝑖𝑛)
// 2: sorted ← []
// We sort the aggregate arguments and use the partition index as a tie-breaker.
// TODO: Use a hash table?
sort_types = aggregator_p.arg_types;
for (const auto &type : payload_types) {
sort_types.emplace_back(type);
}

vector<BoundOrderByNode> orders;
for (const auto &type : sort_types) {
auto expr = make_uniq<BoundConstantExpression>(Value(type));
orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr)));
}

RowLayout payload_layout;
payload_layout.Initialize(payload_types);

global_sort = make_uniq<GlobalSortState>(BufferManager::GetBufferManager(context), orders, payload_layout);

memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
}

WindowDistinctAggregatorGlobalState::~WindowDistinctAggregatorGlobalState() {
Expand All @@ -1449,6 +1470,52 @@ WindowDistinctAggregatorGlobalState::~WindowDistinctAggregatorGlobalState() {
}
}

class WindowDistinctAggregatorLocalState : public WindowAggregatorState {
public:
explicit WindowDistinctAggregatorLocalState(const WindowDistinctAggregatorGlobalState &aggregator);

void Sink(DataChunk &arg_chunk, idx_t input_idx, optional_ptr<SelectionVector> filter_sel, idx_t filtered);
void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result,
idx_t count, idx_t row_idx);

LocalSortState local_sort;

protected:
//! Flush the accumulated intermediate states into the result states
void FlushStates();

//! The aggregator we are working with
const WindowDistinctAggregatorGlobalState &gastate;
DataChunk sort_chunk;
DataChunk payload_chunk;
//! Reused result state container for the window functions
WindowAggregateStates statef;
//! A vector of pointers to "state", used for buffering intermediate aggregates
Vector statep;
//! Reused state pointers for combining tree elements
Vector statel;
//! Count of buffered values
idx_t flush_count;
//! The frame boundaries, used for the window functions
SubFrames frames;
};

WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState(
const WindowDistinctAggregatorGlobalState &gastate)
: gastate(gastate), statef(gastate.aggregator.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER),
flush_count(0) {
InitSubFrames(frames, gastate.aggregator.exclude_mode);
payload_chunk.Initialize(Allocator::DefaultAllocator(), gastate.payload_types);

auto &global_sort = gastate.global_sort;
local_sort.Initialize(*global_sort, global_sort->buffer_manager);

sort_chunk.Initialize(Allocator::DefaultAllocator(), gastate.sort_types);
sort_chunk.data.back().Reference(payload_chunk.data[0]);

gastate.locals++;
}

unique_ptr<WindowAggregatorState> WindowDistinctAggregator::GetGlobalState(idx_t group_count,
const ValidityMask &partition_mask) const {
return make_uniq<WindowDistinctAggregatorGlobalState>(*this, group_count);
Expand All @@ -1458,46 +1525,12 @@ void WindowDistinctAggregator::Sink(WindowAggregatorState &gsink, WindowAggregat
idx_t input_idx, optional_ptr<SelectionVector> filter_sel, idx_t filtered) {
WindowAggregator::Sink(gsink, lstate, arg_chunk, input_idx, filter_sel, filtered);

auto &gdstate = gsink.Cast<WindowDistinctAggregatorGlobalState>();
gdstate.Sink(arg_chunk, input_idx, filter_sel, filtered);
auto &ldstate = lstate.Cast<WindowDistinctAggregatorLocalState>();
ldstate.Sink(arg_chunk, input_idx, filter_sel, filtered);
}

void WindowDistinctAggregatorGlobalState::Sink(DataChunk &arg_chunk, idx_t input_idx,
optional_ptr<SelectionVector> filter_sel, idx_t filtered) {
// Single threaded for now
lock_guard<mutex> sink_guard(lock);

// We sort the arguments and use the partition index as a tie-breaker.
// TODO: Use a hash table?
if (!global_sort) {
// 1: functionComputePrevIdcs(𝑖𝑛)
// 2: sorted ← []
vector<LogicalType> sort_types;
for (const auto &col : arg_chunk.data) {
sort_types.emplace_back(col.GetType());
}

for (const auto &type : payload_types) {
sort_types.emplace_back(type);
}

vector<BoundOrderByNode> orders;
for (const auto &type : sort_types) {
auto expr = make_uniq<BoundConstantExpression>(Value(type));
orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr)));
}

RowLayout payload_layout;
payload_layout.Initialize(payload_types);

global_sort = make_uniq<GlobalSortState>(BufferManager::GetBufferManager(context), orders, payload_layout);
local_sort.Initialize(*global_sort, global_sort->buffer_manager);

sort_chunk.Initialize(Allocator::DefaultAllocator(), sort_types);
sort_chunk.data.back().Reference(payload_chunk.data[0]);
memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context);
}

void WindowDistinctAggregatorLocalState::Sink(DataChunk &arg_chunk, idx_t input_idx,
optional_ptr<SelectionVector> filter_sel, idx_t filtered) {
// 3: for i ← 0 to in.size do
// 4: sorted[i] ← (in[i], i)
const auto count = arg_chunk.size();
Expand All @@ -1521,24 +1554,24 @@ void WindowDistinctAggregatorGlobalState::Sink(DataChunk &arg_chunk, idx_t input

local_sort.SinkChunk(sort_chunk, payload_chunk);

if (local_sort.SizeInBytes() > memory_per_thread) {
local_sort.Sort(*global_sort, true);
if (local_sort.SizeInBytes() > gastate.memory_per_thread) {
local_sort.Sort(*gastate.global_sort, true);
}
}

void WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate,
const FrameStats &stats) {
auto &gdsink = gsink.Cast<WindowDistinctAggregatorGlobalState>();
auto &ldstate = lstate.Cast<WindowDistinctAggregatorLocalState>();

// Single threaded Finalize for now
// Single threaded Combine for now
lock_guard<mutex> gestate_guard(gdsink.lock);
if (gdsink.finalized) {
return;
}
gdsink.global_sort->AddLocalState(ldstate.local_sort);

gdsink.Finalize(stats);

++gdsink.finalized;
// Last one out turns off the lights!
if (++gdsink.finalized == gdsink.locals) {
gdsink.Finalize(stats);
}
}

class WindowDistinctAggregatorGlobalState::DistinctSortTree : public MergeSortTree<idx_t, idx_t> {
Expand All @@ -1552,7 +1585,6 @@ class WindowDistinctAggregatorGlobalState::DistinctSortTree : public MergeSortTr

void WindowDistinctAggregatorGlobalState::Finalize(const FrameStats &stats) {
// 5: Sort sorted lexicographically increasing
global_sort->AddLocalState(local_sort);
global_sort->PrepareMergePhase();
while (global_sort->sorted_blocks.size() > 1) {
global_sort->InitializeMergeRound();
Expand Down Expand Up @@ -1590,7 +1622,7 @@ void WindowDistinctAggregatorGlobalState::Finalize(const FrameStats &stats) {

SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN);
SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN);
auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(sort_chunk.ColumnCount() - 1);
auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(aggregator.arg_types.size());

// 8: for i ← 1 to in.size do
for (++curr; curr.GetIndex() < in_size; ++curr, ++prev) {
Expand Down Expand Up @@ -1738,79 +1770,34 @@ WindowDistinctAggregatorGlobalState::DistinctSortTree::DistinctSortTree(ZippedEl
}
}

class WindowDistinctState : public WindowAggregatorState {
public:
explicit WindowDistinctState(const WindowDistinctAggregator &aggregator);

void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result,
idx_t count, idx_t row_idx);

protected:
//! Flush the accumulated intermediate states into the result states
void FlushStates();

//! The aggregator we are working with
const WindowDistinctAggregator &aggregator;
//! The size of a single aggregate state
const idx_t state_size;
//! Data pointer that contains a vector of states, used for row aggregation
vector<data_t> state;
//! Reused result state container for the window functions
Vector statef;
//! A vector of pointers to "state", used for buffering intermediate aggregates
Vector statep;
//! Reused state pointers for combining tree elements
Vector statel;
//! Count of buffered values
idx_t flush_count;
//! The frame boundaries, used for the window functions
SubFrames frames;
};

WindowDistinctState::WindowDistinctState(const WindowDistinctAggregator &aggregator)
: aggregator(aggregator), state_size(aggregator.state_size), state((aggregator.state_size * STANDARD_VECTOR_SIZE)),
statef(LogicalType::POINTER), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) {
InitSubFrames(frames, aggregator.exclude_mode);

// Build the finalise vector that just points to the result states
data_ptr_t state_ptr = state.data();
D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR);
statef.SetVectorType(VectorType::CONSTANT_VECTOR);
statef.Flatten(STANDARD_VECTOR_SIZE);
auto fdata = FlatVector::GetData<data_ptr_t>(statef);
for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) {
fdata[i] = state_ptr;
state_ptr += state_size;
}
}

void WindowDistinctState::FlushStates() {
void WindowDistinctAggregatorLocalState::FlushStates() {
if (!flush_count) {
return;
}

const auto &aggr = aggregator.aggr;
const auto &aggr = gastate.aggregator.aggr;
AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator);
statel.Verify(flush_count);
aggr.function.combine(statel, statep, aggr_input_data, flush_count);

flush_count = 0;
}

void WindowDistinctState::Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds,
Vector &result, idx_t count, idx_t row_idx) {
auto fdata = FlatVector::GetData<data_ptr_t>(statef);
void WindowDistinctAggregatorLocalState::Evaluate(const WindowDistinctAggregatorGlobalState &gdstate,
const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) {
auto ldata = FlatVector::GetData<data_ptr_t>(statel);
auto pdata = FlatVector::GetData<data_ptr_t>(statep);

const auto &merge_sort_tree = *gdstate.merge_sort_tree;
const auto running_aggs = gdstate.levels_flat_native.get();
const auto exclude_mode = gdstate.aggregator.exclude_mode;
const auto &aggr = gdstate.aggregator.aggr;
const auto state_size = statef.state_size;

// Build the finalise vector that just points to the result states
statef.Initialize(count);

EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t rid) {
auto agg_state = fdata[rid];
aggr.function.initialize(agg_state);
auto agg_state = statef.GetStatePtr(rid);

// TODO: Extend AggregateLowerBound to handle subframes, just like SelectNth.
const auto lower = frames[0].start;
Expand All @@ -1835,24 +1822,19 @@ void WindowDistinctState::Evaluate(const WindowDistinctAggregatorGlobalState &gd
FlushStates();

// Finalise the result aggregates and write to the result
AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator);
aggr.function.finalize(statef, aggr_input_data, result, count, 0);

// Destruct the result aggregates
if (aggr.function.destructor) {
aggr.function.destructor(statef, aggr_input_data, count);
}
statef.Finalize(result);
statef.Destroy();
}

unique_ptr<WindowAggregatorState> WindowDistinctAggregator::GetLocalState(const WindowAggregatorState &gstate) const {
return make_uniq<WindowDistinctState>(*this);
return make_uniq<WindowDistinctAggregatorLocalState>(gstate.Cast<const WindowDistinctAggregatorGlobalState>());
}

void WindowDistinctAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate,
const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const {

const auto &gdstate = gsink.Cast<WindowDistinctAggregatorGlobalState>();
auto &ldstate = lstate.Cast<WindowDistinctState>();
auto &ldstate = lstate.Cast<WindowDistinctAggregatorLocalState>();
ldstate.Evaluate(gdstate, bounds, result, count, row_idx);
}

Expand Down

0 comments on commit 40e2fcd

Please sign in to comment.