Skip to content

Commit

Permalink
Fix list-contains binding (#4644)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Dec 17, 2024
1 parent 45013c1 commit 881912a
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion extension/iceberg/src/connector/iceberg_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace kuzu {
namespace iceberg_extension {

void IcebergConnector::connect(const std::string& /*dbPath*/, const std::string& /*catalogName*/,
main::ClientContext* context) {
const std::string& /*schemaName*/, main::ClientContext* context) {
// Creates an in-memory duckdb instance, then install httpfs and attach postgres.
instance = std::make_unique<duckdb::DuckDB>(nullptr);
connection = std::make_unique<duckdb::Connection>(*instance);
Expand Down
3 changes: 2 additions & 1 deletion extension/iceberg/src/function/iceberg_bindfunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ static std::string generateQueryOptions(const TableFuncBindInput* input,
std::unique_ptr<TableFuncBindData> bindFuncHelper(main::ClientContext* context,
TableFuncBindInput* input, const std::string& functionName) {
auto connector = std::make_shared<IcebergConnector>();
connector->connect("" /* inMemDB */, "" /* defaultCatalogName */, context);
connector->connect("" /* inMemDB */, "" /* defaultCatalogName */, "" /* defaultSchemaName */,
context);

std::string query_options = generateQueryOptions(input, functionName);
std::string query = common::stringFormat("SELECT * FROM {}('{}'{})", functionName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace iceberg_extension {
class IcebergConnector : public duckdb_extension::DuckDBConnector {
public:
void connect(const std::string& dbPath, const std::string& catalogName,
main::ClientContext* context) override;
const std::string& schemaName, main::ClientContext* context) override;
};

} // namespace iceberg_extension
Expand Down
2 changes: 1 addition & 1 deletion src/function/list/list_contains_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct ListContains {

static std::unique_ptr<FunctionBindData> bindFunc(ScalarBindFuncInput input) {
auto scalarFunction = input.definition->ptrCast<ScalarFunction>();
TypeUtils::visit(input.arguments[1]->getDataType().getPhysicalType(),
TypeUtils::visit(ListType::getChildType(input.arguments[0]->getDataType()).getPhysicalType(),
[&scalarFunction]<typename T>(T) {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, T,
uint8_t, ListContains>;
Expand Down
26 changes: 25 additions & 1 deletion test/test_files/function/list.test
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,31 @@ True
False
True

-LOG ListContainsUnmatchType
-STATEMENT MATCH (a:person) where list_contains(cast([0, 7] as int8[]), a.ID) return a.ID
---- 2
0
7

-STATEMENT MATCH (a:person) where list_contains(cast([2, 3] as int16[]), a.ID) return a.ID
---- 2
2
3

-STATEMENT MATCH (a:person) where list_contains(cast([5] as int32[]), a.ID) return a.ID
---- 1
5

-STATEMENT MATCH (a:person) where list_contains(cast([7, 8] as int128[]), a.ID) return a.ID
---- 2
7
8

-STATEMENT MATCH (a:person) where list_contains(cast(['A0EEBC99-9C0B-4EF8-BB6D-6BB9BD380A11', 'a0eebc99-9c0b4ef8-bb6d6bb9-bd380a15'] as uuid[]), a.u) return a.ID
---- 2
0
7

-LOG ListContainsSelect
-STATEMENT MATCH (a:person) WHERE list_contains(a.courseScoresPerTerm, [8]) RETURN a.ID
---- 2
Expand Down Expand Up @@ -2310,4 +2335,3 @@ True

-STATEMENT RETURN list_has_all(null, null)
---- 1

0 comments on commit 881912a

Please sign in to comment.