Skip to content

Commit

Permalink
Fix list unique and distinct (#3310)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Apr 19, 2024
1 parent 7984feb commit 0df1f63
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 265 deletions.
3 changes: 3 additions & 0 deletions src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ void Value::copyFromColLayout(const uint8_t* value, ValueVector* vector) {
case PhysicalTypeID::STRUCT: {
copyFromColLayoutStruct(*(struct_entry_t*)value, vector);
} break;
case PhysicalTypeID::INTERNAL_ID: {
val.internalIDVal = *((nodeID_t*)value);
} break;
default:
KU_UNREACHABLE;
}
Expand Down
203 changes: 9 additions & 194 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,210 +776,25 @@ function_set ListProductFunction::getFunctionSet() {
}

static std::unique_ptr<FunctionBindData> ListDistinctBindFunc(
const binder::expression_vector& arguments, Function* function) {
auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
switch (ListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<int64_t>>;
} break;
case LogicalTypeID::INT32: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<int32_t>>;
} break;
case LogicalTypeID::INT16: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<int16_t>>;
} break;
case LogicalTypeID::INT8: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<int8_t>>;
} break;
case LogicalTypeID::UINT64: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<uint64_t>>;
} break;
case LogicalTypeID::UINT32: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<uint32_t>>;
} break;
case LogicalTypeID::UINT16: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<uint16_t>>;
} break;
case LogicalTypeID::UINT8: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<uint8_t>>;
} break;
case LogicalTypeID::INT128: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<int128_t>>;
} break;
case LogicalTypeID::DOUBLE: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<double>>;
} break;
case LogicalTypeID::FLOAT: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<float>>;
} break;
case LogicalTypeID::BOOL: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<uint8_t>>;
} break;
case LogicalTypeID::STRING: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<ku_string_t>>;
} break;
case LogicalTypeID::DATE: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<date_t>>;
} break;
case LogicalTypeID::TIMESTAMP_MS: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<timestamp_ms_t>>;
} break;
case LogicalTypeID::TIMESTAMP_NS: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<timestamp_ns_t>>;
} break;
case LogicalTypeID::TIMESTAMP_SEC: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<timestamp_sec_t>>;
} break;
case LogicalTypeID::TIMESTAMP_TZ: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<timestamp_tz_t>>;
} break;
case LogicalTypeID::TIMESTAMP: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<timestamp_t>>;
} break;
case LogicalTypeID::INTERVAL: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<interval_t>>;
} break;
case LogicalTypeID::INTERNAL_ID: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
list_entry_t, ListDistinct<internalID_t>>;
} break;
default: {
KU_UNREACHABLE;
}
}
const binder::expression_vector& arguments, Function* /*function*/) {
return std::make_unique<FunctionBindData>(arguments[0]->getDataType().copy());
}

function_set ListDistinctFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::LIST, nullptr, nullptr, ListDistinctBindFunc, false /* isVarlength*/));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::LIST,
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, list_entry_t, ListDistinct>,
nullptr, ListDistinctBindFunc, false /* isVarlength*/));
return result;
}

static std::unique_ptr<FunctionBindData> ListUniqueBindFunc(
const binder::expression_vector& arguments, Function* function) {
auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
switch (ListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<int64_t>>;
} break;
case LogicalTypeID::INT32: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<int32_t>>;
} break;
case LogicalTypeID::INT16: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<int16_t>>;
} break;
case LogicalTypeID::INT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<int8_t>>;
} break;
case LogicalTypeID::UINT64: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<uint64_t>>;
} break;
case LogicalTypeID::UINT32: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<uint32_t>>;
} break;
case LogicalTypeID::UINT16: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<uint16_t>>;
} break;
case LogicalTypeID::UINT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<uint8_t>>;
} break;
case LogicalTypeID::INT128: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<int128_t>>;
} break;
case LogicalTypeID::DOUBLE: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<double>>;
} break;
case LogicalTypeID::FLOAT: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<float>>;
} break;
case LogicalTypeID::BOOL: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<uint8_t>>;
} break;
case LogicalTypeID::STRING: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<ku_string_t>>;
} break;
case LogicalTypeID::DATE: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique<date_t>>;
} break;
case LogicalTypeID::TIMESTAMP: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<timestamp_t>>;
} break;
case LogicalTypeID::TIMESTAMP_MS: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<timestamp_ms_t>>;
} break;
case LogicalTypeID::TIMESTAMP_NS: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<timestamp_ns_t>>;
} break;
case LogicalTypeID::TIMESTAMP_SEC: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<timestamp_sec_t>>;
} break;
case LogicalTypeID::TIMESTAMP_TZ: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<timestamp_tz_t>>;
} break;
case LogicalTypeID::INTERVAL: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<interval_t>>;
} break;
case LogicalTypeID::INTERNAL_ID: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
int64_t, ListUnique<internalID_t>>;
} break;
default: {
KU_UNREACHABLE;
}
}
return std::make_unique<FunctionBindData>(LogicalType::INT64());
}

function_set ListUniqueFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::INT64, nullptr, nullptr, ListUniqueBindFunc, false /* isVarlength*/));
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::LIST}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListUnique>,
false /* isVarlength*/));
return result;
}

Expand Down
35 changes: 12 additions & 23 deletions src/include/function/list/functions/list_distinct_function.h
Original file line number Diff line number Diff line change
@@ -1,37 +1,26 @@
#pragma once

#include <set>

#include "common/vector/value_vector.h"
#include "list_unique_function.h"

namespace kuzu {
namespace function {

template<typename T>
struct ListDistinct {
static inline void operation(common::list_entry_t& input, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
std::set<T> uniqueValues;
auto inputValues =
reinterpret_cast<T*>(common::ListVector::getListValues(&inputVector, input));
auto inputDataVector = common::ListVector::getDataVector(&inputVector);

for (auto i = 0u; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
continue;
}
uniqueValues.insert(inputValues[i]);
}

result = common::ListVector::addList(&resultVector, uniqueValues.size());
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto numUniqueValues = ListUnique::appendListElementsToValueSet(input, inputVector);
result = common::ListVector::addList(&resultVector, numUniqueValues);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = inputDataVector->getNumBytesPerValue();
for (auto val : uniqueValues) {
resultDataVector->copyFromVectorData(resultValues, inputDataVector,
reinterpret_cast<uint8_t*>(&val));
resultValues += numBytesPerValue;
}
auto resultDataVectorBuffer =
common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */);
ListUnique::appendListElementsToValueSet(input, inputVector, nullptr,
[&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector,
uint64_t pos) -> void {
resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector,
dataVector.getData() + pos * dataVector.getNumBytesPerValue());
resultDataVectorBuffer += dataVector.getNumBytesPerValue();
});
}
};

Expand Down
51 changes: 38 additions & 13 deletions src/include/function/list/functions/list_unique_function.h
Original file line number Diff line number Diff line change
@@ -1,28 +1,53 @@
#pragma once

#include <set>

#include "common/type_utils.h"
#include "common/types/value/value.h"
#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

template<typename T>
struct ListUnique {
static inline void operation(common::list_entry_t& input, int64_t& result,
common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) {
std::set<T> uniqueValues;
auto inputValues =
reinterpret_cast<T*>(common::ListVector::getListValues(&inputVector, input));
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
struct ValueHashFunction {
uint64_t operator()(const common::Value& value) const { return (uint64_t)value.computeHash(); }
};

struct ValueEquality {
bool operator()(const common::Value& a, const common::Value& b) const { return a == b; }
};

using ValueSet = std::unordered_set<common::Value, ValueHashFunction, ValueEquality>;

using duplicateValueHandler = std::function<void(const std::string&)>;
using uniqueValueHandler = std::function<void(common::ValueVector& dataVector, uint64_t pos)>;

struct ListUnique {
static uint64_t appendListElementsToValueSet(common::list_entry_t& input,
common::ValueVector& inputVector, duplicateValueHandler duplicateValHandler = nullptr,
uniqueValueHandler uniqueValueHandler = nullptr) {
ValueSet uniqueKeys;
auto dataVector = common::ListVector::getDataVector(&inputVector);
auto val = common::Value::createDefaultValue(dataVector->dataType);
for (auto i = 0u; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
if (dataVector->isNull(input.offset + i)) {
continue;
}
uniqueValues.insert(inputValues[i]);
auto entryVal = common::ListVector::getListValuesWithOffset(&inputVector, input, i);
val.copyFromColLayout(entryVal, dataVector);
auto uniqueKey = uniqueKeys.insert(val).second;
if (duplicateValHandler != nullptr && !uniqueKey) {
duplicateValHandler(
common::TypeUtils::entryToString(dataVector->dataType, entryVal, dataVector));
}
if (uniqueValueHandler != nullptr && uniqueKey) {
uniqueValueHandler(*dataVector, input.offset + i);
}
}
result = uniqueValues.size();
return uniqueKeys.size();
}

static void operation(common::list_entry_t& input, int64_t& result,
common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) {
result = appendListElementsToValueSet(input, inputVector);
}
};

Expand Down
27 changes: 5 additions & 22 deletions src/include/function/map/functions/map_creation_function.h
Original file line number Diff line number Diff line change
@@ -1,36 +1,19 @@
#pragma once

#include "unordered_set"

#include "common/exception/runtime.h"
#include "common/type_utils.h"
#include "common/types/value/value.h"
#include "common/vector/value_vector.h"
#include "function/list/functions/list_unique_function.h"

namespace kuzu {
namespace function {

struct ValueHashFunction {
uint64_t operator()(const common::Value& value) const { return (uint64_t)value.computeHash(); }
};

struct ValueEquality {
bool operator()(const common::Value& a, const common::Value& b) const { return a == b; }
};

static void validateKeys(common::list_entry_t& keyEntry, common::ValueVector& keyVector) {
std::unordered_set<common::Value, ValueHashFunction, ValueEquality> uniqueKeys;
auto dataVector = common::ListVector::getDataVector(&keyVector);
auto val = common::Value::createDefaultValue(dataVector->dataType);
for (auto i = 0u; i < keyEntry.size; i++) {
auto entryVal = common::ListVector::getListValuesWithOffset(&keyVector, keyEntry, i);
val.copyFromColLayout(entryVal, dataVector);
auto unique = uniqueKeys.insert(val).second;
if (!unique) {
throw common::RuntimeException{common::stringFormat("Found duplicate key: {} in map.",
common::TypeUtils::entryToString(dataVector->dataType, entryVal, dataVector))};
}
}
ListUnique::appendListElementsToValueSet(keyEntry, keyVector, [](const std::string& key) {
throw common::RuntimeException{
common::stringFormat("Found duplicate key: {} in map.", key)};
});
}

struct MapCreation {
Expand Down
Loading

0 comments on commit 0df1f63

Please sign in to comment.