Skip to content

Commit

Permalink
Merge AggregateHashTable without intermediate scan
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Jan 8, 2025
1 parent 2a5618e commit eab22da
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 111 deletions.
25 changes: 14 additions & 11 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@ class AggregateHashTable : public BaseHashTable {
const std::vector<common::LogicalType>& distinctAggKeyTypes, uint64_t numEntriesToAllocate,
FactorizedTableSchema tableSchema);

uint8_t* getEntry(uint64_t idx) { return factorizedTable->getTuple(idx); }

FactorizedTable* getFactorizedTable() { return factorizedTable.get(); }

uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); }

void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
Expand All @@ -88,8 +82,8 @@ class AggregateHashTable : public BaseHashTable {
common::ValueVector* aggregateVector);

//! merge aggregate hash table by combining aggregate states under the same key
void merge(const FactorizedTable& other);
void merge(const AggregateHashTable& other) { merge(*other.factorizedTable); }
void merge(FactorizedTable&& other);
void merge(AggregateHashTable&& other) { merge(std::move(*other.factorizedTable)); }

void finalizeAggregateStates();

Expand All @@ -104,10 +98,15 @@ class AggregateHashTable : public BaseHashTable {
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);

uint64_t matchFTEntries(const FactorizedTable& srcTable, uint64_t startOffset,
uint64_t numMayMatches, uint64_t numNoMatches);

void initializeFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
uint64_t numFTEntriesToInitialize);
void initializeFTEntries(const FactorizedTable& sourceTable, uint64_t sourceStartOffset,
uint64_t numFTEntriesToInitialize);

uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);
Expand All @@ -122,9 +121,11 @@ class AggregateHashTable : public BaseHashTable {
const std::vector<common::ValueVector*>& dependentKeyVectors,
common::DataChunkState* leadingState);

void findHashSlots(const FactorizedTable& data, uint64_t startOffset, uint64_t numTuples);

protected:
void initializeFT(const std::vector<function::AggregateFunction>& aggregateFunctions,
FactorizedTableSchema tableSchema);
FactorizedTableSchema&& tableSchema);

void initializeHashTable(uint64_t numEntriesToAllocate);

Expand All @@ -147,6 +148,8 @@ class AggregateHashTable : public BaseHashTable {
void increaseSlotIdx(uint64_t& slotIdx) const;

void initTmpHashSlotsAndIdxes();
void initTmpHashSlotsAndIdxes(const FactorizedTable& sourceTable, uint64_t startOffset,
uint64_t numTuples);

void increaseHashSlotIdxes(uint64_t numNoMatches);

Expand All @@ -164,7 +167,7 @@ class AggregateHashTable : public BaseHashTable {
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity);

void fillEntryWithInitialNullAggregateState(FactorizedTable& factorizedTable, uint8_t* entry);
void fillEntryWithInitialNullAggregateState(FactorizedTable& table, uint8_t* entry);

//! find an uninitialized hash slot for given hash and fill hash slot with block id and offset
void fillHashSlot(common::hash_t hash, uint8_t* groupByKeysAndAggregateStateBuffer);
Expand Down Expand Up @@ -233,7 +236,7 @@ class AggregateHashTable : public BaseHashTable {
AggregateHashTable(const AggregateHashTable& other)
: AggregateHashTable(*other.memoryManager, common::LogicalType::copy(other.keyTypes),
common::LogicalType::copy(other.payloadTypes), other.aggregateFunctions,
getDistinctAggKeyTypes(other), 0, other.factorizedTable->getTableSchema()->copy()) {}
getDistinctAggKeyTypes(other), 0, other.getTableSchema()->copy()) {}

protected:
uint32_t hashColIdxInFT{};
Expand Down
11 changes: 5 additions & 6 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <memory>

#include "aggregate_hash_table.h"
#include "common/concurrent_vector.h"
#include "common/copy_constructors.h"
#include "common/in_mem_overflow_buffer.h"
#include "common/mpsc_queue.h"
Expand Down Expand Up @@ -67,8 +66,8 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {

void setLimitNumber(uint64_t num) { limitNumber = num; }

const FactorizedTableSchema& getTableSchema() const {
return *globalPartitions[0].hashTable->getFactorizedTable()->getTableSchema();
const FactorizedTableSchema* getTableSchema() const {
return globalPartitions[0].hashTable->getTableSchema();
}

void setThreadFinishedProducing() { numThreadsFinishedProducing++; }
Expand All @@ -93,11 +92,11 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
struct TupleBlock {
TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchama)
: numTuplesReserved{0}, numTuplesWritten{0},
factorizedTable{memoryManager, std::move(tableSchama)} {
table{memoryManager, std::move(tableSchama)} {
// Start at a fixed capacity of one full block (so that concurrent writes are safe).
// If it is not filled, we resize it to the actual capacity before writing it to the
// hashTable
factorizedTable.resize(factorizedTable.getNumTuplesPerBlock());
table.resize(table.getNumTuplesPerBlock());
}
// numTuplesReserved may be greater than the capacity of the factorizedTable
// if threads try to write to it it while a new block is being allocated
Expand All @@ -107,7 +106,7 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
// Once numTuplesWritten == factorizedTable.getNumTuplesPerBlock() all writes have
// finished
std::atomic<uint64_t> numTuplesWritten;
FactorizedTable factorizedTable;
FactorizedTable table;
};
common::MPSCQueue<TupleBlock*> queuedTuples;
// When queueing tuples, they are always added to the headBlock until the headBlock is full
Expand Down
5 changes: 1 addition & 4 deletions src/include/processor/operator/hash_join/join_hash_table.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "processor/result/base_hash_table.h"
#include "processor/result/factorized_table.h"
#include "storage/buffer_manager/memory_manager.h"

namespace kuzu {
Expand Down Expand Up @@ -39,7 +40,6 @@ class JoinHashTable : public BaseHashTable {
factorizedTable->lookup(vectors, colIdxesToScan, tuplesToRead, startPos, numTuplesToRead);
}
void merge(JoinHashTable& other) { factorizedTable->merge(*other.factorizedTable); }
uint64_t getNumTuples() { return factorizedTable->getNumTuples(); }
uint8_t** getPrevTuple(const uint8_t* tuple) const {
return (uint8_t**)(tuple + prevPtrColOffset);
}
Expand All @@ -49,8 +49,6 @@ class JoinHashTable : public BaseHashTable {
return ((uint8_t**)(hashSlotsBlocks[slotIdx >> numSlotsPerBlockLog2]
->getData()))[slotIdx & slotIdxInBlockMask];
}
FactorizedTable* getFactorizedTable() { return factorizedTable.get(); }
const FactorizedTableSchema* getTableSchema() { return factorizedTable->getTableSchema(); }

private:
uint8_t** findHashSlot(const uint8_t* tuple) const;
Expand All @@ -65,7 +63,6 @@ class JoinHashTable : public BaseHashTable {
private:
static constexpr uint64_t PREV_PTR_COL_IDX = 1;
static constexpr uint64_t HASH_COL_IDX = 2;
const FactorizedTableSchema* tableSchema;
uint64_t prevPtrColOffset;
};

Expand Down
12 changes: 12 additions & 0 deletions src/include/processor/result/base_hash_table.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#pragma once

#include <functional>

#include "common/copy_constructors.h"
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include "processor/result/factorized_table.h"
#include "processor/result/factorized_table_schema.h"
#include "storage/buffer_manager/memory_manager.h"

namespace kuzu {
namespace processor {

using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using raw_compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;

class BaseHashTable {
public:
Expand All @@ -18,6 +23,12 @@ class BaseHashTable {

DELETE_COPY_DEFAULT_MOVE(BaseHashTable);

const FactorizedTableSchema* getTableSchema() const {
return factorizedTable->getTableSchema();
}
uint64_t getNumEntries() const { return factorizedTable->getNumTuples(); }
const FactorizedTable* getFactorizedTable() const { return factorizedTable.get(); }

protected:
static constexpr uint64_t HASH_BLOCK_SIZE = common::TEMP_PAGE_SIZE;

Expand All @@ -44,6 +55,7 @@ class BaseHashTable {
storage::MemoryManager* memoryManager;
std::unique_ptr<FactorizedTable> factorizedTable;
std::vector<compare_function_t> compareEntryFuncs;
std::vector<raw_compare_function_t> rawCompareEntryFuncs;
std::vector<common::LogicalType> keyTypes;
// Temporary arrays to hold intermediate results for appending.
std::shared_ptr<common::DataChunkState> hashState;
Expand Down
12 changes: 12 additions & 0 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,25 @@ class KUZU_API FactorizedTable {
bool isOverflowColNull(const uint8_t* nullBuffer, ft_tuple_idx_t tupleIdx,
ft_col_idx_t colIdx) const;
bool isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const;
bool isNonOverflowColNull(ft_tuple_idx_t tupleIdx, ft_col_idx_t colIdx) const;
void setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx);
void clear();

storage::MemoryManager* getMemoryManager() { return memoryManager; }

void resize(uint64_t numTuples);

template<typename Func>
void forEach(Func func) {
for (auto& tupleBlock : flatTupleBlockCollection->getBlocks()) {
uint8_t* tuple = tupleBlock->getData();
for (auto i = 0u; i < tupleBlock->numTuples; i++) {
func(tuple);
tuple += getTableSchema()->getNumBytesPerTuple();
}
}
}

private:
void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx);

Expand Down
1 change: 1 addition & 0 deletions src/include/storage/store/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common/mask.h"
#include "storage/predicate/column_predicate.h"
#include "storage/store/column.h"
#include "storage/store/column_chunk_data.h"
#include "storage/store/node_group.h"

namespace kuzu {
Expand Down
Loading

0 comments on commit eab22da

Please sign in to comment.